diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index b5ca9af8..2e16f1a6 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -17,12 +17,14 @@ import ( type MockVerificationCallbacks interface { GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID GetScanQRCodeTransactions() []id.VerificationTransactionID + GetVerificationsReadyTransactions() []id.VerificationTransactionID GetQRCodeShown(id.VerificationTransactionID) *verificationhelper.QRCode } type baseVerificationCallbacks struct { scanQRCodeTransactions []id.VerificationTransactionID verificationsRequested map[id.UserID][]id.VerificationTransactionID + verificationsReady []id.VerificationTransactionID qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode qrCodesScanned map[id.VerificationTransactionID]struct{} doneTransactions map[id.VerificationTransactionID]struct{} @@ -33,6 +35,7 @@ type baseVerificationCallbacks struct { } var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil) +var _ MockVerificationCallbacks = (*baseVerificationCallbacks)(nil) func newBaseVerificationCallbacks() *baseVerificationCallbacks { return &baseVerificationCallbacks{ @@ -55,6 +58,10 @@ func (c *baseVerificationCallbacks) GetScanQRCodeTransactions() []id.Verificatio return c.scanQRCodeTransactions } +func (c *baseVerificationCallbacks) GetVerificationsReadyTransactions() []id.VerificationTransactionID { + return c.verificationsReady +} + func (c *baseVerificationCallbacks) GetQRCodeShown(txnID id.VerificationTransactionID) *verificationhelper.QRCode { return c.qrCodesShown[txnID] } @@ -85,6 +92,10 @@ func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, t c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID) } +func (c *baseVerificationCallbacks) VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID) { + c.verificationsReady = append(c.verificationsReady, txnID) +} + func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) { c.verificationCancellation[txnID] = &event.VerificationCancelEventContent{ Code: code, diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 92d4de23..fcabe312 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -33,6 +33,10 @@ type RequiredCallbacks interface { // from another device. VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) + // VerificationReady is called when a verification request has been + // accepted by both parties. + VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID) + // VerificationCancelled is called when the verification is cancelled. VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) @@ -76,6 +80,7 @@ type VerificationHelper struct { // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) + verificationReady func(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID) verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) @@ -107,6 +112,7 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor panic("callbacks must implement RequiredCallbacks") } else { helper.verificationRequested = c.VerificationRequested + helper.verificationReady = c.VerificationReady helper.verificationCancelledCallback = c.VerificationCancelled helper.verificationDone = c.VerificationDone } @@ -420,6 +426,7 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V return err } txn.VerificationState = VerificationStateReady + vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID) if vh.scanQRCode != nil && slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true @@ -709,6 +716,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif Stringer("their_device_id", txn.TheirDeviceID). Any("their_supported_methods", txn.TheirSupportedMethods). Msg("Received verification ready event") + vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID) // If we sent this verification request, send cancellations to all of the // other devices. diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index 31bc7d6e..d50999be 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -322,6 +322,10 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) + // Ensure that the receiving device get a notification about the + // transaction being ready. + assert.Contains(t, tc.receivingCallbacks.GetVerificationsReadyTransactions(), txnID) + _, sendingIsQRCallbacks := tc.sendingCallbacks.(*showQRCodeVerificationCallbacks) _, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks) sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks @@ -357,6 +361,10 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // device. ts.dispatchToDevice(t, ctx, sendingClient) + // Ensure that the sending device got a notification about the + // transaction being ready. + assert.Contains(t, tc.sendingCallbacks.GetVerificationsReadyTransactions(), txnID) + // Ensure that if the sending device should show a QR code that it // has the correct content. if tc.receivingSupportsScan && sendingCanShowQR {