From 5d7dc80cbc4e5076de789a0556e84f29150b4c00 Mon Sep 17 00:00:00 2001 From: Steven Date: Tue, 27 Aug 2024 00:06:39 +0800 Subject: [PATCH] feat: add store for access token cache --- config.go | 8 ++++ memogram.go | 103 +++++++++++++++++++++++-------------------------- store/store.go | 29 ++++++++++++++ store/user.go | 100 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 54 deletions(-) create mode 100644 store/store.go create mode 100644 store/user.go diff --git a/config.go b/config.go index 5e77ca3..14eb4f9 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,8 @@ package memogram import ( "os" + "path" + "github.com/caarlos0/env" "github.com/joho/godotenv" "github.com/pkg/errors" @@ -10,6 +12,7 @@ import ( type Config struct { ServerAddr string `env:"SERVER_ADDR,required"` BotToken string `env:"BOT_TOKEN,required"` + Data string `env:"DATA"` } func getConfigFromEnv() (*Config, error) { @@ -25,5 +28,10 @@ func getConfigFromEnv() (*Config, error) { if err := env.Parse(&config); err != nil { return nil, errors.Wrap(err, "invalid configuration") } + if config.Data == "" { + // Default to `data.txt` if not specified. + config.Data = "data.txt" + } + config.Data = path.Join(".", config.Data) return &config, nil } diff --git a/memogram.go b/memogram.go index f1a1c1d..f28f897 100644 --- a/memogram.go +++ b/memogram.go @@ -8,11 +8,11 @@ import ( "net/http" "path/filepath" "strings" - "sync" "github.com/go-telegram/bot" "github.com/go-telegram/bot/models" "github.com/pkg/errors" + "github.com/usememos/memogram/store" v1pb "github.com/usememos/memos/proto/gen/api/v1" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -20,16 +20,11 @@ import ( fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb" ) -// userAccessTokenCache is a cache for user access token. -// Key is the user id from telegram. -// Value is the access token from memos. -// TODO: save it to a persistent storage. -var userAccessTokenCache sync.Map // map[int64]string - type Service struct { - config *Config - client *MemosClient bot *bot.Bot + client *MemosClient + config *Config + store *store.Store } func NewService() (*Service, error) { @@ -38,16 +33,21 @@ func NewService() (*Service, error) { return nil, errors.Wrap(err, "failed to get config from env") } - conn, err := grpc.Dial(config.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, err := grpc.NewClient(config.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { slog.Error("failed to connect to server", slog.Any("err", err)) return nil, errors.Wrap(err, "failed to connect to server") } client := NewMemosClient(conn) + store := store.NewStore(config.Data) + if err := store.Init(); err != nil { + return nil, errors.Wrap(err, "failed to init store") + } s := &Service{ config: config, client: client, + store: store, } opts := []bot.Option{ @@ -86,7 +86,7 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) { } userID := m.Message.From.ID - if _, ok := userAccessTokenCache.Load(userID); !ok { + if _, ok := s.store.GetUserAccessToken(userID); !ok { b.SendMessage(ctx, &bot.SendMessageParams{ ChatID: m.Message.Chat.ID, Text: "Please start the bot with /start ", @@ -147,8 +147,8 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) { return } - accessToken, _ := userAccessTokenCache.Load(userID) - ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string)))) + accessToken, _ := s.store.GetUserAccessToken(userID) + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken))) memo, err := s.client.MemoService.CreateMemo(ctx, &v1pb.CreateMemoRequest{ Content: content, }) @@ -164,15 +164,12 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) { if message.Document != nil { s.processFileMessage(ctx, b, m, message.Document.FileID, memo) } - if message.Voice != nil { s.processFileMessage(ctx, b, m, message.Voice.FileID, memo) } - if message.Video != nil { s.processFileMessage(ctx, b, m, message.Video.FileID, memo) } - if len(message.Photo) > 0 { photo := message.Photo[len(message.Photo)-1] s.processFileMessage(ctx, b, m, photo.FileID, memo) @@ -204,7 +201,7 @@ func (s *Service) startHandler(ctx context.Context, b *bot.Bot, m *models.Update return } - userAccessTokenCache.Store(userID, accessToken) + s.store.SetUserAccessToken(userID, accessToken) b.SendMessage(ctx, &bot.SendMessageParams{ ChatID: m.Message.Chat.ID, Text: fmt.Sprintf("Hello %s!", user.Nickname), @@ -214,29 +211,29 @@ func (s *Service) startHandler(ctx context.Context, b *bot.Bot, m *models.Update func (s *Service) keyboard(memo *v1pb.Memo) *models.InlineKeyboardMarkup { // add inline keyboard to edit memo's visibility or pinned status. return &models.InlineKeyboardMarkup{ - InlineKeyboard: [][]models.InlineKeyboardButton{ + InlineKeyboard: [][]models.InlineKeyboardButton{ + { { - { - Text: "Public", - CallbackData: fmt.Sprintf("public %s", memo.Name), - }, - { - Text: "Private", - CallbackData: fmt.Sprintf("private %s", memo.Name), - }, - { - Text: "Pin", - CallbackData: fmt.Sprintf("pin %s", memo.Name), - }, + Text: "Public", + CallbackData: fmt.Sprintf("public %s", memo.Name), + }, + { + Text: "Private", + CallbackData: fmt.Sprintf("private %s", memo.Name), + }, + { + Text: "Pin", + CallbackData: fmt.Sprintf("pin %s", memo.Name), }, }, - } + }, + } } func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *models.Update) { callbackData := update.CallbackQuery.Data userID := update.CallbackQuery.From.ID - accessToken, ok := userAccessTokenCache.Load(userID) + accessToken, ok := s.store.GetUserAccessToken(userID) if !ok { b.AnswerCallbackQuery(ctx, &bot.AnswerCallbackQueryParams{ CallbackQueryID: update.CallbackQuery.ID, @@ -246,7 +243,7 @@ func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update * return } - ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string)))) + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken))) parts := strings.Split(callbackData, " ") if len(parts) != 2 { @@ -313,10 +310,10 @@ func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update * pinnedMarker = "" } b.EditMessageText(ctx, &bot.EditMessageTextParams{ - ChatID: update.CallbackQuery.Message.Message.Chat.ID, - MessageID: update.CallbackQuery.Message.Message.ID, - Text: fmt.Sprintf("Memo updated as %s with [%s](%s/m/%s) %s", v1pb.Visibility_name[int32(memo.Visibility)], memo.Name, s.config.ServerAddr, memo.Uid, pinnedMarker), - ParseMode: models.ParseModeMarkdown, + ChatID: update.CallbackQuery.Message.Message.Chat.ID, + MessageID: update.CallbackQuery.Message.Message.ID, + Text: fmt.Sprintf("Memo updated as %s with [%s](%s/m/%s) %s", v1pb.Visibility_name[int32(memo.Visibility)], memo.Name, s.config.ServerAddr, memo.Uid, pinnedMarker), + ParseMode: models.ParseModeMarkdown, ReplyMarkup: s.keyboard(memo), }) @@ -332,8 +329,8 @@ func (s *Service) searchHandler(ctx context.Context, b *bot.Bot, m *models.Updat filterString := "content_search == ['" + searchString + "']" - accessToken, _ := userAccessTokenCache.Load(userID) - ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string)))) + accessToken, _ := s.store.GetUserAccessToken(userID) + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken))) results, err := s.client.MemoService.ListMemos(ctx, &v1pb.ListMemosRequest{ PageSize: 10, Filter: filterString, @@ -360,8 +357,20 @@ func (s *Service) searchHandler(ctx context.Context, b *bot.Bot, m *models.Updat }) } } +} - return +func (s *Service) processFileMessage(ctx context.Context, b *bot.Bot, m *models.Update, fileID string, memo *v1pb.Memo) { + file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: fileID}) + if err != nil { + s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file")) + return + } + + _, err = s.saveResourceFromFile(ctx, file, memo) + if err != nil { + s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource")) + return + } } func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, memo *v1pb.Memo) (*v1pb.Resource, error) { @@ -397,20 +406,6 @@ func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, m return resource, nil } -func (s *Service) processFileMessage(ctx context.Context, b *bot.Bot, m *models.Update, fileID string, memo *v1pb.Memo) { - file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: fileID}) - if err != nil { - s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file")) - return - } - - _, err = s.saveResourceFromFile(ctx, file, memo) - if err != nil { - s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource")) - return - } -} - func (s *Service) sendError(b *bot.Bot, chatID int64, err error) { slog.Error("error", slog.Any("err", err)) b.SendMessage(context.Background(), &bot.SendMessageParams{ diff --git a/store/store.go b/store/store.go new file mode 100644 index 0000000..251da7e --- /dev/null +++ b/store/store.go @@ -0,0 +1,29 @@ +package store + +import ( + "sync" + + "github.com/pkg/errors" +) + +type Store struct { + Data string + + userAccessTokenCache sync.Map // map[int64]string +} + +func NewStore(data string) *Store { + return &Store{ + Data: data, + + userAccessTokenCache: sync.Map{}, + } +} + +func (s *Store) Init() error { + if err := s.loadUserAccessTokenMapFromFile(); err != nil { + return errors.Wrap(err, "failed to load user access token map from file") + } + + return nil +} diff --git a/store/user.go b/store/user.go new file mode 100644 index 0000000..735062d --- /dev/null +++ b/store/user.go @@ -0,0 +1,100 @@ +package store + +import ( + "bufio" + "log/slog" + "os" + "strconv" + "strings" +) + +// GetUserAccessToken returns the access token for the user. +func (s *Store) GetUserAccessToken(userID int64) (string, bool) { + accessToken, ok := s.userAccessTokenCache.Load(userID) + if !ok { + return "", false + } + return accessToken.(string), true +} + +// SetUserAccessToken sets the access token for the user. +func (s *Store) SetUserAccessToken(userID int64, accessToken string) { + s.userAccessTokenCache.Store(userID, accessToken) + if err := s.SaveUserAccessTokenMapToFile(); err != nil { + slog.Error("failed to save user access token map to file", "error", err) + } +} + +// SaveUserAccessTokenMapToFile saves the user access token map to a data file. +func (s *Store) SaveUserAccessTokenMapToFile() error { + // Open the file for writing + file, err := os.OpenFile(s.Data, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer file.Close() + + // Iterate over the user access token map and write each entry to the file + s.userAccessTokenCache.Range(func(key, value interface{}) bool { + userID := key.(int64) + accessToken := value.(string) + line := strconv.FormatInt(userID, 10) + ":" + accessToken + "\n" + _, err := file.WriteString(line) + if err != nil { + return false + } + return true + }) + + return nil +} + +func (s *Store) loadUserAccessTokenMapFromFile() error { + // Check if the file exists + if _, err := os.Stat(s.Data); os.IsNotExist(err) { + // Create the file if it doesn't exist + file, err := os.Create(s.Data) + if err != nil { + return err + } + defer file.Close() + } + + // Open the file + file, err := os.Open(s.Data) + if err != nil { + return err + } + defer file.Close() + + // Read the file line by line + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + // Parse the line and extract the user ID and access token + userID, accessToken := parseLine(line) + if userID == 0 || accessToken == "" { + continue + } + // Store the user ID and access token in the cache + s.userAccessTokenCache.Store(userID, accessToken) + } + if err := scanner.Err(); err != nil { + return err + } + return nil +} + +func parseLine(line string) (int64, string) { + parts := strings.Split(line, ":") + if len(parts) != 2 { + return 0, "" + } + userIDStr := parts[0] + accessToken := parts[1] + userID, err := strconv.ParseInt(userIDStr, 10, 64) + if err != nil { + return 0, "" + } + return userID, accessToken +}