diff --git a/ROADMAP.md b/ROADMAP.md
index aaba58f..621aece 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -50,7 +50,7 @@
   * [x] Message unsend
   * [x] Message reactions
   * [x] Message edits
-  * [ ] Message history
+  * [x] Message history
   * [ ] Presence
   * [x] Typing notifications
   * [x] Read receipts
diff --git a/backfill.go b/backfill.go
new file mode 100644
index 0000000..1408987
--- /dev/null
+++ b/backfill.go
@@ -0,0 +1,633 @@
+// mautrix-meta - A Matrix-Facebook Messenger and Instagram DM puppeting bridge.
+// Copyright (C) 2024 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package main
+
+import (
+	"cmp"
+	"context"
+	"crypto/sha256"
+	"encoding/base64"
+	"fmt"
+	"slices"
+	"strconv"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/rs/zerolog"
+	"go.mau.fi/util/variationselector"
+	"maunium.net/go/mautrix"
+	"maunium.net/go/mautrix/appservice"
+	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/id"
+
+	"go.mau.fi/mautrix-meta/database"
+	"go.mau.fi/mautrix-meta/messagix/socket"
+	"go.mau.fi/mautrix-meta/messagix/table"
+)
+
+func (user *User) StopBackfillLoop() {
+	if fn := user.stopBackfillTask.Swap(nil); fn != nil {
+		(*fn)()
+	}
+}
+
+type BackfillCollector struct {
+	*table.UpsertMessages
+	Source      id.UserID
+	MaxPages    int
+	Forward     bool
+	LastMessage *database.Message
+	Task        *database.BackfillTask
+	Done        func()
+}
+
+func (user *User) handleBackfillTask(ctx context.Context, task *database.BackfillTask) {
+	log := zerolog.Ctx(ctx)
+	log.Debug().Any("task", task).Msg("Got backfill task")
+	portal := user.bridge.GetExistingPortalByThreadID(task.Key)
+	task.DispatchedAt = time.Now()
+	task.CompletedAt = time.Time{}
+	if !portal.MoreToBackfill {
+		log.Debug().Int64("portal_id", task.Key.ThreadID).Msg("Nothing more to backfill in portal")
+		task.Finished = true
+		task.CompletedAt = time.Now()
+		if err := task.Upsert(ctx); err != nil {
+			log.Err(err).Msg("Failed to save backfill task")
+		}
+		return
+	}
+	if err := task.Upsert(ctx); err != nil {
+		log.Err(err).Msg("Failed to save backfill task")
+	}
+	ok := portal.requestMoreHistory(ctx, user, portal.OldestMessageTS, portal.OldestMessageID)
+	if !ok {
+		task.CooldownUntil = time.Now().Add(1 * time.Hour)
+		if err := task.Upsert(ctx); err != nil {
+			log.Err(err).Msg("Failed to save backfill task")
+		}
+		return
+	}
+	backfillDone := make(chan struct{})
+	doneCallback := sync.OnceFunc(func() {
+		close(backfillDone)
+	})
+	portal.backfillCollector = &BackfillCollector{
+		UpsertMessages: &table.UpsertMessages{
+			Range: &table.LSInsertNewMessageRange{
+				ThreadKey:              portal.ThreadID,
+				MinTimestampMsTemplate: portal.OldestMessageTS,
+				MaxTimestampMsTemplate: portal.OldestMessageTS,
+				MinMessageId:           portal.OldestMessageID,
+				MaxMessageId:           portal.OldestMessageID,
+				MinTimestampMs:         portal.OldestMessageTS,
+				MaxTimestampMs:         portal.OldestMessageTS,
+				HasMoreBefore:          true,
+				HasMoreAfter:           true,
+			},
+		},
+		Source:   user.MXID,
+		MaxPages: user.bridge.Config.Bridge.Backfill.Queue.PagesAtOnce,
+		Forward:  false,
+		Task:     task,
+		Done:     doneCallback,
+	}
+	select {
+	case <-backfillDone:
+	case <-ctx.Done():
+		return
+	}
+	if !portal.MoreToBackfill {
+		task.Finished = true
+	}
+	task.CompletedAt = time.Now()
+	if err := task.Upsert(ctx); err != nil {
+		log.Err(err).Msg("Failed to save backfill task")
+	}
+	log.Debug().Any("task", task).Msg("Finished backfill task")
+}
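// A minimal usage sketch of the loop lifecycle (illustrative only, not part
// of this patch; onLogin/onLogout are hypothetical callers): one BackfillLoop
// goroutine runs per logged-in user, and StopBackfillLoop cancels it by
// swapping out and invoking the cancel func that BackfillLoop stored in
// user.stopBackfillTask.
//
//	func (user *User) onLogin() {
//		go user.BackfillLoop()
//	}
//
//	func (user *User) onLogout() {
//		user.StopBackfillLoop()
//	}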
+
+func (user *User) BackfillLoop() {
+	log := user.log.With().Str("action", "backfill loop").Logger()
+	defer func() {
+		log.Debug().Msg("Backfill loop stopped")
+	}()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	oldFn := user.stopBackfillTask.Swap(&cancel)
+	if oldFn != nil {
+		(*oldFn)()
+	}
+	ctx = log.WithContext(ctx)
+	var extraTime time.Duration
+	log.Debug().Msg("Backfill loop started")
+	for {
+		select {
+		case <-time.After(user.bridge.Config.Bridge.Backfill.Queue.SleepBetweenTasks + extraTime):
+		case <-ctx.Done():
+			return
+		}
+
+		task, err := user.bridge.DB.BackfillTask.GetNext(ctx, user.MXID)
+		if err != nil {
+			log.Err(err).Msg("Failed to get next backfill task")
+		} else if task != nil {
+			user.handleBackfillTask(ctx, task)
+			extraTime = 0
+		} else if extraTime < 1*time.Minute {
+			extraTime += 5 * time.Second
+		}
+		if ctx.Err() != nil {
+			return
+		}
+	}
+}
+
+func (portal *Portal) requestMoreHistory(ctx context.Context, user *User, minTimestampMS int64, minMessageID string) bool {
+	resp, err := user.Client.ExecuteTasks(&socket.FetchMessagesTask{
+		ThreadKey:            portal.ThreadID,
+		Direction:            0,
+		ReferenceTimestampMs: minTimestampMS,
+		ReferenceMessageId:   minMessageID,
+		SyncGroup:            1,
+		Cursor:               user.Client.SyncManager.GetCursor(1),
+	})
+	zerolog.Ctx(ctx).Trace().Any("resp_data", resp).Msg("Response data for fetching messages")
+	if err != nil {
+		zerolog.Ctx(ctx).Err(err).Msg("Failed to request more history")
+		return false
+	} else {
+		zerolog.Ctx(ctx).Debug().
+			Int64("min_timestamp_ms", minTimestampMS).
+			Str("min_message_id", minMessageID).
+			Msg("Requested more history")
+		return true
+	}
+}
+
+var globalUpsertCounter atomic.Int64
+
+func (portal *Portal) handleMetaExistingRange(user *User, rng *table.LSUpdateExistingMessageRange) {
+	portal.backfillLock.Lock()
+	defer portal.backfillLock.Unlock()
+
+	log := portal.log.With().
+		Str("action", "handle meta existing range").
+		Stringer("source_mxid", user.MXID).
+		Int("global_upsert_counter", int(globalUpsertCounter.Add(1))).
+		Logger()
+	logEvt := log.Info().
+		Int64("timestamp_ms", rng.TimestampMS).
+		Bool("bool2", rng.UnknownBool2).
+		Bool("bool3", rng.UnknownBool3)
+	if portal.backfillCollector == nil {
+		logEvt.Msg("Ignoring update existing message range command with no backfill collector")
+	} else if portal.backfillCollector.Source != user.MXID {
+		logEvt.Stringer("prev_mxid", portal.backfillCollector.Source).
+			Msg("Ignoring update existing message range command for another user")
+	} else if portal.backfillCollector.Range.MinTimestampMs != rng.TimestampMS {
+		logEvt.Int64("prev_timestamp_ms", portal.backfillCollector.Range.MinTimestampMs).
+			Msg("Ignoring update existing message range command with different timestamp")
+	} else {
+		if len(portal.backfillCollector.Messages) == 0 {
+			logEvt.Msg("Update existing range marked backfill as done, no messages found")
+			if portal.backfillCollector.Done != nil {
+				portal.backfillCollector.Done()
+			}
+			portal.MoreToBackfill = false
+			err := portal.Update(log.WithContext(context.TODO()))
+			if err != nil {
+				log.Err(err).Msg("Failed to save portal in database")
+			}
+		} else {
+			logEvt.Msg("Update existing range marked backfill as done, processing collected history now")
+			if rng.UnknownBool2 && !rng.UnknownBool3 {
+				portal.backfillCollector.Range.HasMoreBefore = false
+			} else {
+				portal.backfillCollector.Range.HasMoreAfter = false
+			}
+			portal.handleMessageBatch(log.WithContext(context.TODO()), user, portal.backfillCollector.UpsertMessages, portal.backfillCollector.Forward, portal.backfillCollector.LastMessage, portal.backfillCollector.Done)
+		}
+		portal.backfillCollector = nil
+	}
+}
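// A distilled sketch of the UnknownBool2/UnknownBool3 branch above, assuming
// the interpretation documented on LSUpdateExistingMessageRange in
// messagix/table/messages.go (the flags come from Meta and are
// reverse-engineered, so treat this as a best guess rather than a spec):
//
//	func applyRangeUpdate(rng *table.LSInsertNewMessageRange, bool2, bool3 bool) {
//		if bool2 && !bool3 {
//			rng.HasMoreBefore = false // backwards pagination exhausted, backfill can stop
//		} else {
//			rng.HasMoreAfter = false // only the forward direction is exhausted
//		}
//	}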
+ Msg("Ignoring update existing message range command with different timestamp") + } else { + if len(portal.backfillCollector.Messages) == 0 { + logEvt.Msg("Update existing range marked backfill as done, no messages found") + if portal.backfillCollector.Done != nil { + portal.backfillCollector.Done() + } + portal.MoreToBackfill = false + err := portal.Update(log.WithContext(context.TODO())) + if err != nil { + log.Err(err).Msg("Failed to save portal in database") + } + } else { + logEvt.Msg("Update existing range marked backfill as done, processing collected history now") + if rng.UnknownBool2 && !rng.UnknownBool3 { + portal.backfillCollector.Range.HasMoreBefore = false + } else { + portal.backfillCollector.Range.HasMoreAfter = false + } + portal.handleMessageBatch(log.WithContext(context.TODO()), user, portal.backfillCollector.UpsertMessages, portal.backfillCollector.Forward, portal.backfillCollector.LastMessage, portal.backfillCollector.Done) + } + portal.backfillCollector = nil + } +} + +func (portal *Portal) handleMetaUpsertMessages(user *User, upsert *table.UpsertMessages) { + portal.backfillLock.Lock() + defer portal.backfillLock.Unlock() + + if !portal.bridge.Config.Bridge.Backfill.Enabled { + return + } else if upsert.Range == nil { + portal.log.Warn().Int("message_count", len(upsert.Messages)).Msg("Ignoring upsert messages without range") + return + } + log := portal.log.With(). + Str("action", "handle meta upsert"). + Stringer("source_mxid", user.MXID). + Int("global_upsert_counter", int(globalUpsertCounter.Add(1))). + Logger() + log.Info(). + Int64("min_timestamp_ms", upsert.Range.MinTimestampMs). + Str("min_message_id", upsert.Range.MinMessageId). + Int64("max_timestamp_ms", upsert.Range.MaxTimestampMs). + Str("max_message_id", upsert.Range.MaxMessageId). + Bool("has_more_before", upsert.Range.HasMoreBefore). + Bool("has_more_after", upsert.Range.HasMoreAfter). + Int("message_count", len(upsert.Messages)). + Msg("Received upsert messages") + ctx := log.WithContext(context.TODO()) + + // Check if someone is already collecting messages for backfill + if portal.backfillCollector != nil { + if user.MXID != portal.backfillCollector.Source { + log.Warn().Stringer("prev_mxid", portal.backfillCollector.Source).Msg("Ignoring upsert for another user") + return + } else if upsert.Range.MaxTimestampMs > portal.backfillCollector.Range.MinTimestampMs { + log.Warn(). + Int64("prev_min_timestamp_ms", portal.backfillCollector.Range.MinTimestampMs). + Msg("Ignoring unexpected upsert messages while collecting history") + return + } + if portal.backfillCollector.MaxPages > 0 { + portal.backfillCollector.MaxPages-- + } + portal.backfillCollector.UpsertMessages = portal.backfillCollector.Join(upsert) + pageLimitReached := portal.backfillCollector.MaxPages == 0 + endOfChatReached := !upsert.Range.HasMoreBefore + existingMessagesReached := portal.backfillCollector.LastMessage != nil && portal.backfillCollector.Range.MinTimestampMs <= portal.backfillCollector.LastMessage.Timestamp.UnixMilli() + if portal.backfillCollector.Task != nil { + portal.backfillCollector.Task.PageCount++ + if portal.bridge.Config.Bridge.Backfill.Queue.MaxPages >= 0 && portal.backfillCollector.Task.PageCount >= portal.bridge.Config.Bridge.Backfill.Queue.MaxPages { + log.Debug().Any("task", portal.backfillCollector.Task).Msg("Marking backfill task as finished (reached page limit)") + pageLimitReached = true + } + } + logEvt := log.Debug(). + Bool("page_limit_reached", pageLimitReached). 
+ Bool("end_of_chat_reached", endOfChatReached). + Bool("existing_messages_reached", existingMessagesReached) + if !pageLimitReached && !endOfChatReached && !existingMessagesReached { + logEvt.Msg("Requesting more history as collector still has room") + portal.requestMoreHistory(ctx, user, upsert.Range.MinTimestampMs, upsert.Range.MinMessageId) + return + } + logEvt.Msg("Processing collected history now") + portal.handleMessageBatch(ctx, user, portal.backfillCollector.UpsertMessages, portal.backfillCollector.Forward, portal.backfillCollector.LastMessage, portal.backfillCollector.Done) + portal.backfillCollector = nil + return + } + + // No active collector, check the last bridged message + lastMessage, err := portal.bridge.DB.Message.GetLastByTimestamp(ctx, portal.PortalKey, time.Now().Add(1*time.Minute)) + if err != nil { + log.Err(err).Msg("Failed to get last message to check if upsert batch should be handled") + return + } + if lastMessage == nil { + // Chat is empty, request more history or bridge the one received message immediately depending on history_fetch_count + if portal.bridge.Config.Bridge.Backfill.HistoryFetchPages > 0 { + log.Debug().Msg("Got first historical message in empty chat, requesting more") + portal.backfillCollector = &BackfillCollector{ + UpsertMessages: upsert, + Source: user.MXID, + MaxPages: portal.bridge.Config.Bridge.Backfill.HistoryFetchPages, + Forward: true, + } + portal.requestMoreHistory(ctx, user, upsert.Range.MinTimestampMs, upsert.Range.MinMessageId) + } else { + log.Debug().Msg("Got first historical message in empty chat, bridging it immediately") + portal.handleMessageBatch(ctx, user, upsert, true, nil, nil) + } + } else if upsert.Range.MaxTimestampMs > lastMessage.Timestamp.UnixMilli() && upsert.Range.MaxMessageId != lastMessage.ID { + // Chat is not empty and the upsert contains a newer message than the last bridged one, + // request more history to fill the gap or bridge the received one immediately depending on catchup_fetch_count + if portal.bridge.Config.Bridge.Backfill.CatchupFetchPages > 0 { + log.Debug().Msg("Got upsert of new messages, requesting more") + portal.backfillCollector = &BackfillCollector{ + UpsertMessages: upsert, + Source: user.MXID, + MaxPages: portal.bridge.Config.Bridge.Backfill.CatchupFetchPages, + Forward: true, + LastMessage: lastMessage, + } + portal.requestMoreHistory(ctx, user, upsert.Range.MinTimestampMs, upsert.Range.MinMessageId) + } else { + log.Debug().Msg("Got upsert of new messages, bridging them immediately") + portal.handleMessageBatch(ctx, user, upsert, true, lastMessage, nil) + } + } else { + // Chat is not empty and the upsert doesn't contain new messages (and it's not a part of a backfill collector), ignore it. + log.Debug(). + Int64("last_message_ts", lastMessage.Timestamp.UnixMilli()). + Str("last_message_id", lastMessage.ID). + Int64("upsert_max_ts", upsert.Range.MaxTimestampMs). + Str("upsert_max_id", upsert.Range.MaxMessageId). 
+ Msg("Ignoring unrequested upsert before last message") + queueConfig := portal.bridge.Config.Bridge.Backfill.Queue + if lastMessage == nil && queueConfig.MaxPages != 0 && portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) { + task := portal.bridge.DB.BackfillTask.NewWithValues(portal.PortalKey, user.MXID) + err = task.InsertIfNotExists(ctx) + if err != nil { + log.Err(err).Msg("Failed to ensure backfill task exists") + } + } + } +} + +func (portal *Portal) deterministicEventID(msgID string, partIndex int) id.EventID { + data := fmt.Sprintf("%s/%s", portal.MXID, msgID) + if partIndex != 0 { + data = fmt.Sprintf("%s/%d", data, partIndex) + } + sum := sha256.Sum256([]byte(data)) + return id.EventID(fmt.Sprintf("$%s:%s.com", base64.RawURLEncoding.EncodeToString(sum[:]), portal.bridge.BeeperNetworkName)) +} + +type BackfillPartMetadata struct { + Intent *appservice.IntentAPI + MessageID string + OTID int64 + Sender int64 + PartIndex int + EditCount int64 + Reactions []*table.LSUpsertReaction + InBatchReact *table.LSUpsertReaction +} + +func (portal *Portal) handleMessageBatch(ctx context.Context, source *User, upsert *table.UpsertMessages, forward bool, lastMessage *database.Message, doneCallback func()) { + // The messages are probably already sorted in reverse order (newest to oldest). We want to sort them again to be safe, + // but reverse first to make the sorting algorithm's job easier if it's already sorted. + slices.Reverse(upsert.Messages) + slices.SortFunc(upsert.Messages, func(a, b *table.WrappedMessage) int { + key := cmp.Compare(a.PrimarySortKey, b.PrimarySortKey) + if key == 0 { + key = cmp.Compare(a.SecondarySortKey, b.SecondarySortKey) + } + return key + }) + log := zerolog.Ctx(ctx) + upsert.Messages = slices.CompactFunc(upsert.Messages, func(a, b *table.WrappedMessage) bool { + if a.MessageId == b.MessageId { + log.Debug(). + Str("message_id", a.MessageId). + Bool("attachment_counts_match", len(a.XMAAttachments) == len(b.XMAAttachments) && len(a.BlobAttachments) == len(b.BlobAttachments) && len(a.Stickers) == len(b.Stickers)). + Msg("Backfill batch contained duplicate message") + return true + } + return false + }) + if lastMessage != nil { + // For catchup backfills, delete any messages that are older than the last bridged message. + upsert.Messages = slices.DeleteFunc(upsert.Messages, func(message *table.WrappedMessage) bool { + return message.TimestampMs <= lastMessage.Timestamp.UnixMilli() + }) + } + if portal.OldestMessageTS == 0 || portal.OldestMessageTS > upsert.Range.MinTimestampMs { + portal.OldestMessageTS = upsert.Range.MinTimestampMs + portal.OldestMessageID = upsert.Range.MinMessageId + portal.MoreToBackfill = upsert.Range.HasMoreBefore + err := portal.Update(ctx) + if err != nil { + log.Err(err).Msg("Failed to save oldest message ID/timestamp in database") + } else { + log.Debug(). + Int64("oldest_message_ts", portal.OldestMessageTS). + Str("oldest_message_id", portal.OldestMessageID). + Msg("Saved oldest message ID/timestamp in database") + } + } + if len(upsert.Messages) == 0 { + log.Warn().Msg("Got empty batch of historical messages") + return + } + log.Info(). + Int64("oldest_message_ts", upsert.Messages[0].TimestampMs). + Str("oldest_message_id", upsert.Messages[0].MessageId). + Int64("newest_message_ts", upsert.Messages[len(upsert.Messages)-1].TimestampMs). + Str("newest_message_id", upsert.Messages[len(upsert.Messages)-1].MessageId). + Int("message_count", len(upsert.Messages)). + Bool("has_more_before", upsert.Range.HasMoreBefore). 
+ Msg("Handling batch of historical messages") + if lastMessage == nil && (upsert.Messages[0].TimestampMs != upsert.Range.MinTimestampMs || upsert.Messages[0].MessageId != upsert.Range.MinMessageId) { + log.Warn(). + Int64("min_timestamp_ms", upsert.Range.MinTimestampMs). + Str("min_message_id", upsert.Range.MinMessageId). + Int64("first_message_ts", upsert.Messages[0].TimestampMs). + Str("first_message_id", upsert.Messages[0].MessageId). + Msg("First message in batch doesn't match range") + } + if !forward { + go func() { + if doneCallback != nil { + defer doneCallback() + } + portal.convertAndSendBackfill(ctx, source, upsert.Messages, upsert.MarkRead, forward) + }() + } else { + if doneCallback != nil { + defer doneCallback() + } + portal.convertAndSendBackfill(ctx, source, upsert.Messages, upsert.MarkRead, forward) + queueConfig := portal.bridge.Config.Bridge.Backfill.Queue + if lastMessage == nil && queueConfig.MaxPages != 0 && portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) { + task := portal.bridge.DB.BackfillTask.NewWithValues(portal.PortalKey, source.MXID) + err := task.Upsert(ctx) + if err != nil { + log.Err(err).Msg("Failed to save backfill task after initial backfill") + } else { + log.Debug().Msg("Saved backfill task after initial backfill") + } + } + } +} + +func (portal *Portal) convertAndSendBackfill(ctx context.Context, source *User, messages []*table.WrappedMessage, markRead, forward bool) { + log := zerolog.Ctx(ctx) + events := make([]*event.Event, 0, len(messages)) + metas := make([]*BackfillPartMetadata, 0, len(messages)) + ctx = context.WithValue(ctx, msgconvContextKeyClient, source.Client) + if forward { + ctx = context.WithValue(ctx, msgconvContextKeyBackfill, backfillTypeForward) + } else { + ctx = context.WithValue(ctx, msgconvContextKeyBackfill, backfillTypeHistorical) + } + sendReactionsInBatch := portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) + for _, msg := range messages { + intent := portal.bridge.GetPuppetByID(msg.SenderId).IntentFor(portal) + if intent == nil { + log.Warn().Int64("sender_id", msg.SenderId).Msg("Failed to get intent for sender") + continue + } + ctx := context.WithValue(ctx, msgconvContextKeyIntent, intent) + converted := portal.MsgConv.ToMatrix(ctx, msg) + if portal.bridge.Config.Bridge.CaptionInMessage { + converted.MergeCaption() + } + if len(converted.Parts) == 0 { + log.Warn().Str("message_id", msg.MessageId).Msg("Message was empty after conversion") + continue + } + var reactionsToSendSeparately []*table.LSUpsertReaction + if !sendReactionsInBatch { + reactionsToSendSeparately = msg.Reactions + } + for i, part := range converted.Parts { + content := &event.Content{ + Parsed: part.Content, + Raw: part.Extra, + } + evtType, err := portal.encrypt(ctx, intent, content, part.Type) + if err != nil { + log.Err(err).Str("message_id", msg.MessageId).Int("part_index", i).Msg("Failed to encrypt event") + continue + } + + events = append(events, &event.Event{ + Sender: intent.UserID, + Type: evtType, + Timestamp: msg.TimestampMs, + ID: portal.deterministicEventID(msg.MessageId, i), + RoomID: portal.MXID, + Content: *content, + }) + otid, _ := strconv.ParseInt(msg.OfflineThreadingId, 10, 64) + metas = append(metas, &BackfillPartMetadata{ + Intent: intent, + MessageID: msg.MessageId, + OTID: otid, + Sender: msg.SenderId, + PartIndex: i, + EditCount: msg.EditCount, + Reactions: reactionsToSendSeparately, + }) + reactionsToSendSeparately = nil + } + if sendReactionsInBatch { + reactionTargetEventID := 
+
+func (portal *Portal) sendBackfillLegacy(ctx context.Context, source *User, events []*event.Event, metas []*BackfillPartMetadata, markRead bool) {
+	var lastEventID id.EventID
+	for i, evt := range events {
+		resp, err := portal.sendMatrixEvent(ctx, metas[i].Intent, evt.Type, evt.Content.Parsed, evt.Content.Raw, evt.Timestamp)
+		if err != nil {
+			zerolog.Ctx(ctx).Err(err).Int("evt_index", i).Msg("Failed to send event")
+		} else {
+			portal.storeMessageInDB(ctx, resp.EventID, metas[i].MessageID, metas[i].OTID, metas[i].Sender, time.UnixMilli(evt.Timestamp), metas[i].PartIndex)
+			lastEventID = resp.EventID
+		}
+		for _, react := range metas[i].Reactions {
+			portal.handleMetaReaction(react)
+		}
+	}
+	if markRead && lastEventID != "" {
+		puppet := portal.bridge.GetPuppetByCustomMXID(source.MXID)
+		if puppet != nil {
+			err := portal.SendReadReceipt(ctx, puppet, lastEventID)
+			if err != nil {
+				zerolog.Ctx(ctx).Err(err).Msg("Failed to send read receipt after backfill")
+			}
+		}
+	}
+}
+ Msg("Got wrong number of event IDs for backfill batch") + return + } + dbMessages := make([]*database.Message, 0, len(events)) + dbReactions := make([]*database.Reaction, 0) + for i, evtID := range resp.EventIDs { + meta := metas[i] + if meta.InBatchReact != nil { + dbReactions = append(dbReactions, &database.Reaction{ + MessageID: meta.MessageID, + Sender: meta.InBatchReact.ActorId, + Emoji: meta.InBatchReact.Reaction, + MXID: evtID, + }) + } else { + dbMessages = append(dbMessages, &database.Message{ + ID: meta.MessageID, + PartIndex: meta.PartIndex, + Sender: meta.Sender, + OTID: meta.OTID, + MXID: evtID, + Timestamp: time.UnixMilli(events[i].Timestamp), + EditCount: meta.EditCount, + }) + } + } + err = portal.bridge.DB.Message.BulkInsert(ctx, portal.PortalKey, portal.MXID, dbMessages) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save backfill batch messages to database") + } + err = portal.bridge.DB.Reaction.BulkInsert(ctx, portal.PortalKey, portal.MXID, dbReactions) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save backfill batch reactions to database") + } +} diff --git a/config/bridge.go b/config/bridge.go index 593f9ec..b97e2d9 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -55,7 +55,21 @@ type BridgeConfig struct { Deadline time.Duration `yaml:"-"` } `yaml:"message_handling_timeout"` - CommandPrefix string `yaml:"command_prefix"` + CommandPrefix string `yaml:"command_prefix"` + + Backfill struct { + Enabled bool `yaml:"enabled"` + InboxFetchPages int `yaml:"inbox_fetch_pages"` + HistoryFetchPages int `yaml:"history_fetch_pages"` + CatchupFetchPages int `yaml:"catchup_fetch_pages"` + Queue struct { + PagesAtOnce int `yaml:"pages_at_once"` + MaxPages int `yaml:"max_pages"` + SleepBetweenTasks time.Duration `yaml:"sleep_between_tasks"` + DontFetchXMA bool `yaml:"dont_fetch_xma"` + } `yaml:"queue"` + } `yaml:"backfill"` + ManagementRoomText bridgeconfig.ManagementRoomTexts `yaml:"management_room_text"` Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"` diff --git a/config/upgrade.go b/config/upgrade.go index 71caf3f..058f8cd 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -96,6 +96,14 @@ func DoUpgrade(helper *up.Helper) { helper.Copy(up.Bool, "bridge", "double_puppet_allow_discovery") helper.Copy(up.Map, "bridge", "login_shared_secret_map") helper.Copy(up.Str, "bridge", "command_prefix") + helper.Copy(up.Bool, "bridge", "backfill", "enabled") + helper.Copy(up.Int, "bridge", "backfill", "inbox_fetch_pages") + helper.Copy(up.Int, "bridge", "backfill", "history_fetch_pages") + helper.Copy(up.Int, "bridge", "backfill", "catchup_fetch_pages") + helper.Copy(up.Int, "bridge", "backfill", "queue", "pages_at_once") + helper.Copy(up.Int, "bridge", "backfill", "queue", "max_pages") + helper.Copy(up.Str, "bridge", "backfill", "queue", "sleep_between_tasks") + helper.Copy(up.Bool, "bridge", "backfill", "queue", "dont_fetch_xma") helper.Copy(up.Str, "bridge", "management_room_text", "welcome") helper.Copy(up.Str, "bridge", "management_room_text", "welcome_connected") helper.Copy(up.Str, "bridge", "management_room_text", "welcome_unconnected") @@ -153,6 +161,7 @@ var SpacedBlocks = [][]string{ {"bridge"}, {"bridge", "personal_filtering_spaces"}, {"bridge", "command_prefix"}, + {"bridge", "backfill"}, {"bridge", "management_room_text"}, {"bridge", "encryption"}, {"bridge", "provisioning"}, diff --git a/database/backfilltask.go b/database/backfilltask.go new file mode 100644 index 0000000..f13579f --- /dev/null +++ 
@@ -0,0 +1,134 @@
+// mautrix-meta - A Matrix-Facebook Messenger and Instagram DM puppeting bridge.
+// Copyright (C) 2024 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package database
+
+import (
+	"context"
+	"time"
+
+	"go.mau.fi/util/dbutil"
+	"maunium.net/go/mautrix/id"
+)
+
+const (
+	putBackfillTask = `
+		INSERT INTO backfill_task (
+			portal_id, portal_receiver, user_mxid, priority, page_count, finished,
+			dispatched_at, completed_at, cooldown_until
+		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+		ON CONFLICT (portal_id, portal_receiver, user_mxid) DO UPDATE
+			SET priority=excluded.priority, page_count=excluded.page_count, finished=excluded.finished,
+			    dispatched_at=excluded.dispatched_at, completed_at=excluded.completed_at, cooldown_until=excluded.cooldown_until
+	`
+	insertBackfillTaskIfNotExists = `
+		INSERT INTO backfill_task (
+			portal_id, portal_receiver, user_mxid, priority, page_count, finished,
+			dispatched_at, completed_at, cooldown_until
+		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+		ON CONFLICT (portal_id, portal_receiver, user_mxid) DO NOTHING
+	`
+	getNextBackfillTask = `
+		SELECT portal_id, portal_receiver, user_mxid, priority, page_count, finished, dispatched_at, completed_at, cooldown_until
+		FROM backfill_task
+		WHERE user_mxid=$1 AND finished=false AND cooldown_until<$2 AND (dispatched_at<$3 OR completed_at<>0)
+		ORDER BY priority DESC, completed_at, dispatched_at LIMIT 1
+	`
+)
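// How the scheduler reads getNextBackfillTask: a task is dispatchable when it
// isn't finished, its cooldown has passed ($2 = now), and it is either not
// currently dispatched ($3 = now minus one hour, presumably so a dispatch
// that crashed mid-flight is treated as stale and retried) or it has
// completed at least one round before (completed_at<>0). Highest priority
// wins; ties go to the least recently completed, then least recently
// dispatched task.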
+
+type BackfillTaskQuery struct {
+	*dbutil.QueryHelper[*BackfillTask]
+}
+
+type BackfillTask struct {
+	qh *dbutil.QueryHelper[*BackfillTask]
+
+	Key           PortalKey
+	UserMXID      id.UserID
+	Priority      int
+	PageCount     int
+	Finished      bool
+	DispatchedAt  time.Time
+	CompletedAt   time.Time
+	CooldownUntil time.Time
+}
+
+func newBackfillTask(qh *dbutil.QueryHelper[*BackfillTask]) *BackfillTask {
+	return &BackfillTask{qh: qh}
+}
+
+func (btq *BackfillTaskQuery) NewWithValues(portalKey PortalKey, userID id.UserID) *BackfillTask {
+	return &BackfillTask{
+		qh: btq.QueryHelper,
+
+		Key:          portalKey,
+		UserMXID:     userID,
+		DispatchedAt: time.Now(),
+		CompletedAt:  time.Now(),
+	}
+}
+
+func (btq *BackfillTaskQuery) GetNext(ctx context.Context, userID id.UserID) (*BackfillTask, error) {
+	return btq.QueryOne(ctx, getNextBackfillTask, userID, time.Now().UnixMilli(), time.Now().Add(-1*time.Hour).UnixMilli())
+}
+
+func (task *BackfillTask) Scan(row dbutil.Scannable) (*BackfillTask, error) {
+	var dispatchedAt, completedAt, cooldownUntil int64
+	err := row.Scan(&task.Key.ThreadID, &task.Key.Receiver, &task.UserMXID, &task.Priority, &task.PageCount, &task.Finished, &dispatchedAt, &completedAt, &cooldownUntil)
+	if err != nil {
+		return nil, err
+	}
+	task.DispatchedAt = timeFromUnixMilli(dispatchedAt)
+	task.CompletedAt = timeFromUnixMilli(completedAt)
+	task.CooldownUntil = timeFromUnixMilli(cooldownUntil)
+	return task, nil
+}
+
+func timeFromUnixMilli(unix int64) time.Time {
+	if unix == 0 {
+		return time.Time{}
+	}
+	return time.UnixMilli(unix)
+}
+
+func unixMilliOrZero(t time.Time) int64 {
+	if t.IsZero() {
+		return 0
+	}
+	return t.UnixMilli()
+}
+
+func (task *BackfillTask) sqlVariables() []any {
+	return []any{
+		task.Key.ThreadID,
+		task.Key.Receiver,
+		task.UserMXID,
+		task.Priority,
+		task.PageCount,
+		task.Finished,
+		unixMilliOrZero(task.DispatchedAt),
+		unixMilliOrZero(task.CompletedAt),
+		unixMilliOrZero(task.CooldownUntil),
+	}
+}
+
+func (task *BackfillTask) Upsert(ctx context.Context) error {
+	return task.qh.Exec(ctx, putBackfillTask, task.sqlVariables()...)
+}
+
+func (task *BackfillTask) InsertIfNotExists(ctx context.Context) error {
+	return task.qh.Exec(ctx, insertBackfillTaskIfNotExists, task.sqlVariables()...)
+}
diff --git a/database/database.go b/database/database.go
index d1f64be..fbffafc 100644
--- a/database/database.go
+++ b/database/database.go
@@ -29,21 +29,23 @@ import (
 type Database struct {
 	*dbutil.Database
 
-	User     *UserQuery
-	Portal   *PortalQuery
-	Puppet   *PuppetQuery
-	Message  *MessageQuery
-	Reaction *ReactionQuery
+	User         *UserQuery
+	Portal       *PortalQuery
+	Puppet       *PuppetQuery
+	Message      *MessageQuery
+	Reaction     *ReactionQuery
+	BackfillTask *BackfillTaskQuery
 }
 
 func New(db *dbutil.Database) *Database {
 	db.UpgradeTable = upgrades.Table
 	return &Database{
-		Database: db,
-		User:     &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
-		Portal:   &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
-		Puppet:   &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
-		Message:  &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
-		Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
+		Database:     db,
+		User:         &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
+		Portal:       &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
+		Puppet:       &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
+		Message:      &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
+		Reaction:     &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
+		BackfillTask: &BackfillTaskQuery{dbutil.MakeQueryHelper(db, newBackfillTask)},
 	}
 }
diff --git a/database/message.go b/database/message.go
index 35f2465..2d61658 100644
--- a/database/message.go
+++ b/database/message.go
@@ -20,6 +20,8 @@ import (
 	"context"
 	"database/sql"
 	"errors"
+	"fmt"
+	"strings"
 	"time"
 
 	"go.mau.fi/util/dbutil"
@@ -57,7 +59,9 @@ const (
 		INSERT INTO message (id, part_index, thread_id, thread_receiver, msg_sender, otid, mxid, mx_room, timestamp, edit_count)
 		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
 	`
-	deleteMessageQuery = `
+	insertQueryValuePlaceholder   = `($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`
+	bulkInsertPlaceholderTemplate = `($%d, $%d, $1, $2, $%d, $%d, $%d, $3, $%d, $%d)`
+	deleteMessageQuery            = `
 		DELETE FROM message WHERE id=$1 AND thread_receiver=$2 AND part_index=$3
 	`
@@ -66,6 +70,12 @@ const (
 	`
 )
 
+func init() {
+	if strings.ReplaceAll(insertMessageQuery, insertQueryValuePlaceholder, "meow") == insertMessageQuery {
+		panic("Bulk insert query placeholder not found")
+	}
+}
+
 type MessageQuery struct {
 	*dbutil.QueryHelper[*Message]
 }
@@ -119,6 +129,60 @@ func (mq *MessageQuery) FindEditTargetPortal(ctx context.Context, id string, rec
 	return
 }
 
+type bulkInserter[T any] interface {
+	GetDB() *dbutil.Database
+	BulkInsertChunk(context.Context, PortalKey, id.RoomID, []T) error
+}
+
+const BulkInsertChunkSize = 100
+
+func doBulkInsert[T any](q bulkInserter[T], ctx context.Context, thread PortalKey, roomID id.RoomID, entries []T) error {
+	if len(entries) == 0 {
+		return nil
+	}
+	return q.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error {
+		for i := 0; i < len(entries); i += BulkInsertChunkSize {
+			messageChunk := entries[i:]
+			if len(messageChunk) > BulkInsertChunkSize {
+				messageChunk = messageChunk[:BulkInsertChunkSize]
+			}
+			err := q.BulkInsertChunk(ctx, thread, roomID, messageChunk)
+			if err != nil {
+				return err
+			}
+		}
+		return nil
+	})
+}
+
+func (mq *MessageQuery) BulkInsert(ctx context.Context, thread PortalKey, roomID id.RoomID, messages []*Message) error {
+	return doBulkInsert[*Message](mq, ctx, thread, roomID, messages)
+}
+
+func (mq *MessageQuery) BulkInsertChunk(ctx context.Context, thread PortalKey, roomID id.RoomID, messages []*Message) error {
+	if len(messages) == 0 {
+		return nil
+	}
+	placeholders := make([]string, len(messages))
+	values := make([]any, 3+len(messages)*7)
+	values[0] = thread.ThreadID
+	values[1] = thread.Receiver
+	values[2] = roomID
+	for i, msg := range messages {
+		baseIndex := 3 + i*7
+		placeholders[i] = fmt.Sprintf(bulkInsertPlaceholderTemplate, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7)
+		values[baseIndex] = msg.ID
+		values[baseIndex+1] = msg.PartIndex
+		values[baseIndex+2] = msg.Sender
+		values[baseIndex+3] = msg.OTID
+		values[baseIndex+4] = msg.MXID
+		values[baseIndex+5] = msg.Timestamp.UnixMilli()
+		values[baseIndex+6] = msg.EditCount
+	}
+	query := strings.ReplaceAll(insertMessageQuery, insertQueryValuePlaceholder, strings.Join(placeholders, ","))
+	return mq.Exec(ctx, query, values...)
+}
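// A worked example of the query this builds, assuming a chunk of two
// messages. $1-$3 (thread_id, thread_receiver, mx_room) are shared by every
// row, so each row only adds its seven per-message parameters:
//
//	INSERT INTO message (id, part_index, thread_id, thread_receiver, msg_sender, otid, mxid, mx_room, timestamp, edit_count)
//	VALUES ($4, $5, $1, $2, $6, $7, $8, $3, $9, $10),($11, $12, $1, $2, $13, $14, $15, $3, $16, $17)
//
// ReactionQuery.BulkInsertChunk below uses the same trick with four
// per-reaction parameters.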
+
 func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) {
 	var timestamp int64
 	err := row.Scan(
diff --git a/database/portal.go b/database/portal.go
index 663e5c4..3906753 100644
--- a/database/portal.go
+++ b/database/portal.go
@@ -30,7 +30,7 @@ const (
 	portalBaseSelect = `
 		SELECT thread_id, receiver, thread_type, mxid,
 		       name, avatar_id, avatar_url, name_set, avatar_set,
-		       encrypted, relay_user_id
+		       encrypted, relay_user_id, oldest_message_id, oldest_message_ts, more_to_backfill
 		FROM portal
 	`
 	getPortalByMXIDQuery = portalBaseSelect + `WHERE mxid=$1`
@@ -47,14 +47,14 @@ const (
 		INSERT INTO portal (
 			thread_id, receiver, thread_type, mxid,
 			name, avatar_id, avatar_url, name_set, avatar_set,
-			encrypted, relay_user_id
-		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
+			encrypted, relay_user_id, oldest_message_id, oldest_message_ts, more_to_backfill
+		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
 	`
 	updatePortalQuery = `
 		UPDATE portal SET
 			thread_type=$3, mxid=$4, name=$5, avatar_id=$6, avatar_url=$7, name_set=$8, avatar_set=$9,
-			encrypted=$10, relay_user_id=$11
+			encrypted=$10, relay_user_id=$11, oldest_message_id=$12, oldest_message_ts=$13, more_to_backfill=$14
 		WHERE thread_id=$1 AND receiver=$2
 	`
 	deletePortalQuery = `DELETE FROM portal WHERE thread_id=$1 AND receiver=$2`
@@ -82,6 +82,10 @@ type Portal struct {
 	AvatarSet   bool
 	Encrypted   bool
 	RelayUserID id.UserID
+
+	OldestMessageID string
+	OldestMessageTS int64
+	MoreToBackfill  bool
 }
 
 func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal {
@@ -138,6 +142,9 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) {
 		&p.AvatarSet,
 		&p.Encrypted,
 		&p.RelayUserID,
+		&p.OldestMessageID,
+		&p.OldestMessageTS,
+		&p.MoreToBackfill,
 	)
 	if err != nil {
 		return nil, err
@@ -159,6 +166,9 @@ func (p *Portal) sqlVariables() []any {
 		p.AvatarSet,
 		p.Encrypted,
 		p.RelayUserID,
+		p.OldestMessageID,
+		p.OldestMessageTS,
+		p.MoreToBackfill,
 	}
 }
diff --git a/database/reaction.go b/database/reaction.go
index 3b10bee..719406b 100644
--- a/database/reaction.go
+++ b/database/reaction.go
@@ -18,6 +18,8 @@ package database
 
 import (
 	"context"
+	"fmt"
+	"strings"
 
 	"go.mau.fi/util/dbutil"
 	"maunium.net/go/mautrix/id"
@@ -35,7 +37,14 @@ const (
 		INSERT INTO reaction (message_id, thread_id, thread_receiver, reaction_sender, emoji, mxid, mx_room)
 		VALUES ($1, $2, $3, $4, $5, $6, $7)
 	`
-	updateReactionQuery = `
+	bulkInsertReactionQuery = `
+		INSERT INTO reaction (message_id, thread_id, thread_receiver, reaction_sender, emoji, mxid, mx_room)
+		VALUES ($1, $2, $3, $4, $5, $6, $7)
+		ON CONFLICT (message_id, thread_receiver, reaction_sender) DO UPDATE SET mxid=excluded.mxid, emoji=excluded.emoji
+	`
+	bulkInsertReactionQueryValuePlaceholder = `($1, $2, $3, $4, $5, $6, $7)`
+	bulkInsertReactionPlaceholderTemplate   = `($%d, $1, $2, $%d, $%d, $%d, $3)`
+	updateReactionQuery                     = `
 		UPDATE reaction
 		SET mxid=$1, emoji=$2
 		WHERE message_id=$3 AND thread_receiver=$4 AND reaction_sender=$5
@@ -45,6 +54,12 @@ const (
 	`
 )
 
+func init() {
+	if strings.ReplaceAll(bulkInsertReactionQuery, bulkInsertReactionQueryValuePlaceholder, "meow") == bulkInsertReactionQuery {
+		panic("Bulk insert query placeholder not found")
+	}
+}
+
 type ReactionQuery struct {
 	*dbutil.QueryHelper[*Reaction]
 }
@@ -75,6 +90,31 @@ func (rq *ReactionQuery) GetByID(ctx context.Context, msgID string, threadReceiv
 	return rq.QueryOne(ctx, getReactionByIDQuery, msgID, threadReceiver, reactionSender)
 }
 
+func (rq *ReactionQuery) BulkInsert(ctx context.Context, thread PortalKey, roomID id.RoomID, reactions []*Reaction) error {
+	return doBulkInsert[*Reaction](rq, ctx, thread, roomID, reactions)
+}
+
+func (rq *ReactionQuery) BulkInsertChunk(ctx context.Context, thread PortalKey, roomID id.RoomID, reactions []*Reaction) error {
+	if len(reactions) == 0 {
+		return nil
+	}
+	placeholders := make([]string, len(reactions))
+	values := make([]any, 3+len(reactions)*4)
+	values[0] = thread.ThreadID
+	values[1] = thread.Receiver
+	values[2] = roomID
+	for i, react := range reactions {
+		baseIndex := 3 + i*4
+		placeholders[i] = fmt.Sprintf(bulkInsertReactionPlaceholderTemplate, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4)
+		values[baseIndex] = react.MessageID
+		values[baseIndex+1] = react.Sender
+		values[baseIndex+2] = react.Emoji
+		values[baseIndex+3] = react.MXID
+	}
+	query := strings.ReplaceAll(bulkInsertReactionQuery, bulkInsertReactionQueryValuePlaceholder, strings.Join(placeholders, ","))
+	return rq.Exec(ctx, query, values...)
+}
+
 func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) {
 	return dbutil.ValueOrErr(r, row.Scan(
 		&r.MessageID, &r.ThreadID, &r.ThreadReceiver, &r.Sender, &r.Emoji, &r.MXID, &r.RoomID,
diff --git a/database/upgrades/00-latest.sql b/database/upgrades/00-latest.sql
index 8e5c71f..4fa6057 100644
--- a/database/upgrades/00-latest.sql
+++ b/database/upgrades/00-latest.sql
@@ -1,4 +1,4 @@
--- v0 -> v2: Latest revision
+-- v0 -> v3: Latest revision
 
 CREATE TABLE portal (
     thread_id BIGINT NOT NULL,
@@ -15,6 +15,10 @@ CREATE TABLE portal (
     encrypted     BOOLEAN NOT NULL DEFAULT false,
     relay_user_id TEXT    NOT NULL,
 
+    oldest_message_id TEXT    NOT NULL,
+    oldest_message_ts BIGINT  NOT NULL,
+    more_to_backfill  BOOLEAN NOT NULL,
+
     PRIMARY KEY (thread_id, receiver),
     CONSTRAINT portal_mxid_unique UNIQUE(mxid)
 );
@@ -41,6 +45,8 @@ CREATE TABLE "user" (
     meta_id BIGINT,
     cookies jsonb,
 
+    inbox_fetched BOOLEAN NOT NULL,
+
     management_room TEXT,
     space_room      TEXT,
 
@@ -48,10 +54,11 @@ CREATE TABLE "user" (
 );
 
 CREATE TABLE user_portal (
-    user_mxid        TEXT,
-    portal_thread_id BIGINT,
-    portal_receiver  BIGINT,
-    in_space BOOLEAN NOT NULL DEFAULT false,
+    user_mxid        TEXT   NOT NULL,
+    portal_thread_id BIGINT NOT NULL,
+    portal_receiver  BIGINT NOT NULL,
+
+    in_space BOOLEAN NOT NULL DEFAULT false,
 
     PRIMARY KEY (user_mxid, portal_thread_id, portal_receiver),
     CONSTRAINT user_portal_user_fkey FOREIGN KEY (user_mxid)
@@ -60,6 +67,25 @@ CREATE TABLE user_portal (
         REFERENCES portal(thread_id, receiver) ON UPDATE CASCADE ON DELETE CASCADE
 );
 
+CREATE TABLE backfill_task (
+    portal_id       BIGINT NOT NULL,
+    portal_receiver BIGINT NOT NULL,
+    user_mxid       TEXT   NOT NULL,
+
+    priority       INTEGER NOT NULL,
+    page_count     INTEGER NOT NULL,
+    finished       BOOLEAN NOT NULL,
+    dispatched_at  BIGINT  NOT NULL,
+    completed_at   BIGINT  NOT NULL,
+    cooldown_until BIGINT  NOT NULL,
+
+    PRIMARY KEY (portal_id, portal_receiver, user_mxid),
+    CONSTRAINT backfill_task_user_fkey FOREIGN KEY (user_mxid)
+        REFERENCES "user" (mxid) ON UPDATE CASCADE ON DELETE CASCADE,
+    CONSTRAINT backfill_task_portal_fkey FOREIGN KEY (portal_id, portal_receiver)
+        REFERENCES portal (thread_id, receiver) ON UPDATE CASCADE ON DELETE CASCADE
+);
+
 CREATE TABLE message (
     id         TEXT    NOT NULL,
     part_index INTEGER NOT NULL,
diff --git a/database/upgrades/03-backfill-queue.sql b/database/upgrades/03-backfill-queue.sql
new file mode 100644
index 0000000..f053c03
--- /dev/null
+++ b/database/upgrades/03-backfill-queue.sql
@@ -0,0 +1,39 @@
+-- v3: Add backfill queue
+ALTER TABLE portal ADD COLUMN oldest_message_id TEXT NOT NULL DEFAULT '';
+ALTER TABLE portal ADD COLUMN oldest_message_ts BIGINT NOT NULL DEFAULT 0;
+ALTER TABLE portal ADD COLUMN more_to_backfill BOOL NOT NULL DEFAULT true;
+UPDATE portal SET (oldest_message_id, oldest_message_ts) = (
+    SELECT id, timestamp
+    FROM message
+    WHERE thread_id = portal.thread_id
+      AND thread_receiver = portal.receiver
+    ORDER BY timestamp ASC
+    LIMIT 1
+);
+-- only: postgres for next 3 lines
+ALTER TABLE portal ALTER COLUMN oldest_message_id DROP DEFAULT;
+ALTER TABLE portal ALTER COLUMN oldest_message_ts DROP DEFAULT;
+ALTER TABLE portal ALTER COLUMN more_to_backfill DROP DEFAULT;
+
+CREATE TABLE backfill_task (
+    portal_id       BIGINT NOT NULL,
+    portal_receiver BIGINT NOT NULL,
+    user_mxid       TEXT   NOT NULL,
+
+    priority       INTEGER NOT NULL,
+    page_count     INTEGER NOT NULL,
+    finished       BOOLEAN NOT NULL,
+    dispatched_at  BIGINT  NOT NULL,
+    completed_at   BIGINT  NOT NULL,
+    cooldown_until BIGINT  NOT NULL,
+
+    PRIMARY KEY (portal_id, portal_receiver, user_mxid),
+    CONSTRAINT backfill_task_user_fkey FOREIGN KEY (user_mxid)
+        REFERENCES "user" (mxid) ON UPDATE CASCADE ON DELETE CASCADE,
+    CONSTRAINT backfill_task_portal_fkey FOREIGN KEY (portal_id, portal_receiver)
+        REFERENCES portal (thread_id, receiver) ON UPDATE CASCADE ON DELETE CASCADE
+);
+
+ALTER TABLE "user" ADD COLUMN inbox_fetched BOOLEAN NOT NULL DEFAULT false;
+-- only: postgres
+ALTER TABLE "user" ALTER COLUMN inbox_fetched DROP DEFAULT;
diff --git a/database/user.go b/database/user.go
index 2dc970b..08bb88c 100644
--- a/database/user.go
+++ b/database/user.go
@@ -28,11 +28,11 @@ import (
 )
 
 const (
-	getUserByMXIDQuery       = `SELECT mxid, meta_id, cookies, management_room, space_room FROM "user" WHERE mxid=$1`
-	getUserByMetaIDQuery     = `SELECT mxid, meta_id, cookies, management_room, space_room FROM "user" WHERE meta_id=$1`
-	getAllLoggedInUsersQuery = `SELECT mxid, meta_id, cookies, management_room, space_room FROM "user" WHERE cookies IS NOT NULL`
-	insertUserQuery          = `INSERT INTO "user" (mxid, meta_id, cookies, management_room, space_room) VALUES ($1, $2, $3, $4, $5)`
-	updateUserQuery          = `UPDATE "user" SET meta_id=$2, cookies=$3, management_room=$4, space_room=$5 WHERE mxid=$1`
+	getUserByMXIDQuery       = `SELECT mxid, meta_id, cookies, inbox_fetched, management_room, space_room FROM "user" WHERE mxid=$1`
+	getUserByMetaIDQuery     = `SELECT mxid, meta_id, cookies, inbox_fetched, management_room, space_room FROM "user" WHERE meta_id=$1`
+	getAllLoggedInUsersQuery = `SELECT mxid, meta_id, cookies, inbox_fetched, management_room, space_room FROM "user" WHERE cookies IS NOT NULL`
+	insertUserQuery          = `INSERT INTO "user" (mxid, meta_id, cookies, inbox_fetched, management_room, space_room) VALUES ($1, $2, $3, $4, $5, $6)`
+	updateUserQuery          = `UPDATE "user" SET meta_id=$2, cookies=$3, inbox_fetched=$4, management_room=$5, space_room=$6 WHERE mxid=$1`
 )
 
 type UserQuery struct {
@@ -45,6 +45,7 @@ type User struct {
 	MXID           id.UserID
 	MetaID         int64
 	Cookies        cookies.Cookies
+	InboxFetched   bool
 	ManagementRoom id.RoomID
 	SpaceRoom      id.RoomID
 
@@ -73,7 +74,7 @@ func (uq *UserQuery) GetAllLoggedIn(ctx context.Context) ([]*User, error) {
 }
 
 func (u *User) sqlVariables() []any {
-	return []any{u.MXID, dbutil.NumPtr(u.MetaID), dbutil.JSON{Data: u.Cookies}, dbutil.StrPtr(u.ManagementRoom), dbutil.StrPtr(u.SpaceRoom)}
+	return []any{u.MXID, dbutil.NumPtr(u.MetaID), dbutil.JSON{Data: u.Cookies}, u.InboxFetched, dbutil.StrPtr(u.ManagementRoom), dbutil.StrPtr(u.SpaceRoom)}
 }
 
 func (u *User) Insert(ctx context.Context) error {
@@ -94,6 +95,7 @@ func (u *User) Scan(row dbutil.Scannable) (*User, error) {
 		&u.MXID,
 		&metaID,
 		&dbutil.JSON{Data: scannedCookies},
+		&u.InboxFetched,
 		&managementRoom,
 		&spaceRoom,
 	)
diff --git a/example-config.yaml b/example-config.yaml
index e62ba7b..1200d49 100644
--- a/example-config.yaml
+++ b/example-config.yaml
@@ -147,6 +147,32 @@ bridge:
     # If set to "default", will be determined based on meta -> mode, "!ig" for instagram and "!fb" for facebook
     command_prefix: default
 
+    backfill:
+        # If disabled, old messages will never be bridged.
+        enabled: true
+        # By default, Meta sends info about approximately 20 recent threads. If this is set to something other than 0,
+        # the bridge will request more threads on first login, until it reaches the specified number of pages
+        # or the end of the inbox.
+        inbox_fetch_pages: 0
+        # By default, Meta only sends one old message per thread. If this is set to something other than 0,
+        # the bridge will delay handling the one automatically received message and request more messages to backfill.
+        # One page usually contains 20 messages. This can technically be set to -1 to fetch all messages,
+        # but that will block bridging messages until the entire backfill is completed.
+        history_fetch_pages: 0
+        # Same as above, but for catchup backfills (i.e. when the bridge is restarted).
+        catchup_fetch_pages: 5
+        # Backfill queue settings. Only relevant for Beeper, because standard Matrix servers
+        # don't support inserting messages into room history.
+        queue:
+            # How many pages of messages to request in one go (without sleeping between requests).
+            pages_at_once: 5
+            # Maximum number of pages to fetch. -1 to fetch all pages until the start of the chat.
+            max_pages: -1
+            # How long to sleep after fetching a bunch of pages ("bunch" defined by pages_at_once).
+            sleep_between_tasks: 20s
+            # Disable fetching XMA media (reels, stories, etc) when backfilling.
+            dont_fetch_xma: true
+
     # Messages sent upon joining a management room.
     # Markdown is supported. The defaults are listed below.
     management_room_text:
diff --git a/messagix/table/messages.go b/messagix/table/messages.go
index b248dd8..e985bd9 100644
--- a/messagix/table/messages.go
+++ b/messagix/table/messages.go
@@ -206,8 +206,8 @@ type LSInsertNewMessageRange struct {
 	MaxTimestampMsTemplate int64  `index:"2" json:",omitempty"`
 	MinMessageId           string `index:"3" json:",omitempty"`
 	MaxMessageId           string `index:"4" json:",omitempty"`
-	MaxTimestampMs         int64  `index:"5" json:",omitempty"`
-	MinTimestampMs         int64  `index:"6" json:",omitempty"`
+	MinTimestampMs         int64  `index:"5" json:",omitempty"`
+	MaxTimestampMs         int64  `index:"6" json:",omitempty"`
 	HasMoreBefore          bool   `index:"7" json:",omitempty"`
 	HasMoreAfter           bool   `index:"8" json:",omitempty"`
 	Unknown                interface{} `index:"9" json:",omitempty"`
@@ -215,6 +215,20 @@
 	Unrecognized map[int]any `json:",omitempty"`
 }
 
+type LSUpdateExistingMessageRange struct {
+	ThreadKey   int64 `index:"0" json:",omitempty"`
+	TimestampMS int64 `index:"1" json:",omitempty"`
+
+	UnknownBool2 bool `index:"2" json:",omitempty"`
+	UnknownBool3 bool `index:"3" json:",omitempty"`
+
+	// if bool 2 && !3 then clear "has more before" else clear "has more after"
+}
+
+func (ls *LSUpdateExistingMessageRange) GetThreadKey() int64 {
+	return ls.ThreadKey
+}
+
 type LSDeleteExistingMessageRanges struct {
 	ConsistentThreadFbid int64 `index:"0" json:",omitempty"`
 
diff --git a/messagix/table/table.go b/messagix/table/table.go
index a718059..2d64d72 100644
--- a/messagix/table/table.go
+++ b/messagix/table/table.go
@@ -29,6 +29,7 @@ type LSTable struct {
 	LSSetMessageDisplayedContentTypes []*LSSetMessageDisplayedContentTypes `json:",omitempty"`
 	LSUpdateReadReceipt               []*LSUpdateReadReceipt               `json:",omitempty"`
 	LSInsertNewMessageRange           []*LSInsertNewMessageRange           `json:",omitempty"`
+	LSUpdateExistingMessageRange      []*LSUpdateExistingMessageRange      `json:",omitempty"`
 	LSDeleteExistingMessageRanges     []*LSDeleteExistingMessageRanges     `json:",omitempty"`
 	LSUpsertSequenceId                []*LSUpsertSequenceId                `json:",omitempty"`
 	LSVerifyContactRowExists          []*LSVerifyContactRowExists          `json:",omitempty"`
@@ -193,6 +194,7 @@ var SPTable = map[string]string{
 	"truncateTablesForSyncGroup":         "LSTruncateTablesForSyncGroup",
 	"insertXmaAttachment":                "LSInsertXmaAttachment",
 	"insertNewMessageRange":              "LSInsertNewMessageRange",
+	"updateExistingMessageRange":         "LSUpdateExistingMessageRange",
"threadsRangesQuery": "LSThreadsRangesQuery", "updateThreadSnippetFromLastMessage": "LSUpdateThreadSnippetFromLastMessage", "upsertInboxThreadsRange": "LSUpsertInboxThreadsRange", diff --git a/messagix/table/wrappedmessage.go b/messagix/table/wrappedmessage.go index d8114ff..e2c0708 100644 --- a/messagix/table/wrappedmessage.go +++ b/messagix/table/wrappedmessage.go @@ -1,21 +1,93 @@ package table import ( + "slices" + badGlobalLog "github.com/rs/zerolog/log" ) -func (table *LSTable) WrapMessages() []*WrappedMessage { - messages := make([]*WrappedMessage, len(table.LSInsertMessage)+len(table.LSUpsertMessage)) +type UpsertMessages struct { + Range *LSInsertNewMessageRange + Messages []*WrappedMessage + MarkRead bool +} + +func (um *UpsertMessages) Join(other *UpsertMessages) *UpsertMessages { + if um == nil { + return other + } else if other == nil { + return um + } + um.Messages = append(um.Messages, other.Messages...) + um.Range.HasMoreBefore = other.Range.HasMoreBefore + um.Range.MinTimestampMsTemplate = other.Range.MinTimestampMsTemplate + um.Range.MinTimestampMs = other.Range.MinTimestampMs + um.Range.MinMessageId = other.Range.MinMessageId + return um +} + +func (um *UpsertMessages) GetThreadKey() int64 { + if um.Range != nil { + return um.Range.ThreadKey + } else if len(um.Messages) > 0 { + return um.Messages[0].ThreadKey + } + return 0 +} + +func (table *LSTable) WrapMessages() (upsert map[int64]*UpsertMessages, insert []*WrappedMessage) { messageMap := make(map[string]*WrappedMessage, len(table.LSInsertMessage)+len(table.LSUpsertMessage)) - for i, msg := range table.LSUpsertMessage { - messages[i] = &WrappedMessage{LSInsertMessage: msg.ToInsert(), IsUpsert: true} - messageMap[msg.MessageId] = messages[i] + + upsert = make(map[int64]*UpsertMessages, len(table.LSUpsertMessage)) + for _, rng := range table.LSInsertNewMessageRange { + upsert[rng.ThreadKey] = &UpsertMessages{Range: rng} + } + + // TODO are there other places that might have read receipts for upserts than these two? + for _, read := range table.LSMarkThreadRead { + upsertMsg, ok := upsert[read.ThreadKey] + if ok { + upsertMsg.MarkRead = upsertMsg.Range.MaxTimestampMs <= read.LastReadWatermarkTimestampMs + } + } + for _, thread := range table.LSDeleteThenInsertThread { + upsertMsg, ok := upsert[thread.ThreadKey] + if ok { + upsertMsg.MarkRead = upsertMsg.Range.MaxTimestampMs <= thread.LastReadWatermarkTimestampMs + } + } + + for _, msg := range table.LSUpsertMessage { + wrapped := &WrappedMessage{LSInsertMessage: msg.ToInsert(), IsUpsert: true} + chatUpsert, ok := upsert[msg.ThreadKey] + if !ok { + badGlobalLog.Warn(). + Int64("thread_id", msg.ThreadKey). + Msg("Got upsert message for thread without corresponding message range") + upsert[msg.ThreadKey] = &UpsertMessages{Messages: []*WrappedMessage{wrapped}} + } else { + chatUpsert.Messages = append(chatUpsert.Messages, wrapped) + } + messageMap[msg.MessageId] = wrapped + } + if len(table.LSUpsertMessage) > 0 { + // For upserted messages, add reactions to the upsert data, and delete them + // from the main list to avoid handling them as new reactions. 
+
+func (table *LSTable) WrapMessages() (upsert map[int64]*UpsertMessages, insert []*WrappedMessage) {
 	messageMap := make(map[string]*WrappedMessage, len(table.LSInsertMessage)+len(table.LSUpsertMessage))
-	for i, msg := range table.LSUpsertMessage {
-		messages[i] = &WrappedMessage{LSInsertMessage: msg.ToInsert(), IsUpsert: true}
-		messageMap[msg.MessageId] = messages[i]
+
+	upsert = make(map[int64]*UpsertMessages, len(table.LSUpsertMessage))
+	for _, rng := range table.LSInsertNewMessageRange {
+		upsert[rng.ThreadKey] = &UpsertMessages{Range: rng}
+	}
+
+	// TODO are there other places that might have read receipts for upserts than these two?
+	for _, read := range table.LSMarkThreadRead {
+		upsertMsg, ok := upsert[read.ThreadKey]
+		if ok {
+			upsertMsg.MarkRead = upsertMsg.Range.MaxTimestampMs <= read.LastReadWatermarkTimestampMs
+		}
+	}
+	for _, thread := range table.LSDeleteThenInsertThread {
+		upsertMsg, ok := upsert[thread.ThreadKey]
+		if ok {
+			upsertMsg.MarkRead = upsertMsg.Range.MaxTimestampMs <= thread.LastReadWatermarkTimestampMs
+		}
+	}
+
+	for _, msg := range table.LSUpsertMessage {
+		wrapped := &WrappedMessage{LSInsertMessage: msg.ToInsert(), IsUpsert: true}
+		chatUpsert, ok := upsert[msg.ThreadKey]
+		if !ok {
+			badGlobalLog.Warn().
+				Int64("thread_id", msg.ThreadKey).
+				Msg("Got upsert message for thread without corresponding message range")
+			upsert[msg.ThreadKey] = &UpsertMessages{Messages: []*WrappedMessage{wrapped}}
+		} else {
+			chatUpsert.Messages = append(chatUpsert.Messages, wrapped)
+		}
+		messageMap[msg.MessageId] = wrapped
+	}
+	if len(table.LSUpsertMessage) > 0 {
+		// For upserted messages, add reactions to the upsert data, and delete them
+		// from the main list to avoid handling them as new reactions.
+		table.LSUpsertReaction = slices.DeleteFunc(table.LSUpsertReaction, func(reaction *LSUpsertReaction) bool {
+			wrapped, ok := messageMap[reaction.MessageId]
+			if ok && wrapped.IsUpsert {
+				wrapped.Reactions = append(wrapped.Reactions, reaction)
+				return true
+			}
+			return false
+		})
 	}
-	iOffset := len(table.LSUpsertMessage)
+	insert = make([]*WrappedMessage, len(table.LSInsertMessage))
 	for i, msg := range table.LSInsertMessage {
-		messages[iOffset+i] = &WrappedMessage{LSInsertMessage: msg}
-		messageMap[msg.MessageId] = messages[iOffset+i]
+		insert[i] = &WrappedMessage{LSInsertMessage: msg}
+		messageMap[msg.MessageId] = insert[i]
 	}
+
 	for _, blob := range table.LSInsertBlobAttachment {
 		msg, ok := messageMap[blob.MessageId]
 		if ok {
@@ -54,7 +126,7 @@ func (table *LSTable) WrapMessages() []*WrappedMessage {
 			Msg("Got sticker attachment in table without corresponding message")
 		}
 	}
-	return messages
+	return
 }
 
 type WrappedMessage struct {
@@ -63,6 +135,7 @@ type WrappedMessage struct {
 	BlobAttachments []*LSInsertBlobAttachment
 	XMAAttachments  []*WrappedXMA
 	Stickers        []*LSInsertStickerAttachment
+	Reactions       []*LSUpsertReaction
 }
 
 type WrappedXMA struct {
diff --git a/msgconv/from-meta.go b/msgconv/from-meta.go
index 35c6461..9338f52 100644
--- a/msgconv/from-meta.go
+++ b/msgconv/from-meta.go
@@ -117,7 +117,7 @@ func (mc *MessageConverter) ToMatrix(ctx context.Context, msg *table.WrappedMess
 			},
 		})
 	}
-	replyTo, sender := mc.GetMatrixReply(ctx, msg.ReplySourceId)
+	replyTo, sender := mc.GetMatrixReply(ctx, msg.ReplySourceId, msg.ReplyToUserId)
 	for _, part := range cm.Parts {
 		if part.Content.Mentions == nil {
 			part.Content.Mentions = &event.Mentions{}
@@ -240,7 +240,7 @@ var reelActionURLRegex = regexp.MustCompile(`^/stories/direct/(\d+)_(\d+)$`)
 func (mc *MessageConverter) fetchFullXMA(ctx context.Context, att *table.WrappedXMA, minimalConverted *ConvertedMessagePart) *ConvertedMessagePart {
 	ig := mc.GetClient(ctx).Instagram
 	if att.CTA == nil || ig == nil {
-		return nil
+		return minimalConverted
 	}
 	log := zerolog.Ctx(ctx)
 	switch {
@@ -248,6 +248,9 @@ func (mc *MessageConverter) fetchFullXMA(ctx context.Context, att *table.Wrapped
 		log.Trace().Any("cta_data", att.CTA).Msg("Fetching XMA media from CTA data")
 		externalURL := fmt.Sprintf("https://www.instagram.com/p/%s/", strings.TrimPrefix(att.CTA.NativeUrl, "instagram://media/?shortcode="))
 		minimalConverted.Extra["external_url"] = externalURL
+		if !mc.ShouldFetchXMA(ctx) {
+			return minimalConverted
+		}
 
 		resp, err := ig.FetchMedia(strconv.FormatInt(att.CTA.TargetId, 10), att.CTA.NativeUrl)
 		if err != nil {
@@ -269,6 +272,10 @@ func (mc *MessageConverter) fetchFullXMA(ctx context.Context, att *table.Wrapped
 		log.Trace().Any("cta_data", att.CTA).Msg("Fetching XMA story from CTA data")
 		externalURL := fmt.Sprintf("https://www.instagram.com%s", att.CTA.ActionUrl)
 		minimalConverted.Extra["external_url"] = externalURL
+		if !mc.ShouldFetchXMA(ctx) {
+			return minimalConverted
+		}
+
 		if match := reelActionURLRegex.FindStringSubmatch(att.CTA.ActionUrl); len(match) != 3 {
 			log.Warn().Str("action_url", att.CTA.ActionUrl).Msg("Failed to parse story action URL")
 		} else if resp, err := ig.FetchReel([]string{match[2]}, match[1]); err != nil {
diff --git a/msgconv/media.go b/msgconv/media.go
index 558582b..6b4db29 100644
--- a/msgconv/media.go
+++ b/msgconv/media.go
@@ -20,9 +20,11 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"net/url"
 	"path"
+	"time"
 
 	"github.com/rs/zerolog"
 	"maunium.net/go/mautrix"
@@ -31,7 +33,15 @@ import (
 	"go.mau.fi/mautrix-meta/messagix"
 )
 
-var avatarHTTPClient http.Client
+var mediaHTTPClient = http.Client{
+	Transport: &http.Transport{
+		DialContext:           (&net.Dialer{Timeout: 5 * time.Second}).DialContext,
+		TLSHandshakeTimeout:   5 * time.Second,
+		ResponseHeaderTimeout: 10 * time.Second,
+		ForceAttemptHTTP2:     true,
+	},
+	Timeout: 60 * time.Second,
+}
 var MediaReferer string
 
 func DownloadMedia(ctx context.Context, url string) ([]byte, error) {
@@ -48,7 +58,7 @@ func DownloadMedia(ctx context.Context, url string) ([]byte, error) {
 	req.Header.Set("User-Agent", messagix.UserAgent)
 	req.Header.Add("sec-ch-ua", messagix.SecCHUserAgent)
 	req.Header.Add("sec-ch-ua-platform", messagix.SecCHPlatform)
-	resp, err := avatarHTTPClient.Do(req)
+	resp, err := mediaHTTPClient.Do(req)
 	defer func() {
 		if resp != nil && resp.Body != nil {
 			_ = resp.Body.Close()
diff --git a/msgconv/msgconv.go b/msgconv/msgconv.go
index 3a7eee3..6f1a54a 100644
--- a/msgconv/msgconv.go
+++ b/msgconv/msgconv.go
@@ -30,12 +30,12 @@ import (
 type PortalMethods interface {
 	UploadMatrixMedia(ctx context.Context, data []byte, fileName, contentType string) (id.ContentURIString, error)
 	DownloadMatrixMedia(ctx context.Context, uri id.ContentURIString) ([]byte, error)
-	GetMatrixReply(ctx context.Context, messageID string) (replyTo id.EventID, replyTargetSender id.UserID)
+	GetMatrixReply(ctx context.Context, messageID string, replyToUser int64) (replyTo id.EventID, replyTargetSender id.UserID)
 	GetMetaReply(ctx context.Context, content *event.MessageEventContent) *socket.ReplyMetaData
 	GetUserMXID(ctx context.Context, userID int64) id.UserID
+	ShouldFetchXMA(ctx context.Context) bool
 
 	GetClient(ctx context.Context) *messagix.Client
-	GetData(ctx context.Context) *database.Portal
 }
diff --git a/portal.go b/portal.go
index 75ca56d..9799eec 100644
--- a/portal.go
+++ b/portal.go
@@ -199,6 +199,9 @@ type Portal struct {
 	pendingMessages     map[int64]id.EventID
 	pendingMessagesLock sync.Mutex
 
+	backfillLock      sync.Mutex
+	backfillCollector *BackfillCollector
+
 	fetchAttempted atomic.Bool
 
 	relayUser *User
@@ -770,8 +773,21 @@ type msgconvContextKey int
 const (
 	msgconvContextKeyIntent msgconvContextKey = iota
 	msgconvContextKeyClient
+	msgconvContextKeyBackfill
 )
 
+type backfillType int
+
+const (
+	backfillTypeForward backfillType = iota + 1
+	backfillTypeHistorical
+)
+
+func (portal *Portal) ShouldFetchXMA(ctx context.Context) bool {
+	xmaDisabled := ctx.Value(msgconvContextKeyBackfill) == backfillTypeHistorical && portal.bridge.Config.Bridge.Backfill.Queue.DontFetchXMA
+	return !xmaDisabled
+}
+
 func (portal *Portal) UploadMatrixMedia(ctx context.Context, data []byte, fileName, contentType string) (id.ContentURIString, error) {
 	intent := ctx.Value(msgconvContextKeyIntent).(*appservice.IntentAPI)
 	req := mautrix.ReqUploadMedia{
@@ -810,7 +826,7 @@ func (portal *Portal) GetClient(ctx context.Context) *messagix.Client {
 	return ctx.Value(msgconvContextKeyClient).(*messagix.Client)
 }
 
-func (portal *Portal) GetMatrixReply(ctx context.Context, replyToID string) (replyTo id.EventID, replyTargetSender id.UserID) {
+func (portal *Portal) GetMatrixReply(ctx context.Context, replyToID string, replyToUser int64) (replyTo id.EventID, replyTargetSender id.UserID) {
 	if replyToID == "" {
 		return
 	}
@@ -820,15 +836,27 @@ func (portal *Portal) GetMatrixReply(ctx context.Context, replyToID string) (rep
 	if message, err := portal.bridge.DB.Message.GetByID(ctx, replyToID, 0, portal.Receiver); err != nil {
 		log.Err(err).Msg("Failed to get reply target message from database")
 	} else if message == nil {
-		log.Warn().Msg("Reply target message not found")
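+		// During backfill on hungry-serv style homeservers the reply target may
+		// not be bridged yet, but its eventual event ID is deterministic, so
+		// the reply can point at the ID the message will get once backfilled.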
+		if ctx.Value(msgconvContextKeyBackfill) != nil && portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
+			replyTo = portal.deterministicEventID(replyToID, 0)
+		} else {
+			log.Warn().Msg("Reply target message not found")
+			return
+		}
 	} else {
 		replyTo = message.MXID
-		targetUser := portal.bridge.GetUserByMetaID(message.Sender)
-		if targetUser != nil {
-			replyTargetSender = targetUser.MXID
-		} else {
-			replyTargetSender = portal.bridge.FormatPuppetMXID(message.Sender)
+		if message.Sender != replyToUser {
+			log.Warn().
+				Int64("message_sender", message.Sender).
+				Int64("reply_to_user", replyToUser).
+				Msg("Mismatching reply to user and found message sender")
 		}
+		replyToUser = message.Sender
+	}
+	targetUser := portal.bridge.GetUserByMetaID(replyToUser)
+	if targetUser != nil {
+		replyTargetSender = targetUser.MXID
+	} else {
+		replyTargetSender = portal.bridge.FormatPuppetMXID(replyToUser)
 	}
 	return
 }
@@ -869,6 +897,10 @@ func (portal *Portal) handleMetaMessage(portalMessage portalMetaMessage) {
 	switch typedEvt := portalMessage.evt.(type) {
 	case *table.WrappedMessage:
 		portal.handleMetaInsertMessage(portalMessage.user, typedEvt)
+	case *table.UpsertMessages:
+		portal.handleMetaUpsertMessages(portalMessage.user, typedEvt)
+	case *table.LSUpdateExistingMessageRange:
+		portal.handleMetaExistingRange(portalMessage.user, typedEvt)
 	case *table.LSEditMessage:
 		portal.handleMetaEditMessage(typedEvt)
 	case *table.LSDeleteMessage:
@@ -1154,20 +1186,20 @@ type customReadMarkers struct {
 	FullyReadExtra customReadReceipt `json:"com.beeper.fully_read.extra"`
 }
 
-func (portal *Portal) SendReadReceipt(ctx context.Context, sender *Puppet, msg *database.Message) error {
+func (portal *Portal) SendReadReceipt(ctx context.Context, sender *Puppet, eventID id.EventID) error {
 	intent := sender.IntentFor(portal)
 	if intent.IsCustomPuppet {
 		extra := customReadReceipt{DoublePuppetSource: portal.bridge.Name}
 		return intent.SetReadMarkers(ctx, portal.MXID, &customReadMarkers{
 			ReqSetReadMarkers: mautrix.ReqSetReadMarkers{
-				Read:      msg.MXID,
-				FullyRead: msg.MXID,
+				Read:      eventID,
+				FullyRead: eventID,
 			},
 			ReadExtra:      extra,
 			FullyReadExtra: extra,
 		})
 	} else {
-		return intent.MarkRead(ctx, portal.MXID, msg.MXID)
+		return intent.MarkRead(ctx, portal.MXID, eventID)
 	}
 }
 
@@ -1189,7 +1221,7 @@ func (portal *Portal) handleMetaReadReceipt(read *table.LSUpdateReadReceipt) {
 		log.Err(err).Msg("Failed to get message to mark as read")
 	} else if message == nil {
 		log.Warn().Msg("No message found to mark as read")
-	} else if err = portal.SendReadReceipt(ctx, sender, message); err != nil {
+	} else if err = portal.SendReadReceipt(ctx, sender, message.MXID); err != nil {
 		log.Err(err).Stringer("event_id", message.MXID).Msg("Failed to send read receipt")
 	} else {
 		log.Debug().Stringer("event_id", message.MXID).Msg("Sent read receipt to Matrix")
diff --git a/user.go b/user.go
index 6cbac65..5cc4d64 100644
--- a/user.go
+++ b/user.go
@@ -23,8 +23,10 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 
 	"github.com/rs/zerolog"
+	"golang.org/x/exp/maps"
 	"maunium.net/go/mautrix"
 	"maunium.net/go/mautrix/appservice"
 	"maunium.net/go/mautrix/bridge"
@@ -171,6 +173,10 @@ type User struct {
 	spaceMembershipChecked bool
 	spaceCreateLock        sync.Mutex
+
+	stopBackfillTask atomic.Pointer[context.CancelFunc]
+
+	InboxPagesFetched int
 }
 
 var (
@@ -391,6 +397,7 @@ func (user *User) Connect() {
 		})
 	}
 }
+
 func (user *User) Login(ctx context.Context, cookies cookies.Cookies) error {
 	user.Lock()
 	defer user.Unlock()
@@ -460,6 +467,7 @@ func (user *User) handleTable(tbl *table.LSTable) {
 		user.bridge.GetPuppetByID(contact.ContactId).UpdateInfo(ctx, contact)
 	}
 	for _, thread := range tbl.LSDeleteThenInsertThread {
+		// TODO handle last read watermark in here?
 		portal := user.GetPortalByThreadID(thread.ThreadKey, thread.ThreadType)
 		portal.UpdateInfo(ctx, thread)
 		if portal.MXID == "" {
@@ -527,7 +535,10 @@ func (user *User) handleTable(tbl *table.LSTable) {
 			log.Warn().Int64("thread_id", thread.ThreadKey).Msg("Portal doesn't exist in verifyThreadExists, but fetch was already attempted")
 		}
 	}
-	handlePortalEvents(user, tbl.WrapMessages())
+	upsert, insert := tbl.WrapMessages()
+	handlePortalEvents(user, maps.Values(upsert))
+	handlePortalEvents(user, tbl.LSUpdateExistingMessageRange)
+	handlePortalEvents(user, insert)
 	for _, msg := range tbl.LSEditMessage {
 		user.handleEditEvent(ctx, msg)
 	}
@@ -540,6 +551,46 @@ func (user *User) handleTable(tbl *table.LSTable) {
 	handlePortalEvents(user, tbl.LSDeleteThenInsertMessage)
 	handlePortalEvents(user, tbl.LSUpsertReaction)
 	handlePortalEvents(user, tbl.LSDeleteReaction)
+	user.requestMoreInbox(ctx, tbl.LSUpsertInboxThreadsRange)
+}
+
+func (user *User) requestMoreInbox(ctx context.Context, itrs []*table.LSUpsertInboxThreadsRange) {
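+	// Each LSUpsertInboxThreadsRange reports the oldest thread fetched so far;
+	// that thread key and activity timestamp become the cursor for the next
+	// FetchThreadsTask, until the server reports no more pages or the
+	// configured page limit is reached.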
+	maxInboxPages := user.bridge.Config.Bridge.Backfill.InboxFetchPages
+	if len(itrs) == 0 || user.InboxFetched || maxInboxPages == 0 {
+		return
+	}
+	log := zerolog.Ctx(ctx)
+	itr := itrs[0]
+	user.InboxPagesFetched++
+	reachedPageLimit := maxInboxPages > 0 && user.InboxPagesFetched > maxInboxPages
+	logEvt := log.Debug().
+		Int("fetched_pages", user.InboxPagesFetched).
+		Bool("has_more_before", itr.HasMoreBefore).
+		Bool("reached_page_limit", reachedPageLimit).
+		Int64("min_thread_key", itr.MinThreadKey).
+		Int64("min_last_activity_timestamp_ms", itr.MinLastActivityTimestampMs)
+	if !itr.HasMoreBefore || reachedPageLimit {
+		logEvt.Msg("Finished fetching threads")
+		user.InboxFetched = true
+		err := user.Update(ctx)
+		if err != nil {
+			log.Err(err).Msg("Failed to save user after marking inbox as fetched")
+		}
+	} else {
+		logEvt.Msg("Requesting more threads")
+		resp, err := user.Client.ExecuteTasks(&socket.FetchThreadsTask{
+			ReferenceThreadKey:         itr.MinThreadKey,
+			ReferenceActivityTimestamp: itr.MinLastActivityTimestampMs,
+			Cursor:                     user.Client.SyncManager.GetCursor(1),
+			SyncGroup:                  1,
+		})
+		log.Trace().Any("resp", resp).Msg("Fetch threads response data")
+		if err != nil {
+			log.Err(err).Msg("Failed to fetch more threads")
+		} else {
+			log.Debug().Msg("Sent more threads request")
+		}
+	}
+}
 
 type ThreadKeyable interface {
@@ -636,6 +687,7 @@ func (user *User) eventHandler(rawEvt any) {
 		user.BridgeState.Send(status.BridgeState{StateEvent: status.StateConnected})
 		user.tryAutomaticDoublePuppeting()
 		user.handleTable(evt.Table)
+		go user.BackfillLoop()
 	case *messagix.Event_SocketError:
 		user.BridgeState.Send(status.BridgeState{StateEvent: status.StateTransientDisconnect, Message: evt.Err.Error()})
 	case *messagix.Event_Reconnected:
@@ -646,6 +698,7 @@ func (user *User) eventHandler(rawEvt any) {
 			stateEvt = status.StateBadCredentials
 		}
 		user.BridgeState.Send(status.BridgeState{StateEvent: stateEvt, Message: evt.Err.Error()})
+		user.StopBackfillLoop()
 	default:
 		user.log.Warn().Type("event_type", evt).Msg("Unrecognized event type from messagix")
 	}
@@ -662,22 +715,25 @@ func (user *User) GetPortalByThreadID(threadID int64, threadType table.ThreadTyp
 	}, threadType)
 }
 
-func (user *User) Disconnect() error {
-	user.Lock()
-	defer user.Unlock()
+func (user *User) unlockedDisconnect() {
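+	// Caller must hold user.Lock(); both Disconnect and DeleteSession do.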
 	if user.Client != nil {
 		user.Client.Disconnect()
 	}
+	user.StopBackfillLoop()
+	user.Client = nil
+}
+
+func (user *User) Disconnect() error {
+	user.Lock()
+	defer user.Unlock()
+	user.unlockedDisconnect()
 	return nil
 }
 
 func (user *User) DeleteSession() {
 	user.Lock()
 	defer user.Unlock()
-	if user.Client != nil {
-		user.Client.Disconnect()
-	}
-	user.Client = nil
+	user.unlockedDisconnect()
 	user.Cookies = nil
 	user.MetaID = 0
 	doublePuppet := user.bridge.GetPuppetByCustomMXID(user.MXID)