Skip to content

Commit bfd1591

Browse files
committed
feat: add tagger integration and enable automatic tagging for artworks
1 parent 1832c59 commit bfd1591

File tree

10 files changed

+296
-6
lines changed

10 files changed

+296
-6
lines changed

common/common.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import (
88

99
func Init() {
1010
initHttpClient()
11-
initResendClient()
1211
initLogger()
12+
if config.Cfg.Auth.Resend.APIKey != "" {
13+
initResendClient()
14+
}
1315
if searchCfg := config.Cfg.Search; searchCfg.Enable {
1416
switch searchCfg.Engine {
1517
case "meilisearch":
@@ -19,4 +21,7 @@ func Init() {
1921
os.Exit(1)
2022
}
2123
}
24+
if taggerCfg := config.Cfg.Tagger; taggerCfg.Enable {
25+
initTaggerClient()
26+
}
2227
}

common/tagger_client.go

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package common
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"os"
8+
"time"
9+
10+
"github.com/imroc/req/v3"
11+
"github.com/krau/ManyACG/config"
12+
)
13+
14+
type taggerClient struct {
15+
Client *req.Client
16+
host string
17+
token string
18+
timeout time.Duration
19+
}
20+
21+
func (c *taggerClient) Health() (string, error) {
22+
var health struct {
23+
Status string `json:"status"`
24+
}
25+
resp, err := c.Client.R().Get("/health")
26+
if err != nil {
27+
return "", err
28+
}
29+
if err := json.Unmarshal(resp.Bytes(), &health); err != nil {
30+
return "", err
31+
}
32+
return health.Status, nil
33+
}
34+
35+
type taggerPredictResponse struct {
36+
PredictedTags []string `json:"predicted_tags"`
37+
Scores map[string]float64 `json:"scores"`
38+
}
39+
40+
func (c *taggerClient) Predict(ctx context.Context, file []byte) (*taggerPredictResponse, error) {
41+
resp, err := c.Client.R().SetContext(ctx).SetFileBytes("file", "image", file).Post("/predict")
42+
if err != nil {
43+
return nil, err
44+
}
45+
if resp.IsErrorState() {
46+
return nil, fmt.Errorf("tagger predict failed: %s", resp.Status)
47+
}
48+
var predict taggerPredictResponse
49+
if err := json.Unmarshal(resp.Bytes(), &predict); err != nil {
50+
return nil, err
51+
}
52+
return &predict, nil
53+
}
54+
55+
var TaggerClient *taggerClient
56+
57+
func initTaggerClient() {
58+
if config.Cfg.Tagger.Host == "" || config.Cfg.Tagger.Token == "" {
59+
Logger.Fatalf("Tagger configuration is incomplete")
60+
os.Exit(1)
61+
}
62+
client := req.C().
63+
SetCommonBearerAuthToken(config.Cfg.Tagger.Token).
64+
SetBaseURL(config.Cfg.Tagger.Host).
65+
SetTimeout(time.Duration(config.Cfg.Tagger.Timeout) * time.Second).
66+
SetUserAgent("ManyACG/" + Version)
67+
TaggerClient = &taggerClient{
68+
Client: client,
69+
host: config.Cfg.Tagger.Host,
70+
token: config.Cfg.Tagger.Token,
71+
timeout: time.Duration(config.Cfg.Tagger.Timeout) * time.Second,
72+
}
73+
if status, err := TaggerClient.Health(); err != nil {
74+
Logger.Fatalf("Tagger health check failed: %s", err)
75+
os.Exit(1)
76+
} else {
77+
Logger.Infof("Tagger health check: %s", status)
78+
}
79+
}

config/tagger.go

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package config
2+
3+
type taggerConfig struct {
4+
Enable bool `toml:"enable" mapstructure:"enable" json:"enable" yaml:"enable"`
5+
Host string `toml:"host" mapstructure:"host" json:"host" yaml:"host"`
6+
Token string `toml:"token" mapstructure:"token" json:"token" yaml:"token"`
7+
Timeout int `toml:"timeout" mapstructure:"timeout" json:"timeout" yaml:"timeout"`
8+
TagNew bool `toml:"tagnew" mapstructure:"tagnew" json:"tagnew" yaml:"tagnew"`
9+
}

config/viper.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
type Config struct {
1313
Debug bool `toml:"debug" mapstructure:"debug" json:"debug" yaml:"debug"`
1414
WSRVURL string `toml:"wsrv_url" mapstructure:"wsrv_url" json:"wsrv_url" yaml:"wsrv_url"`
15-
Search searchConfig `toml:"search" mapstructure:"search" json:"search" yaml:"search"`
1615
API apiConfig `toml:"api" mapstructure:"api" json:"api" yaml:"api"`
1716
Auth authConfig `toml:"auth" mapstructure:"auth" json:"auth" yaml:"auth"`
1817
Fetcher fetcherConfig `toml:"fetcher" mapstructure:"fetcher" json:"fetcher" yaml:"fetcher"`
@@ -21,6 +20,8 @@ type Config struct {
2120
Storage storageConfigs `toml:"storage" mapstructure:"storage" json:"storage" yaml:"storage"`
2221
Telegram telegramConfig `toml:"telegram" mapstructure:"telegram" json:"telegram" yaml:"telegram"`
2322
Database databaseConfig `toml:"database" mapstructure:"database" json:"database" yaml:"database"`
23+
Search searchConfig `toml:"search" mapstructure:"search" json:"search" yaml:"search"`
24+
Tagger taggerConfig `toml:"tagger" mapstructure:"tagger" json:"tagger" yaml:"tagger"`
2425
}
2526

2627
type fetcherConfig struct {

service/change_stream.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ func (m *artworkSyncManager) ProcessArtworkUpdateEvent(event bson.M) {
112112
}
113113
task, err := common.MeilisearchClient.Index(config.Cfg.Search.MeiliSearch.Index).UpdateDocuments(artworkJSON)
114114
if err != nil {
115-
common.Logger.Errorf("add artwork to meilisearch error: %s", err)
115+
common.Logger.Errorf("update artwork to meilisearch error: %s", err)
116116
return
117117
}
118-
common.Logger.Debugf("commited add artwork task to meilisearch: %d", task.TaskUID)
118+
common.Logger.Debugf("commited update artwork task to meilisearch: %d", task.TaskUID)
119119
}
120120

121121
func (m *artworkSyncManager) ProcessArtworkDeleteEvent(event bson.M) {

service/service.go

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ func InitService() {
1212
if config.Cfg.Search.Enable {
1313
go syncArtworkToSearchEngine()
1414
}
15+
if config.Cfg.Tagger.Enable {
16+
go listenPredictArtworkTagsTask()
17+
}
1518
}
1619

1720
type Service struct{}

service/tags.go

+63
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ import (
66
"fmt"
77

88
"github.com/duke-git/lancet/v2/slice"
9+
"github.com/krau/ManyACG/common"
910
"github.com/krau/ManyACG/dao"
1011
"github.com/krau/ManyACG/errs"
12+
"github.com/krau/ManyACG/storage"
1113
"github.com/krau/ManyACG/types"
1214
"go.mongodb.org/mongo-driver/bson"
1315
"go.mongodb.org/mongo-driver/bson/primitive"
@@ -167,3 +169,64 @@ func AddTagAliasByID(ctx context.Context, tagID primitive.ObjectID, alias ...str
167169
}
168170
return result.(*types.TagModel), nil
169171
}
172+
173+
func PredictArtworkTagsByIDAndUpdate(ctx context.Context, artworkID primitive.ObjectID) error {
174+
if common.TaggerClient == nil {
175+
return errors.New("tagger not available")
176+
}
177+
artwork, err := GetArtworkByID(ctx, artworkID)
178+
if err != nil {
179+
return err
180+
}
181+
predictedTags := make([]string, 0)
182+
for _, picture := range artwork.Pictures {
183+
var pictureFile []byte
184+
if picture.StorageInfo.Regular != nil {
185+
pictureFile, err = storage.GetFile(ctx, picture.StorageInfo.Regular)
186+
} else if picture.StorageInfo.Original != nil {
187+
pictureFile, err = storage.GetFile(ctx, picture.StorageInfo.Original)
188+
} else {
189+
pictureFile, err = common.DownloadWithCache(ctx, picture.Original, nil)
190+
}
191+
if err != nil {
192+
return err
193+
}
194+
common.Logger.Debugf("predict picture %s", picture.Original)
195+
result, err := common.TaggerClient.Predict(ctx, pictureFile)
196+
if err != nil {
197+
common.Logger.Errorf("predict picture %s error: %s", picture.Original, err)
198+
continue
199+
}
200+
if len(result.PredictedTags) == 0 {
201+
continue
202+
}
203+
predictedTags = slice.Union(predictedTags, result.PredictedTags)
204+
}
205+
newTags := slice.Compact(slice.Union(artwork.Tags, predictedTags))
206+
if err := UpdateArtworkTagsByURL(ctx, artwork.SourceURL, newTags); err != nil {
207+
return err
208+
}
209+
return nil
210+
}
211+
212+
type predictArtworkTagsTask struct {
213+
ArtworkID primitive.ObjectID
214+
Ctx context.Context
215+
}
216+
217+
var predictArtworkTagsTaskChan = make(chan *predictArtworkTagsTask)
218+
219+
func AddPredictArtworkTagTask(ctx context.Context, artworkID primitive.ObjectID) {
220+
predictArtworkTagsTaskChan <- &predictArtworkTagsTask{
221+
ArtworkID: artworkID,
222+
Ctx: ctx,
223+
}
224+
}
225+
226+
func listenPredictArtworkTagsTask() {
227+
for task := range predictArtworkTagsTaskChan {
228+
if err := PredictArtworkTagsByIDAndUpdate(task.Ctx, task.ArtworkID); err != nil {
229+
common.Logger.Errorf("predict artwork %s tags error: %s", task.ArtworkID.Hex(), err)
230+
}
231+
}
232+
}

telegram/handlers/admin_edit_artwork.go

+119
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ package handlers
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"strconv"
78
"strings"
89

10+
"github.com/duke-git/lancet/v2/slice"
911
"github.com/krau/ManyACG/common"
1012
"github.com/krau/ManyACG/service"
1113
"github.com/krau/ManyACG/sources"
14+
"github.com/krau/ManyACG/storage"
1215
"github.com/krau/ManyACG/telegram/utils"
1316
"github.com/krau/ManyACG/types"
1417

@@ -324,3 +327,119 @@ func ReCaptionArtwork(ctx context.Context, bot *telego.Bot, message telego.Messa
324327
})
325328
utils.ReplyMessage(bot, message, "已重新生成作品描述")
326329
}
330+
331+
func AutoTaggingArtwork(ctx context.Context, bot *telego.Bot, message telego.Message) {
332+
if !CheckPermissionInGroup(ctx, message, types.PermissionEditArtwork) {
333+
utils.ReplyMessage(bot, message, "你没有编辑作品的权限")
334+
return
335+
}
336+
if common.TaggerClient == nil {
337+
utils.ReplyMessage(bot, message, "Tagger is not available")
338+
}
339+
var sourceURL string
340+
var findUrlInArgs bool
341+
if message.ReplyToMessage != nil {
342+
sourceURL = utils.FindSourceURLForMessage(message.ReplyToMessage)
343+
} else {
344+
sourceURL = sources.FindSourceURL(message.Text)
345+
findUrlInArgs = true
346+
}
347+
if sourceURL == "" {
348+
utils.ReplyMessage(bot, message, "请回复一条消息, 或者指定作品链接")
349+
return
350+
}
351+
352+
artwork, err := service.GetArtworkByURL(ctx, sourceURL)
353+
if err != nil {
354+
utils.ReplyMessage(bot, message, "获取作品信息失败: "+err.Error())
355+
return
356+
}
357+
selectAllPictures := true
358+
pictureIndex := 1
359+
_, _, args := telegoutil.ParseCommand(message.Text)
360+
if len(args) > func() int {
361+
if findUrlInArgs {
362+
return 1
363+
}
364+
return 0
365+
}() {
366+
selectAllPictures = false
367+
pictureIndex, err := strconv.Atoi(args[len(args)-1])
368+
if err != nil {
369+
utils.ReplyMessage(bot, message, "图片序号错误")
370+
return
371+
}
372+
if pictureIndex < 1 || pictureIndex > len(artwork.Pictures) {
373+
utils.ReplyMessage(bot, message, "图片序号超出范围")
374+
return
375+
}
376+
}
377+
pictures := make([]*types.Picture, 0)
378+
if selectAllPictures {
379+
pictures = artwork.Pictures
380+
} else {
381+
picture := artwork.Pictures[pictureIndex-1]
382+
pictures[0] = picture
383+
}
384+
msg, err := utils.ReplyMessage(bot, message, "正在请求...")
385+
if err != nil {
386+
common.Logger.Errorf("Reply message failed: %s", err)
387+
return
388+
}
389+
for i, picture := range pictures {
390+
var file []byte
391+
file, err = storage.GetFile(ctx, func() *types.StorageDetail {
392+
if picture.StorageInfo.Regular != nil {
393+
return picture.StorageInfo.Regular
394+
} else {
395+
return picture.StorageInfo.Original
396+
}
397+
}())
398+
if err != nil {
399+
file, err = common.DownloadWithCache(ctx, picture.Original, nil)
400+
}
401+
if err != nil {
402+
common.Logger.Errorf("Download picture %s failed: %s", picture.Original, err)
403+
continue
404+
}
405+
common.Logger.Debugf("Predicting tags for %s", picture.Original)
406+
predict, err := common.TaggerClient.Predict(ctx, file)
407+
if err != nil {
408+
common.Logger.Errorf("Predict tags failed: %s", err)
409+
utils.ReplyMessage(bot, message, "Predict tags failed")
410+
return
411+
}
412+
if len(predict.PredictedTags) == 0 {
413+
utils.ReplyMessage(bot, message, "No tags predicted")
414+
return
415+
}
416+
newTags := slice.Union(artwork.Tags, predict.PredictedTags)
417+
if err := service.UpdateArtworkTagsByURL(ctx, artwork.SourceURL, newTags); err != nil {
418+
utils.ReplyMessage(bot, message, "更新作品标签失败: "+err.Error())
419+
return
420+
}
421+
artwork, err = service.GetArtworkByURL(ctx, artwork.SourceURL)
422+
if err != nil {
423+
utils.ReplyMessage(bot, message, "获取更新后的作品信息失败: "+err.Error())
424+
return
425+
}
426+
if artwork.Pictures[0].TelegramInfo.MessageID != 0 {
427+
bot.EditMessageCaption(&telego.EditMessageCaptionParams{
428+
ChatID: ChannelChatID,
429+
MessageID: artwork.Pictures[0].TelegramInfo.MessageID,
430+
Caption: utils.GetArtworkHTMLCaption(artwork),
431+
ParseMode: telego.ModeHTML,
432+
})
433+
}
434+
bot.EditMessageText(&telego.EditMessageTextParams{
435+
ChatID: msg.Chat.ChatID(),
436+
MessageID: msg.MessageID,
437+
Text: fmt.Sprintf("选择的第 %d 张图片预测标签成功", i+1),
438+
})
439+
}
440+
bot.EditMessageText(&telego.EditMessageTextParams{
441+
ChatID: msg.Chat.ChatID(),
442+
MessageID: msg.MessageID,
443+
Text: "更新作品标签成功",
444+
})
445+
}

telegram/handlers/handlers.go

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func RegisterHandlers(hg *telegohandler.HandlerGroup) {
4646
mg.HandleMessageCtx(AddTagAlias, telegohandler.CommandEqual("tagalias"))
4747
mg.HandleMessageCtx(DumpArtworkInfo, telegohandler.CommandEqual("dump"))
4848
mg.HandleMessageCtx(ReCaptionArtwork, telegohandler.CommandEqual("recaption"))
49+
mg.HandleMessageCtx(AutoTaggingArtwork, telegohandler.CommandEqual("autotag"))
4950

5051
hg.HandleCallbackQueryCtx(PostArtworkCallbackQuery, telegohandler.CallbackDataContains("post_artwork"))
5152
hg.HandleCallbackQueryCtx(SearchPictureCallbackQuery, telegohandler.CallbackDataPrefix("search_picture"))

telegram/utils/post_artwork.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,18 @@ func PostAndCreateArtwork(ctx context.Context, artwork *types.Artwork, bot *tele
248248
}
249249

250250
func afterCreate(ctx context.Context, artwork *types.Artwork, bot *telego.Bot, fromID int64) {
251-
for _, picture := range artwork.Pictures {
252-
service.AddProcessPictureTask(ctx, picture)
251+
go func() {
252+
for _, picture := range artwork.Pictures {
253+
service.AddProcessPictureTask(ctx, picture)
254+
}
255+
}()
256+
if config.Cfg.Tagger.TagNew {
257+
objectID, err := primitive.ObjectIDFromHex(artwork.ID)
258+
if err != nil {
259+
common.Logger.Fatalf("invalid ObjectID: %s", artwork.ID)
260+
return
261+
}
262+
go service.AddPredictArtworkTagTask(ctx, objectID)
253263
}
254264
go checkDuplicate(ctx, artwork, bot, fromID)
255265
go prettyPostedArtworkTagCaption(ctx, artwork, bot)

0 commit comments

Comments
 (0)