Skip to content

Commit

Permalink
Run dkg state actor properly
Browse files Browse the repository at this point in the history
  • Loading branch information
pool2win committed Nov 18, 2024
1 parent 1a9dc9a commit db761e5
Showing 1 changed file with 44 additions and 41 deletions.
85 changes: 44 additions & 41 deletions src/node/protocol/dkg/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,30 @@ pub(crate) struct Actor {
}

impl Actor {
pub fn start(receiver: mpsc::Receiver<StateMessage>) -> Self {
fn new(receiver: mpsc::Receiver<StateMessage>) -> Self {
Self {
state: State::new(),
receiver,
}
}

pub fn start_new_dkg(&mut self, respond_to: oneshot::Sender<()>) {
self.state.in_progress = true;
let _ = respond_to.send(());
// Add a new method to run the actor
async fn run(&mut self) {
while let Some(message) = self.receiver.recv().await {
match message {
StateMessage::AddRound1Package(identifier, package, respond_to) => {
self.add_round1_package(identifier, package, respond_to);
}
StateMessage::AddSecretPackage(secret_package, respond_to) => {
self.add_secret_package(secret_package, respond_to);
}
}
}
}

pub fn add_round1_package(
fn add_round1_package(
&mut self,
identifier: frost::Identifier,
identifier: Identifier,
package: dkg::round1::Package,
respond_to: oneshot::Sender<()>,
) {
Expand All @@ -82,7 +91,8 @@ impl Actor {
.insert(identifier, package);
let _ = respond_to.send(());
}
pub fn add_secret_package(

fn add_secret_package(
&mut self,
secret_package: frost::keys::dkg::round1::SecretPackage,
respond_to: oneshot::Sender<()>,
Expand All @@ -98,8 +108,16 @@ pub(crate) struct StateHandle {
}

impl StateHandle {
/// Create a new state handle
pub fn new(sender: mpsc::Sender<StateMessage>) -> Self {
/// Create a new state handle and spawn the actor
pub fn new() -> Self {
let (sender, receiver) = mpsc::channel(1);
let mut actor = Actor::new(receiver);

// Spawn the actor task
tokio::spawn(async move {
actor.run().await;
});

Self { sender }
}

Expand Down Expand Up @@ -144,23 +162,21 @@ mod dkg_state_tests {
#[test]
fn test_actor_start() {
let (tx, rx) = mpsc::channel(1);
let mut actor = Actor::start(rx);
let mut actor = Actor::new(rx);
assert_eq!(actor.state.in_progress, false);
}

#[test]
fn test_actor_start_new_dkg() {
#[tokio::test]
async fn test_actor_start_new_dkg() {
let (tx, rx) = mpsc::channel(1);
let mut actor = Actor::start(rx);
let (tx1, rx1) = oneshot::channel();
actor.start_new_dkg(tx1);
assert_eq!(actor.state.in_progress, true);
let mut actor = Actor::new(rx);
assert_eq!(actor.state.in_progress, false);
}

#[test]
fn test_actor_add_round1_package() {
let (_tx, rx) = mpsc::channel(1);
let mut actor = Actor::start(rx);
let mut actor = Actor::new(rx);
let identifier = frost::Identifier::derive(b"1").unwrap();
let rng = thread_rng();

Expand All @@ -175,7 +191,7 @@ mod dkg_state_tests {
#[test]
fn test_actor_add_secret_package() {
let (_tx, rx) = mpsc::channel(1);
let mut actor = Actor::start(rx);
let mut actor = Actor::new(rx);
let identifier = frost::Identifier::derive(b"1").unwrap();
let rng = thread_rng();

Expand All @@ -195,57 +211,44 @@ mod dkg_state_handle_tests {

#[tokio::test]
async fn test_state_handle_new() {
let (tx, _rx) = mpsc::channel(1);
let handle = StateHandle::new(tx);
let handle = StateHandle::new();
assert!(handle.sender.capacity() > 0);
}

#[tokio::test]
async fn test_state_handle_add_round1_package() {
let (tx, mut rx) = mpsc::channel(1);
let state_handle = StateHandle::new(tx);
let state_handle = StateHandle::new();
let identifier = frost::Identifier::derive(b"1").unwrap();
let (_secret_package, package) =
frost::keys::dkg::part1(identifier, 3, 2, thread_rng()).unwrap();

let (respond_tx, _respond_rx) = oneshot::channel();
let (respond_tx, respond_rx) = oneshot::channel();

// Send the package
state_handle
.add_round1_package(identifier, package.clone(), respond_tx)
.await;

// Verify the message was received correctly
if let Some(StateMessage::AddRound1Package(received_identifier, received_package, _)) =
rx.try_recv().ok()
{
assert_eq!(received_identifier, identifier);
assert_eq!(received_package, package);
} else {
panic!("Failed to receive the expected message");
}
// Wait for the response
respond_rx.await.expect("Failed to receive response");
}

#[tokio::test]
async fn test_state_handle_add_secret_package() {
let (tx, mut rx) = mpsc::channel(1);
let state_handle = StateHandle::new(tx);
let state_handle = StateHandle::new();

let identifier = frost::Identifier::derive(b"1").unwrap();
let (secret_package, _package) =
frost::keys::dkg::part1(identifier, 3, 2, thread_rng()).unwrap();

let (respond_tx, _respond_rx) = oneshot::channel();
let (respond_tx, respond_rx) = oneshot::channel();

// Send the secret package
state_handle
.add_secret_package(secret_package.clone(), respond_tx)
.await;

// Verify the message was received correctly
if let Some(StateMessage::AddSecretPackage(received_package, _)) = rx.try_recv().ok() {
assert_eq!(received_package, secret_package);
} else {
panic!("Failed to receive the expected message");
}
// Wait for the response
respond_rx.await.expect("Failed to receive response");
}
}

0 comments on commit db761e5

Please sign in to comment.