diff --git a/memogram.go b/memogram.go index 2955a5a..8a9c565 100644 --- a/memogram.go +++ b/memogram.go @@ -3,7 +3,10 @@ package memogram import ( "context" "fmt" + "io" "log/slog" + "net/http" + "path/filepath" "strings" "sync" @@ -25,6 +28,7 @@ var userAccessTokenCache sync.Map // map[int64]string type Service struct { config *Config client *MemosClient + bot *bot.Bot } func NewService() (*Service, error) { @@ -40,17 +44,9 @@ func NewService() (*Service, error) { } client := NewMemosClient(conn) - return &Service{ - config, - client, - }, nil -} - -func (s *Service) Start(ctx context.Context) { - config, err := getConfigFromEnv() - if err != nil { - slog.Error("failed to get config from env", slog.Any("err", err)) - return + s := &Service{ + config: config, + client: client, } opts := []bot.Option{ @@ -59,12 +55,16 @@ func (s *Service) Start(ctx context.Context) { b, err := bot.New(config.BotToken, opts...) if err != nil { - slog.Error("failed to create bot", slog.Any("err", err)) - return + return nil, errors.Wrap(err, "failed to create bot") } + s.bot = b - slog.Info("memogram started") - b.Start(ctx) + return s, nil +} + +func (s *Service) Start(ctx context.Context) { + slog.Info("Memogram started") + s.bot.Start(ctx) } func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) { @@ -85,7 +85,8 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) { message := m.Message // TODO: handle message.Entities to get markdown text. text := message.Text - if text == "" { + hasResource := message.Document != nil || len(message.Photo) > 0 + if text == "" && !hasResource { b.SendMessage(ctx, &bot.SendMessageParams{ ChatID: m.Message.Chat.ID, Text: "Please input memo content", @@ -107,6 +108,34 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) { return } + if message.Document != nil { + file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: message.Document.FileID}) + if err != nil { + s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file")) + return + } + + _, err = s.saveResourceFromFile(ctx, file, memo) + if err != nil { + s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource")) + return + } + } + if len(message.Photo) > 0 { + photo := message.Photo[len(message.Photo)-1] + file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: photo.FileID}) + if err != nil { + s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file")) + return + } + + _, err = s.saveResourceFromFile(ctx, file, memo) + if err != nil { + s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource")) + return + } + } + b.SendMessage(ctx, &bot.SendMessageParams{ ChatID: m.Message.Chat.ID, Text: fmt.Sprintf("Memo created with %s", memo.Name), @@ -133,3 +162,44 @@ func (s *Service) startHandler(ctx context.Context, b *bot.Bot, m *models.Update Text: fmt.Sprintf("Hello %s!", user.Nickname), }) } + +func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, memo *v1pb.Memo) (*v1pb.Resource, error) { + fileLink := s.bot.FileDownloadLink(file) + response, err := http.Get(fileLink) + if err != nil { + return nil, errors.Wrap(err, "failed to download file") + } + defer response.Body.Close() + + bytes, err := io.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read file") + } + contentType, err := getContentType(fileLink) + if err != nil { + return nil, errors.Wrap(err, "failed to get content type") + } + + resource, err := s.client.ResourceService.CreateResource(ctx, &v1pb.CreateResourceRequest{ + Resource: &v1pb.Resource{ + Filename: filepath.Base(file.FilePath), + Type: contentType, + Size: file.FileSize, + Content: bytes, + Memo: &memo.Name, + }, + }) + if err != nil { + return nil, errors.Wrap(err, "failed to create resource") + } + + return resource, nil +} + +func (s *Service) sendError(b *bot.Bot, chatID int64, err error) { + slog.Error("error", slog.Any("err", err)) + b.SendMessage(context.Background(), &bot.SendMessageParams{ + ChatID: chatID, + Text: fmt.Sprintf("Error: %s", err.Error()), + }) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..57aa0a3 --- /dev/null +++ b/util.go @@ -0,0 +1,41 @@ +package memogram + +import ( + "io" + "mime" + "net/http" + "net/url" + "path" +) + +func getContentType(imageURL string) (string, error) { + resp, err := http.Get(imageURL) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Check if the server provided a Content-Type header. + contentType := resp.Header.Get("Content-Type") + if contentType != "" && contentType != "application/octet-stream" { + return contentType, nil + } + + // Read a few bytes from the body to detect the content type. + buffer := make([]byte, 512) + _, err = io.ReadFull(resp.Body, buffer) + if err != nil && err != io.EOF { + return "", err + } + + // Use the DetectContentType function to get the content type. + contentType = http.DetectContentType(buffer) + if contentType == "application/octet-stream" { + // Try to infer content type from URL if detection fails. + parsedURL, err := url.Parse(imageURL) + if err == nil { + contentType = mime.TypeByExtension(path.Ext(parsedURL.Path)) + } + } + return contentType, nil +}