From fddd91341e6f16072f6103ae08269beae9167c95 Mon Sep 17 00:00:00 2001 From: Vitor Hugo Schwaab Date: Thu, 22 Jun 2023 13:58:07 +0200 Subject: [PATCH] feat(mls): handle wrong epoch error [WPB-1803] (#1817) * feat(mls): [WIP] handle MLS wrong epoch when receiving messages * test: cover MLSWrongEpochHandler with tests * Update logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt Co-authored-by: Jacob Persson <7156+typfel@users.noreply.github.com> * refactor: select protocol info directly instead of the whole conversation --------- Co-authored-by: Jacob Persson <7156+typfel@users.noreply.github.com> --- .../com/wire/kalium/logic/CoreFailure.kt | 2 + .../wire/kalium/logic/data/message/Message.kt | 4 + .../logic/data/message/MessageContent.kt | 3 + .../logic/data/message/MessageMapper.kt | 3 + .../data/message/PersistMessageUseCase.kt | 1 + .../kalium/logic/feature/UserSessionScope.kt | 12 +- .../message/MLSWrongEpochHandler.kt | 98 ++++++++ .../message/NewMessageEventHandler.kt | 10 +- .../message/MLSWrongEpochHandlerTest.kt | 235 ++++++++++++++++++ .../message/NewMessageEventHandlerTest.kt | 41 ++- .../wire/kalium/logic/util/MockativeExt.kt | 53 ++++ .../persistence/dao/message/MessageEntity.kt | 4 +- .../dao/message/MessageInsertExtension.kt | 5 + .../persistence/dao/message/MessageMapper.kt | 2 + 14 files changed, 468 insertions(+), 5 deletions(-) create mode 100644 logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt create mode 100644 logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt create mode 100644 logic/src/commonTest/kotlin/com/wire/kalium/logic/util/MockativeExt.kt diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt index 6f6d76cbf52..577471db244 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/CoreFailure.kt @@ -106,6 +106,8 @@ interface MLSFailure : CoreFailure { object WrongEpoch : MLSFailure + object ConversationDoesNotSupportMLS : MLSFailure + class Generic(internal val exception: Exception) : MLSFailure { val rootCause: Throwable get() = exception } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/Message.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/Message.kt index 003d91cf6ea..e4fd54b11b0 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/Message.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/Message.kt @@ -281,6 +281,10 @@ sealed interface Message { is MessageContent.ConversationCreated -> mutableMapOf( typeKey to "conversationCreated" ) + + is MessageContent.MLSWrongEpochWarning -> mutableMapOf( + typeKey to "mlsWrongEpochWarning" + ) } val standardProperties = mapOf( diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageContent.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageContent.kt index 3a5b9aad8f3..0a89bf341cd 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageContent.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageContent.kt @@ -232,6 +232,8 @@ sealed class MessageContent { val clientId: ClientId? = null ) : Regular() + object MLSWrongEpochWarning : System() + object ClientAction : Signaling() object CryptoSessionReset : System() @@ -276,6 +278,7 @@ fun MessageContent?.getType() = when (this) { is MessageContent.ConversationCreated -> "ConversationCreated" is MessageContent.MemberChange.CreationAdded -> "MemberChange.CreationAdded" is MessageContent.MemberChange.FailedToAdd -> "MemberChange.FailedToAdd" + is MessageContent.MLSWrongEpochWarning -> "MLSWrongEpochWarning" null -> "Unknown" } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageMapper.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageMapper.kt index fb095a9ecce..e8f3e3884a8 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageMapper.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/MessageMapper.kt @@ -222,6 +222,7 @@ class MessageMapperImpl( MessageEntity.ContentType.HISTORY_LOST -> null MessageEntity.ContentType.CONVERSATION_MESSAGE_TIMER_CHANGED -> null MessageEntity.ContentType.CONVERSATION_CREATED -> null + MessageEntity.ContentType.MLS_WRONG_EPOCH_WARNING -> null } } @@ -319,6 +320,7 @@ class MessageMapperImpl( is MessageContent.HistoryLost -> MessageEntityContent.HistoryLost is MessageContent.ConversationMessageTimerChanged -> MessageEntityContent.ConversationMessageTimerChanged(messageTimer) is MessageContent.ConversationCreated -> MessageEntityContent.ConversationCreated + is MessageContent.MLSWrongEpochWarning -> MessageEntityContent.MLSWrongEpochWarning } private fun MessageEntityContent.Regular.toMessageContent(hidden: Boolean): MessageContent.Regular = when (this) { @@ -405,6 +407,7 @@ class MessageMapperImpl( is MessageEntityContent.HistoryLost -> MessageContent.HistoryLost is MessageEntityContent.ConversationMessageTimerChanged -> MessageContent.ConversationMessageTimerChanged(messageTimer) is MessageEntityContent.ConversationCreated -> MessageContent.ConversationCreated + is MessageEntityContent.MLSWrongEpochWarning -> MessageContent.MLSWrongEpochWarning } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/PersistMessageUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/PersistMessageUseCase.kt index c9b57c185dd..5c2ceecc6b5 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/PersistMessageUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/data/message/PersistMessageUseCase.kt @@ -96,5 +96,6 @@ internal class PersistMessageUseCaseImpl( is MessageContent.MemberChange.CreationAdded -> false is MessageContent.MemberChange.FailedToAdd -> false is MessageContent.ConversationCreated -> false + is MessageContent.MLSWrongEpochWarning -> false } } diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 1a69d34e6ae..afd43b144b6 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -275,6 +275,8 @@ import com.wire.kalium.logic.sync.receiver.conversation.ReceiptModeUpdateEventHa import com.wire.kalium.logic.sync.receiver.conversation.ReceiptModeUpdateEventHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandler import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandlerImpl +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandler +import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessageHandler import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessageHandlerImpl import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker @@ -953,9 +955,17 @@ class UserSessionScope internal constructor( userId ) + private val mlsWrongEpochHandler: MLSWrongEpochHandler + get() = MLSWrongEpochHandlerImpl( + selfUserId = userId, + persistMessage = persistMessage, + conversationRepository = conversationRepository, + joinExistingMLSConversation = joinExistingMLSConversationUseCase + ) + private val newMessageHandler: NewMessageEventHandlerImpl get() = NewMessageEventHandlerImpl( - proteusUnpacker, mlsUnpacker, applicationMessageHandler + proteusUnpacker, mlsUnpacker, applicationMessageHandler, mlsWrongEpochHandler ) private val newConversationHandler: NewConversationEventHandler diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt new file mode 100644 index 00000000000..cc3cb04cfe9 --- /dev/null +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandler.kt @@ -0,0 +1,98 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.sync.receiver.conversation.message + +import com.benasher44.uuid.uuid4 +import com.wire.kalium.logger.KaliumLogger +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.MLSFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.id.ConversationId +import com.wire.kalium.logic.data.message.Message +import com.wire.kalium.logic.data.message.MessageContent +import com.wire.kalium.logic.data.message.PersistMessageUseCase +import com.wire.kalium.logic.data.user.UserId +import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.functional.flatMap +import com.wire.kalium.logic.functional.map +import com.wire.kalium.logic.kaliumLogger + +interface MLSWrongEpochHandler { + suspend fun onMLSWrongEpoch( + conversationId: ConversationId, + dateIso: String, + ) +} + +internal class MLSWrongEpochHandlerImpl( + private val selfUserId: UserId, + private val persistMessage: PersistMessageUseCase, + private val conversationRepository: ConversationRepository, + private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase +) : MLSWrongEpochHandler { + + private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) } + + override suspend fun onMLSWrongEpoch( + conversationId: ConversationId, + dateIso: String, + ) { + logger.i("Handling MLS WrongEpoch result") + conversationRepository.getConversationProtocolInfo(conversationId).flatMap { protocol -> + if (protocol is Conversation.ProtocolInfo.MLS) { + Either.Right(protocol) + } else { + Either.Left(MLSFailure.ConversationDoesNotSupportMLS) + } + }.flatMap { currentProtocol -> + getUpdatedConversationEpoch(conversationId).map { updatedEpoch -> + updatedEpoch != null && updatedEpoch != currentProtocol.epoch + } + }.flatMap { isRejoinNeeded -> + if (isRejoinNeeded) { + joinExistingMLSConversation(conversationId) + } else Either.Right(Unit) + }.flatMap { + insertInfoMessage(conversationId, dateIso) + } + } + + private suspend fun getUpdatedConversationEpoch(conversationId: ConversationId): Either { + return conversationRepository.fetchConversation(conversationId).flatMap { + conversationRepository.getConversationProtocolInfo(conversationId) + }.map { updatedProtocol -> + (updatedProtocol as? Conversation.ProtocolInfo.MLS)?.epoch + } + } + + private suspend fun insertInfoMessage(conversationId: ConversationId, dateIso: String): Either { + val mlsEpochWarningMessage = Message.System( + id = uuid4().toString(), + content = MessageContent.MLSWrongEpochWarning, + conversationId = conversationId, + date = dateIso, + senderUserId = selfUserId, + status = Message.Status.READ, + visibility = Message.Visibility.VISIBLE, + senderUserName = null + ) + return persistMessage(mlsEpochWarningMessage) + } +} diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt index 301a07bb7cd..7220154ce2a 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandler.kt @@ -20,6 +20,7 @@ package com.wire.kalium.logic.sync.receiver.conversation.message import com.wire.kalium.cryptography.exceptions.ProteusException import com.wire.kalium.logger.KaliumLogger +import com.wire.kalium.logic.MLSFailure import com.wire.kalium.logic.ProteusFailure import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.event.Event @@ -39,7 +40,8 @@ internal interface NewMessageEventHandler { internal class NewMessageEventHandlerImpl( private val proteusMessageUnpacker: ProteusMessageUnpacker, private val mlsMessageUnpacker: MLSMessageUnpacker, - private val applicationMessageHandler: ApplicationMessageHandler + private val applicationMessageHandler: ApplicationMessageHandler, + private val mlsWrongEpochHandler: MLSWrongEpochHandler ) : NewMessageEventHandler { private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) } @@ -86,7 +88,6 @@ internal class NewMessageEventHandlerImpl( override suspend fun handleNewMLSMessage(event: Event.Conversation.NewMLSMessage) { mlsMessageUnpacker.unpackMlsMessage(event) .onFailure { - val logMap = mapOf( "event" to event.toLogMap(), "errorInfo" to "$it", @@ -95,6 +96,11 @@ internal class NewMessageEventHandlerImpl( logger.e("Failed to decrypt event: ${logMap.toJsonElement()}") + if (it is MLSFailure.WrongEpoch) { + mlsWrongEpochHandler.onMLSWrongEpoch(event.conversationId, event.timestampIso) + return@onFailure + } + applicationMessageHandler.handleDecryptionError( eventId = event.id, conversationId = event.conversationId, diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt new file mode 100644 index 00000000000..0fdb24d94f1 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt @@ -0,0 +1,235 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.sync.receiver.conversation.message + +import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.StorageFailure +import com.wire.kalium.logic.data.conversation.Conversation +import com.wire.kalium.logic.data.conversation.ConversationRepository +import com.wire.kalium.logic.data.message.MessageContent +import com.wire.kalium.logic.data.message.PersistMessageUseCase +import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase +import com.wire.kalium.logic.framework.TestConversation +import com.wire.kalium.logic.framework.TestUser +import com.wire.kalium.logic.functional.Either +import com.wire.kalium.logic.util.thenReturnSequentially +import io.mockative.Mock +import io.mockative.any +import io.mockative.classOf +import io.mockative.eq +import io.mockative.given +import io.mockative.matching +import io.mockative.mock +import io.mockative.once +import io.mockative.verify +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest +import kotlin.test.Test + +@OptIn(ExperimentalCoroutinesApi::class) +class MLSWrongEpochHandlerTest { + + @Test + fun givenConversationIsNotMLS_whenHandlingEpochFailure_thenShouldNotInsertWarning() = runTest { + val (arrangement, mlsWrongEpochHandler) = Arrangement() + .withProtocolByIdReturningSequence(Either.Right(proteusProtocol)) + .arrange() + + mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") + + verify(arrangement.persistMessageUseCase) + .suspendFunction(arrangement.persistMessageUseCase::invoke) + .with(any()) + .wasNotInvoked() + } + + @Test + fun givenConversationIsNotMLS_whenHandlingEpochFailure_thenShouldNotFetchConversationAgain() = runTest { + val (arrangement, mlsWrongEpochHandler) = Arrangement() + .withProtocolByIdReturningSequence(Either.Right(proteusProtocol)) + .arrange() + + mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::fetchConversation) + .with(any()) + .wasNotInvoked() + } + + @Test + fun givenMLSConversation_whenHandlingEpochFailure_thenShouldFetchConversationAgain() = runTest { + val (arrangement, mlsWrongEpochHandler) = Arrangement() + .withProtocolByIdReturning(Either.Right(mlsProtocol)) + .arrange() + + mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") + + verify(arrangement.conversationRepository) + .suspendFunction(arrangement.conversationRepository::fetchConversation) + .with(eq(conversationId)) + .wasInvoked(exactly = once) + } + + @Test + fun givenUpdatedMLSConversationHasDifferentEpoch_whenHandlingEpochFailure_thenShouldRejoinTheConversation() = runTest { + val (arrangement, mlsWrongEpochHandler) = Arrangement() + .withProtocolByIdReturningSequence( + Either.Right(mlsProtocol), + Either.Right(mlsProtocolWithUpdatedEpoch) + ) + .arrange() + + mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") + + verify(arrangement.joinExistingMLSConversationUseCase) + .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) + .with(eq(conversationId)) + .wasInvoked(exactly = once) + } + + @Test + fun givenUpdatedMLSConversationHasSameEpoch_whenHandlingEpochFailure_thenShouldNotRejoinTheConversation() = runTest { + val (arrangement, mlsWrongEpochHandler) = Arrangement() + .withProtocolByIdReturning(Either.Right(mlsProtocol)) + .arrange() + + mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") + + verify(arrangement.joinExistingMLSConversationUseCase) + .suspendFunction(arrangement.joinExistingMLSConversationUseCase::invoke) + .with(any()) + .wasNotInvoked() + } + + @Test + fun givenRejoiningFails_whenHandlingEpochFailure_thenShouldNotPersistAnyMessage() = runTest { + val (arrangement, mlsWrongEpochHandler) = Arrangement() + .withProtocolByIdReturningSequence( + Either.Right(mlsProtocol), + Either.Right(mlsProtocolWithUpdatedEpoch) + ) + .withJoinExistingConversationReturning(Either.Left(CoreFailure.Unknown(null))) + .arrange() + + mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, "date") + + verify(arrangement.persistMessageUseCase) + .suspendFunction(arrangement.persistMessageUseCase::invoke) + .with(any()) + .wasNotInvoked() + } + + @Test + fun givenConversationIsRejoined_whenHandlingEpochFailure_thenShouldInsertMLSWarningWithCorrectDateAndConversation() = runTest { + val date = "date" + val (arrangement, mlsWrongEpochHandler) = Arrangement() + .withProtocolByIdReturningSequence( + Either.Right(mlsProtocol), + Either.Right(mlsProtocolWithUpdatedEpoch) + ) + .arrange() + + mlsWrongEpochHandler.onMLSWrongEpoch(conversationId, date) + + verify(arrangement.persistMessageUseCase) + .suspendFunction(arrangement.persistMessageUseCase::invoke) + .with( + matching { + it.conversationId == conversationId && + it.content == MessageContent.MLSWrongEpochWarning && + it.date == date + } + ) + .wasInvoked(exactly = once) + } + + private class Arrangement { + + @Mock + val persistMessageUseCase = mock(classOf()) + + @Mock + val conversationRepository = mock(classOf()) + + @Mock + val joinExistingMLSConversationUseCase = mock(classOf()) + + init { + withFetchByIdSucceeding() + withPersistMessageSucceeding() + withJoinExistingConversationSucceeding() + } + + fun withFetchByIdReturning(result: Either) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::fetchConversation) + .whenInvokedWith(any()) + .thenReturn(result) + } + + fun withFetchByIdSucceeding() = withFetchByIdReturning(Either.Right(Unit)) + + fun withProtocolByIdReturning(result: Either) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationProtocolInfo) + .whenInvokedWith(any()) + .thenReturn(result) + } + + fun withProtocolByIdReturningSequence(vararg results: Either) = apply { + given(conversationRepository) + .suspendFunction(conversationRepository::getConversationProtocolInfo) + .whenInvokedWith(any()) + .thenReturnSequentially(*results) + } + + fun withPersistMessageReturning(result: Either) = apply { + given(persistMessageUseCase) + .suspendFunction(persistMessageUseCase::invoke) + .whenInvokedWith(any()) + .thenReturn(result) + } + + fun withPersistMessageSucceeding() = withPersistMessageReturning(Either.Right(Unit)) + + fun withJoinExistingConversationReturning(result: Either) = apply { + given(joinExistingMLSConversationUseCase) + .suspendFunction(joinExistingMLSConversationUseCase::invoke) + .whenInvokedWith(any()) + .thenReturn(result) + } + + fun withJoinExistingConversationSucceeding() = withJoinExistingConversationReturning(Either.Right(Unit)) + + fun arrange() = this to MLSWrongEpochHandlerImpl( + TestUser.SELF.id, + persistMessageUseCase, + conversationRepository, + joinExistingMLSConversationUseCase + ) + } + + private companion object { + val conversationId = TestConversation.CONVERSATION.id + val proteusProtocol = Conversation.ProtocolInfo.Proteus + + val mlsProtocol = TestConversation.MLS_CONVERSATION.protocol as Conversation.ProtocolInfo.MLS + val mlsProtocolWithUpdatedEpoch = mlsProtocol.copy(epoch = mlsProtocol.epoch + 1U) + } +} diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt index 10960ed5bac..ad01052793d 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/NewMessageEventHandlerTest.kt @@ -20,6 +20,7 @@ package com.wire.kalium.logic.sync.receiver.conversation.message import com.wire.kalium.cryptography.exceptions.ProteusException import com.wire.kalium.logic.CoreFailure +import com.wire.kalium.logic.MLSFailure import com.wire.kalium.logic.ProteusFailure import com.wire.kalium.logic.framework.TestEvent import com.wire.kalium.logic.functional.Either @@ -114,6 +115,38 @@ class NewMessageEventHandlerTest { .wasInvoked(exactly = once) } + @Test + fun givenMLSEventFailsWithWrongEpoch_whenHandling_shouldCallWrongEpochHandler() = runTest { + val (arrangement, newMessageEventHandler) = Arrangement() + .withMLSUnpackerReturning(Either.Left(MLSFailure.WrongEpoch)) + .arrange() + + val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant()) + + newMessageEventHandler.handleNewMLSMessage(newMessageEvent) + + verify(arrangement.mlsWrongEpochHandler) + .suspendFunction(arrangement.mlsWrongEpochHandler::onMLSWrongEpoch) + .with(eq(newMessageEvent.conversationId),eq(newMessageEvent.timestampIso)) + .wasInvoked(exactly = once) + } + + @Test + fun givenMLSEventFailsWithWrongEpoch_whenHandling_shouldNotPersistDecryptionErrorMessage() = runTest { + val (arrangement, newMessageEventHandler) = Arrangement() + .withMLSUnpackerReturning(Either.Left(MLSFailure.WrongEpoch)) + .arrange() + + val newMessageEvent = TestEvent.newMLSMessageEvent(DateTimeUtil.currentInstant()) + + newMessageEventHandler.handleNewMLSMessage(newMessageEvent) + + verify(arrangement.applicationMessageHandler) + .suspendFunction(arrangement.applicationMessageHandler::handleDecryptionError) + .with(any()) + .wasNotInvoked() + } + private class Arrangement { @Mock @@ -127,8 +160,14 @@ class NewMessageEventHandlerTest { stubsUnitByDefault = true } + @Mock + val mlsWrongEpochHandler = mock(classOf()) + private val newMessageEventHandler: NewMessageEventHandler = NewMessageEventHandlerImpl( - proteusMessageUnpacker, mlsMessageUnpacker, applicationMessageHandler + proteusMessageUnpacker, + mlsMessageUnpacker, + applicationMessageHandler, + mlsWrongEpochHandler ) fun withProteusUnpackerReturning(result: Either) = apply { diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/MockativeExt.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/MockativeExt.kt new file mode 100644 index 00000000000..90ff0bd93d7 --- /dev/null +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/util/MockativeExt.kt @@ -0,0 +1,53 @@ +/* + * Wire + * Copyright (C) 2023 Wire Swiss GmbH + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. + */ +package com.wire.kalium.logic.util + +import io.mockative.AnyResultBuilder +import io.mockative.AnySuspendResultBuilder + +/** + * Sets up the mock to return the given results sequentially. That is, + * the first call to the mock will return the first value provided in [results], the second + * call will return the second value, and so on. + */ +fun AnyResultBuilder.thenReturnSequentially(vararg results: R) { + var index = -1 + return thenInvoke { + index += 1 + require(index <= results.lastIndex) { + "Function called more times than expected. No result set for index $index" + } + results[index++] + } +} + +/** + * Sets up the mock to return the given results sequentially. That is, + * the first call to the mock will return the first value provided in [results], the second + * call will return the second value, and so on. + */ +fun AnySuspendResultBuilder.thenReturnSequentially(vararg results: R) { + var index = -1 + return thenInvoke { + index += 1 + require(index <= results.lastIndex) { + "Function called more times than expected. No result set for index $index" + } + results[index] + } +} diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageEntity.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageEntity.kt index bf7512a59d7..b6a7d4f5eb2 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageEntity.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageEntity.kt @@ -189,7 +189,7 @@ sealed class MessageEntity( TEXT, ASSET, KNOCK, MEMBER_CHANGE, MISSED_CALL, RESTRICTED_ASSET, CONVERSATION_RENAMED, UNKNOWN, FAILED_DECRYPTION, REMOVED_FROM_TEAM, CRYPTO_SESSION_RESET, NEW_CONVERSATION_RECEIPT_MODE, CONVERSATION_RECEIPT_MODE_CHANGED, HISTORY_LOST, CONVERSATION_MESSAGE_TIMER_CHANGED, - CONVERSATION_CREATED + CONVERSATION_CREATED, MLS_WRONG_EPOCH_WARNING } enum class MemberChangeType { @@ -293,6 +293,8 @@ sealed class MessageEntityContent { val senderClientId: String?, ) : Regular() + object MLSWrongEpochWarning : System() + data class MemberChange( val memberUserIdList: List, val memberChangeType: MessageEntity.MemberChangeType diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtension.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtension.kt index 5090d033e3f..09edadf8deb 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtension.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageInsertExtension.kt @@ -222,6 +222,10 @@ internal class MessageInsertExtensionImpl( is MessageEntityContent.ConversationCreated -> { /* no-op */ } + + is MessageEntityContent.MLSWrongEpochWarning -> { + /* no-op */ + } } } @@ -317,5 +321,6 @@ internal class MessageInsertExtensionImpl( is MessageEntityContent.HistoryLost -> MessageEntity.ContentType.HISTORY_LOST is MessageEntityContent.ConversationMessageTimerChanged -> MessageEntity.ContentType.CONVERSATION_MESSAGE_TIMER_CHANGED is MessageEntityContent.ConversationCreated -> MessageEntity.ContentType.CONVERSATION_CREATED + is MessageEntityContent.MLSWrongEpochWarning -> MessageEntity.ContentType.MLS_WRONG_EPOCH_WARNING } } diff --git a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageMapper.kt b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageMapper.kt index 75b3e258d25..71f1b9f8643 100644 --- a/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageMapper.kt +++ b/persistence/src/commonMain/kotlin/com/wire/kalium/persistence/dao/message/MessageMapper.kt @@ -185,6 +185,7 @@ object MessageMapper { MessageEntity.ContentType.HISTORY_LOST -> MessagePreviewEntityContent.Unknown MessageEntity.ContentType.CONVERSATION_MESSAGE_TIMER_CHANGED -> MessagePreviewEntityContent.Unknown MessageEntity.ContentType.CONVERSATION_CREATED -> MessagePreviewEntityContent.Unknown + MessageEntity.ContentType.MLS_WRONG_EPOCH_WARNING -> MessagePreviewEntityContent.Unknown } } @@ -493,6 +494,7 @@ object MessageMapper { ) MessageEntity.ContentType.CONVERSATION_CREATED -> MessageEntityContent.ConversationCreated + MessageEntity.ContentType.MLS_WRONG_EPOCH_WARNING -> MessageEntityContent.MLSWrongEpochWarning } return createMessageEntity(