Skip to content

Commit

Permalink
Support adding received round2 packages to state
Browse files Browse the repository at this point in the history
  • Loading branch information
pool2win committed Nov 22, 2024
1 parent 1678962 commit f55dcf0
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 61 deletions.
143 changes: 87 additions & 56 deletions src/node/protocol/dkg/round_two.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl Service<Message> for Package {
match msg {
Message::Unicast(Unicast::DKGRoundTwoPackage(PackageMessage {
sender_id: _,
message: None, // message is None, so we build a new round2 package
message: None, // message is None, so we build a new round2 package and send it
})) => {
match build_round2_packages(sender_id, state.clone()).await {
Ok((round2_secret_package, round2_packages)) => {
Expand All @@ -147,9 +147,25 @@ impl Service<Message> for Package {
}
}
}
Message::Unicast(Unicast::DKGRoundTwoPackage(PackageMessage {
sender_id: from_sender_id,
message: Some(message), // received a message
})) => {
// Received round2 message and save it in state
let identifier = frost::Identifier::derive(from_sender_id.as_bytes()).unwrap();
state
.dkg_state
.add_round2_package(identifier, message)
.await
.unwrap();
Ok(None)
}
_ => {
// Received round2 message
todo!()
log::error!(
"Not a Unicast message {:?}. Should not happen, but we need to match all message types",
msg
);
Ok(None)
}
}
}
Expand Down Expand Up @@ -195,9 +211,8 @@ mod round_two_tests {
use crate::node::protocol::message_id_generator::MessageIdGenerator;
#[mockall_double::double]
use crate::node::reliable_sender::ReliableSenderHandle;
use crate::node::test_helpers::support::build_membership;
use node::{dkg::state::Round1Map, MembershipHandle};
use rand::thread_rng;
use crate::node::test_helpers::support::{build_membership, build_round2_state};
use node::MembershipHandle;

#[tokio::test]
async fn test_build_round2_packages_insufficient_packages() {
Expand All @@ -206,60 +221,12 @@ mod round_two_tests {
membership_handle,
MessageIdGenerator::new("local".to_string()),
);
let result = build_round2_packages("node1".to_string(), state)
let result = build_round2_packages("localhost".to_string(), state)
.await
.unwrap_err();
assert_eq!(result, frost::Error::InvalidSecretShare);
}

async fn build_round2_state(state: node::state::State) -> (node::state::State, Round1Map) {
let rng = thread_rng();
let mut round1_packages = Round1Map::new();

// generate our round1 secret and package
let (secret_package, round1_package) = frost::keys::dkg::part1(
frost::Identifier::derive(b"node1").unwrap(),
3,
2,
rng.clone(),
)
.unwrap();
log::debug!("Secret package {:?}", secret_package);

// add our secret package to state
state
.dkg_state
.add_round1_secret_package(secret_package)
.await
.unwrap();

// Add packages for other nodes
let (_, round1_package2) = frost::keys::dkg::part1(
frost::Identifier::derive(b"node2").unwrap(),
3,
2,
rng.clone(),
)
.unwrap();
round1_packages.insert(
frost::Identifier::derive(b"node2").unwrap(),
round1_package2,
);

let (_, round1_package3) = frost::keys::dkg::part1(
frost::Identifier::derive(b"node3").unwrap(),
3,
2,
rng.clone(),
)
.unwrap();
round1_packages.insert(
frost::Identifier::derive(b"node3").unwrap(),
round1_package3,
);
(state, round1_packages)
}

#[tokio::test]
async fn test_build_round2_packages_should_succeed() {
let membership_handle = MembershipHandle::start("localhost".to_string()).await;
Expand Down Expand Up @@ -289,7 +256,7 @@ mod round_two_tests {
.unwrap();
}

let result = build_round2_packages("node1".to_string(), state).await;
let result = build_round2_packages("localhost".to_string(), state).await;
assert!(result.is_ok());
let (round2_secret, round2_packages) = result.unwrap();
assert_eq!(round2_packages.len(), 2);
Expand Down Expand Up @@ -400,4 +367,68 @@ mod round_two_tests {
.await;
assert!(res.is_err());
}

#[tokio::test]
async fn test_add_received_round2_package() {
let _ = env_logger::try_init();

let membership_handle = MembershipHandle::start("localhost".to_string()).await;

let state = node::state::State::new(
membership_handle.clone(),
MessageIdGenerator::new("localhost".to_string()),
);

for i in 1..3 {
let mut mock_reliable_sender = ReliableSenderHandle::default();
mock_reliable_sender.expect_clone().returning(|| {
let mut mock = ReliableSenderHandle::default();
mock.expect_clone().returning(ReliableSenderHandle::default);
mock.expect_send()
//.times(1)
.returning(|_| futures::future::ok(()).boxed());
mock
});
let _ = membership_handle
.add_member(format!("localhost{}", i), mock_reliable_sender)
.await;
}

let (mut state, round1_packages) = build_round2_state(state).await;

// add all round1 packages to state
for (identifier, package) in round1_packages {
state
.dkg_state
.add_round1_package(identifier, package)
.await
.unwrap();
}

let (round2_secret_package, round2_map) =
build_round2_packages("localhost".to_string(), state.clone())
.await
.unwrap();
let (identifier, message) = round2_map.iter().next().unwrap();

// call the package service to handle received round2 packages
let mut pkg = Package::new("localhost".to_string(), state.clone());
let res = pkg
.call(Message::Unicast(Unicast::DKGRoundTwoPackage(
PackageMessage::new("localhost1".to_string(), Some(message.clone())),
)))
.await;
assert!(res.is_ok());

let received_round2_package = state
.dkg_state
.get_received_round2_packages()
.await
.unwrap()
.keys()
.next()
.unwrap()
.clone();
assert!(round2_map.contains_key(&received_round2_package));
}
}
10 changes: 5 additions & 5 deletions src/node/test_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub(crate) mod support {

// generate our round1 secret and package
let (secret_package, round1_package) = frost::keys::dkg::part1(
frost::Identifier::derive(b"node1").unwrap(),
frost::Identifier::derive(b"localhost").unwrap(),
3,
2,
rng.clone(),
Expand All @@ -66,26 +66,26 @@ pub(crate) mod support {

// Add packages for other nodes
let (_, round1_package2) = frost::keys::dkg::part1(
frost::Identifier::derive(b"node2").unwrap(),
frost::Identifier::derive(b"localhost1").unwrap(),
3,
2,
rng.clone(),
)
.unwrap();
round1_packages.insert(
frost::Identifier::derive(b"node2").unwrap(),
frost::Identifier::derive(b"localhost1").unwrap(),
round1_package2,
);

let (_, round1_package3) = frost::keys::dkg::part1(
frost::Identifier::derive(b"node3").unwrap(),
frost::Identifier::derive(b"localhost2").unwrap(),
3,
2,
rng.clone(),
)
.unwrap();
round1_packages.insert(
frost::Identifier::derive(b"node3").unwrap(),
frost::Identifier::derive(b"localhost2").unwrap(),
round1_package3,
);
(state, round1_packages)
Expand Down

0 comments on commit f55dcf0

Please sign in to comment.