mirror of
https://github.com/kemko/memes-telegram-integration.git
synced 2026-01-01 15:55:41 +03:00
feat: add store for access token cache
This commit is contained in:
@@ -2,6 +2,8 @@ package memogram
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
|
|
||||||
"github.com/caarlos0/env"
|
"github.com/caarlos0/env"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -10,6 +12,7 @@ import (
|
|||||||
type Config struct {
|
type Config struct {
|
||||||
ServerAddr string `env:"SERVER_ADDR,required"`
|
ServerAddr string `env:"SERVER_ADDR,required"`
|
||||||
BotToken string `env:"BOT_TOKEN,required"`
|
BotToken string `env:"BOT_TOKEN,required"`
|
||||||
|
Data string `env:"DATA"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func getConfigFromEnv() (*Config, error) {
|
func getConfigFromEnv() (*Config, error) {
|
||||||
@@ -25,5 +28,10 @@ func getConfigFromEnv() (*Config, error) {
|
|||||||
if err := env.Parse(&config); err != nil {
|
if err := env.Parse(&config); err != nil {
|
||||||
return nil, errors.Wrap(err, "invalid configuration")
|
return nil, errors.Wrap(err, "invalid configuration")
|
||||||
}
|
}
|
||||||
|
if config.Data == "" {
|
||||||
|
// Default to `data.txt` if not specified.
|
||||||
|
config.Data = "data.txt"
|
||||||
|
}
|
||||||
|
config.Data = path.Join(".", config.Data)
|
||||||
return &config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
|
|||||||
103
memogram.go
103
memogram.go
@@ -8,11 +8,11 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/go-telegram/bot"
|
"github.com/go-telegram/bot"
|
||||||
"github.com/go-telegram/bot/models"
|
"github.com/go-telegram/bot/models"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/usememos/memogram/store"
|
||||||
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
v1pb "github.com/usememos/memos/proto/gen/api/v1"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
@@ -20,16 +20,11 @@ import (
|
|||||||
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
|
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// userAccessTokenCache is a cache for user access token.
|
|
||||||
// Key is the user id from telegram.
|
|
||||||
// Value is the access token from memos.
|
|
||||||
// TODO: save it to a persistent storage.
|
|
||||||
var userAccessTokenCache sync.Map // map[int64]string
|
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
config *Config
|
|
||||||
client *MemosClient
|
|
||||||
bot *bot.Bot
|
bot *bot.Bot
|
||||||
|
client *MemosClient
|
||||||
|
config *Config
|
||||||
|
store *store.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService() (*Service, error) {
|
func NewService() (*Service, error) {
|
||||||
@@ -38,16 +33,21 @@ func NewService() (*Service, error) {
|
|||||||
return nil, errors.Wrap(err, "failed to get config from env")
|
return nil, errors.Wrap(err, "failed to get config from env")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := grpc.Dial(config.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
conn, err := grpc.NewClient(config.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to connect to server", slog.Any("err", err))
|
slog.Error("failed to connect to server", slog.Any("err", err))
|
||||||
return nil, errors.Wrap(err, "failed to connect to server")
|
return nil, errors.Wrap(err, "failed to connect to server")
|
||||||
}
|
}
|
||||||
client := NewMemosClient(conn)
|
client := NewMemosClient(conn)
|
||||||
|
|
||||||
|
store := store.NewStore(config.Data)
|
||||||
|
if err := store.Init(); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to init store")
|
||||||
|
}
|
||||||
s := &Service{
|
s := &Service{
|
||||||
config: config,
|
config: config,
|
||||||
client: client,
|
client: client,
|
||||||
|
store: store,
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := []bot.Option{
|
opts := []bot.Option{
|
||||||
@@ -86,7 +86,7 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userID := m.Message.From.ID
|
userID := m.Message.From.ID
|
||||||
if _, ok := userAccessTokenCache.Load(userID); !ok {
|
if _, ok := s.store.GetUserAccessToken(userID); !ok {
|
||||||
b.SendMessage(ctx, &bot.SendMessageParams{
|
b.SendMessage(ctx, &bot.SendMessageParams{
|
||||||
ChatID: m.Message.Chat.ID,
|
ChatID: m.Message.Chat.ID,
|
||||||
Text: "Please start the bot with /start <access_token>",
|
Text: "Please start the bot with /start <access_token>",
|
||||||
@@ -147,8 +147,8 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, _ := userAccessTokenCache.Load(userID)
|
accessToken, _ := s.store.GetUserAccessToken(userID)
|
||||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string))))
|
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken)))
|
||||||
memo, err := s.client.MemoService.CreateMemo(ctx, &v1pb.CreateMemoRequest{
|
memo, err := s.client.MemoService.CreateMemo(ctx, &v1pb.CreateMemoRequest{
|
||||||
Content: content,
|
Content: content,
|
||||||
})
|
})
|
||||||
@@ -164,15 +164,12 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) {
|
|||||||
if message.Document != nil {
|
if message.Document != nil {
|
||||||
s.processFileMessage(ctx, b, m, message.Document.FileID, memo)
|
s.processFileMessage(ctx, b, m, message.Document.FileID, memo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if message.Voice != nil {
|
if message.Voice != nil {
|
||||||
s.processFileMessage(ctx, b, m, message.Voice.FileID, memo)
|
s.processFileMessage(ctx, b, m, message.Voice.FileID, memo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if message.Video != nil {
|
if message.Video != nil {
|
||||||
s.processFileMessage(ctx, b, m, message.Video.FileID, memo)
|
s.processFileMessage(ctx, b, m, message.Video.FileID, memo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(message.Photo) > 0 {
|
if len(message.Photo) > 0 {
|
||||||
photo := message.Photo[len(message.Photo)-1]
|
photo := message.Photo[len(message.Photo)-1]
|
||||||
s.processFileMessage(ctx, b, m, photo.FileID, memo)
|
s.processFileMessage(ctx, b, m, photo.FileID, memo)
|
||||||
@@ -204,7 +201,7 @@ func (s *Service) startHandler(ctx context.Context, b *bot.Bot, m *models.Update
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userAccessTokenCache.Store(userID, accessToken)
|
s.store.SetUserAccessToken(userID, accessToken)
|
||||||
b.SendMessage(ctx, &bot.SendMessageParams{
|
b.SendMessage(ctx, &bot.SendMessageParams{
|
||||||
ChatID: m.Message.Chat.ID,
|
ChatID: m.Message.Chat.ID,
|
||||||
Text: fmt.Sprintf("Hello %s!", user.Nickname),
|
Text: fmt.Sprintf("Hello %s!", user.Nickname),
|
||||||
@@ -214,29 +211,29 @@ func (s *Service) startHandler(ctx context.Context, b *bot.Bot, m *models.Update
|
|||||||
func (s *Service) keyboard(memo *v1pb.Memo) *models.InlineKeyboardMarkup {
|
func (s *Service) keyboard(memo *v1pb.Memo) *models.InlineKeyboardMarkup {
|
||||||
// add inline keyboard to edit memo's visibility or pinned status.
|
// add inline keyboard to edit memo's visibility or pinned status.
|
||||||
return &models.InlineKeyboardMarkup{
|
return &models.InlineKeyboardMarkup{
|
||||||
InlineKeyboard: [][]models.InlineKeyboardButton{
|
InlineKeyboard: [][]models.InlineKeyboardButton{
|
||||||
|
{
|
||||||
{
|
{
|
||||||
{
|
Text: "Public",
|
||||||
Text: "Public",
|
CallbackData: fmt.Sprintf("public %s", memo.Name),
|
||||||
CallbackData: fmt.Sprintf("public %s", memo.Name),
|
},
|
||||||
},
|
{
|
||||||
{
|
Text: "Private",
|
||||||
Text: "Private",
|
CallbackData: fmt.Sprintf("private %s", memo.Name),
|
||||||
CallbackData: fmt.Sprintf("private %s", memo.Name),
|
},
|
||||||
},
|
{
|
||||||
{
|
Text: "Pin",
|
||||||
Text: "Pin",
|
CallbackData: fmt.Sprintf("pin %s", memo.Name),
|
||||||
CallbackData: fmt.Sprintf("pin %s", memo.Name),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *models.Update) {
|
func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *models.Update) {
|
||||||
callbackData := update.CallbackQuery.Data
|
callbackData := update.CallbackQuery.Data
|
||||||
userID := update.CallbackQuery.From.ID
|
userID := update.CallbackQuery.From.ID
|
||||||
accessToken, ok := userAccessTokenCache.Load(userID)
|
accessToken, ok := s.store.GetUserAccessToken(userID)
|
||||||
if !ok {
|
if !ok {
|
||||||
b.AnswerCallbackQuery(ctx, &bot.AnswerCallbackQueryParams{
|
b.AnswerCallbackQuery(ctx, &bot.AnswerCallbackQueryParams{
|
||||||
CallbackQueryID: update.CallbackQuery.ID,
|
CallbackQueryID: update.CallbackQuery.ID,
|
||||||
@@ -246,7 +243,7 @@ func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string))))
|
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken)))
|
||||||
|
|
||||||
parts := strings.Split(callbackData, " ")
|
parts := strings.Split(callbackData, " ")
|
||||||
if len(parts) != 2 {
|
if len(parts) != 2 {
|
||||||
@@ -313,10 +310,10 @@ func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *
|
|||||||
pinnedMarker = ""
|
pinnedMarker = ""
|
||||||
}
|
}
|
||||||
b.EditMessageText(ctx, &bot.EditMessageTextParams{
|
b.EditMessageText(ctx, &bot.EditMessageTextParams{
|
||||||
ChatID: update.CallbackQuery.Message.Message.Chat.ID,
|
ChatID: update.CallbackQuery.Message.Message.Chat.ID,
|
||||||
MessageID: update.CallbackQuery.Message.Message.ID,
|
MessageID: update.CallbackQuery.Message.Message.ID,
|
||||||
Text: fmt.Sprintf("Memo updated as %s with [%s](%s/m/%s) %s", v1pb.Visibility_name[int32(memo.Visibility)], memo.Name, s.config.ServerAddr, memo.Uid, pinnedMarker),
|
Text: fmt.Sprintf("Memo updated as %s with [%s](%s/m/%s) %s", v1pb.Visibility_name[int32(memo.Visibility)], memo.Name, s.config.ServerAddr, memo.Uid, pinnedMarker),
|
||||||
ParseMode: models.ParseModeMarkdown,
|
ParseMode: models.ParseModeMarkdown,
|
||||||
ReplyMarkup: s.keyboard(memo),
|
ReplyMarkup: s.keyboard(memo),
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -332,8 +329,8 @@ func (s *Service) searchHandler(ctx context.Context, b *bot.Bot, m *models.Updat
|
|||||||
|
|
||||||
filterString := "content_search == ['" + searchString + "']"
|
filterString := "content_search == ['" + searchString + "']"
|
||||||
|
|
||||||
accessToken, _ := userAccessTokenCache.Load(userID)
|
accessToken, _ := s.store.GetUserAccessToken(userID)
|
||||||
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string))))
|
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken)))
|
||||||
results, err := s.client.MemoService.ListMemos(ctx, &v1pb.ListMemosRequest{
|
results, err := s.client.MemoService.ListMemos(ctx, &v1pb.ListMemosRequest{
|
||||||
PageSize: 10,
|
PageSize: 10,
|
||||||
Filter: filterString,
|
Filter: filterString,
|
||||||
@@ -360,8 +357,20 @@ func (s *Service) searchHandler(ctx context.Context, b *bot.Bot, m *models.Updat
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return
|
func (s *Service) processFileMessage(ctx context.Context, b *bot.Bot, m *models.Update, fileID string, memo *v1pb.Memo) {
|
||||||
|
file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: fileID})
|
||||||
|
if err != nil {
|
||||||
|
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.saveResourceFromFile(ctx, file, memo)
|
||||||
|
if err != nil {
|
||||||
|
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource"))
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, memo *v1pb.Memo) (*v1pb.Resource, error) {
|
func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, memo *v1pb.Memo) (*v1pb.Resource, error) {
|
||||||
@@ -397,20 +406,6 @@ func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, m
|
|||||||
return resource, nil
|
return resource, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) processFileMessage(ctx context.Context, b *bot.Bot, m *models.Update, fileID string, memo *v1pb.Memo) {
|
|
||||||
file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: fileID})
|
|
||||||
if err != nil {
|
|
||||||
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = s.saveResourceFromFile(ctx, file, memo)
|
|
||||||
if err != nil {
|
|
||||||
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Service) sendError(b *bot.Bot, chatID int64, err error) {
|
func (s *Service) sendError(b *bot.Bot, chatID int64, err error) {
|
||||||
slog.Error("error", slog.Any("err", err))
|
slog.Error("error", slog.Any("err", err))
|
||||||
b.SendMessage(context.Background(), &bot.SendMessageParams{
|
b.SendMessage(context.Background(), &bot.SendMessageParams{
|
||||||
|
|||||||
29
store/store.go
Normal file
29
store/store.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Store struct {
|
||||||
|
Data string
|
||||||
|
|
||||||
|
userAccessTokenCache sync.Map // map[int64]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStore(data string) *Store {
|
||||||
|
return &Store{
|
||||||
|
Data: data,
|
||||||
|
|
||||||
|
userAccessTokenCache: sync.Map{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) Init() error {
|
||||||
|
if err := s.loadUserAccessTokenMapFromFile(); err != nil {
|
||||||
|
return errors.Wrap(err, "failed to load user access token map from file")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
100
store/user.go
Normal file
100
store/user.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetUserAccessToken returns the access token for the user.
|
||||||
|
func (s *Store) GetUserAccessToken(userID int64) (string, bool) {
|
||||||
|
accessToken, ok := s.userAccessTokenCache.Load(userID)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return accessToken.(string), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserAccessToken sets the access token for the user.
|
||||||
|
func (s *Store) SetUserAccessToken(userID int64, accessToken string) {
|
||||||
|
s.userAccessTokenCache.Store(userID, accessToken)
|
||||||
|
if err := s.SaveUserAccessTokenMapToFile(); err != nil {
|
||||||
|
slog.Error("failed to save user access token map to file", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveUserAccessTokenMapToFile saves the user access token map to a data file.
|
||||||
|
func (s *Store) SaveUserAccessTokenMapToFile() error {
|
||||||
|
// Open the file for writing
|
||||||
|
file, err := os.OpenFile(s.Data, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Iterate over the user access token map and write each entry to the file
|
||||||
|
s.userAccessTokenCache.Range(func(key, value interface{}) bool {
|
||||||
|
userID := key.(int64)
|
||||||
|
accessToken := value.(string)
|
||||||
|
line := strconv.FormatInt(userID, 10) + ":" + accessToken + "\n"
|
||||||
|
_, err := file.WriteString(line)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) loadUserAccessTokenMapFromFile() error {
|
||||||
|
// Check if the file exists
|
||||||
|
if _, err := os.Stat(s.Data); os.IsNotExist(err) {
|
||||||
|
// Create the file if it doesn't exist
|
||||||
|
file, err := os.Create(s.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open the file
|
||||||
|
file, err := os.Open(s.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
// Read the file line by line
|
||||||
|
scanner := bufio.NewScanner(file)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
// Parse the line and extract the user ID and access token
|
||||||
|
userID, accessToken := parseLine(line)
|
||||||
|
if userID == 0 || accessToken == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Store the user ID and access token in the cache
|
||||||
|
s.userAccessTokenCache.Store(userID, accessToken)
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLine(line string) (int64, string) {
|
||||||
|
parts := strings.Split(line, ":")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return 0, ""
|
||||||
|
}
|
||||||
|
userIDStr := parts[0]
|
||||||
|
accessToken := parts[1]
|
||||||
|
userID, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, ""
|
||||||
|
}
|
||||||
|
return userID, accessToken
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user