diff --git a/src/node/protocol/dkg/round_one.rs b/src/node/protocol/dkg/round_one.rs index 16803f0..4419a75 100644 --- a/src/node/protocol/dkg/round_one.rs +++ b/src/node/protocol/dkg/round_one.rs @@ -154,6 +154,7 @@ mod round_one_package_tests { assert!(res.is_some()); assert_eq!(res.unwrap().get_sender_id(), "local"); } + #[tokio::test] async fn it_should_serialize_and_deserialize_round_one_public_key_package() { let message_id_generator = MessageIdGenerator::new("localhost".to_string()); diff --git a/src/node/protocol/dkg/state.rs b/src/node/protocol/dkg/state.rs index b8cde20..927af2c 100644 --- a/src/node/protocol/dkg/state.rs +++ b/src/node/protocol/dkg/state.rs @@ -48,7 +48,7 @@ impl State { /// Message for state handle to actor communication pub(crate) enum StateMessage { /// Add a received round1 package to state - AddRound1Package(dkg::round1::Package, oneshot::Sender<()>), + AddRound1Package(frost::Identifier, dkg::round1::Package, oneshot::Sender<()>), /// Add a received secret package to state AddSecretPackage(frost::keys::dkg::round1::SecretPackage, oneshot::Sender<()>), } @@ -110,7 +110,7 @@ impl StateHandle { package: dkg::round1::Package, respond_to: oneshot::Sender<()>, ) { - let message = StateMessage::AddRound1Package(package, respond_to); + let message = StateMessage::AddRound1Package(identifier, package, respond_to); let _ = self.sender.send(message).await; } @@ -203,8 +203,7 @@ mod dkg_state_handle_tests { #[tokio::test] async fn test_state_handle_add_round1_package() { let (tx, mut rx) = mpsc::channel(1); - let handle = StateHandle::new(tx); - + let state_handle = StateHandle::new(tx); let identifier = frost::Identifier::derive(b"1").unwrap(); let (_secret_package, package) = frost::keys::dkg::part1(identifier, 3, 2, thread_rng()).unwrap(); @@ -212,12 +211,15 @@ mod dkg_state_handle_tests { let (respond_tx, _respond_rx) = oneshot::channel(); // Send the package - handle + state_handle .add_round1_package(identifier, package.clone(), respond_tx) .await; // Verify the message was received correctly - if let Some(StateMessage::AddRound1Package(received_package, _)) = rx.try_recv().ok() { + if let Some(StateMessage::AddRound1Package(received_identifier, received_package, _)) = + rx.try_recv().ok() + { + assert_eq!(received_identifier, identifier); assert_eq!(received_package, package); } else { panic!("Failed to receive the expected message"); @@ -226,7 +228,7 @@ mod dkg_state_handle_tests { #[tokio::test] async fn test_state_handle_add_secret_package() { let (tx, mut rx) = mpsc::channel(1); - let handle = StateHandle::new(tx); + let state_handle = StateHandle::new(tx); let identifier = frost::Identifier::derive(b"1").unwrap(); let (secret_package, _package) = @@ -235,7 +237,7 @@ mod dkg_state_handle_tests { let (respond_tx, _respond_rx) = oneshot::channel(); // Send the secret package - handle + state_handle .add_secret_package(secret_package.clone(), respond_tx) .await;