diff --git a/src/node.rs b/src/node.rs index 654c05a..b21ab0f 100644 --- a/src/node.rs +++ b/src/node.rs @@ -125,12 +125,15 @@ impl Node { ) { log::debug!("Starting... {}", self.bind_address); let node_id = self.get_node_id().clone(); - let state = self.state.clone(); - let (round_one_tx, round_one_rx) = mpsc::channel::<()>(1); - self.state.round_one_tx = Some(round_one_tx.clone()); - let (round_two_tx, round_two_rx) = mpsc::channel::<()>(1); - self.state.round_two_tx = Some(round_two_tx.clone()); let echo_broadcast_handle = self.echo_broadcast_handle.clone(); + + // We can send message on the channel from both sending and receiving tasks + let (round_one_tx, round_one_rx) = mpsc::channel::<()>(2); + self.state.round_one_tx = Some(round_one_tx); + let (round_two_tx, round_two_rx) = mpsc::channel::<()>(2); + self.state.round_two_tx = Some(round_two_tx); + + let state = self.state.clone(); tokio::spawn(async move { dkg::trigger::run_dkg_trigger( 15000, diff --git a/src/node/protocol/dkg/round_one.rs b/src/node/protocol/dkg/round_one.rs index fce0f13..426712d 100644 --- a/src/node/protocol/dkg/round_one.rs +++ b/src/node/protocol/dkg/round_one.rs @@ -45,7 +45,7 @@ impl PackageMessage { /// Builds a round one package using the frost-secp256k1 crate async fn build_round1_package( sender_id: String, - state: crate::node::state::State, + state: &crate::node::state::State, ) -> Result { let (max_signers, min_signers) = get_max_min_signers(&state).await; @@ -123,8 +123,19 @@ impl Service for Package { _message_id, ) => { log::debug!("Build round one package"); - let response = build_round1_package(this_sender_id, state).await?; + let response = build_round1_package(this_sender_id, &state).await?; log::info!("Sending round one package {:?}", response); + let finished = state + .dkg_state + .get_received_round1_packages() + .await + .unwrap() + .len() + == state.dkg_state.get_expected_members().await.unwrap(); + if finished { + log::debug!("Round one finished, sending signal"); + let _ = state.round_one_tx.unwrap().send(()).await; + } Ok(Some(response)) } Message::Broadcast( @@ -207,7 +218,7 @@ mod round_one_package_tests { let membership_handle = build_membership(3).await; let state = State::new(membership_handle, message_id_generator).await; - let round1_package = build_round1_package("local".into(), state).await.unwrap(); + let round1_package = build_round1_package("local".into(), &state).await.unwrap(); // Extract the public key package from the NetworkMessage if let Message::Broadcast(BroadcastProtocol::DKGRoundOnePackage(pkg_msg), _message_id) = @@ -232,7 +243,7 @@ mod round_one_package_tests { let state_clone = state.clone(); // First create a round1 package that we'll pretend came from another node - let round1_package = build_round1_package("remote".into(), state).await.unwrap(); + let round1_package = build_round1_package("remote".into(), &state).await.unwrap(); // Create our local package service let mut pkg = Package::new("local".into(), state_clone); diff --git a/src/node/protocol/dkg/round_two.rs b/src/node/protocol/dkg/round_two.rs index f46097d..4e84319 100644 --- a/src/node/protocol/dkg/round_two.rs +++ b/src/node/protocol/dkg/round_two.rs @@ -141,6 +141,17 @@ impl Service for Package { log::error!("Failed to send round2 packages: {:?}", e); return Err(e.into()); } + let finished = state + .dkg_state + .get_received_round2_packages() + .await + .unwrap() + .len() + == state.dkg_state.get_expected_members().await.unwrap(); + if finished { + log::debug!("Round two finished on send, sending signal"); + state.round_two_tx.unwrap().send(()).await?; + } log::debug!("Sent round2 packages"); Ok(None) } @@ -155,6 +166,11 @@ impl Service for Package { message: Some(message), // received a message })) => { // Received round2 message and save it in state + log::debug!( + "Received round two message from {} \n {:?}", + from_sender_id, + message + ); let identifier = frost::Identifier::derive(from_sender_id.as_bytes()).unwrap(); let finished = state .dkg_state @@ -162,8 +178,8 @@ impl Service for Package { .await .unwrap(); if finished { - log::debug!("Round two finished, sending signal"); - let _ = state.round_two_tx.unwrap().send(()).await; + log::debug!("Round two finished on receive, sending signal"); + state.round_two_tx.unwrap().send(()).await?; } Ok(None) } diff --git a/src/node/protocol/dkg/trigger.rs b/src/node/protocol/dkg/trigger.rs index e5e590f..9973519 100644 --- a/src/node/protocol/dkg/trigger.rs +++ b/src/node/protocol/dkg/trigger.rs @@ -131,10 +131,10 @@ pub(crate) async fn trigger_dkg( // TODO Improve this to allow round1 to finish as soon as all other parties have sent their round1 message // This will mean moving the timeout into round1 service - // Wait for round1 to finish, give it 5 seconds - if round1_future.await.is_err() { - log::error!("Error running round 1"); - return Err("Error running round 1".into()); + // Start round1 + if let Err(e) = round1_future.await { + log::error!("Error running round 1: {:?}", e); + return Err("Error running round 1: failed with error".into()); } round_one_rx.recv().await.unwrap(); log::info!("Round 1 finished"); @@ -149,9 +149,9 @@ pub(crate) async fn trigger_dkg( ); // start round2 - if round2_future.await.is_err() { - log::error!("Error running round 2"); - return Err("Error running round 2".into()); + if let Err(e) = round2_future.await { + log::error!("Error running round 2: {:?}", e); + return Err("Error running round 2: failed with error".into()); } round_two_rx.recv().await.unwrap(); log::info!("Round 2 finished");