diff --git a/src/node.rs b/src/node.rs index cf2d955..e15a3e9 100644 --- a/src/node.rs +++ b/src/node.rs @@ -361,7 +361,7 @@ impl Node { ) { tokio::spawn(async move { let protocol_service = - Protocol::new(node_id.clone(), state, reliable_sender_handle.clone()); + Protocol::new(node_id.clone(), state, Some(reliable_sender_handle.clone())); let reliable_sender_service = ReliableSend::new(protocol_service, reliable_sender_handle); let timeout_layer = @@ -390,7 +390,7 @@ impl Node { let protocol_service = Protocol::new( node_id.clone(), state.clone(), - reliable_sender_handle.clone(), + Some(reliable_sender_handle.clone()), ); let echo_broadcast_service = EchoBroadcast::new(protocol_service, echo_broadcast_handle, state, node_id); diff --git a/src/node/echo_broadcast/service.rs b/src/node/echo_broadcast/service.rs index a40c203..3014be4 100644 --- a/src/node/echo_broadcast/service.rs +++ b/src/node/echo_broadcast/service.rs @@ -155,8 +155,11 @@ mod echo_broadcast_service_tests { } .into(); - let handshake_service = - Protocol::new("localhost".to_string(), state.clone(), mock_reliable_sender); + let handshake_service = Protocol::new( + "localhost".to_string(), + state.clone(), + Some(mock_reliable_sender), + ); let echo_broadcast_service = EchoBroadcast::new( handshake_service, echo_bcast_handle, diff --git a/src/node/protocol.rs b/src/node/protocol.rs index 96f9c69..b63f0e6 100644 --- a/src/node/protocol.rs +++ b/src/node/protocol.rs @@ -152,11 +152,11 @@ impl From for Message { pub struct Protocol { node_id: String, state: State, - peer_sender: ReliableSenderHandle, + peer_sender: Option, } impl Protocol { - pub fn new(node_id: String, state: State, peer_sender: ReliableSenderHandle) -> Self { + pub fn new(node_id: String, state: State, peer_sender: Option) -> Self { Protocol { node_id, state, @@ -230,7 +230,7 @@ mod protocol_tests { let membership_handle = MembershipHandle::start("localhost".to_string()).await; let state = State::new(membership_handle, message_id_generator).await; - let p = Protocol::new("local".into(), state, reliable_sender_handle); + let p = Protocol::new("local".into(), state, Some(reliable_sender_handle)); let m = p.oneshot(PingMessage::default().into()).await; assert!(m.unwrap().is_some()); } diff --git a/src/node/protocol/dkg/trigger.rs b/src/node/protocol/dkg/trigger.rs index 986897a..8e29464 100644 --- a/src/node/protocol/dkg/trigger.rs +++ b/src/node/protocol/dkg/trigger.rs @@ -112,7 +112,7 @@ pub(crate) async fn trigger_dkg( reliable_sender_handle: ReliableSenderHandle, ) -> Result<(KeyPackage, PublicKeyPackage), BoxError> { let protocol_service: Protocol = - Protocol::new(node_id.clone(), state.clone(), reliable_sender_handle); + Protocol::new(node_id.clone(), state.clone(), Some(reliable_sender_handle)); let round1_future = build_round1_future( node_id.clone(), diff --git a/src/node/protocol/handshake.rs b/src/node/protocol/handshake.rs index 555dce5..1b4b19a 100644 --- a/src/node/protocol/handshake.rs +++ b/src/node/protocol/handshake.rs @@ -47,11 +47,11 @@ impl HandshakeMessage { pub struct Handshake { node_id: String, state: State, - peer_sender: ReliableSenderHandle, + peer_sender: Option, } impl Handshake { - pub fn new(node_id: String, state: State, peer_sender: ReliableSenderHandle) -> Self { + pub fn new(node_id: String, state: State, peer_sender: Option) -> Self { Handshake { node_id, state, @@ -86,7 +86,7 @@ impl Service for Handshake { })) => match message.as_str() { "helo" => { if membership_handle - .add_member(sender_id, peer_sender) + .add_member(sender_id, peer_sender.unwrap()) .await .is_err() { @@ -102,7 +102,7 @@ impl Service for Handshake { } "oleh" => { if membership_handle - .add_member(sender_id, peer_sender) + .add_member(sender_id, peer_sender.unwrap()) .await .is_err() { @@ -154,7 +154,7 @@ mod handshake_tests { .expect_clone() .returning(ReliableSenderHandle::default); - let mut p = Handshake::new("local".to_string(), state, reliable_sender_handle); + let mut p = Handshake::new("local".to_string(), state, Some(reliable_sender_handle)); let res = p .ready() @@ -185,7 +185,7 @@ mod handshake_tests { .expect_clone() .returning(ReliableSenderHandle::default); - let mut p = Handshake::new("local".to_string(), state, reliable_sender_handle); + let mut p = Handshake::new("local".to_string(), state, Some(reliable_sender_handle)); let res = p .ready() @@ -224,7 +224,7 @@ mod handshake_tests { mock }); - let mut p = Handshake::new("local".to_string(), state, reliable_sender_handle); + let mut p = Handshake::new("local".to_string(), state, Some(reliable_sender_handle)); let res = p .ready() diff --git a/src/node/protocol/init.rs b/src/node/protocol/init.rs index 94e916c..548db20 100644 --- a/src/node/protocol/init.rs +++ b/src/node/protocol/init.rs @@ -32,7 +32,7 @@ pub(crate) async fn initialize_handshake( reliable_sender_handle: ReliableSenderHandle, delivery_timeout: u64, ) { - let protocol_service = Protocol::new(node_id, state, reliable_sender_handle.clone()); + let protocol_service = Protocol::new(node_id, state, Some(reliable_sender_handle.clone())); let reliable_sender_service = ReliableSend::new(protocol_service, reliable_sender_handle); let timeout_layer = TimeoutLayer::new(Duration::from_millis(delivery_timeout)); let _ = timeout_layer @@ -50,7 +50,7 @@ pub(crate) async fn send_membership( delivery_time: u64, ) { log::info!("Sending membership information"); - let protocol_service = Protocol::new(node_id.clone(), state, sender.clone()); + let protocol_service = Protocol::new(node_id.clone(), state, Some(sender.clone())); let reliable_sender_service = ReliableSend::new(protocol_service, sender); let timeout_layer = tower::timeout::TimeoutLayer::new(tokio::time::Duration::from_millis(delivery_time));