diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..c268467 --- /dev/null +++ b/cache.go @@ -0,0 +1,68 @@ +package memogram + +import ( + "sync" + "time" +) + +// Cache is a simple cache implementation +type Cache struct { + sync.RWMutex + items map[string]*CacheItem +} + +type CacheItem struct { + Value interface{} + Expiration time.Time +} + +func NewCache() *Cache { + return &Cache{ + items: make(map[string]*CacheItem), + } +} + +// set adds a key value pair to the cache with a given duration +func (c *Cache) set(key string, value interface{}, duration time.Duration) { + c.Lock() + defer c.Unlock() + c.items[key] = &CacheItem{ + Value: value, + Expiration: time.Now().Add(duration), + } +} + +// get returns a value from the cache if it exists +func (c *Cache) get(key string) (interface{}, bool) { + c.RLock() + defer c.RUnlock() + item, found := c.items[key] + if !found { + return nil, false + } + if time.Now().After(item.Expiration) { + return nil, false + } + return item.Value, true +} + +// deleteExpired deletes all expired key value pairs +func (c *Cache) deleteExpired() { + c.Lock() + defer c.Unlock() + for k, v := range c.items { + if time.Now().After(v.Expiration) { + delete(c.items, k) + } + } +} + +// startGC starts a goroutine to clean expired key value pairs +func (c *Cache) startGC() { + go func() { + for { + <-time.After(5 * time.Minute) + c.deleteExpired() + } + }() +} diff --git a/memogram.go b/memogram.go index 4067193..b9bcfcc 100644 --- a/memogram.go +++ b/memogram.go @@ -8,6 +8,7 @@ import ( "net/http" "path/filepath" "strings" + "time" "github.com/go-telegram/bot" "github.com/go-telegram/bot/models" @@ -25,6 +26,7 @@ type Service struct { client *MemosClient config *Config store *store.Store + cache *Cache } func NewService() (*Service, error) { @@ -48,7 +50,9 @@ func NewService() (*Service, error) { config: config, client: client, store: store, + cache: NewCache(), } + s.cache.startGC() opts := []bot.Option{ bot.WithDefaultHandler(s.handler), @@ -93,6 +97,43 @@ func (s *Service) Start(ctx context.Context) { s.bot.Start(ctx) } +func (s *Service) createMemo(ctx context.Context, content string) (*v1pb.Memo, error) { + memo, err := s.client.MemoService.CreateMemo(ctx, &v1pb.CreateMemoRequest{ + Content: content, + }) + if err != nil { + slog.Error("failed to create memo", slog.Any("err", err)) + return nil, err + } + return memo, nil +} + +func (s *Service) handleMemoCreation(ctx context.Context, m *models.Update, content string) (*v1pb.Memo, error) { + var memo *v1pb.Memo + var err error + + if m.Message.MediaGroupID != "" { + cacheMemo, ok := s.cache.get(m.Message.MediaGroupID) + if !ok { + memo, err = s.createMemo(ctx, content) + if err != nil { + return nil, err + } + + s.cache.set(m.Message.MediaGroupID, memo, 24*time.Hour) + } else { + memo = cacheMemo.(*v1pb.Memo) + } + } else { + memo, err = s.createMemo(ctx, content) + if err != nil { + return nil, err + } + } + + return memo, nil +} + func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) { if strings.HasPrefix(m.Message.Text, "/start ") { s.startHandler(ctx, b, m) @@ -166,11 +207,10 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) { accessToken, _ := s.store.GetUserAccessToken(userID) ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken))) - memo, err := s.client.MemoService.CreateMemo(ctx, &v1pb.CreateMemoRequest{ - Content: content, - }) + + var memo *v1pb.Memo + memo, err := s.handleMemoCreation(ctx, m, content) if err != nil { - slog.Error("failed to create memo", slog.Any("err", err)) b.SendMessage(ctx, &bot.SendMessageParams{ ChatID: m.Message.Chat.ID, Text: "Failed to create memo",