diff --git a/src/lib.rs b/src/lib.rs index c9d7d93..b614ef3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -371,10 +371,9 @@ impl FSChaCha20 { } } -fn gen_key() -> Result { - let mut rnd = rand::thread_rng(); +fn gen_key(rng: &mut impl Rng) -> Result { let mut buffer: Vec = vec![0; 32]; - rnd.fill(&mut buffer[..]); + rng.fill(&mut buffer[..]); let sk = SecretKey::from_slice(&buffer)?; Ok(sk) } @@ -386,8 +385,7 @@ fn new_elligator_swift(sk: SecretKey) -> ElligatorSwift { ElligatorSwift::from_pubkey(pk) } -fn gen_garbage(garbage_len: u32) -> Vec { - let mut rng = rand::thread_rng(); +fn gen_garbage(garbage_len: u32, rng: &mut impl Rng) -> Vec { let buffer: Vec = (0..garbage_len).map(|_| rng.gen()).collect(); buffer } @@ -461,13 +459,37 @@ fn initialize_session_key_material(ikm: &[u8]) -> SessionKeyMaterial { /// # Errors /// /// Fails if their was an error generating the keypair. +#[cfg(feature = "std")] pub fn initialize_v2_handshake( garbage_len: Option, ) -> Result { - let sk = gen_key()?; + let mut rng = rand::thread_rng(); + initialize_v2_handshake_with_rng(garbage_len, &mut rng) +} + +/// Initialize a V2 transport handshake with a peer. The `InitiatorHandshake` contains a message ready to be sent over the wire, +/// and the information necessary for completing ECDH when the peer responds. +/// +/// # Arguments +/// +/// `garbage_len` - The length of the additional garbage to be sent along with the encoded public key. +/// `rng` - supplied Random Number Generator. +/// +/// # Returns +/// +/// A partial handshake. +/// +/// # Errors +/// +/// Fails if their was an error generating the keypair. +pub fn initialize_v2_handshake_with_rng( + garbage_len: Option, + rng: &mut impl Rng, +) -> Result { + let sk = gen_key(rng)?; let es = new_elligator_swift(sk); let garbage_len = garbage_len.unwrap_or(MAX_GARBAGE_LEN); - let garbage = gen_garbage(garbage_len); + let garbage = gen_garbage(garbage_len, rng); let point = EcdhPoint { secret_key: sk, elligator_swift: es, @@ -494,8 +516,31 @@ pub fn initialize_v2_handshake( /// # Errors /// /// Fails if the packet was not prepared properly. +#[cfg(feature = "std")] pub fn receive_v2_handshake( message: Vec, +) -> Result { + let mut rng = rand::thread_rng(); + receive_v2_handshake_with_rng(message, &mut rng) +} + +/// Receive a V2 handshake over the wire. The `ResponderHandshake` contains the message ready to be sent over the wire and a struct for parsing packets. +/// +/// # Arguments +/// +/// `message` - The message received over the wire. +/// `rng` - Supplied Random Number Generator. +/// +/// # Returns +/// +/// A completed handshake containing a `PacketHandler`. +/// +/// # Errors +/// +/// Fails if the packet was not prepared properly. +pub fn receive_v2_handshake_with_rng( + message: Vec, + rng: &mut impl Rng, ) -> Result { let mut network_magic = NETWORK_MAGIC.to_vec(); let mut version_bytes = "version".as_bytes().to_vec(); @@ -508,7 +553,7 @@ pub fn receive_v2_handshake( )) } else { let mut response: Vec = Vec::new(); - let sk = gen_key().map_err(ResponderHandshakeError::ECC)?; + let sk = gen_key(rng).map_err(ResponderHandshakeError::ECC)?; let es = new_elligator_swift(sk); response.extend(&es.to_array()); let elliswift_message = &message[..64]; @@ -519,7 +564,7 @@ pub fn receive_v2_handshake( let session_keys = get_shared_secrets(theirs, es, sk, ElligatorSwiftParty::B); let initiator_garbage = message[64..].to_vec(); let initiator_garbage_len = initiator_garbage.len() as u32; - let response_garbage = gen_garbage(initiator_garbage_len); + let response_garbage = gen_garbage(initiator_garbage_len, rng); if initiator_garbage_len > MAX_GARBAGE_LEN { return Err(ResponderHandshakeError::IncorrectMessage( "Garbage length is too large.".to_string(), @@ -661,7 +706,8 @@ mod tests { #[test] fn test_sec_keygen() { - gen_key().unwrap(); + let mut rng = rand::thread_rng(); + gen_key(&mut rng).unwrap(); } #[test] @@ -777,6 +823,7 @@ mod tests { #[test] fn test_fuzz_packets() { + let mut rng = rand::thread_rng(); let alice = SecretKey::from_str("61062ea5071d800bbfd59e2e8b53d47d194b095ae5a4df04936b49772ef0d4d7") .unwrap(); @@ -792,7 +839,7 @@ mod tests { PacketHandler::new(session_keys.clone(), HandshakeRole::Initiator); let mut bob_packet_handler = PacketHandler::new(session_keys, HandshakeRole::Responder); for _ in 0..REKEY_INTERVAL + 100 { - let message = gen_garbage(4095); + let message = gen_garbage(4095, &mut rng); let enc_packet = alice_packet_handler .prepare_v2_packet(message.clone(), None, false) .unwrap(); @@ -801,7 +848,7 @@ mod tests { .unwrap(); let secret_message = dec_packet.first().unwrap().message.clone().unwrap(); assert_eq!(message, secret_message); - let message = gen_garbage(420); + let message = gen_garbage(420, &mut rng); let enc_packet = bob_packet_handler .prepare_v2_packet(message.clone(), None, false) .unwrap(); @@ -815,6 +862,7 @@ mod tests { #[test] fn test_authenticated_garbage() { + let mut rng = rand::thread_rng(); let alice = SecretKey::from_str("61062ea5071d800bbfd59e2e8b53d47d194b095ae5a4df04936b49772ef0d4d7") .unwrap(); @@ -829,7 +877,7 @@ mod tests { let mut alice_packet_handler = PacketHandler::new(session_keys.clone(), HandshakeRole::Initiator); let mut bob_packet_handler = PacketHandler::new(session_keys, HandshakeRole::Responder); - let auth_garbage = gen_garbage(200); + let auth_garbage = gen_garbage(200, &mut rng); let enc_packet = alice_packet_handler .prepare_v2_packet(Vec::new(), Some(auth_garbage.clone()), false) .unwrap(); @@ -897,6 +945,7 @@ mod tests { #[test] fn test_fuzz_decode_multiple_messages() { + let mut rng = rand::thread_rng(); let handshake_init = initialize_v2_handshake(None).unwrap(); let mut handshake_response = receive_v2_handshake(handshake_init.message.clone()).unwrap(); let alice_completion = @@ -911,7 +960,7 @@ mod tests { let mut bob = handshake_response.packet_handler; let mut message_to_bob = Vec::new(); for _ in 0..REKEY_INTERVAL + 100 { - let message = gen_garbage(420); + let message = gen_garbage(420, &mut rng); let enc_packet = alice .prepare_v2_packet(message.clone(), None, false) .unwrap(); @@ -924,6 +973,7 @@ mod tests { #[test] fn test_vector_1() { + let mut rng = rand::thread_rng(); let alice = SecretKey::from_str("61062ea5071d800bbfd59e2e8b53d47d194b095ae5a4df04936b49772ef0d4d7") .unwrap(); @@ -938,7 +988,7 @@ mod tests { let mut alice_packet_handler = PacketHandler::new(session_keys.clone(), HandshakeRole::Initiator); let mut bob_packet_handler = PacketHandler::new(session_keys, HandshakeRole::Responder); - let first = gen_garbage(100); + let first = gen_garbage(100, &mut rng); let enc = alice_packet_handler .prepare_v2_packet(first.clone(), None, false) .unwrap();