mirror of
https://github.com/kemko/memes-telegram-integration.git
synced 2026-01-01 15:55:41 +03:00
feat: add store for access token cache
This commit is contained in:
@@ -2,6 +2,8 @@ package memogram
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"github.com/caarlos0/env"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/pkg/errors"
|
||||
@@ -10,6 +12,7 @@ import (
|
||||
type Config struct {
|
||||
ServerAddr string `env:"SERVER_ADDR,required"`
|
||||
BotToken string `env:"BOT_TOKEN,required"`
|
||||
Data string `env:"DATA"`
|
||||
}
|
||||
|
||||
func getConfigFromEnv() (*Config, error) {
|
||||
@@ -25,5 +28,10 @@ func getConfigFromEnv() (*Config, error) {
|
||||
if err := env.Parse(&config); err != nil {
|
||||
return nil, errors.Wrap(err, "invalid configuration")
|
||||
}
|
||||
if config.Data == "" {
|
||||
// Default to `data.txt` if not specified.
|
||||
config.Data = "data.txt"
|
||||
}
|
||||
config.Data = path.Join(".", config.Data)
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
103
memogram.go
103
memogram.go
@@ -8,11 +8,11 @@ import (
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/go-telegram/bot"
|
||||
"github.com/go-telegram/bot/models"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memogram/store"
|
||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@@ -20,16 +20,11 @@ import (
|
||||
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
)
|
||||
|
||||
// userAccessTokenCache is a cache for user access token.
|
||||
// Key is the user id from telegram.
|
||||
// Value is the access token from memos.
|
||||
// TODO: save it to a persistent storage.
|
||||
var userAccessTokenCache sync.Map // map[int64]string
|
||||
|
||||
type Service struct {
|
||||
config *Config
|
||||
client *MemosClient
|
||||
bot *bot.Bot
|
||||
client *MemosClient
|
||||
config *Config
|
||||
store *store.Store
|
||||
}
|
||||
|
||||
func NewService() (*Service, error) {
|
||||
@@ -38,16 +33,21 @@ func NewService() (*Service, error) {
|
||||
return nil, errors.Wrap(err, "failed to get config from env")
|
||||
}
|
||||
|
||||
conn, err := grpc.Dial(config.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
conn, err := grpc.NewClient(config.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
slog.Error("failed to connect to server", slog.Any("err", err))
|
||||
return nil, errors.Wrap(err, "failed to connect to server")
|
||||
}
|
||||
client := NewMemosClient(conn)
|
||||
|
||||
store := store.NewStore(config.Data)
|
||||
if err := store.Init(); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to init store")
|
||||
}
|
||||
s := &Service{
|
||||
config: config,
|
||||
client: client,
|
||||
store: store,
|
||||
}
|
||||
|
||||
opts := []bot.Option{
|
||||
@@ -86,7 +86,7 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) {
|
||||
}
|
||||
|
||||
userID := m.Message.From.ID
|
||||
if _, ok := userAccessTokenCache.Load(userID); !ok {
|
||||
if _, ok := s.store.GetUserAccessToken(userID); !ok {
|
||||
b.SendMessage(ctx, &bot.SendMessageParams{
|
||||
ChatID: m.Message.Chat.ID,
|
||||
Text: "Please start the bot with /start <access_token>",
|
||||
@@ -147,8 +147,8 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) {
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, _ := userAccessTokenCache.Load(userID)
|
||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string))))
|
||||
accessToken, _ := s.store.GetUserAccessToken(userID)
|
||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken)))
|
||||
memo, err := s.client.MemoService.CreateMemo(ctx, &v1pb.CreateMemoRequest{
|
||||
Content: content,
|
||||
})
|
||||
@@ -164,15 +164,12 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) {
|
||||
if message.Document != nil {
|
||||
s.processFileMessage(ctx, b, m, message.Document.FileID, memo)
|
||||
}
|
||||
|
||||
if message.Voice != nil {
|
||||
s.processFileMessage(ctx, b, m, message.Voice.FileID, memo)
|
||||
}
|
||||
|
||||
if message.Video != nil {
|
||||
s.processFileMessage(ctx, b, m, message.Video.FileID, memo)
|
||||
}
|
||||
|
||||
if len(message.Photo) > 0 {
|
||||
photo := message.Photo[len(message.Photo)-1]
|
||||
s.processFileMessage(ctx, b, m, photo.FileID, memo)
|
||||
@@ -204,7 +201,7 @@ func (s *Service) startHandler(ctx context.Context, b *bot.Bot, m *models.Update
|
||||
return
|
||||
}
|
||||
|
||||
userAccessTokenCache.Store(userID, accessToken)
|
||||
s.store.SetUserAccessToken(userID, accessToken)
|
||||
b.SendMessage(ctx, &bot.SendMessageParams{
|
||||
ChatID: m.Message.Chat.ID,
|
||||
Text: fmt.Sprintf("Hello %s!", user.Nickname),
|
||||
@@ -214,29 +211,29 @@ func (s *Service) startHandler(ctx context.Context, b *bot.Bot, m *models.Update
|
||||
func (s *Service) keyboard(memo *v1pb.Memo) *models.InlineKeyboardMarkup {
|
||||
// add inline keyboard to edit memo's visibility or pinned status.
|
||||
return &models.InlineKeyboardMarkup{
|
||||
InlineKeyboard: [][]models.InlineKeyboardButton{
|
||||
InlineKeyboard: [][]models.InlineKeyboardButton{
|
||||
{
|
||||
{
|
||||
{
|
||||
Text: "Public",
|
||||
CallbackData: fmt.Sprintf("public %s", memo.Name),
|
||||
},
|
||||
{
|
||||
Text: "Private",
|
||||
CallbackData: fmt.Sprintf("private %s", memo.Name),
|
||||
},
|
||||
{
|
||||
Text: "Pin",
|
||||
CallbackData: fmt.Sprintf("pin %s", memo.Name),
|
||||
},
|
||||
Text: "Public",
|
||||
CallbackData: fmt.Sprintf("public %s", memo.Name),
|
||||
},
|
||||
{
|
||||
Text: "Private",
|
||||
CallbackData: fmt.Sprintf("private %s", memo.Name),
|
||||
},
|
||||
{
|
||||
Text: "Pin",
|
||||
CallbackData: fmt.Sprintf("pin %s", memo.Name),
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *models.Update) {
|
||||
callbackData := update.CallbackQuery.Data
|
||||
userID := update.CallbackQuery.From.ID
|
||||
accessToken, ok := userAccessTokenCache.Load(userID)
|
||||
accessToken, ok := s.store.GetUserAccessToken(userID)
|
||||
if !ok {
|
||||
b.AnswerCallbackQuery(ctx, &bot.AnswerCallbackQueryParams{
|
||||
CallbackQueryID: update.CallbackQuery.ID,
|
||||
@@ -246,7 +243,7 @@ func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *
|
||||
return
|
||||
}
|
||||
|
||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string))))
|
||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken)))
|
||||
|
||||
parts := strings.Split(callbackData, " ")
|
||||
if len(parts) != 2 {
|
||||
@@ -313,10 +310,10 @@ func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *
|
||||
pinnedMarker = ""
|
||||
}
|
||||
b.EditMessageText(ctx, &bot.EditMessageTextParams{
|
||||
ChatID: update.CallbackQuery.Message.Message.Chat.ID,
|
||||
MessageID: update.CallbackQuery.Message.Message.ID,
|
||||
Text: fmt.Sprintf("Memo updated as %s with [%s](%s/m/%s) %s", v1pb.Visibility_name[int32(memo.Visibility)], memo.Name, s.config.ServerAddr, memo.Uid, pinnedMarker),
|
||||
ParseMode: models.ParseModeMarkdown,
|
||||
ChatID: update.CallbackQuery.Message.Message.Chat.ID,
|
||||
MessageID: update.CallbackQuery.Message.Message.ID,
|
||||
Text: fmt.Sprintf("Memo updated as %s with [%s](%s/m/%s) %s", v1pb.Visibility_name[int32(memo.Visibility)], memo.Name, s.config.ServerAddr, memo.Uid, pinnedMarker),
|
||||
ParseMode: models.ParseModeMarkdown,
|
||||
ReplyMarkup: s.keyboard(memo),
|
||||
})
|
||||
|
||||
@@ -332,8 +329,8 @@ func (s *Service) searchHandler(ctx context.Context, b *bot.Bot, m *models.Updat
|
||||
|
||||
filterString := "content_search == ['" + searchString + "']"
|
||||
|
||||
accessToken, _ := userAccessTokenCache.Load(userID)
|
||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string))))
|
||||
accessToken, _ := s.store.GetUserAccessToken(userID)
|
||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken)))
|
||||
results, err := s.client.MemoService.ListMemos(ctx, &v1pb.ListMemosRequest{
|
||||
PageSize: 10,
|
||||
Filter: filterString,
|
||||
@@ -360,8 +357,20 @@ func (s *Service) searchHandler(ctx context.Context, b *bot.Bot, m *models.Updat
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
func (s *Service) processFileMessage(ctx context.Context, b *bot.Bot, m *models.Update, fileID string, memo *v1pb.Memo) {
|
||||
file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: fileID})
|
||||
if err != nil {
|
||||
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file"))
|
||||
return
|
||||
}
|
||||
|
||||
_, err = s.saveResourceFromFile(ctx, file, memo)
|
||||
if err != nil {
|
||||
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, memo *v1pb.Memo) (*v1pb.Resource, error) {
|
||||
@@ -397,20 +406,6 @@ func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, m
|
||||
return resource, nil
|
||||
}
|
||||
|
||||
func (s *Service) processFileMessage(ctx context.Context, b *bot.Bot, m *models.Update, fileID string, memo *v1pb.Memo) {
|
||||
file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: fileID})
|
||||
if err != nil {
|
||||
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file"))
|
||||
return
|
||||
}
|
||||
|
||||
_, err = s.saveResourceFromFile(ctx, file, memo)
|
||||
if err != nil {
|
||||
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) sendError(b *bot.Bot, chatID int64, err error) {
|
||||
slog.Error("error", slog.Any("err", err))
|
||||
b.SendMessage(context.Background(), &bot.SendMessageParams{
|
||||
|
||||
29
store/store.go
Normal file
29
store/store.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
Data string
|
||||
|
||||
userAccessTokenCache sync.Map // map[int64]string
|
||||
}
|
||||
|
||||
func NewStore(data string) *Store {
|
||||
return &Store{
|
||||
Data: data,
|
||||
|
||||
userAccessTokenCache: sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) Init() error {
|
||||
if err := s.loadUserAccessTokenMapFromFile(); err != nil {
|
||||
return errors.Wrap(err, "failed to load user access token map from file")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
100
store/user.go
Normal file
100
store/user.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetUserAccessToken returns the access token for the user.
|
||||
func (s *Store) GetUserAccessToken(userID int64) (string, bool) {
|
||||
accessToken, ok := s.userAccessTokenCache.Load(userID)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return accessToken.(string), true
|
||||
}
|
||||
|
||||
// SetUserAccessToken sets the access token for the user.
|
||||
func (s *Store) SetUserAccessToken(userID int64, accessToken string) {
|
||||
s.userAccessTokenCache.Store(userID, accessToken)
|
||||
if err := s.SaveUserAccessTokenMapToFile(); err != nil {
|
||||
slog.Error("failed to save user access token map to file", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SaveUserAccessTokenMapToFile saves the user access token map to a data file.
|
||||
func (s *Store) SaveUserAccessTokenMapToFile() error {
|
||||
// Open the file for writing
|
||||
file, err := os.OpenFile(s.Data, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Iterate over the user access token map and write each entry to the file
|
||||
s.userAccessTokenCache.Range(func(key, value interface{}) bool {
|
||||
userID := key.(int64)
|
||||
accessToken := value.(string)
|
||||
line := strconv.FormatInt(userID, 10) + ":" + accessToken + "\n"
|
||||
_, err := file.WriteString(line)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) loadUserAccessTokenMapFromFile() error {
|
||||
// Check if the file exists
|
||||
if _, err := os.Stat(s.Data); os.IsNotExist(err) {
|
||||
// Create the file if it doesn't exist
|
||||
file, err := os.Create(s.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
}
|
||||
|
||||
// Open the file
|
||||
file, err := os.Open(s.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Read the file line by line
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
// Parse the line and extract the user ID and access token
|
||||
userID, accessToken := parseLine(line)
|
||||
if userID == 0 || accessToken == "" {
|
||||
continue
|
||||
}
|
||||
// Store the user ID and access token in the cache
|
||||
s.userAccessTokenCache.Store(userID, accessToken)
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseLine(line string) (int64, string) {
|
||||
parts := strings.Split(line, ":")
|
||||
if len(parts) != 2 {
|
||||
return 0, ""
|
||||
}
|
||||
userIDStr := parts[0]
|
||||
accessToken := parts[1]
|
||||
userID, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
return 0, ""
|
||||
}
|
||||
return userID, accessToken
|
||||
}
|
||||
Reference in New Issue
Block a user