Skip to content

Commit

Permalink
feat: enhance hybrid search functionality with improved message handling and result limits
Browse files Browse the repository at this point in the history
  • Loading branch information
krau committed Jan 26, 2025
1 parent 7a1a2fd commit 3aff22d
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 34 deletions.
77 changes: 69 additions & 8 deletions telegram/handlers/query_artwork.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package handlers

import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
"strconv"
"strings"

"github.com/duke-git/lancet/v2/slice"
"github.com/krau/ManyACG/adapter"
"github.com/krau/ManyACG/common"
"github.com/krau/ManyACG/config"
Expand All @@ -19,6 +21,7 @@ import (

"github.com/mymmrac/telego"
"github.com/mymmrac/telego/telegoutil"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
)

Expand Down Expand Up @@ -86,7 +89,7 @@ func HybridSearchArtworks(ctx context.Context, bot *telego.Bot, message telego.M
}
_, _, args := telegoutil.ParseCommand(message.Text)
if len(args) == 0 {
utils.ReplyMessage(bot, message, "使用方法: /query <搜索内容> [语义比例]\n语义比例为0-1的浮点数, 应位于参数列表最后, 越大越趋向于基于语义搜索, 若不提供, 使用默认值0.8")
utils.ReplyMessage(bot, message, "使用方法: /hybrid <搜索内容> [语义比例]\n语义比例为0-1的浮点数, 应位于参数列表最后, 越大越趋向于基于语义搜索, 若不提供, 使用默认值0.8")
return
}
var hybridSemanticRatio float64
Expand All @@ -102,7 +105,7 @@ func HybridSearchArtworks(ctx context.Context, bot *telego.Bot, message telego.M
}
queryText = strings.Join(args[:len(args)-1], " ")
}
artworks, err := service.HybridSearchArtworks(ctx, queryText, hybridSemanticRatio, 0, 10)
artworks, err := service.HybridSearchArtworks(ctx, queryText, hybridSemanticRatio, 0, 50)
if err != nil {
common.Logger.Errorf("搜索失败: %s", err)
utils.ReplyMessage(bot, message, "搜索失败, 请联系管理员检查搜索引擎设置与状态")
Expand All @@ -114,7 +117,7 @@ func HybridSearchArtworks(ctx context.Context, bot *telego.Bot, message telego.M
}

if len(artworks) > 10 {
artworks = artworks[:10]
artworks = slice.Shuffle(artworks)[:10]
}

inputMedias := make([]telego.InputMedia, 0, len(artworks))
Expand Down Expand Up @@ -147,12 +150,53 @@ func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego.
return
}
if message.ReplyToMessage == nil {
utils.ReplyMessage(bot, message, "请回复一张图片")
utils.ReplyMessage(bot, message, "请回复一条包含图片或作品链接的消息")
return
}
sourceURL := utils.FindSourceURLForMessage(message.ReplyToMessage)
var sourceURL string
sourceURL = utils.FindSourceURLForMessage(message.ReplyToMessage)
if sourceURL == "" {
if message.ReplyToMessage.Photo != nil || message.ReplyToMessage.Document != nil {
handleGetSourceURLFromPicture := func() (string, error) {
file, err := utils.GetMessagePhotoFile(bot, message.ReplyToMessage)
if err != nil {
return "", err
}
hash, err := common.GetImagePhashFromReader(bytes.NewReader(file))
if err != nil {
return "", err
}
pictures, err := service.GetPicturesByHashHammingDistance(ctx, hash, 10)
if err != nil {
return "", err
}
if len(pictures) == 0 {
return "", errors.New("not found similar pictures by hash")
}
picture := pictures[0]
artworkID, err := primitive.ObjectIDFromHex(picture.ArtworkID)
if err != nil {
return "", err
}
artwork, err := service.GetArtworkByID(ctx, artworkID)
if err != nil {
return "", err
}
return artwork.SourceURL, nil
}
var err error
sourceURL, err = handleGetSourceURLFromPicture()
if err != nil {
common.Logger.Warnf("获取图片链接失败: %s", err)
utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接或图片")
return
}
} else {
utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接")
return
}
}
if sourceURL == "" {
utils.ReplyMessage(bot, message, "回复的消息中未找到支持的链接")
return
}
artwork, err := service.GetArtworkByURL(ctx, sourceURL)
Expand All @@ -161,7 +205,24 @@ func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego.
utils.ReplyMessage(bot, message, "获取作品信息失败")
return
}
artworks, err := service.SearchSimilarArtworks(ctx, artwork.ID, 0, 10)
_, _, args := telegoutil.ParseCommand(message.Text)
offset := 0
limit := 50
if len(args) > 0 {
offset, err = strconv.Atoi(args[0])
if err != nil || offset < 0 {
utils.ReplyMessage(bot, message, "参数错误: 偏移量应为非负整数")
return
}
}
if len(args) > 1 {
limit, err = strconv.Atoi(args[1])
if err != nil || limit < 1 || limit > 100 {
utils.ReplyMessage(bot, message, "参数错误: 限制数量应为1-10的整数")
return
}
}
artworks, err := service.SearchSimilarArtworks(ctx, artwork.ID, int64(offset), int64(limit))
if err != nil {
common.Logger.Errorf("搜索失败: %s", err)
utils.ReplyMessage(bot, message, "搜索失败")
Expand All @@ -172,7 +233,7 @@ func SearchSimilarArtworks(ctx context.Context, bot *telego.Bot, message telego.
return
}
if len(artworks) > 10 {
artworks = artworks[:10]
artworks = slice.Shuffle(artworks)[:10]
}
inputMedias := make([]telego.InputMedia, 0, len(artworks))
for _, artwork := range artworks {
Expand Down
52 changes: 26 additions & 26 deletions telegram/handlers/search_picture.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,34 +139,34 @@ func getDBSearchResultText(ctx context.Context, file []byte) (string, bool, erro
}
channelMessageAvailable := ChannelChatID.ID != 0 || ChannelChatID.Username != ""
enableSite := config.Cfg.API.SiteURL != ""
if len(pictures) > 0 {
text := fmt.Sprintf("找到%d张相似的图片\n\n", len(pictures))
for _, picture := range pictures {
artworkObjectID, err := primitive.ObjectIDFromHex(picture.ArtworkID)
if err != nil {
common.Logger.Errorf("无效的ObjectID: %s", picture.ID)
continue
}
artwork, err := service.GetArtworkByID(ctx, artworkObjectID)
if err != nil {
common.Logger.Errorf("获取作品信息失败: %s", err)
continue
}
text += fmt.Sprintf("[%s\\_%d](%s)\n",
common.EscapeMarkdown(artwork.Title),
picture.Index+1,
common.EscapeMarkdown(artwork.SourceURL),
)
if channelMessageAvailable && picture.TelegramInfo != nil && picture.TelegramInfo.MessageID != 0 {
text += fmt.Sprintf("[频道消息](%s)\n", utils.GetArtworkPostMessageURL(picture.TelegramInfo.MessageID, ChannelChatID))
}
if enableSite {
text += fmt.Sprintf("[ManyACG](%s)\n\n", config.Cfg.API.SiteURL+"/artwork/"+artwork.ID)
}
if len(pictures) == 0 {
return "未在数据库中找到相似图片", false, nil
}
text := fmt.Sprintf("找到%d张相似的图片\n\n", len(pictures))
for _, picture := range pictures {
artworkObjectID, err := primitive.ObjectIDFromHex(picture.ArtworkID)
if err != nil {
common.Logger.Errorf("无效的ObjectID: %s", picture.ID)
continue
}
artwork, err := service.GetArtworkByID(ctx, artworkObjectID)
if err != nil {
common.Logger.Errorf("获取作品信息失败: %s", err)
continue
}
text += fmt.Sprintf("[%s\\_%d](%s)\n",
common.EscapeMarkdown(artwork.Title),
picture.Index+1,
common.EscapeMarkdown(artwork.SourceURL),
)
if channelMessageAvailable && picture.TelegramInfo != nil && picture.TelegramInfo.MessageID != 0 {
text += fmt.Sprintf("[频道消息](%s)\n", utils.GetArtworkPostMessageURL(picture.TelegramInfo.MessageID, ChannelChatID))
}
if enableSite {
text += fmt.Sprintf("[ManyACG](%s)\n\n", config.Cfg.API.SiteURL+"/artwork/"+artwork.ID)
}
return text, true, nil
}
return "未在数据库中找到相似图片", false, nil
return text, true, nil
}

type ascii2dResult struct {
Expand Down

0 comments on commit 3aff22d

Please sign in to comment.