-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsearch_picture.go
242 lines (223 loc) · 7.39 KB
/
search_picture.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
package handlers
import (
"bytes"
"context"
"errors"
"fmt"
"path"
"strings"
"github.com/PuerkitoBio/goquery"
"github.com/krau/ManyACG/common"
"github.com/krau/ManyACG/config"
"github.com/krau/ManyACG/service"
"github.com/krau/ManyACG/telegram/utils"
"github.com/mymmrac/telego"
"github.com/mymmrac/telego/telegoutil"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// SearchPicture handles the picture-search command, which must be sent as a reply
// to a message containing a photo.
//
// It first searches the local database by perceptual hash. If nothing is found
// (or the database search errors), it falls back to an ascii2d search and lists the
// matches, attaching the first match's thumbnail when it can be downloaded.
// Progress and results are reported by editing a single status message that is
// sent as a reply to the command.
func SearchPicture(ctx context.Context, bot *telego.Bot, message telego.Message) {
	if message.ReplyToMessage == nil {
		utils.ReplyMessage(bot, message, "请使用该命令回复一条图片消息")
		return
	}
	msg, err := utils.ReplyMessage(bot, message, "少女祈祷中...")
	if err != nil {
		common.Logger.Errorf("reply message failed: %s", err)
		return
	}
	chatID, messageID := msg.Chat.ChatID(), msg.GetMessageID()

	file, err := utils.GetMessagePhotoFile(bot, message.ReplyToMessage)
	if err != nil {
		editSearchStatus(bot, chatID, messageID, "获取图片文件失败: "+err.Error(), "")
		return
	}

	text, dbExists, err := getDBSearchResultText(ctx, file)
	if err != nil {
		// The database lookup is best-effort; fall through to ascii2d.
		common.Logger.Errorf("search in db failed: %s", err)
	}
	if dbExists {
		editSearchStatus(bot, chatID, messageID, text, telego.ModeMarkdownV2)
		return
	}
	// Deliberately synchronous: an asynchronous edit here could reach Telegram after
	// one of the final edits below and overwrite the result with this progress text.
	editSearchStatus(bot, chatID, messageID, "数据库搜索无结果, 使用 ascii2d 搜索中...", "")

	ascii2dResults, err := getAscii2dSearchResult(file)
	if err != nil {
		common.Logger.Errorf("search in ascii2d failed: %s", err)
		editSearchStatus(bot, chatID, messageID, "ascii2d 搜索失败", "")
		return
	}
	if len(ascii2dResults) == 0 {
		editSearchStatus(bot, chatID, messageID, "没有搜索到相似图片", "")
		return
	}

	var sb strings.Builder
	fmt.Fprintf(&sb, "在 ascii2d 搜索到%d张相似的图片\n\n", len(ascii2dResults))
	for _, result := range ascii2dResults {
		fmt.Fprintf(&sb, "[%s](%s)\n\n", common.EscapeMarkdown(result.Name), common.EscapeMarkdown(result.Link))
	}
	text = sb.String()

	thumbURL := ascii2dResults[0].Thumbnail
	thumbFile, err := common.DownloadWithCache(ctx, thumbURL, nil)
	if err != nil {
		common.Logger.Errorf("download thumbnail failed: %s", err)
		editSearchStatus(bot, chatID, messageID, text, telego.ModeMarkdownV2)
		return
	}
	_, err = bot.EditMessageMedia(&telego.EditMessageMediaParams{
		ChatID:    chatID,
		MessageID: messageID,
		Media: telegoutil.MediaPhoto(telegoutil.File(telegoutil.NameReader(bytes.NewReader(thumbFile), path.Base(thumbURL)))).
			WithCaption(text).
			WithParseMode(telego.ModeMarkdownV2),
	})
	if err != nil {
		common.Logger.Errorf("edit message media failed: %s", err)
		// Fall back to a text-only result (e.g. when the caption is rejected for
		// length) so the status message does not stay on the progress text.
		editSearchStatus(bot, chatID, messageID, text, telego.ModeMarkdownV2)
	}
}

// editSearchStatus replaces the text of the search status message identified by
// chatID/messageID. An empty parseMode sends the text unformatted. Failures are
// logged; there is no other place to report them.
func editSearchStatus(bot *telego.Bot, chatID telego.ChatID, messageID int, text, parseMode string) {
	if _, err := bot.EditMessageText(&telego.EditMessageTextParams{
		ChatID:    chatID,
		MessageID: messageID,
		Text:      text,
		ParseMode: parseMode,
	}); err != nil {
		common.Logger.Errorf("edit message text failed: %s", err)
	}
}
// SearchPictureCallbackQuery handles the inline-button variant of picture search.
// It looks up the photo of the message the button belongs to in the database,
// answers the callback query with a short notice, and replies to that message
// with the MarkdownV2 search result text. Lookup errors are shown to the user as
// an alert and no reply is sent.
func SearchPictureCallbackQuery(ctx context.Context, bot *telego.Bot, query telego.CallbackQuery) {
	if !query.Message.IsAccessible() {
		return
	}
	message := query.Message.(*telego.Message)

	photo, err := utils.GetMessagePhotoFile(bot, message)
	if err != nil {
		bot.AnswerCallbackQuery(telegoutil.CallbackQuery(query.ID).
			WithText("获取图片文件失败: " + err.Error()).
			WithShowAlert().
			WithCacheTime(5))
		return
	}

	resultText, found, err := getDBSearchResultText(ctx, photo)
	if err != nil {
		bot.AnswerCallbackQuery(telegoutil.CallbackQuery(query.ID).
			WithText(err.Error()).
			WithShowAlert().
			WithCacheTime(5))
		return
	}

	// When nothing matched, the "not found" text itself doubles as the notice.
	notice := resultText
	if found {
		notice = "搜索到相似图片"
	}
	// The notice is fire-and-forget; the reply below carries the actual result.
	go bot.AnswerCallbackQuery(telegoutil.CallbackQuery(query.ID).WithText(notice).WithCacheTime(5))
	utils.ReplyMessageWithMarkdown(bot, *message, resultText)
}
// getDBSearchResultText hashes file with common.GetImagePhashFromReader and asks
// the service for stored pictures by hash Hamming distance (threshold argument 10).
// It renders the matches as a MarkdownV2 message body. Each entry holds the
// artwork title/page number linked to its source URL, plus the channel post link
// (when a channel is configured and the picture was posted) and the ManyACG site
// link (when a site URL is configured).
//
// It returns the rendered text, whether at least one match was rendered, and an
// error if hashing or the database query failed. Matches whose artwork cannot be
// resolved are logged and skipped. If none can be rendered, the text is the plain
// "not found" message and the bool is false.
func getDBSearchResultText(ctx context.Context, file []byte) (string, bool, error) {
	const notFound = "未在数据库中找到相似图片"

	hash, err := common.GetImagePhashFromReader(bytes.NewReader(file))
	if err != nil {
		return "", false, fmt.Errorf("获取图片哈希失败: %w", err)
	}
	pictures, err := service.GetPicturesByHashHammingDistance(ctx, hash, 10)
	if err != nil {
		return "", false, fmt.Errorf("搜索图片失败: %w", err)
	}
	if len(pictures) == 0 {
		return notFound, false, nil
	}

	channelMessageAvailable := ChannelChatID.ID != 0 || ChannelChatID.Username != ""
	enableSite := config.Cfg.API.SiteURL != ""

	var entries strings.Builder
	rendered := 0
	for _, picture := range pictures {
		artworkObjectID, err := primitive.ObjectIDFromHex(picture.ArtworkID)
		if err != nil {
			// Log the value that actually failed to parse, together with the reason.
			common.Logger.Errorf("无效的ObjectID: %s: %s", picture.ArtworkID, err)
			continue
		}
		artwork, err := service.GetArtworkByID(ctx, artworkObjectID)
		if err != nil {
			common.Logger.Errorf("获取作品信息失败: %s", err)
			continue
		}
		fmt.Fprintf(&entries, "[%s\\_%d](%s)\n",
			common.EscapeMarkdown(artwork.Title),
			picture.Index+1,
			common.EscapeMarkdown(artwork.SourceURL),
		)
		if channelMessageAvailable && picture.TelegramInfo != nil && picture.TelegramInfo.MessageID != 0 {
			fmt.Fprintf(&entries, "[频道消息](%s)\n", utils.GetArtworkPostMessageURL(picture.TelegramInfo.MessageID, ChannelChatID))
		}
		if enableSite {
			fmt.Fprintf(&entries, "[ManyACG](%s)\n", config.Cfg.API.SiteURL+"/artwork/"+artwork.ID)
		}
		// Blank line between entries regardless of which optional links were added.
		entries.WriteString("\n")
		rendered++
	}
	if rendered == 0 {
		// Every hit was unusable; report not-found so callers can fall back.
		return notFound, false, nil
	}
	// The header counts only the entries actually listed below it.
	return fmt.Sprintf("找到%d张相似的图片\n\n", rendered) + entries.String(), true, nil
}
// ascii2dResult is one match scraped from an ascii2d search results page: the
// match's display name, the link it points to, and the URL of its thumbnail image.
type ascii2dResult struct {
	Name, Link, Thumbnail string
}
const (
	// ascii2dURL is the ascii2d site root, also used to absolutize site-relative paths.
	ascii2dURL = "https://ascii2d.net"
	// ascii2dAPI is the endpoint that image files are uploaded to for searching.
	ascii2dAPI = ascii2dURL + "/search/file"
)
// getAscii2dSearchResult uploads file to ascii2d and scrapes the matches from the
// bovw (feature) search results page. It examines at most the first ten result
// boxes and skips boxes without a link or thumbnail.
//
// The upload URL's response is expected to be reached via a redirect to the color
// search page (…/search/color/<hash>). The redirect's Location header is read from
// the redirected request and rewritten to the matching …/search/bovw/<hash> page.
// NOTE(review): the color/bovw URL layout is assumed from ascii2d's public site
// structure; if Location lacks "/search/color/", it is fetched unchanged.
func getAscii2dSearchResult(file []byte) ([]*ascii2dResult, error) {
	respColor, err := common.Client.R().SetFileBytes("file", "image.jpg", file).Post(ascii2dAPI)
	if err != nil {
		return nil, fmt.Errorf("请求 ascii2d 失败: %w", err)
	}
	if respColor.IsErrorState() {
		return nil, fmt.Errorf("请求 ascii2d 失败: %s", respColor.Status)
	}

	// Request.Response is the redirect response that caused the final request; it is
	// only populated when the client followed a redirect.
	colorURL := ""
	if r := respColor.Response; r != nil && r.Request != nil && r.Request.Response != nil {
		colorURL = r.Request.Response.Header.Get("Location")
	}
	if colorURL == "" {
		return nil, errors.New("无法获取 bovw 页面")
	}
	// Without this rewrite the "bovw" request would just re-fetch the color page.
	bovwURL := absoluteAscii2dURL(strings.Replace(colorURL, "/search/color/", "/search/bovw/", 1))

	common.Logger.Debugf("getting ascii2d bovw url: %s", bovwURL)
	respBovw, err := common.Client.R().Get(bovwURL)
	if err != nil {
		return nil, fmt.Errorf("请求 ascii2d bovw 页面失败: %w", err)
	}
	if respBovw.IsErrorState() {
		return nil, fmt.Errorf("请求 ascii2d bovw 页面失败: %s", respBovw.Status)
	}

	doc, err := goquery.NewDocumentFromReader(respBovw.Body)
	if err != nil {
		return nil, fmt.Errorf("解析 ascii2d 页面失败: %w", err)
	}

	const maxBoxes = 10
	results := make([]*ascii2dResult, 0, maxBoxes)
	doc.Find(".row.item-box").EachWithBreak(func(i int, s *goquery.Selection) bool {
		if i >= maxBoxes {
			return false // stop scanning; later boxes are never used
		}
		detail := s.Find(".detail-box h6")
		name := detail.First().Find("a").First().Text()
		link, ok := detail.Find("a").First().Attr("href")
		if !ok {
			return true // box without a link, e.g. not a match entry
		}
		thumbnail, ok := s.Find(".image-box img").Attr("src")
		if !ok {
			return true
		}
		results = append(results, &ascii2dResult{
			Name:      strings.TrimSpace(name),
			Link:      link,
			Thumbnail: absoluteAscii2dURL(thumbnail),
		})
		return true
	})
	return results, nil
}

// absoluteAscii2dURL resolves a reference found on ascii2d into an absolute URL:
// protocol-relative ("//host/…") references get an https scheme, site-relative
// ("/…") ones are prefixed with ascii2dURL, and anything else is returned as is.
func absoluteAscii2dURL(ref string) string {
	switch {
	case strings.HasPrefix(ref, "//"):
		return "https:" + ref
	case strings.HasPrefix(ref, "/"):
		return ascii2dURL + ref
	default:
		return ref
	}
}