From 3aff22de00895d95cf80e78d4ba7980f011a2df2 Mon Sep 17 00:00:00 2001 From: krau <71133316+krau@users.noreply.github.com> Date: Mon, 27 Jan 2025 00:20:53 +0800 Subject: [PATCH] feat: enhance hybrid search functionality with improved message handling and result limits --- telegram/handlers/query_artwork.go | 77 ++++++++++++++++++++++++++--- telegram/handlers/search_picture.go | 52 +++++++++---------- 2 files changed, 95 insertions(+), 34 deletions(-) diff --git a/telegram/handlers/query_artwork.go b/telegram/handlers/query_artwork.go index a55c2fb..b6f6d0f 100644 --- a/telegram/handlers/query_artwork.go +++ b/telegram/handlers/query_artwork.go @@ -1,6 +1,7 @@ package handlers import ( + "bytes" "context" "errors" "fmt" @@ -8,6 +9,7 @@ import ( "strconv" "strings" + "github.com/duke-git/lancet/v2/slice" "github.com/krau/ManyACG/adapter" "github.com/krau/ManyACG/common" "github.com/krau/ManyACG/config" @@ -19,6 +21,7 @@ import ( "github.com/mymmrac/telego" "github.com/mymmrac/telego/telegoutil" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" ) @@ -86,7 +89,7 @@ func HybridSearchArtworks(ctx context.Context, bot *telego.Bot, message telego.M } _, _, args := telegoutil.ParseCommand(message.Text) if len(args) == 0 { - utils.ReplyMessage(bot, message, "使用方法: /query <搜索内容> [语义比例]\n语义比例为0-1的浮点数, 应位于参数列表最后, 越大越趋向于基于语义搜索, 若不提供, 使用默认值0.8") + utils.ReplyMessage(bot, message, "使用方法: /hybrid <搜索内容> [语义比例]\n语义比例为0-1的浮点数, 应位于参数列表最后, 越大越趋向于基于语义搜索, 若不提供, 使用默认值0.8") return } var hybridSemanticRatio float64 @@ -102,7 +105,7 @@ func HybridSearchArtworks(ctx context.Context, bot *telego.Bot, message telego.M } queryText = strings.Join(args[:len(args)-1], " ") } - artworks, err := service.HybridSearchArtworks(ctx, queryText, hybridSemanticRatio, 0, 10) + artworks, err := service.HybridSearchArtworks(ctx, queryText, hybridSemanticRatio, 0, 50) if err != nil { common.Logger.Errorf("搜索失败: %s", err) utils.ReplyMessage(bot, message, "搜索失败, 请联系管理员检查搜索引擎设置与状态") @@ -114,7 +117,7 @@ func HybridSearchArtworks(ctx context.Context, bot *telego.Bot, message telego.M } if len(artworks) > 10 { - artworks = artworks[:10] + artworks = slice.Shuffle(artworks)[:10] } inputMedias := make([]telego.InputMedia, 0, len(artworks)) @@ -147,12 +150,53 @@ func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego. return } if message.ReplyToMessage == nil { - utils.ReplyMessage(bot, message, "请回复一张图片") + utils.ReplyMessage(bot, message, "请回复一条包含图片或作品链接的消息") return } - sourceURL := utils.FindSourceURLForMessage(message.ReplyToMessage) + var sourceURL string + sourceURL = utils.FindSourceURLForMessage(message.ReplyToMessage) + if sourceURL == "" { + if message.ReplyToMessage.Photo != nil || message.ReplyToMessage.Document != nil { + handleGetSourceURLFromPicture := func() (string, error) { + file, err := utils.GetMessagePhotoFile(bot, message.ReplyToMessage) + if err != nil { + return "", err + } + hash, err := common.GetImagePhashFromReader(bytes.NewReader(file)) + if err != nil { + return "", err + } + pictures, err := service.GetPicturesByHashHammingDistance(ctx, hash, 10) + if err != nil { + return "", err + } + if len(pictures) == 0 { + return "", errors.New("not found similar pictures by hash") + } + picture := pictures[0] + artworkID, err := primitive.ObjectIDFromHex(picture.ArtworkID) + if err != nil { + return "", err + } + artwork, err := service.GetArtworkByID(ctx, artworkID) + if err != nil { + return "", err + } + return artwork.SourceURL, nil + } + var err error + sourceURL, err = handleGetSourceURLFromPicture() + if err != nil { + common.Logger.Warnf("获取图片链接失败: %s", err) + utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接或图片") + return + } + } else { + utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接") + return + } + } if sourceURL == "" { - utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接") return } artwork, err := service.GetArtworkByURL(ctx, sourceURL) @@ -161,7 +205,24 @@ func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego. utils.ReplyMessage(bot, message, "获取作品信息失败") return } - artworks, err := service.SearchSimilarArtworks(ctx, artwork.ID, 0, 10) + _, _, args := telegoutil.ParseCommand(message.Text) + offset := 0 + limit := 50 + if len(args) > 0 { + offset, err = strconv.Atoi(args[0]) + if err != nil || offset < 0 { + utils.ReplyMessage(bot, message, "参数错误: 偏移量应为非负整数") + return + } + } + if len(args) > 1 { + limit, err = strconv.Atoi(args[1]) + if err != nil || limit < 1 || limit > 100 { + utils.ReplyMessage(bot, message, "参数错误: 限制数量应为1-10的整数") + return + } + } + artworks, err := service.SearchSimilarArtworks(ctx, artwork.ID, int64(offset), int64(limit)) if err != nil { common.Logger.Errorf("搜索失败: %s", err) utils.ReplyMessage(bot, message, "搜索失败") @@ -172,7 +233,7 @@ func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego. return } if len(artworks) > 10 { - artworks = artworks[:10] + artworks = slice.Shuffle(artworks)[:10] } inputMedias := make([]telego.InputMedia, 0, len(artworks)) for _, artwork := range artworks { diff --git a/telegram/handlers/search_picture.go b/telegram/handlers/search_picture.go index 264ca6a..4006bb7 100644 --- a/telegram/handlers/search_picture.go +++ b/telegram/handlers/search_picture.go @@ -139,34 +139,34 @@ func getDBSearchResultText(ctx context.Context, file []byte) (string, bool, erro } channelMessageAvailable := ChannelChatID.ID != 0 || ChannelChatID.Username != "" enableSite := config.Cfg.API.SiteURL != "" - if len(pictures) > 0 { - text := fmt.Sprintf("找到%d张相似的图片\n\n", len(pictures)) - for _, picture := range pictures { - artworkObjectID, err := primitive.ObjectIDFromHex(picture.ArtworkID) - if err != nil { - common.Logger.Errorf("无效的ObjectID: %s", picture.ID) - continue - } - artwork, err := service.GetArtworkByID(ctx, artworkObjectID) - if err != nil { - common.Logger.Errorf("获取作品信息失败: %s", err) - continue - } - text += fmt.Sprintf("[%s\\_%d](%s)\n", - common.EscapeMarkdown(artwork.Title), - picture.Index+1, - common.EscapeMarkdown(artwork.SourceURL), - ) - if channelMessageAvailable && picture.TelegramInfo != nil && picture.TelegramInfo.MessageID != 0 { - text += fmt.Sprintf("[频道消息](%s)\n", utils.GetArtworkPostMessageURL(picture.TelegramInfo.MessageID, ChannelChatID)) - } - if enableSite { - text += fmt.Sprintf("[ManyACG](%s)\n\n", config.Cfg.API.SiteURL+"/artwork/"+artwork.ID) - } + if len(pictures) == 0 { + return "未在数据库中找到相似图片", false, nil + } + text := fmt.Sprintf("找到%d张相似的图片\n\n", len(pictures)) + for _, picture := range pictures { + artworkObjectID, err := primitive.ObjectIDFromHex(picture.ArtworkID) + if err != nil { + common.Logger.Errorf("无效的ObjectID: %s", picture.ID) + continue + } + artwork, err := service.GetArtworkByID(ctx, artworkObjectID) + if err != nil { + common.Logger.Errorf("获取作品信息失败: %s", err) + continue + } + text += fmt.Sprintf("[%s\\_%d](%s)\n", + common.EscapeMarkdown(artwork.Title), + picture.Index+1, + common.EscapeMarkdown(artwork.SourceURL), + ) + if channelMessageAvailable && picture.TelegramInfo != nil && picture.TelegramInfo.MessageID != 0 { + text += fmt.Sprintf("[频道消息](%s)\n", utils.GetArtworkPostMessageURL(picture.TelegramInfo.MessageID, ChannelChatID)) + } + if enableSite { + text += fmt.Sprintf("[ManyACG](%s)\n\n", config.Cfg.API.SiteURL+"/artwork/"+artwork.ID) } - return text, true, nil } - return "未在数据库中找到相似图片", false, nil + return text, true, nil } type ascii2dResult struct {