Skip to content

Commit

Permalink
feat(mls): handle wrong epoch error [WPB-1803] (#1817)
Browse files Browse the repository at this point in the history
* feat(mls): [WIP] handle MLS wrong epoch when receiving messages

* test: cover MLSWrongEpochHandler with tests

* Update logic/src/commonTest/kotlin/com/wire/kalium/logic/sync/receiver/conversation/message/MLSWrongEpochHandlerTest.kt

Co-authored-by: Jacob Persson <7156+typfel@users.noreply.github.com>

* refactor: select protocol info directly instead of the whole conversation

---------

Co-authored-by: Jacob Persson <7156+typfel@users.noreply.github.com>
  • Loading branch information
vitorhugods and typfel authored Jun 22, 2023
1 parent f445ee8 commit fddd913
Show file tree
Hide file tree
Showing 14 changed files with 468 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ interface MLSFailure : CoreFailure {

object WrongEpoch : MLSFailure

object ConversationDoesNotSupportMLS : MLSFailure

class Generic(internal val exception: Exception) : MLSFailure {
val rootCause: Throwable get() = exception
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ sealed interface Message {
is MessageContent.ConversationCreated -> mutableMapOf(
typeKey to "conversationCreated"
)

is MessageContent.MLSWrongEpochWarning -> mutableMapOf(
typeKey to "mlsWrongEpochWarning"
)
}

val standardProperties = mapOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ sealed class MessageContent {
val clientId: ClientId? = null
) : Regular()

object MLSWrongEpochWarning : System()

object ClientAction : Signaling()

object CryptoSessionReset : System()
Expand Down Expand Up @@ -276,6 +278,7 @@ fun MessageContent?.getType() = when (this) {
is MessageContent.ConversationCreated -> "ConversationCreated"
is MessageContent.MemberChange.CreationAdded -> "MemberChange.CreationAdded"
is MessageContent.MemberChange.FailedToAdd -> "MemberChange.FailedToAdd"
is MessageContent.MLSWrongEpochWarning -> "MLSWrongEpochWarning"
null -> "Unknown"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class MessageMapperImpl(
MessageEntity.ContentType.HISTORY_LOST -> null
MessageEntity.ContentType.CONVERSATION_MESSAGE_TIMER_CHANGED -> null
MessageEntity.ContentType.CONVERSATION_CREATED -> null
MessageEntity.ContentType.MLS_WRONG_EPOCH_WARNING -> null
}
}

Expand Down Expand Up @@ -319,6 +320,7 @@ class MessageMapperImpl(
is MessageContent.HistoryLost -> MessageEntityContent.HistoryLost
is MessageContent.ConversationMessageTimerChanged -> MessageEntityContent.ConversationMessageTimerChanged(messageTimer)
is MessageContent.ConversationCreated -> MessageEntityContent.ConversationCreated
is MessageContent.MLSWrongEpochWarning -> MessageEntityContent.MLSWrongEpochWarning
}

private fun MessageEntityContent.Regular.toMessageContent(hidden: Boolean): MessageContent.Regular = when (this) {
Expand Down Expand Up @@ -405,6 +407,7 @@ class MessageMapperImpl(
is MessageEntityContent.HistoryLost -> MessageContent.HistoryLost
is MessageEntityContent.ConversationMessageTimerChanged -> MessageContent.ConversationMessageTimerChanged(messageTimer)
is MessageEntityContent.ConversationCreated -> MessageContent.ConversationCreated
is MessageEntityContent.MLSWrongEpochWarning -> MessageContent.MLSWrongEpochWarning
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,6 @@ internal class PersistMessageUseCaseImpl(
is MessageContent.MemberChange.CreationAdded -> false
is MessageContent.MemberChange.FailedToAdd -> false
is MessageContent.ConversationCreated -> false
is MessageContent.MLSWrongEpochWarning -> false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ import com.wire.kalium.logic.sync.receiver.conversation.ReceiptModeUpdateEventHa
import com.wire.kalium.logic.sync.receiver.conversation.ReceiptModeUpdateEventHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandler
import com.wire.kalium.logic.sync.receiver.conversation.RenamedConversationEventHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandler
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSWrongEpochHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessageHandler
import com.wire.kalium.logic.sync.receiver.conversation.message.ApplicationMessageHandlerImpl
import com.wire.kalium.logic.sync.receiver.conversation.message.MLSMessageUnpacker
Expand Down Expand Up @@ -953,9 +955,17 @@ class UserSessionScope internal constructor(
userId
)

private val mlsWrongEpochHandler: MLSWrongEpochHandler
get() = MLSWrongEpochHandlerImpl(
selfUserId = userId,
persistMessage = persistMessage,
conversationRepository = conversationRepository,
joinExistingMLSConversation = joinExistingMLSConversationUseCase
)

private val newMessageHandler: NewMessageEventHandlerImpl
get() = NewMessageEventHandlerImpl(
proteusUnpacker, mlsUnpacker, applicationMessageHandler
proteusUnpacker, mlsUnpacker, applicationMessageHandler, mlsWrongEpochHandler
)

private val newConversationHandler: NewConversationEventHandler
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Wire
* Copyright (C) 2023 Wire Swiss GmbH
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.wire.kalium.logic.sync.receiver.conversation.message

import com.benasher44.uuid.uuid4
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.data.conversation.Conversation
import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.message.Message
import com.wire.kalium.logic.data.message.MessageContent
import com.wire.kalium.logic.data.message.PersistMessageUseCase
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.conversation.JoinExistingMLSConversationUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.kaliumLogger

interface MLSWrongEpochHandler {
suspend fun onMLSWrongEpoch(
conversationId: ConversationId,
dateIso: String,
)
}

internal class MLSWrongEpochHandlerImpl(
private val selfUserId: UserId,
private val persistMessage: PersistMessageUseCase,
private val conversationRepository: ConversationRepository,
private val joinExistingMLSConversation: JoinExistingMLSConversationUseCase
) : MLSWrongEpochHandler {

private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) }

override suspend fun onMLSWrongEpoch(
conversationId: ConversationId,
dateIso: String,
) {
logger.i("Handling MLS WrongEpoch result")
conversationRepository.getConversationProtocolInfo(conversationId).flatMap { protocol ->
if (protocol is Conversation.ProtocolInfo.MLS) {
Either.Right(protocol)
} else {
Either.Left(MLSFailure.ConversationDoesNotSupportMLS)
}
}.flatMap { currentProtocol ->
getUpdatedConversationEpoch(conversationId).map { updatedEpoch ->
updatedEpoch != null && updatedEpoch != currentProtocol.epoch
}
}.flatMap { isRejoinNeeded ->
if (isRejoinNeeded) {
joinExistingMLSConversation(conversationId)
} else Either.Right(Unit)
}.flatMap {
insertInfoMessage(conversationId, dateIso)
}
}

private suspend fun getUpdatedConversationEpoch(conversationId: ConversationId): Either<CoreFailure, ULong?> {
return conversationRepository.fetchConversation(conversationId).flatMap {
conversationRepository.getConversationProtocolInfo(conversationId)
}.map { updatedProtocol ->
(updatedProtocol as? Conversation.ProtocolInfo.MLS)?.epoch
}
}

private suspend fun insertInfoMessage(conversationId: ConversationId, dateIso: String): Either<CoreFailure, Unit> {
val mlsEpochWarningMessage = Message.System(
id = uuid4().toString(),
content = MessageContent.MLSWrongEpochWarning,
conversationId = conversationId,
date = dateIso,
senderUserId = selfUserId,
status = Message.Status.READ,
visibility = Message.Visibility.VISIBLE,
senderUserName = null
)
return persistMessage(mlsEpochWarningMessage)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package com.wire.kalium.logic.sync.receiver.conversation.message

import com.wire.kalium.cryptography.exceptions.ProteusException
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.ProteusFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.event.Event
Expand All @@ -39,7 +40,8 @@ internal interface NewMessageEventHandler {
internal class NewMessageEventHandlerImpl(
private val proteusMessageUnpacker: ProteusMessageUnpacker,
private val mlsMessageUnpacker: MLSMessageUnpacker,
private val applicationMessageHandler: ApplicationMessageHandler
private val applicationMessageHandler: ApplicationMessageHandler,
private val mlsWrongEpochHandler: MLSWrongEpochHandler
) : NewMessageEventHandler {

private val logger by lazy { kaliumLogger.withFeatureId(KaliumLogger.Companion.ApplicationFlow.EVENT_RECEIVER) }
Expand Down Expand Up @@ -86,7 +88,6 @@ internal class NewMessageEventHandlerImpl(
override suspend fun handleNewMLSMessage(event: Event.Conversation.NewMLSMessage) {
mlsMessageUnpacker.unpackMlsMessage(event)
.onFailure {

val logMap = mapOf(
"event" to event.toLogMap(),
"errorInfo" to "$it",
Expand All @@ -95,6 +96,11 @@ internal class NewMessageEventHandlerImpl(

logger.e("Failed to decrypt event: ${logMap.toJsonElement()}")

if (it is MLSFailure.WrongEpoch) {
mlsWrongEpochHandler.onMLSWrongEpoch(event.conversationId, event.timestampIso)
return@onFailure
}

applicationMessageHandler.handleDecryptionError(
eventId = event.id,
conversationId = event.conversationId,
Expand Down
Loading

0 comments on commit fddd913

Please sign in to comment.