diff --git a/Cargo.toml b/Cargo.toml index 21985cbc..9af67434 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,8 @@ [workspace] -members = [ - "tap_core", - "tap_aggregator", - "tap_integration_tests" -] +resolver = "2" +members = ["tap_core", "tap_aggregator", "tap_integration_tests"] [workspace.package] -version="0.1.0" +version = "0.1.0" edition = "2021" license = "Apache-2.0" diff --git a/tap_aggregator/src/aggregator.rs b/tap_aggregator/src/aggregator.rs index 22afa6e2..7479cd9c 100644 --- a/tap_aggregator/src/aggregator.rs +++ b/tap_aggregator/src/aggregator.rs @@ -14,7 +14,7 @@ use tap_core::{ receipt_aggregate_voucher::ReceiptAggregateVoucher, tap_receipt::Receipt, }; -pub async fn check_and_aggregate_receipts( +pub fn check_and_aggregate_receipts( domain_separator: &Eip712Domain, receipts: &[EIP712SignedMessage], previous_rav: Option>, @@ -69,7 +69,7 @@ pub async fn check_and_aggregate_receipts( let rav = ReceiptAggregateVoucher::aggregate_receipts(allocation_id, receipts, previous_rav)?; // Sign the rav and return - Ok(EIP712SignedMessage::new(domain_separator, rav, wallet).await?) + Ok(EIP712SignedMessage::new(domain_separator, rav, wallet)?) } fn check_signature_is_from_one_of_addresses( @@ -170,8 +170,8 @@ mod tests { } #[rstest] - #[tokio::test] - async fn check_signatures_unique_fail( + #[test] + fn check_signatures_unique_fail( keys: (LocalWallet, Address), allocation_ids: Vec
, domain_separator: Eip712Domain, @@ -183,7 +183,6 @@ mod tests { Receipt::new(allocation_ids[0], 42).unwrap(), &keys.0, ) - .await .unwrap(); receipts.push(receipt.clone()); receipts.push(receipt); @@ -193,41 +192,36 @@ mod tests { } #[rstest] - #[tokio::test] - async fn check_signatures_unique_ok( + #[test] + fn check_signatures_unique_ok( keys: (LocalWallet, Address), allocation_ids: Vec
, domain_separator: Eip712Domain, ) { // Create 2 different receipts - let mut receipts = Vec::new(); - receipts.push( + let receipts = vec![ EIP712SignedMessage::new( &domain_separator, Receipt::new(allocation_ids[0], 42).unwrap(), &keys.0, ) - .await .unwrap(), - ); - receipts.push( EIP712SignedMessage::new( &domain_separator, Receipt::new(allocation_ids[0], 43).unwrap(), &keys.0, ) - .await .unwrap(), - ); + ]; let res = aggregator::check_signatures_unique(&receipts); assert!(res.is_ok()); } #[rstest] - #[tokio::test] + #[test] /// Test that a receipt with a timestamp greater then the rav timestamp passes - async fn check_receipt_timestamps( + fn check_receipt_timestamps( keys: (LocalWallet, Address), allocation_ids: Vec
, domain_separator: Eip712Domain, @@ -247,7 +241,6 @@ mod tests { }, &keys.0, ) - .await .unwrap(), ); } @@ -262,7 +255,6 @@ mod tests { }, &keys.0, ) - .await .unwrap(); assert!(aggregator::check_receipt_timestamps(&receipts, Some(&rav)).is_ok()); @@ -277,7 +269,6 @@ mod tests { }, &keys.0, ) - .await .unwrap(); assert!(aggregator::check_receipt_timestamps(&receipts, Some(&rav)).is_err()); @@ -292,48 +283,39 @@ mod tests { }, &keys.0, ) - .await .unwrap(); assert!(aggregator::check_receipt_timestamps(&receipts, Some(&rav)).is_err()); } #[rstest] - #[tokio::test] + #[test] /// Test check_allocation_id with 2 receipts that have the correct allocation id /// and 1 receipt that has the wrong allocation id - async fn check_allocation_id_fail( + fn check_allocation_id_fail( keys: (LocalWallet, Address), allocation_ids: Vec
, domain_separator: Eip712Domain, ) { - let mut receipts = Vec::new(); - receipts.push( + let receipts = vec![ EIP712SignedMessage::new( &domain_separator, Receipt::new(allocation_ids[0], 42).unwrap(), &keys.0, ) - .await .unwrap(), - ); - receipts.push( EIP712SignedMessage::new( &domain_separator, Receipt::new(allocation_ids[0], 43).unwrap(), &keys.0, ) - .await .unwrap(), - ); - receipts.push( EIP712SignedMessage::new( &domain_separator, Receipt::new(allocation_ids[1], 44).unwrap(), &keys.0, ) - .await .unwrap(), - ); + ]; let res = aggregator::check_allocation_id(&receipts, allocation_ids[0]); @@ -341,41 +323,33 @@ mod tests { } #[rstest] - #[tokio::test] + #[test] /// Test check_allocation_id with 3 receipts that have the correct allocation id - async fn check_allocation_id_ok( + fn check_allocation_id_ok( keys: (LocalWallet, Address), allocation_ids: Vec
, domain_separator: Eip712Domain, ) { - let mut receipts = Vec::new(); - receipts.push( + let receipts = vec![ EIP712SignedMessage::new( &domain_separator, Receipt::new(allocation_ids[0], 42).unwrap(), &keys.0, ) - .await .unwrap(), - ); - receipts.push( EIP712SignedMessage::new( &domain_separator, Receipt::new(allocation_ids[0], 43).unwrap(), &keys.0, ) - .await .unwrap(), - ); - receipts.push( EIP712SignedMessage::new( &domain_separator, Receipt::new(allocation_ids[0], 44).unwrap(), &keys.0, ) - .await .unwrap(), - ); + ]; let res = aggregator::check_allocation_id(&receipts, allocation_ids[0]); diff --git a/tap_aggregator/src/server.rs b/tap_aggregator/src/server.rs index aea28468..6d0d1a6e 100644 --- a/tap_aggregator/src/server.rs +++ b/tap_aggregator/src/server.rs @@ -7,11 +7,7 @@ use alloy_primitives::Address; use alloy_sol_types::Eip712Domain; use anyhow::Result; use ethers_signers::LocalWallet; -use jsonrpsee::{ - proc_macros::rpc, - server::ServerBuilder, - {core::async_trait, server::ServerHandle}, -}; +use jsonrpsee::{proc_macros::rpc, server::ServerBuilder, server::ServerHandle}; use lazy_static::lazy_static; use prometheus::{register_counter, register_int_counter, Counter, IntCounter}; @@ -82,12 +78,12 @@ lazy_static! { pub trait Rpc { /// Returns the versions of the TAP JSON-RPC API implemented by this server. #[method(name = "api_versions")] - async fn api_versions(&self) -> JsonRpcResult; + fn api_versions(&self) -> JsonRpcResult; /// Aggregates the given receipts into a receipt aggregate voucher. /// Returns an error if the user expected API version is not supported. #[method(name = "aggregate_receipts")] - async fn aggregate_receipts( + fn aggregate_receipts( &self, api_version: String, receipts: Vec>, @@ -131,7 +127,7 @@ fn check_api_version_deprecation(api_version: &TapRpcApiVersion) -> Option, @@ -156,16 +152,13 @@ async fn aggregate_receipts_( } let res = match api_version { - TapRpcApiVersion::V0_0 => { - check_and_aggregate_receipts( - domain_separator, - &receipts, - previous_rav, - wallet, - accepted_addresses, - ) - .await - } + TapRpcApiVersion::V0_0 => check_and_aggregate_receipts( + domain_separator, + &receipts, + previous_rav, + wallet, + accepted_addresses, + ), }; // Handle aggregation error @@ -179,13 +172,12 @@ async fn aggregate_receipts_( } } -#[async_trait] impl RpcServer for RpcImpl { - async fn api_versions(&self) -> JsonRpcResult { + fn api_versions(&self) -> JsonRpcResult { Ok(JsonRpcResponse::ok(tap_rpc_api_versions_info())) } - async fn aggregate_receipts( + fn aggregate_receipts( &self, api_version: String, receipts: Vec>, @@ -202,9 +194,7 @@ impl RpcServer for RpcImpl { &self.domain_separator, receipts, previous_rav, - ) - .await - { + ) { Ok(res) => { TOTAL_GRT_AGGREGATED.inc_by(receipts_grt as f64); TOTAL_AGGREGATED_RECEIPTS.inc_by(receipts_count); @@ -410,7 +400,6 @@ mod tests { Receipt::new(allocation_ids[0], value).unwrap(), &all_wallets.choose(&mut rng).unwrap().wallet, ) - .await .unwrap(), ); } @@ -492,7 +481,6 @@ mod tests { Receipt::new(allocation_ids[0], value).unwrap(), &all_wallets.choose(&mut rng).unwrap().wallet, ) - .await .unwrap(), ); } @@ -509,7 +497,6 @@ mod tests { prev_rav, &all_wallets.choose(&mut rng).unwrap().wallet, ) - .await .unwrap(); // Create new RAV from last half of receipts and prev_rav through the JSON-RPC server @@ -569,7 +556,6 @@ mod tests { Receipt::new(allocation_ids[0], 42).unwrap(), &keys_main.wallet, ) - .await .unwrap()]; // Skipping receipts validation in this test, aggregate_receipts assumes receipts are valid. @@ -665,7 +651,6 @@ mod tests { Receipt::new(allocation_ids[0], u128::MAX / 1000).unwrap(), &keys_main.wallet, ) - .await .unwrap(), ); } diff --git a/tap_core/Cargo.toml b/tap_core/Cargo.toml index bed741d2..3b9cece3 100644 --- a/tap_core/Cargo.toml +++ b/tap_core/Cargo.toml @@ -7,7 +7,7 @@ description = "Core Timeline Aggregation Protocol library: a fast, efficient and [dependencies] rand_core = "0.6.4" -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } rand = "0.8.5" thiserror = "1.0.38" ethereum-types = { version = "0.14.1" } @@ -24,10 +24,11 @@ strum = "0.24.1" strum_macros = "0.24.3" async-trait = "0.1.72" tokio = { version = "1.29.1", features = ["macros", "rt-multi-thread"] } +typetag = "0.2.14" +futures = "0.3.17" [dev-dependencies] criterion = { version = "0.5", features = ["async_std"] } -futures = "0.3.17" [features] diff --git a/tap_core/benches/timeline_aggretion_protocol_benchmark.rs b/tap_core/benches/timeline_aggretion_protocol_benchmark.rs index 7bcc36e8..37cbc183 100644 --- a/tap_core/benches/timeline_aggretion_protocol_benchmark.rs +++ b/tap_core/benches/timeline_aggretion_protocol_benchmark.rs @@ -12,7 +12,6 @@ use std::str::FromStr; use alloy_primitives::Address; use alloy_sol_types::Eip712Domain; -use criterion::async_executor::AsyncStdExecutor; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use ethers::signers::{LocalWallet, Signer, Wallet}; use ethers_core::k256::ecdsa::SigningKey; @@ -22,9 +21,8 @@ use tap_core::{ eip_712_signed_message::EIP712SignedMessage, receipt_aggregate_voucher::ReceiptAggregateVoucher, tap_receipt::Receipt, }; -use tokio::runtime::Runtime; -pub async fn create_and_sign_receipt( +pub fn create_and_sign_receipt( domain_separator: &Eip712Domain, allocation_id: Address, value: u128, @@ -35,15 +33,12 @@ pub async fn create_and_sign_receipt( Receipt::new(allocation_id, value).unwrap(), wallet, ) - .await .unwrap() } pub fn criterion_benchmark(c: &mut Criterion) { let domain_seperator = tap_eip712_domain(1, Address::from([0x11u8; 20])); - let async_runtime = Runtime::new().unwrap(); - let wallet = LocalWallet::new(&mut OsRng); let address: [u8; 20] = wallet.address().into(); let address: Address = address.into(); @@ -53,7 +48,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { let value = 12345u128; c.bench_function("Create Receipt", |b| { - b.to_async(AsyncStdExecutor).iter(|| { + b.iter(|| { create_and_sign_receipt( black_box(&domain_seperator), black_box(allocation_id), @@ -63,12 +58,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { }) }); - let receipt = async_runtime.block_on(create_and_sign_receipt( - &domain_seperator, - allocation_id, - value, - &wallet, - )); + let receipt = create_and_sign_receipt(&domain_seperator, allocation_id, value, &wallet); c.bench_function("Validate Receipt", |b| { b.iter(|| { @@ -82,14 +72,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { for log_number_of_receipts in 10..30 { let receipts = (0..2 ^ log_number_of_receipts) - .map(|_| { - async_runtime.block_on(create_and_sign_receipt( - &domain_seperator, - allocation_id, - value, - &wallet, - )) - }) + .map(|_| create_and_sign_receipt(&domain_seperator, allocation_id, value, &wallet)) .collect::>(); rav_group.bench_function( @@ -105,14 +88,12 @@ pub fn criterion_benchmark(c: &mut Criterion) { }, ); - let signed_rav = async_runtime - .block_on(EIP712SignedMessage::new( - &domain_seperator, - ReceiptAggregateVoucher::aggregate_receipts(allocation_id, &receipts, None) - .unwrap(), - &wallet, - )) - .unwrap(); + let signed_rav = EIP712SignedMessage::new( + &domain_seperator, + ReceiptAggregateVoucher::aggregate_receipts(allocation_id, &receipts, None).unwrap(), + &wallet, + ) + .unwrap(); rav_group.bench_function( &format!("Validate RAV w/ 2^{} receipt's", log_number_of_receipts), diff --git a/tap_core/src/adapters/mock.rs b/tap_core/src/adapters/mock.rs index 85c63db2..d5900997 100644 --- a/tap_core/src/adapters/mock.rs +++ b/tap_core/src/adapters/mock.rs @@ -1,9 +1,4 @@ // Copyright 2023-, Semiotic AI, Inc. // SPDX-License-Identifier: Apache-2.0 -pub mod auditor_executor_mock; -pub mod escrow_adapter_mock; pub mod executor_mock; -pub mod rav_storage_adapter_mock; -pub mod receipt_checks_adapter_mock; -pub mod receipt_storage_adapter_mock; diff --git a/tap_core/src/adapters/mock/auditor_executor_mock.rs b/tap_core/src/adapters/mock/auditor_executor_mock.rs deleted file mode 100644 index 49fa8eff..00000000 --- a/tap_core/src/adapters/mock/auditor_executor_mock.rs +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -use super::{escrow_adapter_mock::AdpaterErrorMock, receipt_checks_adapter_mock::AdapterErrorMock}; -use crate::adapters::escrow_adapter::EscrowAdapter; -use crate::adapters::receipt_checks_adapter::ReceiptChecksAdapter; -use crate::eip_712_signed_message::EIP712SignedMessage; -use crate::tap_receipt::{Receipt, ReceivedReceipt}; -use alloy_primitives::Address; -use async_trait::async_trait; -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; -use tokio::sync::RwLock; - -#[derive(Clone)] -pub struct AuditorExecutorMock { - receipt_storage: Arc>>, - - sender_escrow_storage: Arc>>, - - query_appraisals: Arc>>, - allocation_ids: Arc>>, - sender_ids: Arc>>, -} - -impl AuditorExecutorMock { - pub fn new( - receipt_storage: Arc>>, - sender_escrow_storage: Arc>>, - query_appraisals: Arc>>, - allocation_ids: Arc>>, - sender_ids: Arc>>, - ) -> Self { - AuditorExecutorMock { - receipt_storage, - sender_escrow_storage, - allocation_ids, - sender_ids, - query_appraisals, - } - } -} - -impl AuditorExecutorMock { - pub async fn escrow(&self, sender_id: Address) -> Result { - let sender_escrow_storage = self.sender_escrow_storage.read().await; - if let Some(escrow) = sender_escrow_storage.get(&sender_id) { - return Ok(*escrow); - } - Err(AdpaterErrorMock::AdapterError { - error: "No escrow exists for provided sender ID.".to_owned(), - }) - } - - pub async fn increase_escrow(&mut self, sender_id: Address, value: u128) { - let mut sender_escrow_storage = self.sender_escrow_storage.write().await; - - if let Some(current_value) = sender_escrow_storage.get(&sender_id) { - let mut sender_escrow_storage = self.sender_escrow_storage.write().await; - sender_escrow_storage.insert(sender_id, current_value + value); - } else { - sender_escrow_storage.insert(sender_id, value); - } - } - - pub async fn reduce_escrow( - &self, - sender_id: Address, - value: u128, - ) -> Result<(), AdpaterErrorMock> { - let mut sender_escrow_storage = self.sender_escrow_storage.write().await; - - if let Some(current_value) = sender_escrow_storage.get(&sender_id) { - let checked_new_value = current_value.checked_sub(value); - if let Some(new_value) = checked_new_value { - sender_escrow_storage.insert(sender_id, new_value); - return Ok(()); - } - } - Err(AdpaterErrorMock::AdapterError { - error: "Provided value is greater than existing escrow.".to_owned(), - }) - } -} - -#[async_trait] -impl EscrowAdapter for AuditorExecutorMock { - type AdapterError = AdpaterErrorMock; - async fn get_available_escrow(&self, sender_id: Address) -> Result { - self.escrow(sender_id).await - } - async fn subtract_escrow( - &self, - sender_id: Address, - value: u128, - ) -> Result<(), Self::AdapterError> { - self.reduce_escrow(sender_id, value).await - } -} - -#[async_trait] -impl ReceiptChecksAdapter for AuditorExecutorMock { - type AdapterError = AdapterErrorMock; - - async fn is_unique( - &self, - receipt: &EIP712SignedMessage, - receipt_id: u64, - ) -> Result { - let receipt_storage = self.receipt_storage.read().await; - Ok(receipt_storage - .iter() - .all(|(stored_receipt_id, stored_receipt)| { - (stored_receipt.signed_receipt().message != receipt.message) - || *stored_receipt_id == receipt_id - })) - } - - async fn is_valid_allocation_id( - &self, - allocation_id: Address, - ) -> Result { - let allocation_ids = self.allocation_ids.read().await; - Ok(allocation_ids.contains(&allocation_id)) - } - - async fn is_valid_value(&self, value: u128, query_id: u64) -> Result { - let query_appraisals = self.query_appraisals.read().await; - let appraised_value = query_appraisals.get(&query_id).unwrap(); - - if value != *appraised_value { - return Ok(false); - } - Ok(true) - } - - async fn is_valid_sender_id(&self, sender_id: Address) -> Result { - let sender_ids = self.sender_ids.read().await; - Ok(sender_ids.contains(&sender_id)) - } -} diff --git a/tap_core/src/adapters/mock/escrow_adapter_mock.rs b/tap_core/src/adapters/mock/escrow_adapter_mock.rs deleted file mode 100644 index b52517eb..00000000 --- a/tap_core/src/adapters/mock/escrow_adapter_mock.rs +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -use std::{collections::HashMap, sync::Arc}; - -use alloy_primitives::Address; -use async_trait::async_trait; -use tokio::sync::RwLock; - -use crate::adapters::escrow_adapter::EscrowAdapter; - -pub struct EscrowAdapterMock { - sender_escrow_storage: Arc>>, -} - -use thiserror::Error; -#[derive(Debug, Error)] -pub enum AdpaterErrorMock { - #[error("something went wrong: {error}")] - AdapterError { error: String }, -} - -impl EscrowAdapterMock { - pub fn new(sender_escrow_storage: Arc>>) -> Self { - EscrowAdapterMock { - sender_escrow_storage, - } - } - pub async fn escrow(&self, sender_id: Address) -> Result { - let sender_escrow_storage = self.sender_escrow_storage.read().await; - if let Some(escrow) = sender_escrow_storage.get(&sender_id) { - return Ok(*escrow); - } - Err(AdpaterErrorMock::AdapterError { - error: "No escrow exists for provided sender ID.".to_owned(), - }) - } - - pub async fn increase_escrow(&mut self, sender_id: Address, value: u128) { - let mut sender_escrow_storage = self.sender_escrow_storage.write().await; - - if let Some(current_value) = sender_escrow_storage.get(&sender_id) { - let mut sender_escrow_storage = self.sender_escrow_storage.write().await; - sender_escrow_storage.insert(sender_id, current_value + value); - } else { - sender_escrow_storage.insert(sender_id, value); - } - } - - pub async fn reduce_escrow( - &self, - sender_id: Address, - value: u128, - ) -> Result<(), AdpaterErrorMock> { - let mut sender_escrow_storage = self.sender_escrow_storage.write().await; - - if let Some(current_value) = sender_escrow_storage.get(&sender_id) { - let checked_new_value = current_value.checked_sub(value); - if let Some(new_value) = checked_new_value { - sender_escrow_storage.insert(sender_id, new_value); - return Ok(()); - } - } - Err(AdpaterErrorMock::AdapterError { - error: "Provided value is greater than existing escrow.".to_owned(), - }) - } -} - -#[async_trait] -impl EscrowAdapter for EscrowAdapterMock { - type AdapterError = AdpaterErrorMock; - async fn get_available_escrow(&self, sender_id: Address) -> Result { - self.escrow(sender_id).await - } - async fn subtract_escrow( - &self, - sender_id: Address, - value: u128, - ) -> Result<(), Self::AdapterError> { - self.reduce_escrow(sender_id, value).await - } -} diff --git a/tap_core/src/adapters/mock/executor_mock.rs b/tap_core/src/adapters/mock/executor_mock.rs index 76898659..28ee8194 100644 --- a/tap_core/src/adapters/mock/executor_mock.rs +++ b/tap_core/src/adapters/mock/executor_mock.rs @@ -1,14 +1,12 @@ // Copyright 2023-, Semiotic AI, Inc. // SPDX-License-Identifier: Apache-2.0 -use super::{escrow_adapter_mock::AdpaterErrorMock, receipt_checks_adapter_mock::AdapterErrorMock}; use crate::adapters::escrow_adapter::EscrowAdapter; -use crate::adapters::receipt_checks_adapter::ReceiptChecksAdapter; use crate::adapters::receipt_storage_adapter::{ safe_truncate_receipts, ReceiptRead, ReceiptStore, StoredReceipt, }; -use crate::eip_712_signed_message::EIP712SignedMessage; -use crate::tap_receipt::{Receipt, ReceivedReceipt}; +use crate::checks::TimestampCheck; +use crate::tap_receipt::ReceivedReceipt; use crate::{ adapters::rav_storage_adapter::{RAVRead, RAVStore}, tap_manager::SignedRAV, @@ -16,67 +14,125 @@ use crate::{ use alloy_primitives::Address; use async_trait::async_trait; use std::ops::RangeBounds; -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; -use tokio::sync::RwLock; +use std::sync::RwLock; +use std::{collections::HashMap, sync::Arc}; pub type EscrowStorage = Arc>>; pub type QueryAppraisals = Arc>>; +pub type ReceiptStorage = Arc>>; +pub type RAVStorage = Arc>>; + +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum AdapterErrorMock { + #[error("something went wrong: {error}")] + AdapterError { error: String }, +} #[derive(Clone)] pub struct ExecutorMock { /// local RAV store with rwlocks to allow sharing with other compenents as needed - rav_storage: Arc>>, - receipt_storage: Arc>>, + rav_storage: RAVStorage, + receipt_storage: ReceiptStorage, unique_id: Arc>, sender_escrow_storage: EscrowStorage, - query_appraisals: QueryAppraisals, - allocation_ids: Arc>>, - sender_ids: Arc>>, + timestamp_check: Arc, } impl ExecutorMock { pub fn new( - rav_storage: Arc>>, - receipt_storage: Arc>>, - sender_escrow_storage: Arc>>, - query_appraisals: Arc>>, - allocation_ids: Arc>>, - sender_ids: Arc>>, + rav_storage: RAVStorage, + receipt_storage: ReceiptStorage, + sender_escrow_storage: EscrowStorage, + timestamp_check: Arc, ) -> Self { ExecutorMock { rav_storage, receipt_storage, unique_id: Arc::new(RwLock::new(0)), sender_escrow_storage, - allocation_ids, - sender_ids, - query_appraisals, + timestamp_check, } } + + pub async fn retrieve_receipt_by_id( + &self, + receipt_id: u64, + ) -> Result { + let receipt_storage = self.receipt_storage.read().unwrap(); + + receipt_storage + .get(&receipt_id) + .cloned() + .ok_or(AdapterErrorMock::AdapterError { + error: "No receipt found with ID".to_owned(), + }) + } + + pub async fn retrieve_receipts_by_timestamp( + &self, + timestamp_ns: u64, + ) -> Result, AdapterErrorMock> { + let receipt_storage = self.receipt_storage.read().unwrap(); + Ok(receipt_storage + .iter() + .filter(|(_, rx_receipt)| { + rx_receipt.signed_receipt().message.timestamp_ns == timestamp_ns + }) + .map(|(&id, rx_receipt)| (id, rx_receipt.clone())) + .collect()) + } + + pub async fn retrieve_receipts_upto_timestamp( + &self, + timestamp_ns: u64, + ) -> Result, AdapterErrorMock> { + self.retrieve_receipts_in_timestamp_range(..=timestamp_ns, None) + .await + } + + pub async fn remove_receipt_by_id(&mut self, receipt_id: u64) -> Result<(), AdapterErrorMock> { + let mut receipt_storage = self.receipt_storage.write().unwrap(); + receipt_storage + .remove(&receipt_id) + .map(|_| ()) + .ok_or(AdapterErrorMock::AdapterError { + error: "No receipt found with ID".to_owned(), + }) + } + pub async fn remove_receipts_by_ids( + &mut self, + receipt_ids: &[u64], + ) -> Result<(), AdapterErrorMock> { + for receipt_id in receipt_ids { + self.remove_receipt_by_id(*receipt_id).await?; + } + Ok(()) + } } #[async_trait] impl RAVStore for ExecutorMock { - type AdapterError = AdpaterErrorMock; + type AdapterError = AdapterErrorMock; async fn update_last_rav(&self, rav: SignedRAV) -> Result<(), Self::AdapterError> { - let mut rav_storage = self.rav_storage.write().await; + let mut rav_storage = self.rav_storage.write().unwrap(); + let timestamp = rav.message.timestampNs; *rav_storage = Some(rav); + self.timestamp_check.update_min_timestamp_ns(timestamp); Ok(()) } } #[async_trait] impl RAVRead for ExecutorMock { - type AdapterError = AdpaterErrorMock; + type AdapterError = AdapterErrorMock; async fn last_rav(&self) -> Result, Self::AdapterError> { - Ok(self.rav_storage.read().await.clone()) + Ok(self.rav_storage.read().unwrap().clone()) } } @@ -84,9 +140,9 @@ impl RAVRead for ExecutorMock { impl ReceiptStore for ExecutorMock { type AdapterError = AdapterErrorMock; async fn store_receipt(&self, receipt: ReceivedReceipt) -> Result { - let mut id_pointer = self.unique_id.write().await; + let mut id_pointer = self.unique_id.write().unwrap(); let id_previous = *id_pointer; - let mut receipt_storage = self.receipt_storage.write().await; + let mut receipt_storage = self.receipt_storage.write().unwrap(); receipt_storage.insert(*id_pointer, receipt); *id_pointer += 1; Ok(id_previous) @@ -96,7 +152,7 @@ impl ReceiptStore for ExecutorMock { receipt_id: u64, receipt: ReceivedReceipt, ) -> Result<(), Self::AdapterError> { - let mut receipt_storage = self.receipt_storage.write().await; + let mut receipt_storage = self.receipt_storage.write().unwrap(); if !receipt_storage.contains_key(&receipt_id) { return Err(AdapterErrorMock::AdapterError { @@ -105,14 +161,14 @@ impl ReceiptStore for ExecutorMock { }; receipt_storage.insert(receipt_id, receipt); - *self.unique_id.write().await += 1; + *self.unique_id.write().unwrap() += 1; Ok(()) } async fn remove_receipts_in_timestamp_range + std::marker::Send>( &self, timestamp_ns: R, ) -> Result<(), Self::AdapterError> { - let mut receipt_storage = self.receipt_storage.write().await; + let mut receipt_storage = self.receipt_storage.write().unwrap(); receipt_storage.retain(|_, rx_receipt| { !timestamp_ns.contains(&rx_receipt.signed_receipt().message.timestamp_ns) }); @@ -128,7 +184,7 @@ impl ReceiptRead for ExecutorMock { timestamp_range_ns: R, limit: Option, ) -> Result, Self::AdapterError> { - let receipt_storage = self.receipt_storage.read().await; + let receipt_storage = self.receipt_storage.read().unwrap(); let mut receipts_in_range: Vec<(u64, ReceivedReceipt)> = receipt_storage .iter() .filter(|(_, rx_receipt)| { @@ -145,33 +201,29 @@ impl ReceiptRead for ExecutorMock { } impl ExecutorMock { - pub async fn escrow(&self, sender_id: Address) -> Result { - let sender_escrow_storage = self.sender_escrow_storage.read().await; + pub fn escrow(&self, sender_id: Address) -> Result { + let sender_escrow_storage = self.sender_escrow_storage.read().unwrap(); if let Some(escrow) = sender_escrow_storage.get(&sender_id) { return Ok(*escrow); } - Err(AdpaterErrorMock::AdapterError { + Err(AdapterErrorMock::AdapterError { error: "No escrow exists for provided sender ID.".to_owned(), }) } - pub async fn increase_escrow(&mut self, sender_id: Address, value: u128) { - let mut sender_escrow_storage = self.sender_escrow_storage.write().await; + pub fn increase_escrow(&mut self, sender_id: Address, value: u128) { + let mut sender_escrow_storage = self.sender_escrow_storage.write().unwrap(); if let Some(current_value) = sender_escrow_storage.get(&sender_id) { - let mut sender_escrow_storage = self.sender_escrow_storage.write().await; + let mut sender_escrow_storage = self.sender_escrow_storage.write().unwrap(); sender_escrow_storage.insert(sender_id, current_value + value); } else { sender_escrow_storage.insert(sender_id, value); } } - pub async fn reduce_escrow( - &self, - sender_id: Address, - value: u128, - ) -> Result<(), AdpaterErrorMock> { - let mut sender_escrow_storage = self.sender_escrow_storage.write().await; + pub fn reduce_escrow(&self, sender_id: Address, value: u128) -> Result<(), AdapterErrorMock> { + let mut sender_escrow_storage = self.sender_escrow_storage.write().unwrap(); if let Some(current_value) = sender_escrow_storage.get(&sender_id) { let checked_new_value = current_value.checked_sub(value); @@ -180,7 +232,7 @@ impl ExecutorMock { return Ok(()); } } - Err(AdpaterErrorMock::AdapterError { + Err(AdapterErrorMock::AdapterError { error: "Provided value is greater than existing escrow.".to_owned(), }) } @@ -188,57 +240,15 @@ impl ExecutorMock { #[async_trait] impl EscrowAdapter for ExecutorMock { - type AdapterError = AdpaterErrorMock; + type AdapterError = AdapterErrorMock; async fn get_available_escrow(&self, sender_id: Address) -> Result { - self.escrow(sender_id).await + self.escrow(sender_id) } async fn subtract_escrow( &self, sender_id: Address, value: u128, ) -> Result<(), Self::AdapterError> { - self.reduce_escrow(sender_id, value).await - } -} - -#[async_trait] -impl ReceiptChecksAdapter for ExecutorMock { - type AdapterError = AdapterErrorMock; - - async fn is_unique( - &self, - receipt: &EIP712SignedMessage, - receipt_id: u64, - ) -> Result { - let receipt_storage = self.receipt_storage.read().await; - Ok(receipt_storage - .iter() - .all(|(stored_receipt_id, stored_receipt)| { - (stored_receipt.signed_receipt().message != receipt.message) - || *stored_receipt_id == receipt_id - })) - } - - async fn is_valid_allocation_id( - &self, - allocation_id: Address, - ) -> Result { - let allocation_ids = self.allocation_ids.read().await; - Ok(allocation_ids.contains(&allocation_id)) - } - - async fn is_valid_value(&self, value: u128, query_id: u64) -> Result { - let query_appraisals = self.query_appraisals.read().await; - let appraised_value = query_appraisals.get(&query_id).unwrap(); - - if value != *appraised_value { - return Ok(false); - } - Ok(true) - } - - async fn is_valid_sender_id(&self, sender_id: Address) -> Result { - let sender_ids = self.sender_ids.read().await; - Ok(sender_ids.contains(&sender_id)) + self.reduce_escrow(sender_id, value) } } diff --git a/tap_core/src/adapters/mock/rav_storage_adapter_mock.rs b/tap_core/src/adapters/mock/rav_storage_adapter_mock.rs deleted file mode 100644 index 603af06d..00000000 --- a/tap_core/src/adapters/mock/rav_storage_adapter_mock.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -use std::sync::Arc; - -use async_trait::async_trait; -use thiserror::Error; -use tokio::sync::RwLock; - -use crate::{ - adapters::rav_storage_adapter::{RAVRead, RAVStore}, - tap_manager::SignedRAV, -}; - -/// `RAVStorageAdapterMock` is a mock implementation of the `RAVStorageAdapter` trait. -/// -/// It serves two main purposes: -/// -/// 1. **Unit Testing**: The `RAVStorageAdapterMock` is primarily intended to be used for unit tests, -/// providing a way to simulate the behavior of a real `RAVStorageAdapter` without requiring a real -/// implementation. By using a mock implementation, you can create predictable behaviors and -/// responses, enabling isolated and focused testing of the logic that depends on the `RAVStorageAdapter` trait. -/// -/// 2. **Example Implementation**: New users of the `RAVStorageAdapter` trait can look to -/// `RAVStorageAdapterMock` as a basic example of how to implement the trait. -/// -/// Note: This mock implementation is not suitable for production use. Its methods simply manipulate a -/// local `RwLock>`, and it provides no real error handling. -/// -/// # Usage -/// -/// To use `RAVStorageAdapterMock`, first create an `Arc>>`, then pass it to -/// `RAVStorageAdapterMock::new()`. Now, it can be used anywhere a `RAVStorageAdapter` is required. -/// -/// ```rust -/// use std::sync::{Arc}; -/// use tokio::sync::RwLock; -/// use tap_core::{tap_manager::SignedRAV, adapters::rav_storage_adapter_mock::RAVStorageAdapterMock}; -/// -/// let rav_storage: Arc>> = Arc::new(RwLock::new(None)); -/// let adapter = RAVStorageAdapterMock::new(rav_storage); -/// ``` -pub struct RAVStorageAdapterMock { - /// local RAV store with rwlocks to allow sharing with other compenents as needed - rav_storage: Arc>>, -} - -impl RAVStorageAdapterMock { - pub fn new(rav_storage: Arc>>) -> Self { - RAVStorageAdapterMock { rav_storage } - } -} - -#[derive(Debug, Error)] -pub enum AdpaterErrorMock { - #[error("something went wrong: {error}")] - AdapterError { error: String }, -} - -#[async_trait] -impl RAVStore for RAVStorageAdapterMock { - type AdapterError = AdpaterErrorMock; - - async fn update_last_rav(&self, rav: SignedRAV) -> Result<(), Self::AdapterError> { - let mut rav_storage = self.rav_storage.write().await; - *rav_storage = Some(rav); - Ok(()) - } -} - -#[async_trait] -impl RAVRead for RAVStorageAdapterMock { - type AdapterError = AdpaterErrorMock; - - async fn last_rav(&self) -> Result, Self::AdapterError> { - Ok(self.rav_storage.read().await.clone()) - } -} diff --git a/tap_core/src/adapters/mock/receipt_checks_adapter_mock.rs b/tap_core/src/adapters/mock/receipt_checks_adapter_mock.rs deleted file mode 100644 index 3dd0f4d7..00000000 --- a/tap_core/src/adapters/mock/receipt_checks_adapter_mock.rs +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; - -use alloy_primitives::Address; -use async_trait::async_trait; -use thiserror::Error; -use tokio::sync::RwLock; - -use crate::{ - adapters::receipt_checks_adapter::ReceiptChecksAdapter, - eip_712_signed_message::EIP712SignedMessage, - tap_receipt::{Receipt, ReceiptError, ReceivedReceipt}, -}; - -pub struct ReceiptChecksAdapterMock { - receipt_storage: Arc>>, - query_appraisals: Arc>>, - allocation_ids: Arc>>, - sender_ids: Arc>>, -} - -#[derive(Debug, Error)] -pub enum AdapterErrorMock { - #[error("something went wrong: {error}")] - AdapterError { error: String }, -} - -impl From for ReceiptError { - fn from(val: AdapterErrorMock) -> Self { - ReceiptError::CheckFailedToComplete { - source_error_message: val.to_string(), - } - } -} - -impl ReceiptChecksAdapterMock { - pub fn new( - receipt_storage: Arc>>, - query_appraisals: Arc>>, - allocation_ids: Arc>>, - sender_ids: Arc>>, - ) -> Self { - Self { - receipt_storage, - query_appraisals, - allocation_ids, - sender_ids, - } - } -} - -#[async_trait] -impl ReceiptChecksAdapter for ReceiptChecksAdapterMock { - type AdapterError = AdapterErrorMock; - - async fn is_unique( - &self, - receipt: &EIP712SignedMessage, - receipt_id: u64, - ) -> Result { - let receipt_storage = self.receipt_storage.read().await; - Ok(receipt_storage - .iter() - .all(|(stored_receipt_id, stored_receipt)| { - (stored_receipt.signed_receipt().message != receipt.message) - || *stored_receipt_id == receipt_id - })) - } - - async fn is_valid_allocation_id( - &self, - allocation_id: Address, - ) -> Result { - let allocation_ids = self.allocation_ids.read().await; - Ok(allocation_ids.contains(&allocation_id)) - } - - async fn is_valid_value(&self, value: u128, query_id: u64) -> Result { - let query_appraisals = self.query_appraisals.read().await; - let appraised_value = query_appraisals.get(&query_id).unwrap(); - - if value != *appraised_value { - return Ok(false); - } - Ok(true) - } - - async fn is_valid_sender_id(&self, sender_id: Address) -> Result { - let sender_ids = self.sender_ids.read().await; - Ok(sender_ids.contains(&sender_id)) - } -} diff --git a/tap_core/src/adapters/mock/receipt_storage_adapter_mock.rs b/tap_core/src/adapters/mock/receipt_storage_adapter_mock.rs deleted file mode 100644 index af444c51..00000000 --- a/tap_core/src/adapters/mock/receipt_storage_adapter_mock.rs +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -use std::{collections::HashMap, ops::RangeBounds, sync::Arc}; - -use async_trait::async_trait; -use tokio::sync::RwLock; - -use crate::{ - adapters::receipt_storage_adapter::{ - safe_truncate_receipts, ReceiptRead, ReceiptStore, StoredReceipt, - }, - tap_receipt::ReceivedReceipt, -}; - -pub struct ReceiptStorageAdapterMock { - receipt_storage: Arc>>, - unique_id: RwLock, -} - -impl ReceiptStorageAdapterMock { - pub fn new(receipt_storage: Arc>>) -> Self { - Self { - receipt_storage, - unique_id: RwLock::new(0u64), - } - } - pub async fn retrieve_receipt_by_id( - &self, - receipt_id: u64, - ) -> Result { - let receipt_storage = self.receipt_storage.read().await; - - receipt_storage - .get(&receipt_id) - .cloned() - .ok_or(AdapterErrorMock::AdapterError { - error: "No receipt found with ID".to_owned(), - }) - } - pub async fn retrieve_receipts_by_timestamp( - &self, - timestamp_ns: u64, - ) -> Result, AdapterErrorMock> { - let receipt_storage = self.receipt_storage.read().await; - Ok(receipt_storage - .iter() - .filter(|(_, rx_receipt)| { - rx_receipt.signed_receipt().message.timestamp_ns == timestamp_ns - }) - .map(|(&id, rx_receipt)| (id, rx_receipt.clone())) - .collect()) - } - pub async fn retrieve_receipts_upto_timestamp( - &self, - timestamp_ns: u64, - ) -> Result, AdapterErrorMock> { - self.retrieve_receipts_in_timestamp_range(..=timestamp_ns, None) - .await - } - pub async fn remove_receipt_by_id(&mut self, receipt_id: u64) -> Result<(), AdapterErrorMock> { - let mut receipt_storage = self.receipt_storage.write().await; - receipt_storage - .remove(&receipt_id) - .map(|_| ()) - .ok_or(AdapterErrorMock::AdapterError { - error: "No receipt found with ID".to_owned(), - }) - } - pub async fn remove_receipts_by_ids( - &mut self, - receipt_ids: &[u64], - ) -> Result<(), AdapterErrorMock> { - for receipt_id in receipt_ids { - self.remove_receipt_by_id(*receipt_id).await?; - } - Ok(()) - } -} - -use thiserror::Error; -#[derive(Debug, Error)] -pub enum AdapterErrorMock { - #[error("something went wrong: {error}")] - AdapterError { error: String }, -} - -#[async_trait] -impl ReceiptStore for ReceiptStorageAdapterMock { - type AdapterError = AdapterErrorMock; - async fn store_receipt(&self, receipt: ReceivedReceipt) -> Result { - let mut id_pointer = self.unique_id.write().await; - let id_previous = *id_pointer; - let mut receipt_storage = self.receipt_storage.write().await; - receipt_storage.insert(*id_pointer, receipt); - *id_pointer += 1; - Ok(id_previous) - } - async fn update_receipt_by_id( - &self, - receipt_id: u64, - receipt: ReceivedReceipt, - ) -> Result<(), Self::AdapterError> { - let mut receipt_storage = self.receipt_storage.write().await; - - if !receipt_storage.contains_key(&receipt_id) { - return Err(AdapterErrorMock::AdapterError { - error: "Invalid receipt_id".to_owned(), - }); - }; - - receipt_storage.insert(receipt_id, receipt); - *self.unique_id.write().await += 1; - Ok(()) - } - async fn remove_receipts_in_timestamp_range + std::marker::Send>( - &self, - timestamp_ns: R, - ) -> Result<(), Self::AdapterError> { - let mut receipt_storage = self.receipt_storage.write().await; - receipt_storage.retain(|_, rx_receipt| { - !timestamp_ns.contains(&rx_receipt.signed_receipt().message.timestamp_ns) - }); - Ok(()) - } -} - -#[async_trait] -impl ReceiptRead for ReceiptStorageAdapterMock { - type AdapterError = AdapterErrorMock; - async fn retrieve_receipts_in_timestamp_range + std::marker::Send>( - &self, - timestamp_range_ns: R, - limit: Option, - ) -> Result, Self::AdapterError> { - let receipt_storage = self.receipt_storage.read().await; - let mut receipts_in_range: Vec<(u64, ReceivedReceipt)> = receipt_storage - .iter() - .filter(|(_, rx_receipt)| { - timestamp_range_ns.contains(&rx_receipt.signed_receipt().message.timestamp_ns) - }) - .map(|(&id, rx_receipt)| (id, rx_receipt.clone())) - .collect(); - - if limit.is_some_and(|limit| receipts_in_range.len() > limit as usize) { - safe_truncate_receipts(&mut receipts_in_range, limit.unwrap()); - - Ok(receipts_in_range.into_iter().map(|r| r.into()).collect()) - } else { - Ok(receipts_in_range.into_iter().map(|r| r.into()).collect()) - } - } -} diff --git a/tap_core/src/adapters/mod.rs b/tap_core/src/adapters/mod.rs index 96ab5cea..6b045e4e 100644 --- a/tap_core/src/adapters/mod.rs +++ b/tap_core/src/adapters/mod.rs @@ -17,7 +17,6 @@ pub mod escrow_adapter; pub mod rav_storage_adapter; -pub mod receipt_checks_adapter; pub mod receipt_storage_adapter; #[cfg(feature = "mock")] diff --git a/tap_core/src/adapters/receipt_checks_adapter.rs b/tap_core/src/adapters/receipt_checks_adapter.rs deleted file mode 100644 index ed4dbd7f..00000000 --- a/tap_core/src/adapters/receipt_checks_adapter.rs +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -use crate::{eip_712_signed_message::EIP712SignedMessage, tap_receipt::Receipt}; -use alloy_primitives::Address; -use async_trait::async_trait; - -/// `ReceiptChecksAdapter` defines a trait for adapters to handle checks related to TAP receipts. -/// -/// This trait is designed to be implemented by users of this library who want to -/// customize the checks done on TAP receipts. This includes ensuring the receipt is unique, -/// verifying the allocation ID, the value and the sender ID. -/// -/// # Usage -/// -/// The `is_unique` method should be used to check if the given receipt is unique in the system. -/// -/// The `is_valid_allocation_id` method should verify if the allocation ID is valid. -/// -/// The `is_valid_value` method should confirm the value of the receipt is valid for the given query ID. -/// -/// The `is_valid_sender_id` method should confirm the sender ID is valid. -/// -/// This trait is utilized by [crate::tap_manager], which relies on these -/// operations for managing TAP receipts. -/// -/// # Example -/// -/// For example code see [crate::adapters::receipt_checks_adapter_mock] - -#[async_trait] -pub trait ReceiptChecksAdapter { - /// Defines the user-specified error type. - /// - /// This error type should implement the `Error` and `Debug` traits from the standard library. - /// Errors of this type are returned to the user when an operation fails. - type AdapterError: std::error::Error + std::fmt::Debug + Send + Sync + 'static; - - /// Checks if the given receipt is unique in the system. - /// - /// This method should be implemented to verify the uniqueness of a given receipt in your system. Keep in mind that - /// the receipt likely will be in storage when this check is performed so the receipt id should be used to check - /// for uniqueness. - async fn is_unique( - &self, - receipt: &EIP712SignedMessage, - receipt_id: u64, - ) -> Result; - - /// Verifies if the allocation ID is valid. - /// - /// This method should be implemented to validate the given allocation ID is a valid allocation for the indexer. Valid is defined as - /// an allocation ID that is owned by the indexer and still available for redeeming. - async fn is_valid_allocation_id( - &self, - allocation_id: Address, - ) -> Result; - - /// Confirms the value of the receipt is valid for the given query ID. - /// - /// This method should be implemented to confirm the validity of the given value for a specific query ID. - async fn is_valid_value(&self, value: u128, query_id: u64) -> Result; - - /// Confirms the sender ID is valid. - /// - /// This method should be implemented to validate the given sender ID is one associated with a sender the indexer considers valid. - /// The provided sender ID is the address of the sender that is recovered from the signature of the receipt. - async fn is_valid_sender_id(&self, sender_id: Address) -> Result; -} diff --git a/tap_core/src/adapters/test/escrow_adapter_test.rs b/tap_core/src/adapters/test/escrow_adapter_test.rs index 1b35b9be..18c28fd5 100644 --- a/tap_core/src/adapters/test/escrow_adapter_test.rs +++ b/tap_core/src/adapters/test/escrow_adapter_test.rs @@ -3,20 +3,37 @@ #[cfg(test)] mod escrow_adapter_unit_test { - use std::{collections::HashMap, sync::Arc}; + use std::{ + collections::HashMap, + sync::{Arc, RwLock}, + }; use ethers::signers::{coins_bip39::English, LocalWallet, MnemonicBuilder, Signer}; use rstest::*; - use tokio::sync::RwLock; - use crate::adapters::{escrow_adapter::EscrowAdapter, escrow_adapter_mock::EscrowAdapterMock}; + use crate::{ + adapters::{escrow_adapter::EscrowAdapter, executor_mock::ExecutorMock}, + checks::TimestampCheck, + }; - #[rstest] - #[tokio::test] - async fn escrow_adapter_test() { + #[fixture] + fn executor() -> ExecutorMock { let escrow_storage = Arc::new(RwLock::new(HashMap::new())); - let mut escrow_adapter = EscrowAdapterMock::new(escrow_storage); + let rav_storage = Arc::new(RwLock::new(None)); + let receipt_storage = Arc::new(RwLock::new(HashMap::new())); + + let timestamp_check = Arc::new(TimestampCheck::new(0)); + ExecutorMock::new( + rav_storage, + receipt_storage.clone(), + escrow_storage.clone(), + timestamp_check, + ) + } + #[rstest] + #[tokio::test] + async fn escrow_adapter_test(mut executor: ExecutorMock) { let wallet: LocalWallet = MnemonicBuilder::::default() .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") .build() @@ -35,42 +52,31 @@ mod escrow_adapter_unit_test { let initial_value = 500u128; - escrow_adapter - .increase_escrow(sender_id, initial_value) - .await; + executor.increase_escrow(sender_id, initial_value); // Check that sender exists and has valid value through adapter - assert!(escrow_adapter.get_available_escrow(sender_id).await.is_ok()); + assert!(executor.get_available_escrow(sender_id).await.is_ok()); assert_eq!( - escrow_adapter - .get_available_escrow(sender_id) - .await - .unwrap(), + executor.get_available_escrow(sender_id).await.unwrap(), initial_value ); // Check that subtracting is valid for valid sender, and results in expected value - assert!(escrow_adapter + assert!(executor .subtract_escrow(sender_id, initial_value) .await .is_ok()); - assert!(escrow_adapter.get_available_escrow(sender_id).await.is_ok()); - assert_eq!( - escrow_adapter - .get_available_escrow(sender_id) - .await - .unwrap(), - 0 - ); + assert!(executor.get_available_escrow(sender_id).await.is_ok()); + assert_eq!(executor.get_available_escrow(sender_id).await.unwrap(), 0); // Check that subtracting to negative escrow results in err - assert!(escrow_adapter + assert!(executor .subtract_escrow(sender_id, initial_value) .await .is_err()); // Check that accessing non initialized sender results in err - assert!(escrow_adapter + assert!(executor .get_available_escrow(invalid_sender_id) .await .is_err()); diff --git a/tap_core/src/adapters/test/mod.rs b/tap_core/src/adapters/test/mod.rs index 318c6216..33fb14c0 100644 --- a/tap_core/src/adapters/test/mod.rs +++ b/tap_core/src/adapters/test/mod.rs @@ -3,5 +3,4 @@ pub mod escrow_adapter_test; pub mod rav_storage_adapter_test; -pub mod receipt_checks_adapter_test; pub mod receipt_storage_adapter_test; diff --git a/tap_core/src/adapters/test/rav_storage_adapter_test.rs b/tap_core/src/adapters/test/rav_storage_adapter_test.rs index ce5c324f..3b51999b 100644 --- a/tap_core/src/adapters/test/rav_storage_adapter_test.rs +++ b/tap_core/src/adapters/test/rav_storage_adapter_test.rs @@ -3,6 +3,8 @@ #[cfg(test)] mod rav_storage_adapter_unit_test { + use std::collections::HashMap; + use std::sync::RwLock; use std::{str::FromStr, sync::Arc}; use alloy_primitives::Address; @@ -10,16 +12,17 @@ mod rav_storage_adapter_unit_test { use ethers::signers::coins_bip39::English; use ethers::signers::{LocalWallet, MnemonicBuilder}; use rstest::*; - use tokio::sync::RwLock; - use crate::adapters::rav_storage_adapter::RAVRead; - use crate::adapters::{ - rav_storage_adapter::RAVStore, rav_storage_adapter_mock::RAVStorageAdapterMock, - }; - use crate::tap_eip712_domain; + use crate::checks::TimestampCheck; use crate::{ + adapters::{ + executor_mock::ExecutorMock, + rav_storage_adapter::{RAVRead, RAVStore}, + }, eip_712_signed_message::EIP712SignedMessage, - receipt_aggregate_voucher::ReceiptAggregateVoucher, tap_receipt::Receipt, + receipt_aggregate_voucher::ReceiptAggregateVoucher, + tap_eip712_domain, + tap_receipt::Receipt, }; #[fixture] @@ -27,12 +30,24 @@ mod rav_storage_adapter_unit_test { tap_eip712_domain(1, Address::from([0x11u8; 20])) } - #[rstest] - #[tokio::test] - async fn rav_storage_adapter_test(domain_separator: Eip712Domain) { + #[fixture] + fn executor() -> ExecutorMock { + let escrow_storage = Arc::new(RwLock::new(HashMap::new())); let rav_storage = Arc::new(RwLock::new(None)); - let rav_storage_adapter = RAVStorageAdapterMock::new(rav_storage); + let receipt_storage = Arc::new(RwLock::new(HashMap::new())); + + let timestamp_check = Arc::new(TimestampCheck::new(0)); + ExecutorMock::new( + rav_storage, + receipt_storage.clone(), + escrow_storage.clone(), + timestamp_check, + ) + } + #[rstest] + #[tokio::test] + async fn rav_storage_adapter_test(domain_separator: Eip712Domain, executor: ExecutorMock) { let wallet: LocalWallet = MnemonicBuilder::::default() .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") .build() @@ -53,7 +68,6 @@ mod rav_storage_adapter_unit_test { Receipt::new(allocation_id, value).unwrap(), &wallet, ) - .await .unwrap(), ); } @@ -63,16 +77,12 @@ mod rav_storage_adapter_unit_test { ReceiptAggregateVoucher::aggregate_receipts(allocation_id, &receipts, None).unwrap(), &wallet, ) - .await .unwrap(); - rav_storage_adapter - .update_last_rav(signed_rav.clone()) - .await - .unwrap(); + executor.update_last_rav(signed_rav.clone()).await.unwrap(); // Retreive rav - let retrieved_rav = rav_storage_adapter.last_rav().await; + let retrieved_rav = executor.last_rav().await; assert!(retrieved_rav.unwrap().unwrap() == signed_rav); // Testing the last rav update... @@ -86,7 +96,6 @@ mod rav_storage_adapter_unit_test { Receipt::new(allocation_id, value).unwrap(), &wallet, ) - .await .unwrap(), ); } @@ -96,17 +105,13 @@ mod rav_storage_adapter_unit_test { ReceiptAggregateVoucher::aggregate_receipts(allocation_id, &receipts, None).unwrap(), &wallet, ) - .await .unwrap(); // Update the last rav - rav_storage_adapter - .update_last_rav(signed_rav.clone()) - .await - .unwrap(); + executor.update_last_rav(signed_rav.clone()).await.unwrap(); // Retreive rav - let retrieved_rav = rav_storage_adapter.last_rav().await; + let retrieved_rav = executor.last_rav().await; assert!(retrieved_rav.unwrap().unwrap() == signed_rav); } } diff --git a/tap_core/src/adapters/test/receipt_checks_adapter_test.rs b/tap_core/src/adapters/test/receipt_checks_adapter_test.rs deleted file mode 100644 index 0d25f771..00000000 --- a/tap_core/src/adapters/test/receipt_checks_adapter_test.rs +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -#[cfg(test)] -mod receipt_checks_adapter_unit_test { - use std::{ - collections::{HashMap, HashSet}, - str::FromStr, - sync::Arc, - }; - - use alloy_primitives::Address; - use alloy_sol_types::Eip712Domain; - use ethers::signers::{coins_bip39::English, LocalWallet, MnemonicBuilder}; - use futures::{stream, StreamExt}; - use rstest::*; - use tokio::sync::RwLock; - - use crate::{ - adapters::{ - receipt_checks_adapter::ReceiptChecksAdapter, - receipt_checks_adapter_mock::ReceiptChecksAdapterMock, - }, - eip_712_signed_message::EIP712SignedMessage, - tap_eip712_domain, - tap_receipt::{get_full_list_of_checks, Receipt, ReceivedReceipt}, - }; - - #[fixture] - fn domain_separator() -> Eip712Domain { - tap_eip712_domain(1, Address::from([0x11u8; 20])) - } - - #[rstest] - #[tokio::test] - async fn receipt_checks_adapter_test(domain_separator: Eip712Domain) { - let sender_ids = [ - Address::from_str("0xfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfb").unwrap(), - Address::from_str("0xfafafafafafafafafafafafafafafafafafafafa").unwrap(), - Address::from_str("0xadadadadadadadadadadadadadadadadadadadad").unwrap(), - ]; - let sender_ids_set = Arc::new(RwLock::new(HashSet::from(sender_ids))); - - let allocation_ids = [ - Address::from_str("0xabababababababababababababababababababab").unwrap(), - Address::from_str("0xbabababababababababababababababababababa").unwrap(), - Address::from_str("0xdfdfdfdfdfdfdfdfdfdfdfdfdfdfdfdfdfdfdfdf").unwrap(), - ]; - let allocation_ids_set = Arc::new(RwLock::new(HashSet::from(allocation_ids))); - - let wallet: LocalWallet = MnemonicBuilder::::default() - .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") - .build() - .unwrap(); - let value = 100u128; - - let receipts: HashMap = stream::iter(0..10) - .then(|id| { - let wallet = wallet.clone(); - let domain_separator = &domain_separator; - async move { - ( - id, - ReceivedReceipt::new( - EIP712SignedMessage::new( - domain_separator, - Receipt::new(allocation_ids[0], value).unwrap(), - &wallet, - ) - .await - .unwrap(), - id, - &get_full_list_of_checks(), - ), - ) - } - }) - .collect::>() - .await; - let receipt_storage = Arc::new(RwLock::new(receipts)); - - let query_appraisals = (0..11).map(|id| (id, 100u128)).collect::>(); - - let query_appraisals_storage = Arc::new(RwLock::new(query_appraisals)); - - let receipt_checks_adapter = ReceiptChecksAdapterMock::new( - Arc::clone(&receipt_storage), - query_appraisals_storage, - allocation_ids_set, - sender_ids_set, - ); - - let new_receipt = ( - 10u64, - ReceivedReceipt::new( - EIP712SignedMessage::new( - &domain_separator, - Receipt::new(allocation_ids[0], value).unwrap(), - &wallet, - ) - .await - .unwrap(), - 10u64, - &get_full_list_of_checks(), - ), - ); - - let unique_receipt_id = 0u64; - receipt_storage - .write() - .await - .insert(unique_receipt_id, new_receipt.1.clone()); - - assert!(receipt_checks_adapter - .is_unique(new_receipt.1.signed_receipt(), unique_receipt_id) - .await - .unwrap()); - assert!(receipt_checks_adapter - .is_valid_allocation_id(new_receipt.1.signed_receipt().message.allocation_id) - .await - .unwrap()); - // TODO: Add check when sender_id is available from received receipt (issue: #56) - // assert!(receipt_checks_adapter.is_valid_sender_id(sender_id)); - assert!(receipt_checks_adapter - .is_valid_value( - new_receipt.1.signed_receipt().message.value, - new_receipt.1.query_id() - ) - .await - .unwrap()); - } -} diff --git a/tap_core/src/adapters/test/receipt_storage_adapter_test.rs b/tap_core/src/adapters/test/receipt_storage_adapter_test.rs index 26beeaf4..b6158edf 100644 --- a/tap_core/src/adapters/test/receipt_storage_adapter_test.rs +++ b/tap_core/src/adapters/test/receipt_storage_adapter_test.rs @@ -5,38 +5,67 @@ mod receipt_storage_adapter_unit_test { use rand::seq::SliceRandom; use rand::thread_rng; - use std::collections::HashMap; + use std::collections::{HashMap, HashSet}; use std::str::FromStr; - use std::sync::Arc; + use std::sync::{Arc, RwLock}; + use crate::checks::TimestampCheck; + use crate::{ + adapters::{executor_mock::ExecutorMock, receipt_storage_adapter::ReceiptStore}, + checks::{mock::get_full_list_of_checks, ReceiptCheck}, + eip_712_signed_message::EIP712SignedMessage, + tap_eip712_domain, + tap_receipt::{Receipt, ReceivedReceipt}, + }; use alloy_primitives::Address; use alloy_sol_types::Eip712Domain; use ethers::signers::coins_bip39::English; use ethers::signers::{LocalWallet, MnemonicBuilder}; use rstest::*; - use tokio::sync::RwLock; - - use crate::adapters::{ - receipt_storage_adapter::ReceiptStore, - receipt_storage_adapter_mock::ReceiptStorageAdapterMock, - }; - use crate::tap_eip712_domain; - use crate::tap_receipt::ReceivedReceipt; - use crate::{ - eip_712_signed_message::EIP712SignedMessage, tap_receipt::get_full_list_of_checks, - tap_receipt::Receipt, - }; #[fixture] fn domain_separator() -> Eip712Domain { tap_eip712_domain(1, Address::from([0x11u8; 20])) } + struct ExecutorFixture { + executor: ExecutorMock, + checks: Vec, + } + + #[fixture] + fn executor_mock(domain_separator: Eip712Domain) -> ExecutorFixture { + let escrow_storage = Arc::new(RwLock::new(HashMap::new())); + let rav_storage = Arc::new(RwLock::new(None)); + let query_appraisals = Arc::new(RwLock::new(HashMap::new())); + let receipt_storage = Arc::new(RwLock::new(HashMap::new())); + + let timestamp_check = Arc::new(TimestampCheck::new(0)); + let executor = ExecutorMock::new( + rav_storage, + receipt_storage.clone(), + escrow_storage.clone(), + timestamp_check.clone(), + ); + let mut checks = get_full_list_of_checks( + domain_separator, + HashSet::new(), + Arc::new(RwLock::new(HashSet::new())), + receipt_storage, + query_appraisals.clone(), + ); + checks.push(timestamp_check); + + ExecutorFixture { executor, checks } + } + #[rstest] #[tokio::test] - async fn receipt_adapter_test(domain_separator: Eip712Domain) { - let receipt_storage = Arc::new(RwLock::new(HashMap::new())); - let mut receipt_adapter = ReceiptStorageAdapterMock::new(receipt_storage); + async fn receipt_adapter_test(domain_separator: Eip712Domain, executor_mock: ExecutorFixture) { + let ExecutorFixture { + mut executor, + checks, + } = executor_mock; let wallet: LocalWallet = MnemonicBuilder::::default() .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") @@ -55,50 +84,42 @@ mod receipt_storage_adapter_unit_test { Receipt::new(allocation_id, value).unwrap(), &wallet, ) - .await .unwrap(), query_id, - &get_full_list_of_checks(), + &checks, ); - let receipt_store_result = receipt_adapter.store_receipt(received_receipt).await; + let receipt_store_result = executor.store_receipt(received_receipt).await; assert!(receipt_store_result.is_ok()); let receipt_id = receipt_store_result.unwrap(); // Retreive receipt with id expected to be valid - assert!(receipt_adapter - .retrieve_receipt_by_id(receipt_id) - .await - .is_ok()); + assert!(executor.retrieve_receipt_by_id(receipt_id).await.is_ok()); // Retreive receipt with arbitrary id expected to be invalid - assert!(receipt_adapter.retrieve_receipt_by_id(999).await.is_err()); + assert!(executor.retrieve_receipt_by_id(999).await.is_err()); // Remove receipt with id expected to be valid - assert!(receipt_adapter - .remove_receipt_by_id(receipt_id) - .await - .is_ok()); + assert!(executor.remove_receipt_by_id(receipt_id).await.is_ok()); // Remove receipt with arbitrary id expected to be invalid - assert!(receipt_adapter.remove_receipt_by_id(999).await.is_err()); + assert!(executor.remove_receipt_by_id(999).await.is_err()); // Retreive receipt that was removed previously - assert!(receipt_adapter - .retrieve_receipt_by_id(receipt_id) - .await - .is_err()); + assert!(executor.retrieve_receipt_by_id(receipt_id).await.is_err()); // Remove receipt that was removed previously - assert!(receipt_adapter - .remove_receipt_by_id(receipt_id) - .await - .is_err()); + assert!(executor.remove_receipt_by_id(receipt_id).await.is_err()); } #[rstest] #[tokio::test] - async fn multi_receipt_adapter_test(domain_separator: Eip712Domain) { - let receipt_storage = Arc::new(RwLock::new(HashMap::new())); - let mut receipt_adapter = ReceiptStorageAdapterMock::new(receipt_storage); + async fn multi_receipt_adapter_test( + domain_separator: Eip712Domain, + executor_mock: ExecutorFixture, + ) { + let ExecutorFixture { + mut executor, + checks, + } = executor_mock; let wallet: LocalWallet = MnemonicBuilder::::default() .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") @@ -117,17 +138,16 @@ mod receipt_storage_adapter_unit_test { Receipt::new(allocation_id, value).unwrap(), &wallet, ) - .await .unwrap(), query_id as u64, - &get_full_list_of_checks(), + &checks, )); } let mut receipt_ids = Vec::new(); let mut receipt_timestamps = Vec::new(); for received_receipt in received_receipts { receipt_ids.push( - receipt_adapter + executor .store_receipt(received_receipt.clone()) .await .unwrap(), @@ -136,23 +156,23 @@ mod receipt_storage_adapter_unit_test { } // Retreive receipts with timestamp - assert!(receipt_adapter + assert!(executor .retrieve_receipts_by_timestamp(receipt_timestamps[0]) .await .is_ok()); - assert!(!receipt_adapter + assert!(!executor .retrieve_receipts_by_timestamp(receipt_timestamps[0]) .await .unwrap() .is_empty()); // Retreive receipts before timestamp - assert!(receipt_adapter + assert!(executor .retrieve_receipts_upto_timestamp(receipt_timestamps[3]) .await .is_ok()); assert!( - receipt_adapter + executor .retrieve_receipts_upto_timestamp(receipt_timestamps[3]) .await .unwrap() @@ -161,21 +181,18 @@ mod receipt_storage_adapter_unit_test { ); // Remove all receipts with one call - assert!(receipt_adapter + assert!(executor .remove_receipts_by_ids(receipt_ids.as_slice()) .await .is_ok()); // Removal should no longer be valid - assert!(receipt_adapter + assert!(executor .remove_receipts_by_ids(receipt_ids.as_slice()) .await .is_err()); // Retrieval should be invalid for receipt_id in receipt_ids { - assert!(receipt_adapter - .retrieve_receipt_by_id(receipt_id) - .await - .is_err()); + assert!(executor.retrieve_receipt_by_id(receipt_id).await.is_err()); } } @@ -185,9 +202,10 @@ mod receipt_storage_adapter_unit_test { #[case(vec![1, 2, 3, 3, 4, 5], 3, vec![1, 2])] #[case(vec![1, 2, 3, 4, 4, 4], 3, vec![1, 2, 3])] #[case(vec![1, 1, 1, 1, 2, 3], 3, vec![])] - #[tokio::test] - async fn safe_truncate_receipts_test( + #[test] + fn safe_truncate_receipts_test( domain_separator: Eip712Domain, + executor_mock: ExecutorFixture, #[case] input: Vec, #[case] limit: u64, #[case] expected: Vec, @@ -196,6 +214,7 @@ mod receipt_storage_adapter_unit_test { .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") .build() .unwrap(); + let checks = executor_mock.checks; // Vec of (id, receipt) let mut receipts_orig: Vec<(u64, ReceivedReceipt)> = Vec::new(); @@ -215,10 +234,9 @@ mod receipt_storage_adapter_unit_test { }, &wallet, ) - .await .unwrap(), i as u64, // Will use that to check the IDs - &get_full_list_of_checks(), + &checks, ), )); } diff --git a/tap_core/src/checks/mod.rs b/tap_core/src/checks/mod.rs new file mode 100644 index 00000000..709708ea --- /dev/null +++ b/tap_core/src/checks/mod.rs @@ -0,0 +1,250 @@ +// Copyright 2023-, Semiotic AI, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use crate::tap_receipt::{Checking, ReceiptError, ReceiptResult, ReceiptWithState}; +use serde::{Deserialize, Serialize}; +use std::sync::{Arc, RwLock}; + +pub type ReceiptCheck = Arc; + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub enum CheckingChecks { + Pending(ReceiptCheck), + Executed(ReceiptResult<()>), +} + +impl CheckingChecks { + pub fn new(check: ReceiptCheck) -> Self { + Self::Pending(check) + } + + pub async fn execute(self, receipt: &ReceiptWithState) -> Self { + match self { + Self::Pending(check) => { + let result = check.check(receipt).await; + Self::Executed(result) + } + Self::Executed(_) => self, + } + } + + pub fn is_failed(&self) -> bool { + matches!(self, Self::Executed(Err(_))) + } + + pub fn is_pending(&self) -> bool { + matches!(self, Self::Pending(_)) + } + + pub fn is_complete(&self) -> bool { + matches!(self, Self::Executed(_)) + } +} + +#[async_trait::async_trait] +#[typetag::serde(tag = "type")] +pub trait Check: std::fmt::Debug + Send + Sync { + async fn check(&self, receipt: &ReceiptWithState) -> ReceiptResult<()>; + + async fn check_batch(&self, receipts: &[ReceiptWithState]) -> Vec> { + let mut results = Vec::new(); + for receipt in receipts { + let result = self.check(receipt).await; + results.push(result); + } + results + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TimestampCheck { + #[serde(skip)] + min_timestamp_ns: RwLock, +} + +impl TimestampCheck { + pub fn new(min_timestamp_ns: u64) -> Self { + Self { + min_timestamp_ns: RwLock::new(min_timestamp_ns), + } + } + /// Updates the minimum timestamp that will be accepted for a receipt (exclusive). + pub fn update_min_timestamp_ns(&self, min_timestamp_ns: u64) { + *self.min_timestamp_ns.write().unwrap() = min_timestamp_ns; + } +} + +#[async_trait::async_trait] +#[typetag::serde] +impl Check for TimestampCheck { + async fn check(&self, receipt: &ReceiptWithState) -> ReceiptResult<()> { + let min_timestamp_ns = *self.min_timestamp_ns.read().unwrap(); + let signed_receipt = receipt.signed_receipt(); + if signed_receipt.message.timestamp_ns <= min_timestamp_ns { + return Err(ReceiptError::InvalidTimestamp { + received_timestamp: signed_receipt.message.timestamp_ns, + timestamp_min: min_timestamp_ns, + }); + } + Ok(()) + } +} + +#[cfg(feature = "mock")] +pub mod mock { + + use super::*; + use crate::tap_receipt::ReceivedReceipt; + use alloy_primitives::Address; + use alloy_sol_types::Eip712Domain; + use std::{ + collections::{HashMap, HashSet}, + fmt::Debug, + }; + + pub fn get_full_list_of_checks( + domain_separator: Eip712Domain, + valid_signers: HashSet
, + allocation_ids: Arc>>, + receipt_storage: Arc>>, + query_appraisals: Arc>>, + ) -> Vec { + vec![ + Arc::new(UniqueCheck { receipt_storage }), + Arc::new(ValueCheck { query_appraisals }), + Arc::new(AllocationIdCheck { allocation_ids }), + Arc::new(SignatureCheck { + domain_separator, + valid_signers, + }), + ] + } + + #[derive(Serialize, Deserialize)] + struct UniqueCheck { + #[serde(skip)] + receipt_storage: Arc>>, + } + impl Debug for UniqueCheck { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "UniqueCheck") + } + } + + #[async_trait::async_trait] + #[typetag::serde] + impl Check for UniqueCheck { + async fn check(&self, receipt: &ReceiptWithState) -> ReceiptResult<()> { + let receipt_storage = self.receipt_storage.read().unwrap(); + // let receipt_id = receipt. + let unique = receipt_storage + .iter() + .all(|(_stored_receipt_id, stored_receipt)| { + stored_receipt.signed_receipt().message != receipt.signed_receipt().message + || stored_receipt.query_id() == receipt.query_id + }); + + unique.then_some(()).ok_or(ReceiptError::NonUniqueReceipt) + } + + async fn check_batch( + &self, + receipts: &[ReceiptWithState], + ) -> Vec> { + let mut signatures: HashSet = HashSet::new(); + let mut results = Vec::new(); + + for received_receipt in receipts { + let signature = received_receipt.signed_receipt.signature; + if signatures.insert(signature) { + results.push(Ok(())); + } else { + results.push(Err(ReceiptError::NonUniqueReceipt)); + } + } + results + } + } + + #[derive(Debug, Serialize, Deserialize)] + struct ValueCheck { + #[serde(skip)] + query_appraisals: Arc>>, + } + + #[async_trait::async_trait] + #[typetag::serde] + impl Check for ValueCheck { + async fn check(&self, receipt: &ReceiptWithState) -> ReceiptResult<()> { + let query_id = receipt.query_id; + let value = receipt.signed_receipt().message.value; + let query_appraisals = self.query_appraisals.read().unwrap(); + let appraised_value = + query_appraisals + .get(&query_id) + .ok_or(ReceiptError::CheckFailedToComplete { + source_error_message: "Could not find query_appraisals".into(), + })?; + + if value != *appraised_value { + Err(ReceiptError::InvalidValue { + received_value: value, + }) + } else { + Ok(()) + } + } + } + + #[derive(Debug, Serialize, Deserialize)] + struct AllocationIdCheck { + #[serde(skip)] + allocation_ids: Arc>>, + } + + #[async_trait::async_trait] + #[typetag::serde] + impl Check for AllocationIdCheck { + async fn check(&self, receipt: &ReceiptWithState) -> ReceiptResult<()> { + let received_allocation_id = receipt.signed_receipt().message.allocation_id; + if self + .allocation_ids + .read() + .unwrap() + .contains(&received_allocation_id) + { + Ok(()) + } else { + Err(ReceiptError::InvalidAllocationID { + received_allocation_id, + }) + } + } + } + + #[derive(Debug, Serialize, Deserialize)] + struct SignatureCheck { + domain_separator: Eip712Domain, + valid_signers: HashSet
, + } + + #[async_trait::async_trait] + #[typetag::serde] + impl Check for SignatureCheck { + async fn check(&self, receipt: &ReceiptWithState) -> ReceiptResult<()> { + let recovered_address = receipt + .signed_receipt() + .recover_signer(&self.domain_separator) + .map_err(|e| ReceiptError::InvalidSignature { + source_error_message: e.to_string(), + })?; + if !self.valid_signers.contains(&recovered_address) { + Err(ReceiptError::InvalidSignature { + source_error_message: "Invalid signer".to_string(), + }) + } else { + Ok(()) + } + } + } +} diff --git a/tap_core/src/eip_712_signed_message.rs b/tap_core/src/eip_712_signed_message.rs index 4d050ab0..b68c4aa2 100644 --- a/tap_core/src/eip_712_signed_message.rs +++ b/tap_core/src/eip_712_signed_message.rs @@ -21,7 +21,7 @@ pub struct EIP712SignedMessage { impl EIP712SignedMessage { /// creates signed message with signed EIP712 hash of `message` using `signing_wallet` - pub async fn new( + pub fn new( domain_separator: &Eip712Domain, message: M, signing_wallet: &LocalWallet, diff --git a/tap_core/src/lib.rs b/tap_core/src/lib.rs index 42fcc0bb..fd61aabc 100644 --- a/tap_core/src/lib.rs +++ b/tap_core/src/lib.rs @@ -12,6 +12,7 @@ use alloy_sol_types::eip712_domain; use thiserror::Error; pub mod adapters; +pub mod checks; pub mod eip_712_signed_message; mod error; pub mod receipt_aggregate_voucher; @@ -87,8 +88,8 @@ mod tap_tests { #[rstest] #[case::basic_rav_test (vec![45,56,34,23])] #[case::rav_from_zero_valued_receipts (vec![0,0,0,0])] - #[tokio::test] - async fn signed_rav_is_valid_with_no_previous_rav( + #[test] + fn signed_rav_is_valid_with_no_previous_rav( keys: (LocalWallet, Address), allocation_ids: Vec
, domain_separator: Eip712Domain, @@ -103,7 +104,6 @@ mod tap_tests { Receipt::new(allocation_ids[0], value).unwrap(), &keys.0, ) - .await .unwrap(), ); } @@ -112,17 +112,15 @@ mod tap_tests { let rav = ReceiptAggregateVoucher::aggregate_receipts(allocation_ids[0], &receipts, None) .unwrap(); - let signed_rav = EIP712SignedMessage::new(&domain_separator, rav, &keys.0) - .await - .unwrap(); + let signed_rav = EIP712SignedMessage::new(&domain_separator, rav, &keys.0).unwrap(); assert!(signed_rav.recover_signer(&domain_separator).unwrap() == keys.1); } #[rstest] #[case::basic_rav_test(vec![45,56,34,23])] #[case::rav_from_zero_valued_receipts(vec![0,0,0,0])] - #[tokio::test] - async fn signed_rav_is_valid_with_previous_rav( + #[test] + fn signed_rav_is_valid_with_previous_rav( keys: (LocalWallet, Address), allocation_ids: Vec
, domain_separator: Eip712Domain, @@ -137,7 +135,6 @@ mod tap_tests { Receipt::new(allocation_ids[0], value).unwrap(), &keys.0, ) - .await .unwrap(), ); } @@ -149,9 +146,8 @@ mod tap_tests { None, ) .unwrap(); - let signed_prev_rav = EIP712SignedMessage::new(&domain_separator, prev_rav, &keys.0) - .await - .unwrap(); + let signed_prev_rav = + EIP712SignedMessage::new(&domain_separator, prev_rav, &keys.0).unwrap(); // Create new RAV from last half of receipts and prev_rav let rav = ReceiptAggregateVoucher::aggregate_receipts( @@ -160,16 +156,14 @@ mod tap_tests { Some(signed_prev_rav), ) .unwrap(); - let signed_rav = EIP712SignedMessage::new(&domain_separator, rav, &keys.0) - .await - .unwrap(); + let signed_rav = EIP712SignedMessage::new(&domain_separator, rav, &keys.0).unwrap(); assert!(signed_rav.recover_signer(&domain_separator).unwrap() == keys.1); } #[rstest] - #[tokio::test] - async fn verify_signature( + #[test] + fn verify_signature( keys: (LocalWallet, Address), allocation_ids: Vec
, domain_separator: Eip712Domain, @@ -179,7 +173,6 @@ mod tap_tests { Receipt::new(allocation_ids[0], 42).unwrap(), &keys.0, ) - .await .unwrap(); assert!(signed_message.verify(&domain_separator, keys.1).is_ok()); diff --git a/tap_core/src/tap_manager/manager.rs b/tap_core/src/tap_manager/manager.rs index 78ac2412..1f959a63 100644 --- a/tap_core/src/tap_manager/manager.rs +++ b/tap_core/src/tap_manager/manager.rs @@ -1,20 +1,22 @@ // Copyright 2023-, Semiotic AI, Inc. // SPDX-License-Identifier: Apache-2.0 +use alloy_primitives::Address; use alloy_sol_types::Eip712Domain; +use futures::Future; use super::{RAVRequest, SignedRAV, SignedReceipt}; use crate::{ adapters::{ escrow_adapter::EscrowAdapter, rav_storage_adapter::{RAVRead, RAVStore}, - receipt_checks_adapter::ReceiptChecksAdapter, receipt_storage_adapter::{ReceiptRead, ReceiptStore}, }, + checks::ReceiptCheck, receipt_aggregate_voucher::ReceiptAggregateVoucher, tap_receipt::{ - CategorizedReceiptsWithState, Failed, ReceiptAuditor, ReceiptCheck, ReceiptWithId, - ReceiptWithState, ReceivedReceipt, Reserved, + CategorizedReceiptsWithState, Failed, ReceiptAuditor, ReceiptWithId, ReceiptWithState, + ReceivedReceipt, Reserved, }, Error, }; @@ -41,13 +43,8 @@ where domain_separator: Eip712Domain, executor: E, required_checks: Vec, - starting_min_timestamp_ns: u64, ) -> Self { - let receipt_auditor = ReceiptAuditor::new( - domain_separator, - executor.clone(), - starting_min_timestamp_ns, - ); + let receipt_auditor = ReceiptAuditor::new(domain_separator, executor.clone()); Self { executor, required_checks, @@ -58,7 +55,7 @@ where impl Manager where - E: RAVStore + ReceiptChecksAdapter, + E: RAVStore, { /// Verify `signed_rav` matches all values on `expected_rav`, and that `signed_rav` has a valid signer. /// @@ -66,13 +63,18 @@ where /// /// Returns [`Error::AdapterError`] if there are any errors while storing RAV /// - pub async fn verify_and_store_rav( + pub async fn verify_and_store_rav( &self, expected_rav: ReceiptAggregateVoucher, signed_rav: SignedRAV, - ) -> std::result::Result<(), Error> { + verify_signer: F, + ) -> std::result::Result<(), Error> + where + F: FnOnce(Address) -> Fut, + Fut: Future>, + { self.receipt_auditor - .check_rav_signature(&signed_rav) + .check_rav_signature(&signed_rav, verify_signer) .await?; if signed_rav.message != expected_rav { @@ -111,7 +113,7 @@ where impl Manager where - E: ReceiptRead + EscrowAdapter + ReceiptChecksAdapter, + E: ReceiptRead + EscrowAdapter, { async fn collect_receipts( &self, @@ -148,14 +150,12 @@ where mut reserved_receipts, } = received_receipts.into(); - for received_receipt in checking_receipts { + for received_receipt in checking_receipts.into_iter() { let ReceiptWithId { receipt, - receipt_id, + receipt_id: _, } = received_receipt; - let receipt = receipt - .finalize_receipt_checks(receipt_id, &self.receipt_auditor) - .await; + let receipt = receipt.finalize_receipt_checks().await; match receipt { Ok(checked) => awaiting_reserve_receipts.push(checked), @@ -178,7 +178,7 @@ where impl Manager where - E: ReceiptRead + RAVRead + EscrowAdapter + ReceiptChecksAdapter, + E: ReceiptRead + RAVRead + EscrowAdapter, { /// Completes remaining checks on all receipts up to (current time - `timestamp_buffer_ns`). Returns them in /// two lists (valid receipts and invalid receipts) along with the expected RAV that should be received @@ -207,9 +207,6 @@ where let expected_rav = Self::generate_expected_rav(&valid_receipts, previous_rav.clone())?; - self.receipt_auditor - .update_min_timestamp_ns(expected_rav.timestampNs) - .await; let valid_receipts = valid_receipts .into_iter() .map(|rx_receipt| rx_receipt.signed_receipt) @@ -274,7 +271,7 @@ where impl Manager where - E: ReceiptStore + EscrowAdapter + ReceiptChecksAdapter, + E: ReceiptStore + EscrowAdapter, { /// Runs `initial_checks` on `signed_receipt` for initial verification, then stores received receipt. /// The provided `query_id` will be used as a key when chaecking query appraisal. @@ -309,7 +306,13 @@ where if let ReceivedReceipt::Checking(received_receipt) = &mut received_receipt { received_receipt - .perform_checks(initial_checks, receipt_id, &self.receipt_auditor) + .perform_checks( + initial_checks + .iter() + .map(|check| check.typetag_name()) + .collect::>() + .as_slice(), + ) .await; } diff --git a/tap_core/src/tap_manager/rav_request.rs b/tap_core/src/tap_manager/rav_request.rs index bcf395e4..4b59b5f8 100644 --- a/tap_core/src/tap_manager/rav_request.rs +++ b/tap_core/src/tap_manager/rav_request.rs @@ -10,7 +10,7 @@ use crate::{ }; #[derive(Debug, Serialize, Deserialize, Clone)] - +#[serde(bound(deserialize = "'de: 'static"))] pub struct RAVRequest { pub valid_receipts: Vec, pub previous_rav: Option, diff --git a/tap_core/src/tap_manager/test/manager_test.rs b/tap_core/src/tap_manager/test/manager_test.rs index dc04882c..c231559d 100644 --- a/tap_core/src/tap_manager/test/manager_test.rs +++ b/tap_core/src/tap_manager/test/manager_test.rs @@ -1,502 +1,482 @@ // Copyright 2023-, Semiotic AI, Inc. // SPDX-License-Identifier: Apache-2.0 -#[cfg(test)] -#[allow(clippy::too_many_arguments)] -mod manager_unit_test { - use std::{ - collections::{HashMap, HashSet}, - str::FromStr, - sync::Arc, - }; - - use alloy_primitives::Address; - use alloy_sol_types::Eip712Domain; - use ethers::signers::{coins_bip39::English, LocalWallet, MnemonicBuilder, Signer}; - use rstest::*; - use tokio::sync::RwLock; - - use super::super::Manager; - use crate::{ - adapters::{ - escrow_adapter_mock::EscrowAdapterMock, - executor_mock::{EscrowStorage, ExecutorMock, QueryAppraisals}, - receipt_checks_adapter_mock::ReceiptChecksAdapterMock, - receipt_storage_adapter::ReceiptRead, - }, - eip_712_signed_message::EIP712SignedMessage, - get_current_timestamp_u64_ns, tap_eip712_domain, - tap_receipt::{get_full_list_of_checks, Receipt, ReceiptCheck}, - }; - - #[fixture] - fn keys() -> (LocalWallet, Address) { - let wallet: LocalWallet = MnemonicBuilder::::default() - .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") - .build() - .unwrap(); - // Alloy library does not have feature parity with ethers library (yet) This workaround is needed to get the address - // to convert to an alloy Address. This will not be needed when the alloy library has wallet support. - let address: [u8; 20] = wallet.address().into(); - - (wallet, address.into()) - } - - #[fixture] - fn allocation_ids() -> Vec
{ - vec![ - Address::from_str("0xabababababababababababababababababababab").unwrap(), - Address::from_str("0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead").unwrap(), - Address::from_str("0xbeefbeefbeefbeefbeefbeefbeefbeefbeefbeef").unwrap(), - Address::from_str("0x1234567890abcdef1234567890abcdef12345678").unwrap(), - ] - } - - #[fixture] - fn sender_ids() -> Vec
{ - vec![ - Address::from_str("0xfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfb").unwrap(), - Address::from_str("0xfafafafafafafafafafafafafafafafafafafafa").unwrap(), - Address::from_str("0xadadadadadadadadadadadadadadadadadadadad").unwrap(), - keys().1, - ] - } - - #[fixture] - fn domain_separator() -> Eip712Domain { - tap_eip712_domain(1, Address::from([0x11u8; 20])) - } +use std::{ + collections::HashMap, + ops::Range, + str::FromStr, + sync::{Arc, RwLock}, +}; + +use alloy_primitives::Address; +use alloy_sol_types::Eip712Domain; +use ethers::signers::{coins_bip39::English, LocalWallet, MnemonicBuilder, Signer}; +use rstest::*; + +use super::super::Manager; +use crate::{ + adapters::{ + executor_mock::{EscrowStorage, ExecutorMock, QueryAppraisals}, + receipt_storage_adapter::ReceiptRead, + }, + checks::{mock::get_full_list_of_checks, ReceiptCheck, TimestampCheck}, + eip_712_signed_message::EIP712SignedMessage, + get_current_timestamp_u64_ns, tap_eip712_domain, + tap_receipt::Receipt, +}; + +const LENGTH_OF_CHECKS: usize = 4; + +#[fixture] +fn keys() -> (LocalWallet, Address) { + let wallet: LocalWallet = MnemonicBuilder::::default() + .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") + .build() + .unwrap(); + // Alloy library does not have feature parity with ethers library (yet) This workaround is needed to get the address + // to convert to an alloy Address. This will not be needed when the alloy library has wallet support. + let address: [u8; 20] = wallet.address().into(); - #[fixture] - fn executor_mock() -> (ExecutorMock, EscrowStorage, QueryAppraisals) { - let rav_storage = Arc::new(RwLock::new(None)); - let receipt_storage = Arc::new(RwLock::new(HashMap::new())); - - let sender_escrow_storage = Arc::new(RwLock::new(HashMap::new())); - - let allocation_ids_set = Arc::new(RwLock::new(HashSet::from_iter(allocation_ids()))); - let sender_ids_set = Arc::new(RwLock::new(HashSet::from_iter(sender_ids()))); - let query_appraisal_storage = Arc::new(RwLock::new(HashMap::new())); - - ( - ExecutorMock::new( - rav_storage, - receipt_storage, - sender_escrow_storage.clone(), - query_appraisal_storage.clone(), - allocation_ids_set, - sender_ids_set, - ), - sender_escrow_storage, - query_appraisal_storage, - ) - } + (wallet, address.into()) +} - #[fixture] - fn escrow_adapters() -> (EscrowAdapterMock, EscrowStorage) { - let sender_escrow_storage = Arc::new(RwLock::new(HashMap::new())); - let escrow_adapter = EscrowAdapterMock::new(Arc::clone(&sender_escrow_storage)); - (escrow_adapter, sender_escrow_storage) - } +#[fixture] +fn allocation_ids() -> Vec
{ + vec![ + Address::from_str("0xabababababababababababababababababababab").unwrap(), + Address::from_str("0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead").unwrap(), + Address::from_str("0xbeefbeefbeefbeefbeefbeefbeefbeefbeefbeef").unwrap(), + Address::from_str("0x1234567890abcdef1234567890abcdef12345678").unwrap(), + ] +} - #[fixture] - fn receipt_adapters() -> (ReceiptChecksAdapterMock, Arc>>) { - let receipt_storage = Arc::new(RwLock::new(HashMap::new())); +#[fixture] +fn sender_ids() -> Vec
{ + vec![ + Address::from_str("0xfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfb").unwrap(), + Address::from_str("0xfafafafafafafafafafafafafafafafafafafafa").unwrap(), + Address::from_str("0xadadadadadadadadadadadadadadadadadadadad").unwrap(), + keys().1, + ] +} - let allocation_ids_set = Arc::new(RwLock::new(HashSet::from_iter(allocation_ids()))); - let sender_ids_set = Arc::new(RwLock::new(HashSet::from_iter(sender_ids()))); - let query_appraisal_storage = Arc::new(RwLock::new(HashMap::new())); +#[fixture] +fn domain_separator() -> Eip712Domain { + tap_eip712_domain(1, Address::from([0x11u8; 20])) +} - let receipt_checks_adapter = ReceiptChecksAdapterMock::new( - Arc::clone(&receipt_storage), - Arc::clone(&query_appraisal_storage), - Arc::clone(&allocation_ids_set), - Arc::clone(&sender_ids_set), - ); +struct ExecutorFixture { + executor: ExecutorMock, + escrow_storage: EscrowStorage, + query_appraisals: QueryAppraisals, + checks: Vec, +} - (receipt_checks_adapter, query_appraisal_storage) +#[fixture] +fn executor_mock( + domain_separator: Eip712Domain, + allocation_ids: Vec
, + sender_ids: Vec
, +) -> ExecutorFixture { + let escrow_storage = Arc::new(RwLock::new(HashMap::new())); + let rav_storage = Arc::new(RwLock::new(None)); + let query_appraisals = Arc::new(RwLock::new(HashMap::new())); + let receipt_storage = Arc::new(RwLock::new(HashMap::new())); + let timestamp_check = Arc::new(TimestampCheck::new(0)); + let executor = ExecutorMock::new( + rav_storage, + receipt_storage.clone(), + escrow_storage.clone(), + timestamp_check.clone(), + ); + + let mut checks = get_full_list_of_checks( + domain_separator, + sender_ids.iter().cloned().collect(), + Arc::new(RwLock::new(allocation_ids.iter().cloned().collect())), + receipt_storage, + query_appraisals.clone(), + ); + checks.push(timestamp_check); + + ExecutorFixture { + executor, + escrow_storage, + query_appraisals, + checks, } +} - #[rstest] - #[case::full_checks(get_full_list_of_checks())] - #[case::partial_checks(vec![ReceiptCheck::CheckSignature])] - #[case::no_checks(Vec::::new())] - #[tokio::test] - async fn manager_verify_and_store_varying_initial_checks( - executor_mock: (ExecutorMock, EscrowStorage, QueryAppraisals), - keys: (LocalWallet, Address), - allocation_ids: Vec
, - domain_separator: Eip712Domain, - #[case] initial_checks: Vec, - ) { - let (executor, escrow_storage, query_appraisal_storage) = executor_mock; - // give receipt 5 second variance for min start time - let starting_min_timestamp = get_current_timestamp_u64_ns().unwrap() - 500000000; - - let manager = Manager::new( - domain_separator.clone(), - executor, - get_full_list_of_checks(), - starting_min_timestamp, - ); +#[rstest] +#[case::full_checks(0..LENGTH_OF_CHECKS)] +#[case::partial_checks(0..2)] +#[case::no_checks(0..0)] +#[tokio::test] +async fn manager_verify_and_store_varying_initial_checks( + keys: (LocalWallet, Address), + allocation_ids: Vec
, + domain_separator: Eip712Domain, + #[case] range: Range, + executor_mock: ExecutorFixture, +) { + let ExecutorFixture { + executor, + checks, + query_appraisals, + escrow_storage, + .. + } = executor_mock; + // give receipt 5 second variance for min start time + let manager = Manager::new(domain_separator.clone(), executor, checks.clone()); + + let query_id = 1; + let value = 20u128; + let signed_receipt = EIP712SignedMessage::new( + &domain_separator, + Receipt::new(allocation_ids[0], value).unwrap(), + &keys.0, + ) + .unwrap(); + query_appraisals.write().unwrap().insert(query_id, value); + escrow_storage.write().unwrap().insert(keys.1, 999999); + + assert!(manager + .verify_and_store_receipt(signed_receipt, query_id, &checks[range]) + .await + .is_ok()); +} - let query_id = 1; +#[rstest] +#[case::full_checks(0..LENGTH_OF_CHECKS)] +#[case::partial_checks(0..2)] +#[case::no_checks(0..0)] +#[tokio::test] +async fn manager_create_rav_request_all_valid_receipts( + keys: (LocalWallet, Address), + allocation_ids: Vec
, + domain_separator: Eip712Domain, + #[case] range: Range, + executor_mock: ExecutorFixture, +) { + let ExecutorFixture { + executor, + checks, + query_appraisals, + escrow_storage, + .. + } = executor_mock; + let initial_checks = &checks[range]; + + let manager = Manager::new(domain_separator.clone(), executor, checks.clone()); + escrow_storage.write().unwrap().insert(keys.1, 999999); + + let mut stored_signed_receipts = Vec::new(); + for query_id in 0..10 { let value = 20u128; let signed_receipt = EIP712SignedMessage::new( &domain_separator, Receipt::new(allocation_ids[0], value).unwrap(), &keys.0, ) - .await .unwrap(); - query_appraisal_storage - .write() - .await - .insert(query_id, value); - escrow_storage.write().await.insert(keys.1, 999999); - + stored_signed_receipts.push(signed_receipt.clone()); + query_appraisals.write().unwrap().insert(query_id, value); assert!(manager - .verify_and_store_receipt(signed_receipt, query_id, initial_checks.as_slice()) + .verify_and_store_receipt(signed_receipt, query_id, initial_checks) .await .is_ok()); } - - #[rstest] - #[case::full_checks(get_full_list_of_checks())] - #[case::partial_checks(vec![ReceiptCheck::CheckSignature])] - #[case::no_checks(Vec::::new())] - #[tokio::test] - async fn manager_create_rav_request_all_valid_receipts( - executor_mock: (ExecutorMock, EscrowStorage, QueryAppraisals), - keys: (LocalWallet, Address), - allocation_ids: Vec
, - domain_separator: Eip712Domain, - #[case] initial_checks: Vec, - ) { - let (executor, escrow_storage, query_appraisal_storage) = executor_mock; - // give receipt 5 second variance for min start time - let starting_min_timestamp = get_current_timestamp_u64_ns().unwrap() - 500000000; - - let manager = Manager::new( - domain_separator.clone(), - executor, - get_full_list_of_checks(), - starting_min_timestamp, - ); - escrow_storage.write().await.insert(keys.1, 999999); - - let mut stored_signed_receipts = Vec::new(); - for query_id in 0..10 { - let value = 20u128; - let signed_receipt = EIP712SignedMessage::new( - &domain_separator, - Receipt::new(allocation_ids[0], value).unwrap(), - &keys.0, - ) - .await + let rav_request_result = manager.create_rav_request(0, None).await; + println!("{:?}", rav_request_result); + assert!(rav_request_result.is_ok()); + + let rav_request = rav_request_result.unwrap(); + // all passing + assert_eq!( + rav_request.valid_receipts.len(), + stored_signed_receipts.len() + ); + // no failing + assert_eq!(rav_request.invalid_receipts.len(), 0); + + let signed_rav = + EIP712SignedMessage::new(&domain_separator, rav_request.expected_rav.clone(), &keys.0) .unwrap(); - stored_signed_receipts.push(signed_receipt.clone()); - query_appraisal_storage - .write() - .await - .insert(query_id, value); - assert!(manager - .verify_and_store_receipt(signed_receipt, query_id, initial_checks.as_slice()) - .await - .is_ok()); - } - let rav_request_result = manager.create_rav_request(0, None).await; - assert!(rav_request_result.is_ok()); - - let rav_request = rav_request_result.unwrap(); - // all passing - assert_eq!( - rav_request.valid_receipts.len(), - stored_signed_receipts.len() - ); - // no failing - assert_eq!(rav_request.invalid_receipts.len(), 0); + assert!(manager + .verify_and_store_rav( + rav_request.expected_rav, + signed_rav, + |address: Address| async move { Ok(keys.1 == address) } + ) + .await + .is_ok()); +} - let signed_rav = - EIP712SignedMessage::new(&domain_separator, rav_request.expected_rav.clone(), &keys.0) - .await - .unwrap(); +#[rstest] +#[case::full_checks(0..LENGTH_OF_CHECKS)] +#[case::partial_checks(0..2)] +#[case::no_checks(0..0)] +#[tokio::test] +async fn manager_create_multiple_rav_requests_all_valid_receipts( + keys: (LocalWallet, Address), + allocation_ids: Vec
, + domain_separator: Eip712Domain, + #[case] range: Range, + executor_mock: ExecutorFixture, +) { + let ExecutorFixture { + executor, + checks, + query_appraisals, + escrow_storage, + .. + } = executor_mock; + let initial_checks = &checks[range]; + // give receipt 5 second variance for min start time + + let manager = Manager::new(domain_separator.clone(), executor, checks.clone()); + + escrow_storage.write().unwrap().insert(keys.1, 999999); + + let mut stored_signed_receipts = Vec::new(); + let mut expected_accumulated_value = 0; + for query_id in 0..10 { + let value = 20u128; + let signed_receipt = EIP712SignedMessage::new( + &domain_separator, + Receipt::new(allocation_ids[0], value).unwrap(), + &keys.0, + ) + .unwrap(); + stored_signed_receipts.push(signed_receipt.clone()); + query_appraisals.write().unwrap().insert(query_id, value); assert!(manager - .verify_and_store_rav(rav_request.expected_rav, signed_rav) + .verify_and_store_receipt(signed_receipt, query_id, initial_checks) .await .is_ok()); + expected_accumulated_value += value; } - - #[rstest] - #[case::full_checks(get_full_list_of_checks())] - #[case::partial_checks(vec![ReceiptCheck::CheckSignature])] - #[case::no_checks(Vec::::new())] - #[tokio::test] - async fn manager_create_multiple_rav_requests_all_valid_receipts( - executor_mock: (ExecutorMock, EscrowStorage, QueryAppraisals), - keys: (LocalWallet, Address), - allocation_ids: Vec
, - domain_separator: Eip712Domain, - #[case] initial_checks: Vec, - ) { - let (executor, escrow_storage, query_appraisal_storage) = executor_mock; - // give receipt 5 second variance for min start time - let starting_min_timestamp = get_current_timestamp_u64_ns().unwrap() - 500000000; - - let manager = Manager::new( - domain_separator.clone(), - executor, - get_full_list_of_checks(), - starting_min_timestamp, - ); - - escrow_storage.write().await.insert(keys.1, 999999); - - let mut stored_signed_receipts = Vec::new(); - let mut expected_accumulated_value = 0; - for query_id in 0..10 { - let value = 20u128; - let signed_receipt = EIP712SignedMessage::new( - &domain_separator, - Receipt::new(allocation_ids[0], value).unwrap(), - &keys.0, - ) - .await + let rav_request_result = manager.create_rav_request(0, None).await; + assert!(rav_request_result.is_ok()); + + let rav_request = rav_request_result.unwrap(); + // all receipts passing + assert_eq!( + rav_request.valid_receipts.len(), + stored_signed_receipts.len() + ); + // no receipts failing + assert_eq!(rav_request.invalid_receipts.len(), 0); + // accumulated value is correct + assert_eq!( + rav_request.expected_rav.valueAggregate, + expected_accumulated_value + ); + // no previous rav + assert!(rav_request.previous_rav.is_none()); + + let signed_rav = + EIP712SignedMessage::new(&domain_separator, rav_request.expected_rav.clone(), &keys.0) .unwrap(); - stored_signed_receipts.push(signed_receipt.clone()); - query_appraisal_storage - .write() - .await - .insert(query_id, value); - assert!(manager - .verify_and_store_receipt(signed_receipt, query_id, initial_checks.as_slice()) - .await - .is_ok()); - expected_accumulated_value += value; - } - let rav_request_result = manager.create_rav_request(0, None).await; - assert!(rav_request_result.is_ok()); - - let rav_request = rav_request_result.unwrap(); - // all receipts passing - assert_eq!( - rav_request.valid_receipts.len(), - stored_signed_receipts.len() - ); - // no receipts failing - assert_eq!(rav_request.invalid_receipts.len(), 0); - // accumulated value is correct - assert_eq!( - rav_request.expected_rav.valueAggregate, - expected_accumulated_value - ); - // no previous rav - assert!(rav_request.previous_rav.is_none()); + assert!(manager + .verify_and_store_rav( + rav_request.expected_rav, + signed_rav, + |address: Address| async move { Ok(keys.1 == address) } + ) + .await + .is_ok()); - let signed_rav = - EIP712SignedMessage::new(&domain_separator, rav_request.expected_rav.clone(), &keys.0) - .await - .unwrap(); + stored_signed_receipts.clear(); + for query_id in 10..20 { + let value = 20u128; + let signed_receipt = EIP712SignedMessage::new( + &domain_separator, + Receipt::new(allocation_ids[0], value).unwrap(), + &keys.0, + ) + .unwrap(); + stored_signed_receipts.push(signed_receipt.clone()); + query_appraisals.write().unwrap().insert(query_id, value); assert!(manager - .verify_and_store_rav(rav_request.expected_rav, signed_rav) + .verify_and_store_receipt(signed_receipt, query_id, initial_checks) .await .is_ok()); - - stored_signed_receipts.clear(); - for query_id in 10..20 { - let value = 20u128; - let signed_receipt = EIP712SignedMessage::new( - &domain_separator, - Receipt::new(allocation_ids[0], value).unwrap(), - &keys.0, - ) - .await + expected_accumulated_value += value; + } + let rav_request_result = manager.create_rav_request(0, None).await; + assert!(rav_request_result.is_ok()); + + let rav_request = rav_request_result.unwrap(); + // all receipts passing + assert_eq!( + rav_request.valid_receipts.len(), + stored_signed_receipts.len() + ); + // no receipts failing + assert_eq!(rav_request.invalid_receipts.len(), 0); + // accumulated value is correct + assert_eq!( + rav_request.expected_rav.valueAggregate, + expected_accumulated_value + ); + // Verify there is a previous rav + assert!(rav_request.previous_rav.is_some()); + + let signed_rav = + EIP712SignedMessage::new(&domain_separator, rav_request.expected_rav.clone(), &keys.0) .unwrap(); - stored_signed_receipts.push(signed_receipt.clone()); - query_appraisal_storage - .write() - .await - .insert(query_id, value); - assert!(manager - .verify_and_store_receipt(signed_receipt, query_id, initial_checks.as_slice()) - .await - .is_ok()); - expected_accumulated_value += value; - } - let rav_request_result = manager.create_rav_request(0, None).await; - assert!(rav_request_result.is_ok()); - - let rav_request = rav_request_result.unwrap(); - // all receipts passing - assert_eq!( - rav_request.valid_receipts.len(), - stored_signed_receipts.len() - ); - // no receipts failing - assert_eq!(rav_request.invalid_receipts.len(), 0); - // accumulated value is correct - assert_eq!( - rav_request.expected_rav.valueAggregate, - expected_accumulated_value - ); - // Verify there is a previous rav - assert!(rav_request.previous_rav.is_some()); + assert!(manager + .verify_and_store_rav( + rav_request.expected_rav, + signed_rav, + |address: Address| async move { Ok(keys.1 == address) } + ) + .await + .is_ok()); +} - let signed_rav = - EIP712SignedMessage::new(&domain_separator, rav_request.expected_rav.clone(), &keys.0) - .await - .unwrap(); +#[rstest] +#[tokio::test] +async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_timestamps( + keys: (LocalWallet, Address), + allocation_ids: Vec
, + domain_separator: Eip712Domain, + #[values(0..0, 0..2, 0..LENGTH_OF_CHECKS)] range: Range, + #[values(true, false)] remove_old_receipts: bool, + executor_mock: ExecutorFixture, +) { + let ExecutorFixture { + executor, + checks, + query_appraisals, + escrow_storage, + .. + } = executor_mock; + let initial_checks = &checks[range]; + // give receipt 5 second variance for min start time + let starting_min_timestamp = get_current_timestamp_u64_ns().unwrap() - 500000000; + + let manager = Manager::new(domain_separator.clone(), executor, checks.clone()); + + escrow_storage.write().unwrap().insert(keys.1, 999999); + + let mut stored_signed_receipts = Vec::new(); + let mut expected_accumulated_value = 0; + for query_id in 0..10 { + let value = 20u128; + let mut receipt = Receipt::new(allocation_ids[0], value).unwrap(); + receipt.timestamp_ns = starting_min_timestamp + query_id + 1; + let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &keys.0).unwrap(); + stored_signed_receipts.push(signed_receipt.clone()); + query_appraisals.write().unwrap().insert(query_id, value); assert!(manager - .verify_and_store_rav(rav_request.expected_rav, signed_rav) + .verify_and_store_receipt(signed_receipt, query_id, initial_checks) .await .is_ok()); + expected_accumulated_value += value; } - #[rstest] - #[tokio::test] - async fn manager_create_multiple_rav_requests_all_valid_receipts_consecutive_timestamps( - executor_mock: (ExecutorMock, EscrowStorage, QueryAppraisals), - keys: (LocalWallet, Address), - allocation_ids: Vec
, - domain_separator: Eip712Domain, - #[values(get_full_list_of_checks(), vec![ReceiptCheck::CheckSignature], Vec::::new())] - initial_checks: Vec, - #[values(true, false)] remove_old_receipts: bool, - ) { - let (executor, escrow_storage, query_appraisal_storage) = executor_mock; - // give receipt 5 second variance for min start time - let starting_min_timestamp = get_current_timestamp_u64_ns().unwrap() - 500000000; - - let manager = Manager::new( - domain_separator.clone(), - executor, - get_full_list_of_checks(), - starting_min_timestamp, - ); - - escrow_storage.write().await.insert(keys.1, 999999); - - let mut stored_signed_receipts = Vec::new(); - let mut expected_accumulated_value = 0; - for query_id in 0..10 { - let value = 20u128; - let mut receipt = Receipt::new(allocation_ids[0], value).unwrap(); - receipt.timestamp_ns = starting_min_timestamp + query_id + 1; - let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &keys.0) - .await - .unwrap(); - stored_signed_receipts.push(signed_receipt.clone()); - query_appraisal_storage - .write() - .await - .insert(query_id, value); - assert!(manager - .verify_and_store_receipt(signed_receipt, query_id, initial_checks.as_slice()) - .await - .is_ok()); - expected_accumulated_value += value; - } - - // Remove old receipts if requested - // This shouldn't do anything since there has been no rav created yet - if remove_old_receipts { - manager.remove_obsolete_receipts().await.unwrap(); - } - - let rav_request_1_result = manager.create_rav_request(0, None).await; - assert!(rav_request_1_result.is_ok()); - - let rav_request_1 = rav_request_1_result.unwrap(); - // all receipts passing - assert_eq!( - rav_request_1.valid_receipts.len(), - stored_signed_receipts.len() - ); - // no receipts failing - assert_eq!(rav_request_1.invalid_receipts.len(), 0); - // accumulated value is correct - assert_eq!( - rav_request_1.expected_rav.valueAggregate, - expected_accumulated_value - ); - // no previous rav - assert!(rav_request_1.previous_rav.is_none()); + // Remove old receipts if requested + // This shouldn't do anything since there has been no rav created yet + if remove_old_receipts { + manager.remove_obsolete_receipts().await.unwrap(); + } - let signed_rav_1 = EIP712SignedMessage::new( - &domain_separator, - rav_request_1.expected_rav.clone(), - &keys.0, + let rav_request_1_result = manager.create_rav_request(0, None).await; + assert!(rav_request_1_result.is_ok()); + + let rav_request_1 = rav_request_1_result.unwrap(); + // all receipts passing + assert_eq!( + rav_request_1.valid_receipts.len(), + stored_signed_receipts.len() + ); + // no receipts failing + assert_eq!(rav_request_1.invalid_receipts.len(), 0); + // accumulated value is correct + assert_eq!( + rav_request_1.expected_rav.valueAggregate, + expected_accumulated_value + ); + // no previous rav + assert!(rav_request_1.previous_rav.is_none()); + + let signed_rav_1 = EIP712SignedMessage::new( + &domain_separator, + rav_request_1.expected_rav.clone(), + &keys.0, + ) + .unwrap(); + assert!(manager + .verify_and_store_rav( + rav_request_1.expected_rav, + signed_rav_1, + |address: Address| async move { Ok(keys.1 == address) } ) .await - .unwrap(); + .is_ok()); + + stored_signed_receipts.clear(); + for query_id in 10..20 { + let value = 20u128; + let mut receipt = Receipt::new(allocation_ids[0], value).unwrap(); + receipt.timestamp_ns = starting_min_timestamp + query_id + 1; + let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &keys.0).unwrap(); + stored_signed_receipts.push(signed_receipt.clone()); + query_appraisals.write().unwrap().insert(query_id, value); assert!(manager - .verify_and_store_rav(rav_request_1.expected_rav, signed_rav_1) + .verify_and_store_receipt(signed_receipt, query_id, initial_checks) .await .is_ok()); + expected_accumulated_value += value; + } - stored_signed_receipts.clear(); - for query_id in 10..20 { - let value = 20u128; - let mut receipt = Receipt::new(allocation_ids[0], value).unwrap(); - receipt.timestamp_ns = starting_min_timestamp + query_id + 1; - let signed_receipt = EIP712SignedMessage::new(&domain_separator, receipt, &keys.0) - .await - .unwrap(); - stored_signed_receipts.push(signed_receipt.clone()); - query_appraisal_storage - .write() - .await - .insert(query_id, value); - assert!(manager - .verify_and_store_receipt(signed_receipt, query_id, initial_checks.as_slice()) - .await - .is_ok()); - expected_accumulated_value += value; - } - - // Remove old receipts if requested - if remove_old_receipts { - manager.remove_obsolete_receipts().await.unwrap(); - // We expect to have 10 receipts left in receipt storage - assert_eq!( - manager - .executor - .retrieve_receipts_in_timestamp_range(.., None) - .await - .unwrap() - .len(), - 10 - ); - } - - let rav_request_2_result = manager.create_rav_request(0, None).await; - assert!(rav_request_2_result.is_ok()); - - let rav_request_2 = rav_request_2_result.unwrap(); - // all receipts passing - assert_eq!( - rav_request_2.valid_receipts.len(), - stored_signed_receipts.len() - ); - // no receipts failing - assert_eq!(rav_request_2.invalid_receipts.len(), 0); - // accumulated value is correct + // Remove old receipts if requested + if remove_old_receipts { + manager.remove_obsolete_receipts().await.unwrap(); + // We expect to have 10 receipts left in receipt storage assert_eq!( - rav_request_2.expected_rav.valueAggregate, - expected_accumulated_value + manager + .executor + .retrieve_receipts_in_timestamp_range(.., None) + .await + .unwrap() + .len(), + 10 ); - // Verify there is a previous rav - assert!(rav_request_2.previous_rav.is_some()); + } - let signed_rav_2 = EIP712SignedMessage::new( - &domain_separator, - rav_request_2.expected_rav.clone(), - &keys.0, + let rav_request_2_result = manager.create_rav_request(0, None).await; + assert!(rav_request_2_result.is_ok()); + + let rav_request_2 = rav_request_2_result.unwrap(); + // all receipts passing + assert_eq!( + rav_request_2.valid_receipts.len(), + stored_signed_receipts.len() + ); + // no receipts failing + assert_eq!(rav_request_2.invalid_receipts.len(), 0); + // accumulated value is correct + assert_eq!( + rav_request_2.expected_rav.valueAggregate, + expected_accumulated_value + ); + // Verify there is a previous rav + assert!(rav_request_2.previous_rav.is_some()); + + let signed_rav_2 = EIP712SignedMessage::new( + &domain_separator, + rav_request_2.expected_rav.clone(), + &keys.0, + ) + .unwrap(); + assert!(manager + .verify_and_store_rav( + rav_request_2.expected_rav, + signed_rav_2, + |address: Address| async move { Ok(keys.1 == address) } ) .await - .unwrap(); - assert!(manager - .verify_and_store_rav(rav_request_2.expected_rav, signed_rav_2) - .await - .is_ok()); - } + .is_ok()); } diff --git a/tap_core/src/tap_receipt/mod.rs b/tap_core/src/tap_receipt/mod.rs index 8e6912ca..4261a21b 100644 --- a/tap_core/src/tap_receipt/mod.rs +++ b/tap_core/src/tap_receipt/mod.rs @@ -15,9 +15,10 @@ pub use received_receipt::{ }; use serde::{Deserialize, Serialize}; -use strum_macros::{Display, EnumString}; use thiserror::Error; +use crate::checks::CheckingChecks; + #[derive(Error, Debug, Clone, Serialize, Deserialize)] pub enum ReceiptError { #[error("invalid allocation ID: {received_allocation_id}")] @@ -40,36 +41,4 @@ pub enum ReceiptError { } pub type ReceiptResult = Result; -pub type ReceiptCheckResults = HashMap>>; -#[derive(Hash, Eq, PartialEq, Debug, Clone, EnumString, Display, Serialize, Deserialize)] -pub enum ReceiptCheck { - CheckUnique, - CheckAllocationId, - CheckTimestamp, - CheckValue, - CheckSignature, -} - -pub fn get_full_list_of_receipt_check_results() -> ReceiptCheckResults { - let mut all_checks_list = ReceiptCheckResults::new(); - all_checks_list.insert(ReceiptCheck::CheckUnique, None); - all_checks_list.insert(ReceiptCheck::CheckAllocationId, None); - all_checks_list.insert(ReceiptCheck::CheckTimestamp, None); - all_checks_list.insert(ReceiptCheck::CheckValue, None); - all_checks_list.insert(ReceiptCheck::CheckSignature, None); - - all_checks_list -} - -pub fn get_full_list_of_checks() -> Vec { - vec![ - ReceiptCheck::CheckUnique, - ReceiptCheck::CheckAllocationId, - ReceiptCheck::CheckTimestamp, - ReceiptCheck::CheckValue, - ReceiptCheck::CheckSignature, - ] -} - -#[cfg(test)] -pub mod tests; +pub type ReceiptCheckResults = HashMap<&'static str, CheckingChecks>; diff --git a/tap_core/src/tap_receipt/receipt.rs b/tap_core/src/tap_receipt/receipt.rs index 43c1317c..630d337b 100644 --- a/tap_core/src/tap_receipt/receipt.rs +++ b/tap_core/src/tap_receipt/receipt.rs @@ -41,3 +41,64 @@ impl Receipt { }) } } + +#[cfg(test)] +mod receipt_unit_test { + use super::*; + use rstest::*; + use std::str::FromStr; + use std::time::{SystemTime, UNIX_EPOCH}; + + #[fixture] + fn allocation_ids() -> Vec
{ + vec![ + Address::from_str("0xabababababababababababababababababababab").unwrap(), + Address::from_str("0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead").unwrap(), + Address::from_str("0xbeefbeefbeefbeefbeefbeefbeefbeefbeefbeef").unwrap(), + Address::from_str("0x1234567890abcdef1234567890abcdef12345678").unwrap(), + ] + } + + #[rstest] + fn test_new_receipt(allocation_ids: Vec
) { + let value = 1234; + + let receipt = Receipt::new(allocation_ids[0], value).unwrap(); + + assert_eq!(receipt.allocation_id, allocation_ids[0]); + assert_eq!(receipt.value, value); + + // Check that the timestamp is within a reasonable range + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Current system time should be greater than `UNIX_EPOCH`") + .as_nanos() as u64; + assert!(receipt.timestamp_ns <= now); + assert!(receipt.timestamp_ns >= now - 5000000); // 5 second tolerance + } + + #[rstest] + fn test_unique_nonce_and_timestamp(allocation_ids: Vec
) { + let value = 1234; + + let receipt1 = Receipt::new(allocation_ids[0], value).unwrap(); + let receipt2 = Receipt::new(allocation_ids[0], value).unwrap(); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Current system time should be greater than `UNIX_EPOCH`") + .as_nanos() as u64; + + // Check that nonces are different + // Note: This test has an *extremely low* (~1/2^64) probability of false failure, if a failure happens + // once it is not neccessarily a sign of an issue. If this test fails more than once, especially + // in a short period of time (within a ) then there may be an issue with randomness + // of the nonce generation. + assert_ne!(receipt1.nonce, receipt2.nonce); + + assert!(receipt1.timestamp_ns <= now); + assert!(receipt1.timestamp_ns >= now - 5000000); // 5 second tolerance + + assert!(receipt2.timestamp_ns <= now); + assert!(receipt2.timestamp_ns >= now - 5000000); // 5 second tolerance + } +} diff --git a/tap_core/src/tap_receipt/receipt_auditor.rs b/tap_core/src/tap_receipt/receipt_auditor.rs index 02cbea08..34f10578 100644 --- a/tap_core/src/tap_receipt/receipt_auditor.rs +++ b/tap_core/src/tap_receipt/receipt_auditor.rs @@ -1,311 +1,49 @@ // Copyright 2023-, Semiotic AI, Inc. // SPDX-License-Identifier: Apache-2.0 -use std::collections::HashSet; - +use alloy_primitives::Address; use alloy_sol_types::Eip712Domain; -use ethers::types::Signature; -use tokio::sync::RwLock; +use futures::Future; use crate::{ - adapters::{escrow_adapter::EscrowAdapter, receipt_checks_adapter::ReceiptChecksAdapter}, - eip_712_signed_message::EIP712SignedMessage, - receipt_aggregate_voucher::ReceiptAggregateVoucher, - tap_receipt::{Receipt, ReceiptCheck, ReceiptError, ReceiptResult}, - Error, Result, + adapters::escrow_adapter::EscrowAdapter, + tap_manager::SignedRAV, + tap_receipt::{ReceiptError, ReceiptResult}, + Error, }; -use super::{received_receipt::Checking, AwaitingReserve, ReceiptWithState}; +use super::{AwaitingReserve, ReceiptWithState}; pub struct ReceiptAuditor { domain_separator: Eip712Domain, executor: E, - min_timestamp_ns: RwLock, } impl ReceiptAuditor { - pub fn new( - domain_separator: Eip712Domain, - executor: E, - starting_min_timestamp_ns: u64, - ) -> Self { + pub fn new(domain_separator: Eip712Domain, executor: E) -> Self { Self { domain_separator, executor, - min_timestamp_ns: RwLock::new(starting_min_timestamp_ns), - } - } - - /// Updates the minimum timestamp that will be accepted for a receipt (exclusive). - pub async fn update_min_timestamp_ns(&self, min_timestamp_ns: u64) { - *self.min_timestamp_ns.write().await = min_timestamp_ns; - } - - async fn check_timestamp( - &self, - signed_receipt: &EIP712SignedMessage, - ) -> ReceiptResult<()> { - let min_timestamp_ns = *self.min_timestamp_ns.read().await; - if signed_receipt.message.timestamp_ns <= min_timestamp_ns { - return Err(ReceiptError::InvalidTimestamp { - received_timestamp: signed_receipt.message.timestamp_ns, - timestamp_min: min_timestamp_ns, - }); - } - Ok(()) - } - - async fn check_timestamp_batch( - &self, - received_receipts: &mut [ReceiptWithState], - ) -> Vec> { - let mut results = Vec::new(); - - for received_receipt in received_receipts - .iter_mut() - .filter(|r| r.state.checks.contains_key(&ReceiptCheck::CheckTimestamp)) - { - if received_receipt.state.checks[&ReceiptCheck::CheckTimestamp].is_none() { - let signed_receipt = &received_receipt.signed_receipt; - results.push(self.check_timestamp(signed_receipt).await); - } - } - - results - } - - async fn check_uniqueness_batch( - &self, - received_receipts: &mut [ReceiptWithState], - ) -> Vec> { - let mut results = Vec::new(); - - // If at least one of the receipts in the batch hasn't been checked for uniqueness yet, check the whole batch. - if received_receipts - .iter() - .filter(|r| r.state.checks.contains_key(&ReceiptCheck::CheckUnique)) - .any(|r| r.state.checks[&ReceiptCheck::CheckUnique].is_none()) - { - let mut signatures: HashSet = HashSet::new(); - - for received_receipt in received_receipts { - let signature = received_receipt.signed_receipt.signature; - if signatures.insert(signature) { - results.push(Ok(())); - } else { - results.push(Err(ReceiptError::NonUniqueReceipt)); - } - } - } - - results - } -} - -impl ReceiptAuditor -where - E: EscrowAdapter + ReceiptChecksAdapter, -{ - pub async fn check( - &self, - receipt_check: &ReceiptCheck, - signed_receipt: &EIP712SignedMessage, - query_id: u64, - receipt_id: u64, - ) -> ReceiptResult<()> { - match receipt_check { - ReceiptCheck::CheckUnique => self.check_uniqueness(signed_receipt, receipt_id).await, - ReceiptCheck::CheckAllocationId => self.check_allocation_id(signed_receipt).await, - ReceiptCheck::CheckSignature => self.check_signature(signed_receipt).await, - ReceiptCheck::CheckTimestamp => self.check_timestamp(signed_receipt).await, - ReceiptCheck::CheckValue => self.check_value(signed_receipt, query_id).await, - } - } - - pub async fn check_batch( - &self, - receipt_check: &ReceiptCheck, - received_receipts: &mut [ReceiptWithState], - ) -> Vec> { - match receipt_check { - ReceiptCheck::CheckUnique => self.check_uniqueness_batch(received_receipts).await, - ReceiptCheck::CheckAllocationId => { - self.check_allocation_id_batch(received_receipts).await - } - ReceiptCheck::CheckSignature => self.check_signature_batch(received_receipts).await, - ReceiptCheck::CheckTimestamp => self.check_timestamp_batch(received_receipts).await, - ReceiptCheck::CheckValue => self.check_value_batch(received_receipts).await, - } - } -} - -impl ReceiptAuditor -where - E: ReceiptChecksAdapter, -{ - async fn check_uniqueness( - &self, - signed_receipt: &EIP712SignedMessage, - receipt_id: u64, - ) -> ReceiptResult<()> { - if !self - .executor - .is_unique(signed_receipt, receipt_id) - .await - .map_err(|e| ReceiptError::CheckFailedToComplete { - source_error_message: e.to_string(), - })? - { - return Err(ReceiptError::NonUniqueReceipt); - } - Ok(()) - } - - async fn check_allocation_id( - &self, - signed_receipt: &EIP712SignedMessage, - ) -> ReceiptResult<()> { - if !self - .executor - .is_valid_allocation_id(signed_receipt.message.allocation_id) - .await - .map_err(|e| ReceiptError::CheckFailedToComplete { - source_error_message: e.to_string(), - })? - { - return Err(ReceiptError::InvalidAllocationID { - received_allocation_id: signed_receipt.message.allocation_id, - }); } - Ok(()) } - async fn check_allocation_id_batch( + pub async fn check_rav_signature( &self, - received_receipts: &mut [ReceiptWithState], - ) -> Vec> { - let mut results = Vec::new(); - - for received_receipt in received_receipts.iter_mut().filter(|r| { - r.state - .checks - .contains_key(&ReceiptCheck::CheckAllocationId) - }) { - if received_receipt.state.checks[&ReceiptCheck::CheckAllocationId].is_none() { - let signed_receipt = &received_receipt.signed_receipt; - results.push(self.check_allocation_id(signed_receipt).await); - } + signed_rav: &SignedRAV, + verify_signer: F, + ) -> Result<(), Error> + where + F: FnOnce(Address) -> Fut, + Fut: Future>, + { + let recovered_address = signed_rav.recover_signer(&self.domain_separator)?; + if verify_signer(recovered_address).await? { + Ok(()) + } else { + Err(Error::InvalidRecoveredSigner { + address: recovered_address, + }) } - - results - } - - async fn check_value( - &self, - signed_receipt: &EIP712SignedMessage, - query_id: u64, - ) -> ReceiptResult<()> { - if !self - .executor - .is_valid_value(signed_receipt.message.value, query_id) - .await - .map_err(|e| ReceiptError::CheckFailedToComplete { - source_error_message: e.to_string(), - })? - { - return Err(ReceiptError::InvalidValue { - received_value: signed_receipt.message.value, - }); - } - Ok(()) - } - - async fn check_value_batch( - &self, - received_receipts: &mut [ReceiptWithState], - ) -> Vec> { - let mut results = Vec::new(); - - for received_receipt in received_receipts - .iter_mut() - .filter(|r| r.state.checks.contains_key(&ReceiptCheck::CheckValue)) - { - if received_receipt.state.checks[&ReceiptCheck::CheckValue].is_none() { - let signed_receipt = &received_receipt.signed_receipt; - results.push( - self.check_value(signed_receipt, received_receipt.query_id) - .await, - ); - } - } - - results - } - - async fn check_signature( - &self, - signed_receipt: &EIP712SignedMessage, - ) -> ReceiptResult<()> { - let receipt_signer_address = signed_receipt - .recover_signer(&self.domain_separator) - .map_err(|err| ReceiptError::InvalidSignature { - source_error_message: err.to_string(), - })?; - if !self - .executor - .is_valid_sender_id(receipt_signer_address) - .await - .map_err(|e| ReceiptError::CheckFailedToComplete { - source_error_message: e.to_string(), - })? - { - return Err(ReceiptError::InvalidSignature { - source_error_message: format!( - "Recovered sender id is not valid: {}", - receipt_signer_address - ), - }); - } - Ok(()) - } - - async fn check_signature_batch( - &self, - received_receipts: &mut [ReceiptWithState], - ) -> Vec> { - let mut results = Vec::new(); - - for received_receipt in received_receipts - .iter_mut() - .filter(|r| r.state.checks.contains_key(&ReceiptCheck::CheckSignature)) - { - if received_receipt.state.checks[&ReceiptCheck::CheckSignature].is_none() { - let signed_receipt = &received_receipt.signed_receipt; - results.push(self.check_signature(signed_receipt).await); - } - } - - results - } - - pub async fn check_rav_signature( - &self, - signed_rav: &EIP712SignedMessage, - ) -> Result<()> { - let rav_signer_address = signed_rav.recover_signer(&self.domain_separator)?; - if !self - .executor - .is_valid_sender_id(rav_signer_address) - .await - .map_err(|err| Error::AdapterError { - source_error: anyhow::Error::new(err), - })? - { - return Err(Error::InvalidRecoveredSigner { - address: rav_signer_address, - }); - } - Ok(()) } } diff --git a/tap_core/src/tap_receipt/received_receipt.rs b/tap_core/src/tap_receipt/received_receipt.rs index d008ef55..8a69d181 100644 --- a/tap_core/src/tap_receipt/received_receipt.rs +++ b/tap_core/src/tap_receipt/received_receipt.rs @@ -17,22 +17,22 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; -use super::{receipt_auditor::ReceiptAuditor, Receipt, ReceiptCheck, ReceiptCheckResults}; +use super::{receipt_auditor::ReceiptAuditor, Receipt, ReceiptCheckResults}; use crate::{ - adapters::{ - escrow_adapter::EscrowAdapter, receipt_checks_adapter::ReceiptChecksAdapter, - receipt_storage_adapter::StoredReceipt, - }, + adapters::{escrow_adapter::EscrowAdapter, receipt_storage_adapter::StoredReceipt}, + checks::{CheckingChecks, ReceiptCheck}, eip_712_signed_message::EIP712SignedMessage, - Error, Result, }; #[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(bound(deserialize = "'de: 'static"))] pub struct Checking { /// A list of checks to be completed for the receipt, along with their current result pub(crate) checks: ReceiptCheckResults, } + #[derive(Debug, Serialize, Deserialize, Clone)] +#[serde(bound(deserialize = "'de: 'static"))] pub struct Failed { /// A list of checks to be completed for the receipt, along with their current result pub(crate) checks: ReceiptCheckResults, @@ -80,7 +80,7 @@ pub struct ReceiptWithId where T: ReceiptState, { - pub(crate) receipt_id: u64, + pub receipt_id: u64, pub(crate) receipt: ReceiptWithState, } @@ -223,7 +223,7 @@ impl ReceiptWithState { auditor: &ReceiptAuditor, ) -> ResultReceipt where - A: EscrowAdapter + ReceiptChecksAdapter, + A: EscrowAdapter, { match auditor.check_and_reserve_escrow(&self).await { Ok(_) => Ok(self.perform_state_changes(Reserved)), @@ -243,43 +243,14 @@ impl ReceiptWithState { /// /// Returns [`Error::InvalidCheckError] if requested error in not a required check (list of required checks provided by user on construction) /// - pub async fn perform_check( - &mut self, - check: &ReceiptCheck, - receipt_id: u64, - receipt_auditor: &ReceiptAuditor, - ) where - A: EscrowAdapter + ReceiptChecksAdapter, - { + pub async fn perform_check(&mut self, check_name: &'static str) { // Only perform check if it is incomplete // Don't check if already failed - if !self.check_is_complete(check) && !self.any_check_resulted_in_error() { - let _ = self.update_check( - check, - Some( - receipt_auditor - .check(check, &self.signed_receipt, self.query_id, receipt_id) - .await, - ), - ); - } - } - - pub async fn perform_check_batch( - batch: &mut [Self], - check: &ReceiptCheck, - receipt_auditor: &ReceiptAuditor, - ) -> Result<()> - where - A: EscrowAdapter + ReceiptChecksAdapter, - { - let results = receipt_auditor.check_batch(check, batch).await; - - for (receipt, result) in batch.iter_mut().zip(results) { - receipt.update_check(check, Some(result))?; + let check = self.state.checks.remove(check_name); + if let Some(check) = check { + let result = check.execute(self).await; + self.state.checks.insert(check_name, result); } - - Ok(()) } /// Completes a list of *incomplete* check and stores the result, if the check already has a result it is skipped @@ -292,16 +263,9 @@ impl ReceiptWithState { /// /// Returns [`Error::InvalidCheckError] if requested error in not a required check (list of required checks provided by user on construction) /// - pub async fn perform_checks( - &mut self, - checks: &[ReceiptCheck], - receipt_id: u64, - receipt_auditor: &ReceiptAuditor, - ) where - A: EscrowAdapter + ReceiptChecksAdapter, - { + pub async fn perform_checks(&mut self, checks: &[&'static str]) { for check in checks { - self.perform_check(check, receipt_id, receipt_auditor).await; + self.perform_check(check).await; } } @@ -309,18 +273,10 @@ impl ReceiptWithState { /// /// Returns `Err` only if unable to complete a check, returns `Ok` if no check failed to complete (*Important:* this is not the result of the check, just the result of _completing_ the check) /// - pub async fn finalize_receipt_checks( - mut self, - receipt_id: u64, - receipt_auditor: &ReceiptAuditor, - ) -> ResultReceipt - where - A: EscrowAdapter + ReceiptChecksAdapter, - { + pub async fn finalize_receipt_checks(mut self) -> ResultReceipt { let incomplete_checks = self.incomplete_checks(); - self.perform_checks(incomplete_checks.as_slice(), receipt_id, receipt_auditor) - .await; + self.perform_checks(incomplete_checks.as_slice()).await; if self.any_check_resulted_in_error() { let failed = self.perform_state_changes_into(); @@ -331,74 +287,33 @@ impl ReceiptWithState { } } - /// Returns all checks that completed with errors - pub fn completed_checks_with_errors(&self) -> ReceiptCheckResults { - self.state - .checks - .iter() - .filter_map(|(check, result)| { - if let Some(unwrapped_result) = result { - if unwrapped_result.is_err() { - return Some(((*check).clone(), Some((*unwrapped_result).clone()))); - } - } - None - }) - .collect() - } - /// Returns all checks that have not been completed - pub fn incomplete_checks(&self) -> Vec { - let incomplete_checks: Vec = self + pub(crate) fn incomplete_checks(&self) -> Vec<&'static str> { + let incomplete_checks = self .state .checks .iter() .filter_map(|(check, result)| { - if result.is_none() { - Some((*check).clone()) - } else { - None + if result.is_complete() { + return None; } + Some(*check) }) .collect(); incomplete_checks } - pub(crate) fn update_check( - &mut self, - check: &ReceiptCheck, - result: Option>, - ) -> Result<()> { - if !self.state.checks.contains_key(check) { - return Err(Error::InvalidCheckError { - check_string: check.to_string(), - }); - } - - self.state.checks.insert(check.clone(), result); - Ok(()) - } - - /// returns true `check` has a result, otherwise false - pub(crate) fn check_is_complete(&self, check: &ReceiptCheck) -> bool { - matches!(self.state.checks.get(check), Some(Some(_))) - } - fn any_check_resulted_in_error(&self) -> bool { - self.state.checks.iter().any(|(_, status)| match &status { - Some(result) => result.is_err(), - None => false, - }) - } - - pub fn checking_is_complete(&self) -> bool { - self.state.checks.iter().all(|(_, status)| status.is_some()) + self.state + .checks + .iter() + .any(|(_, status)| status.is_failed()) } fn get_empty_required_checks_hashmap(required_checks: &[ReceiptCheck]) -> ReceiptCheckResults { required_checks .iter() - .map(|check| (check.clone(), None)) + .map(|check| (check.typetag_name(), CheckingChecks::Pending(check.clone()))) .collect() } } @@ -438,3 +353,5 @@ where self.query_id } } +#[cfg(test)] +pub mod received_receipt_unit_test; diff --git a/tap_core/src/tap_receipt/received_receipt/received_receipt_unit_test.rs b/tap_core/src/tap_receipt/received_receipt/received_receipt_unit_test.rs new file mode 100644 index 00000000..6d1a7034 --- /dev/null +++ b/tap_core/src/tap_receipt/received_receipt/received_receipt_unit_test.rs @@ -0,0 +1,333 @@ +// Copyright 2023-, Semiotic AI, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use std::{ + collections::HashMap, + str::FromStr, + sync::{Arc, RwLock}, +}; + +use alloy_primitives::Address; +use alloy_sol_types::Eip712Domain; +use ethers::signers::{coins_bip39::English, LocalWallet, MnemonicBuilder, Signer}; +use rstest::*; + +use crate::{ + adapters::executor_mock::{EscrowStorage, ExecutorMock, QueryAppraisals}, + checks::{mock::get_full_list_of_checks, ReceiptCheck, TimestampCheck}, + eip_712_signed_message::EIP712SignedMessage, + tap_eip712_domain, + tap_receipt::{Receipt, ReceiptAuditor, ReceiptCheckResults, ReceivedReceipt}, +}; + +use super::{Checking, ReceiptWithState}; + +impl ReceiptWithState { + fn check_is_complete(&self, check: &str) -> bool { + self.state.checks.get(check).unwrap().is_complete() + } + + fn checking_is_complete(&self) -> bool { + self.state + .checks + .iter() + .all(|(_, status)| status.is_complete()) + } + /// Returns all checks that completed with errors + fn completed_checks_with_errors(&self) -> ReceiptCheckResults { + self.state + .checks + .iter() + .filter_map(|(check, result)| { + if result.is_failed() { + return Some((*check, result.clone())); + } + None + }) + .collect() + } +} + +#[fixture] +fn keys() -> (LocalWallet, Address) { + let wallet: LocalWallet = MnemonicBuilder::::default() + .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") + .build() + .unwrap(); + // Alloy library does not have feature parity with ethers library (yet) This workaround is needed to get the address + // to convert to an alloy Address. This will not be needed when the alloy library has wallet support. + let address: [u8; 20] = wallet.address().into(); + + (wallet, address.into()) +} + +#[fixture] +fn allocation_ids() -> Vec
{ + vec![ + Address::from_str("0xabababababababababababababababababababab").unwrap(), + Address::from_str("0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead").unwrap(), + Address::from_str("0xbeefbeefbeefbeefbeefbeefbeefbeefbeefbeef").unwrap(), + Address::from_str("0x1234567890abcdef1234567890abcdef12345678").unwrap(), + ] +} + +#[fixture] +fn sender_ids() -> Vec
{ + vec![ + Address::from_str("0xfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfb").unwrap(), + Address::from_str("0xfafafafafafafafafafafafafafafafafafafafa").unwrap(), + Address::from_str("0xadadadadadadadadadadadadadadadadadadadad").unwrap(), + keys().1, + ] +} + +#[fixture] +fn domain_separator() -> Eip712Domain { + tap_eip712_domain(1, Address::from([0x11u8; 20])) +} + +struct ExecutorFixture { + executor: ExecutorMock, + escrow_storage: EscrowStorage, + query_appraisals: QueryAppraisals, + checks: Vec, +} + +#[fixture] +fn executor_mock( + domain_separator: Eip712Domain, + allocation_ids: Vec
, + sender_ids: Vec
, +) -> ExecutorFixture { + let escrow_storage = Arc::new(RwLock::new(HashMap::new())); + let rav_storage = Arc::new(RwLock::new(None)); + let receipt_storage = Arc::new(RwLock::new(HashMap::new())); + let query_appraisals = Arc::new(RwLock::new(HashMap::new())); + + let timestamp_check = Arc::new(TimestampCheck::new(0)); + let executor = ExecutorMock::new( + rav_storage, + receipt_storage.clone(), + escrow_storage.clone(), + timestamp_check.clone(), + ); + let mut checks = get_full_list_of_checks( + domain_separator, + sender_ids.iter().cloned().collect(), + Arc::new(RwLock::new(allocation_ids.iter().cloned().collect())), + receipt_storage, + query_appraisals.clone(), + ); + checks.push(timestamp_check); + + ExecutorFixture { + executor, + escrow_storage, + query_appraisals, + checks, + } +} + +#[rstest] +#[tokio::test] +async fn initialization_valid_receipt( + keys: (LocalWallet, Address), + allocation_ids: Vec
, + domain_separator: Eip712Domain, + executor_mock: ExecutorFixture, +) { + let ExecutorFixture { checks, .. } = executor_mock; + + let signed_receipt = EIP712SignedMessage::new( + &domain_separator, + Receipt::new(allocation_ids[0], 10).unwrap(), + &keys.0, + ) + .unwrap(); + let query_id = 1; + + let received_receipt = ReceivedReceipt::new(signed_receipt, query_id, &checks); + + let received_receipt = match received_receipt { + ReceivedReceipt::Checking(checking) => checking, + _ => panic!("ReceivedReceipt should be in Checking state"), + }; + + assert!(received_receipt.completed_checks_with_errors().is_empty()); + assert!(received_receipt.incomplete_checks().len() == checks.len()); +} + +#[rstest] +#[tokio::test] +async fn partial_then_full_check_valid_receipt( + keys: (LocalWallet, Address), + domain_separator: Eip712Domain, + allocation_ids: Vec
, + executor_mock: ExecutorFixture, +) { + let ExecutorFixture { + checks, + escrow_storage, + query_appraisals, + .. + } = executor_mock; + + let query_value = 20u128; + let signed_receipt = EIP712SignedMessage::new( + &domain_separator, + Receipt::new(allocation_ids[0], query_value).unwrap(), + &keys.0, + ) + .unwrap(); + + let query_id = 1; + + // add escrow for sender + escrow_storage + .write() + .unwrap() + .insert(keys.1, query_value + 500); + // appraise query + query_appraisals + .write() + .unwrap() + .insert(query_id, query_value); + + let received_receipt = ReceivedReceipt::new(signed_receipt, query_id, &checks); + + let mut received_receipt = match received_receipt { + ReceivedReceipt::Checking(checking) => checking, + _ => panic!("ReceivedReceipt should be in Checking state"), + }; + + // perform single arbitrary check + let arbitrary_check_to_perform = checks[0].typetag_name(); + + received_receipt + .perform_check(arbitrary_check_to_perform) + .await; + assert!(received_receipt.check_is_complete(arbitrary_check_to_perform)); + + received_receipt + .perform_checks(&checks.iter().map(|c| c.typetag_name()).collect::>()) + .await; + assert!(received_receipt.checking_is_complete()); +} + +#[rstest] +#[tokio::test] +async fn partial_then_finalize_valid_receipt( + keys: (LocalWallet, Address), + allocation_ids: Vec
, + domain_separator: Eip712Domain, + executor_mock: ExecutorFixture, +) { + let ExecutorFixture { + checks, + executor, + escrow_storage, + query_appraisals, + .. + } = executor_mock; + let receipt_auditor = ReceiptAuditor::new(domain_separator.clone(), executor); + + let query_value = 20u128; + let signed_receipt = EIP712SignedMessage::new( + &domain_separator, + Receipt::new(allocation_ids[0], query_value).unwrap(), + &keys.0, + ) + .unwrap(); + + let query_id = 1; + + // add escrow for sender + escrow_storage + .write() + .unwrap() + .insert(keys.1, query_value + 500); + // appraise query + query_appraisals + .write() + .unwrap() + .insert(query_id, query_value); + + let received_receipt = ReceivedReceipt::new(signed_receipt, query_id, &checks); + + let mut received_receipt = match received_receipt { + ReceivedReceipt::Checking(checking) => checking, + _ => panic!("ReceivedReceipt should be in Checking state"), + }; + + // perform single arbitrary check + let arbitrary_check_to_perform = checks[0].typetag_name(); + + received_receipt + .perform_check(arbitrary_check_to_perform) + .await; + assert!(received_receipt.check_is_complete(arbitrary_check_to_perform)); + + let awaiting_escrow_receipt = received_receipt.finalize_receipt_checks().await; + assert!(awaiting_escrow_receipt.is_ok()); + + let awaiting_escrow_receipt = awaiting_escrow_receipt.unwrap(); + let receipt = awaiting_escrow_receipt + .check_and_reserve_escrow(&receipt_auditor) + .await; + assert!(receipt.is_ok()); +} + +#[rstest] +#[tokio::test] +async fn standard_lifetime_valid_receipt( + keys: (LocalWallet, Address), + allocation_ids: Vec
, + domain_separator: Eip712Domain, + executor_mock: ExecutorFixture, +) { + let ExecutorFixture { + checks, + executor, + escrow_storage, + query_appraisals, + .. + } = executor_mock; + let receipt_auditor = ReceiptAuditor::new(domain_separator.clone(), executor); + + let query_value = 20u128; + let signed_receipt = EIP712SignedMessage::new( + &domain_separator, + Receipt::new(allocation_ids[0], query_value).unwrap(), + &keys.0, + ) + .unwrap(); + + let query_id = 1; + + // add escrow for sender + escrow_storage + .write() + .unwrap() + .insert(keys.1, query_value + 500); + // appraise query + query_appraisals + .write() + .unwrap() + .insert(query_id, query_value); + + let received_receipt = ReceivedReceipt::new(signed_receipt, query_id, &checks); + + let received_receipt = match received_receipt { + ReceivedReceipt::Checking(checking) => checking, + _ => panic!("ReceivedReceipt should be in Checking state"), + }; + + let awaiting_escrow_receipt = received_receipt.finalize_receipt_checks().await; + assert!(awaiting_escrow_receipt.is_ok()); + + let awaiting_escrow_receipt = awaiting_escrow_receipt.unwrap(); + let receipt = awaiting_escrow_receipt + .check_and_reserve_escrow(&receipt_auditor) + .await; + assert!(receipt.is_ok()); +} diff --git a/tap_core/src/tap_receipt/tests/mod.rs b/tap_core/src/tap_receipt/tests/mod.rs deleted file mode 100644 index 73889387..00000000 --- a/tap_core/src/tap_receipt/tests/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -pub mod receipt_tests; -pub mod received_receipt_tests; diff --git a/tap_core/src/tap_receipt/tests/receipt_tests.rs b/tap_core/src/tap_receipt/tests/receipt_tests.rs deleted file mode 100644 index 1820cdd0..00000000 --- a/tap_core/src/tap_receipt/tests/receipt_tests.rs +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -#[cfg(test)] -mod receipt_unit_test { - use std::str::FromStr; - use std::time::{SystemTime, UNIX_EPOCH}; - - use alloy_primitives::Address; - use rstest::*; - - use crate::tap_receipt::Receipt; - - #[fixture] - fn allocation_ids() -> Vec
{ - vec![ - Address::from_str("0xabababababababababababababababababababab").unwrap(), - Address::from_str("0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead").unwrap(), - Address::from_str("0xbeefbeefbeefbeefbeefbeefbeefbeefbeefbeef").unwrap(), - Address::from_str("0x1234567890abcdef1234567890abcdef12345678").unwrap(), - ] - } - - #[rstest] - fn test_new_receipt(allocation_ids: Vec
) { - let value = 1234; - - let receipt = Receipt::new(allocation_ids[0], value).unwrap(); - - assert_eq!(receipt.allocation_id, allocation_ids[0]); - assert_eq!(receipt.value, value); - - // Check that the timestamp is within a reasonable range - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Current system time should be greater than `UNIX_EPOCH`") - .as_nanos() as u64; - assert!(receipt.timestamp_ns <= now); - assert!(receipt.timestamp_ns >= now - 5000000); // 5 second tolerance - } - - #[rstest] - fn test_unique_nonce_and_timestamp(allocation_ids: Vec
) { - let value = 1234; - - let receipt1 = Receipt::new(allocation_ids[0], value).unwrap(); - let receipt2 = Receipt::new(allocation_ids[0], value).unwrap(); - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Current system time should be greater than `UNIX_EPOCH`") - .as_nanos() as u64; - - // Check that nonces are different - // Note: This test has an *extremely low* (~1/2^64) probability of false failure, if a failure happens - // once it is not neccessarily a sign of an issue. If this test fails more than once, especially - // in a short period of time (within a ) then there may be an issue with randomness - // of the nonce generation. - assert_ne!(receipt1.nonce, receipt2.nonce); - - assert!(receipt1.timestamp_ns <= now); - assert!(receipt1.timestamp_ns >= now - 5000000); // 5 second tolerance - - assert!(receipt2.timestamp_ns <= now); - assert!(receipt2.timestamp_ns >= now - 5000000); // 5 second tolerance - } -} diff --git a/tap_core/src/tap_receipt/tests/received_receipt_tests.rs b/tap_core/src/tap_receipt/tests/received_receipt_tests.rs deleted file mode 100644 index 64a52029..00000000 --- a/tap_core/src/tap_receipt/tests/received_receipt_tests.rs +++ /dev/null @@ -1,328 +0,0 @@ -// Copyright 2023-, Semiotic AI, Inc. -// SPDX-License-Identifier: Apache-2.0 - -#[cfg(test)] -mod received_receipt_unit_test { - use std::{ - collections::{HashMap, HashSet}, - str::FromStr, - sync::Arc, - }; - - use alloy_primitives::Address; - use alloy_sol_types::Eip712Domain; - use ethers::signers::{coins_bip39::English, LocalWallet, MnemonicBuilder, Signer}; - use rstest::*; - use tokio::sync::RwLock; - - use crate::{ - adapters::{ - auditor_executor_mock::AuditorExecutorMock, - escrow_adapter_mock::EscrowAdapterMock, - executor_mock::{EscrowStorage, QueryAppraisals}, - receipt_checks_adapter_mock::ReceiptChecksAdapterMock, - receipt_storage_adapter_mock::ReceiptStorageAdapterMock, - }, - eip_712_signed_message::EIP712SignedMessage, - get_current_timestamp_u64_ns, tap_eip712_domain, - tap_receipt::{ - get_full_list_of_checks, Receipt, ReceiptAuditor, ReceiptCheck, ReceivedReceipt, - }, - }; - - #[fixture] - fn keys() -> (LocalWallet, Address) { - let wallet: LocalWallet = MnemonicBuilder::::default() - .phrase("abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about") - .build() - .unwrap(); - // Alloy library does not have feature parity with ethers library (yet) This workaround is needed to get the address - // to convert to an alloy Address. This will not be needed when the alloy library has wallet support. - let address: [u8; 20] = wallet.address().into(); - - (wallet, address.into()) - } - - #[fixture] - fn allocation_ids() -> Vec
{ - vec![ - Address::from_str("0xabababababababababababababababababababab").unwrap(), - Address::from_str("0xdeaddeaddeaddeaddeaddeaddeaddeaddeaddead").unwrap(), - Address::from_str("0xbeefbeefbeefbeefbeefbeefbeefbeefbeefbeef").unwrap(), - Address::from_str("0x1234567890abcdef1234567890abcdef12345678").unwrap(), - ] - } - - #[fixture] - fn sender_ids() -> Vec
{ - vec![ - Address::from_str("0xfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfb").unwrap(), - Address::from_str("0xfafafafafafafafafafafafafafafafafafafafa").unwrap(), - Address::from_str("0xadadadadadadadadadadadadadadadadadadadad").unwrap(), - keys().1, - ] - } - - #[fixture] - fn receipt_adapters() -> ( - ReceiptStorageAdapterMock, - ReceiptChecksAdapterMock, - Arc>>, - ) { - let receipt_storage = Arc::new(RwLock::new(HashMap::new())); - let receipt_storage_adapter = ReceiptStorageAdapterMock::new(Arc::clone(&receipt_storage)); - - let allocation_ids_set = Arc::new(RwLock::new(HashSet::from_iter(allocation_ids()))); - let sender_ids_set = Arc::new(RwLock::new(HashSet::from_iter(sender_ids()))); - let query_appraisal_storage = Arc::new(RwLock::new(HashMap::new())); - - let receipt_checks_adapter = ReceiptChecksAdapterMock::new( - Arc::clone(&receipt_storage), - Arc::clone(&query_appraisal_storage), - Arc::clone(&allocation_ids_set), - Arc::clone(&sender_ids_set), - ); - - ( - receipt_storage_adapter, - receipt_checks_adapter, - query_appraisal_storage, - ) - } - - #[fixture] - fn escrow_adapters() -> (EscrowAdapterMock, Arc>>) { - let sender_escrow_storage = Arc::new(RwLock::new(HashMap::new())); - let escrow_adapter = EscrowAdapterMock::new(Arc::clone(&sender_escrow_storage)); - (escrow_adapter, sender_escrow_storage) - } - - #[fixture] - fn auditor_executor() -> (AuditorExecutorMock, EscrowStorage, QueryAppraisals) { - let sender_escrow_storage = Arc::new(RwLock::new(HashMap::new())); - - let receipt_storage = Arc::new(RwLock::new(HashMap::new())); - - let allocation_ids_set = Arc::new(RwLock::new(HashSet::from_iter(allocation_ids()))); - let sender_ids_set = Arc::new(RwLock::new(HashSet::from_iter(sender_ids()))); - let query_appraisal_storage = Arc::new(RwLock::new(HashMap::new())); - ( - AuditorExecutorMock::new( - receipt_storage, - sender_escrow_storage.clone(), - query_appraisal_storage.clone(), - allocation_ids_set, - sender_ids_set, - ), - sender_escrow_storage, - query_appraisal_storage, - ) - } - - #[fixture] - fn domain_separator() -> Eip712Domain { - tap_eip712_domain(1, Address::from([0x11u8; 20])) - } - - #[rstest] - #[tokio::test] - async fn initialization_valid_receipt( - keys: (LocalWallet, Address), - allocation_ids: Vec
, - domain_separator: Eip712Domain, - ) { - let signed_receipt = EIP712SignedMessage::new( - &domain_separator, - Receipt::new(allocation_ids[0], 10).unwrap(), - &keys.0, - ) - .await - .unwrap(); - let query_id = 1; - let checks = get_full_list_of_checks(); - - let received_receipt = ReceivedReceipt::new(signed_receipt, query_id, &checks); - - let received_receipt = match received_receipt { - ReceivedReceipt::Checking(checking) => checking, - _ => panic!("ReceivedReceipt should be in Checking state"), - }; - - assert!(received_receipt.completed_checks_with_errors().is_empty()); - assert!(received_receipt.incomplete_checks().len() == checks.len()); - } - - #[rstest] - #[tokio::test] - async fn partial_then_full_check_valid_receipt( - keys: (LocalWallet, Address), - domain_separator: Eip712Domain, - allocation_ids: Vec
, - auditor_executor: (AuditorExecutorMock, EscrowStorage, QueryAppraisals), - ) { - let (executor, escrow_storage, query_appraisal_storage) = auditor_executor; - // give receipt 5 second variance for min start time - let starting_min_timestamp = get_current_timestamp_u64_ns().unwrap() - 500000000; - let receipt_auditor = - ReceiptAuditor::new(domain_separator.clone(), executor, starting_min_timestamp); - - let query_value = 20u128; - let signed_receipt = EIP712SignedMessage::new( - &domain_separator, - Receipt::new(allocation_ids[0], query_value).unwrap(), - &keys.0, - ) - .await - .unwrap(); - - let query_id = 1; - - // prepare adapters and storage to correctly validate receipt - - // add escrow for sender - escrow_storage - .write() - .await - .insert(keys.1, query_value + 500); - // appraise query - query_appraisal_storage - .write() - .await - .insert(query_id, query_value); - - let checks = get_full_list_of_checks(); - let received_receipt = ReceivedReceipt::new(signed_receipt, query_id, &checks); - let receipt_id = 0u64; - - let mut received_receipt = match received_receipt { - ReceivedReceipt::Checking(checking) => checking, - _ => panic!("ReceivedReceipt should be in Checking state"), - }; - - // perform single arbitrary check - let arbitrary_check_to_perform = ReceiptCheck::CheckUnique; - received_receipt - .perform_check(&arbitrary_check_to_perform, receipt_id, &receipt_auditor) - .await; - assert!(received_receipt.check_is_complete(&arbitrary_check_to_perform)); - - received_receipt - .perform_checks(&checks, receipt_id, &receipt_auditor) - .await; - assert!(received_receipt.checking_is_complete()); - } - - #[rstest] - #[tokio::test] - async fn partial_then_finalize_valid_receipt( - keys: (LocalWallet, Address), - allocation_ids: Vec
, - domain_separator: Eip712Domain, - auditor_executor: (AuditorExecutorMock, EscrowStorage, QueryAppraisals), - ) { - let (executor, escrow_storage, query_appraisal_storage) = auditor_executor; - // give receipt 5 second variance for min start time - let starting_min_timestamp = get_current_timestamp_u64_ns().unwrap() - 500000000; - let receipt_auditor = - ReceiptAuditor::new(domain_separator.clone(), executor, starting_min_timestamp); - - let query_value = 20u128; - let signed_receipt = EIP712SignedMessage::new( - &domain_separator, - Receipt::new(allocation_ids[0], query_value).unwrap(), - &keys.0, - ) - .await - .unwrap(); - - let query_id = 1; - - // prepare adapters and storage to correctly validate receipt - - // add escrow for sender - escrow_storage - .write() - .await - .insert(keys.1, query_value + 500); - // appraise query - query_appraisal_storage - .write() - .await - .insert(query_id, query_value); - - let checks = get_full_list_of_checks(); - let received_receipt = ReceivedReceipt::new(signed_receipt, query_id, &checks); - let receipt_id = 0u64; - - let mut received_receipt = match received_receipt { - ReceivedReceipt::Checking(checking) => checking, - _ => panic!("ReceivedReceipt should be in Checking state"), - }; - - // perform single arbitrary check - let arbitrary_check_to_perform = ReceiptCheck::CheckUnique; - - received_receipt - .perform_check(&arbitrary_check_to_perform, receipt_id, &receipt_auditor) - .await; - assert!(received_receipt.check_is_complete(&arbitrary_check_to_perform)); - - assert!(received_receipt - .finalize_receipt_checks(receipt_id, &receipt_auditor) - .await - .is_ok()); - } - - #[rstest] - #[tokio::test] - async fn standard_lifetime_valid_receipt( - keys: (LocalWallet, Address), - allocation_ids: Vec
, - domain_separator: Eip712Domain, - auditor_executor: (AuditorExecutorMock, EscrowStorage, QueryAppraisals), - ) { - let (executor, escrow_storage, query_appraisal_storage) = auditor_executor; - // give receipt 5 second variance for min start time - let starting_min_timestamp = get_current_timestamp_u64_ns().unwrap() - 500000000; - let receipt_auditor = - ReceiptAuditor::new(domain_separator.clone(), executor, starting_min_timestamp); - - let query_value = 20u128; - let signed_receipt = EIP712SignedMessage::new( - &domain_separator, - Receipt::new(allocation_ids[0], query_value).unwrap(), - &keys.0, - ) - .await - .unwrap(); - - let query_id = 1; - - // prepare adapters and storage to correctly validate receipt - - // add escrow for sender - escrow_storage - .write() - .await - .insert(keys.1, query_value + 500); - // appraise query - query_appraisal_storage - .write() - .await - .insert(query_id, query_value); - - let checks = get_full_list_of_checks(); - let received_receipt = ReceivedReceipt::new(signed_receipt, query_id, &checks); - let receipt_id = 0u64; - - let received_receipt = match received_receipt { - ReceivedReceipt::Checking(checking) => checking, - _ => panic!("ReceivedReceipt should be in Checking state"), - }; - - assert!(received_receipt - .finalize_receipt_checks(receipt_id, &receipt_auditor) - .await - .is_ok()); - } -} diff --git a/tap_integration_tests/tests/indexer_mock/mod.rs b/tap_integration_tests/tests/indexer_mock/mod.rs index 45a9c309..cc96559b 100644 --- a/tap_integration_tests/tests/indexer_mock/mod.rs +++ b/tap_integration_tests/tests/indexer_mock/mod.rs @@ -1,13 +1,11 @@ // Copyright 2023-, Semiotic AI, Inc. // SPDX-License-Identifier: Apache-2.0 -use std::{ - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, - }, - time::{SystemTime, UNIX_EPOCH}, +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, }; +use alloy_primitives::Address; use alloy_sol_types::Eip712Domain; use anyhow::{Error, Result}; use jsonrpsee::{ @@ -23,12 +21,10 @@ use tap_core::{ adapters::{ escrow_adapter::EscrowAdapter, rav_storage_adapter::{RAVRead, RAVStore}, - receipt_checks_adapter::ReceiptChecksAdapter, receipt_storage_adapter::{ReceiptRead, ReceiptStore}, }, + checks::ReceiptCheck, tap_manager::{Manager, SignedRAV, SignedReceipt}, - tap_receipt::ReceiptCheck, - Error as TapCoreError, }; /// Rpc trait represents a JSON-RPC server that has a single async method `request`. /// This method is designed to handle incoming JSON-RPC requests. @@ -56,6 +52,7 @@ pub struct RpcManager { receipt_count: Arc, // Thread-safe atomic counter for receipts threshold: u64, // The count at which a RAV request will be triggered aggregator_client: (HttpClient, String), // HTTP client for sending requests to the aggregator server + sender_id: Address, // The sender address } /// Implementation for `RpcManager`, includes the constructor and the `request` method. @@ -71,6 +68,7 @@ where initial_checks: Vec, required_checks: Vec, threshold: u64, + sender_id: Address, aggregate_server_address: String, aggregate_server_api_version: String, ) -> Result { @@ -79,11 +77,11 @@ where domain_separator, executor, required_checks, - get_current_timestamp_u64_ns()?, )), initial_checks, receipt_count: Arc::new(AtomicU64::new(0)), threshold, + sender_id, aggregator_client: ( HttpClientBuilder::default().build(aggregate_server_address)?, aggregate_server_api_version, @@ -95,15 +93,7 @@ where #[async_trait] impl RpcServer for RpcManager where - E: ReceiptStore - + ReceiptRead - + RAVStore - + RAVRead - + ReceiptChecksAdapter - + EscrowAdapter - + Send - + Sync - + 'static, + E: ReceiptStore + ReceiptRead + RAVStore + RAVRead + EscrowAdapter + Send + Sync + 'static, { async fn request( &self, @@ -135,6 +125,7 @@ where time_stamp_buffer, &self.aggregator_client, self.threshold as usize, + self.sender_id, ) .await { @@ -163,13 +154,13 @@ pub async fn run_server( threshold: u64, // The count at which a RAV request will be triggered aggregate_server_address: String, // Address of the aggregator server aggregate_server_api_version: String, // API version of the aggregator server + sender_id: Address, // The sender address ) -> Result<(ServerHandle, std::net::SocketAddr)> where E: ReceiptStore + ReceiptRead + RAVStore + RAVRead - + ReceiptChecksAdapter + EscrowAdapter + Clone + Send @@ -190,6 +181,7 @@ where initial_checks, required_checks, threshold, + sender_id, aggregate_server_address, aggregate_server_api_version, )?; @@ -204,9 +196,10 @@ async fn request_rav( time_stamp_buffer: u64, // Buffer for timestamping, see tap_core for details aggregator_client: &(HttpClient, String), // HttpClient for making requests to the tap_aggregator server threshold: usize, + expected_sender_id: Address, ) -> Result<()> where - E: ReceiptRead + RAVRead + RAVStore + EscrowAdapter + ReceiptChecksAdapter, + E: ReceiptRead + RAVRead + RAVStore + EscrowAdapter, { // Create the aggregate_receipts request params let rav_request = manager.create_rav_request(time_stamp_buffer, None).await?; @@ -224,7 +217,11 @@ where .request("aggregate_receipts", params) .await?; manager - .verify_and_store_rav(rav_request.expected_rav, remote_rav_result.data) + .verify_and_store_rav( + rav_request.expected_rav, + remote_rav_result.data, + |address| async move { Ok(address == expected_sender_id) }, + ) .await?; // For these tests, we expect every receipt to be valid, i.e. there should be no invalid receipts, nor any missing receipts (less than the expected threshold). @@ -237,16 +234,6 @@ where Ok(()) } -// get_current_timestamp_u64_ns function returns current system time since UNIX_EPOCH as a 64-bit unsigned integer. -fn get_current_timestamp_u64_ns() -> Result { - Ok(SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_err(|err| TapCoreError::InvalidSystemTime { - source_error_message: err.to_string(), - })? - .as_nanos() as u64) -} - fn to_rpc_error(e: Box, msg: &str) -> jsonrpsee::types::ErrorObjectOwned { jsonrpsee::types::ErrorObject::owned(-32000, format!("{} - {}", e, msg), None::<()>) } diff --git a/tap_integration_tests/tests/showcase.rs b/tap_integration_tests/tests/showcase.rs index 9121270a..61911669 100644 --- a/tap_integration_tests/tests/showcase.rs +++ b/tap_integration_tests/tests/showcase.rs @@ -7,10 +7,9 @@ use std::{ collections::{HashMap, HashSet}, convert::TryInto, - iter::FromIterator, net::{SocketAddr, TcpListener}, str::FromStr, - sync::Arc, + sync::{Arc, RwLock}, }; use alloy_primitives::Address; @@ -22,20 +21,15 @@ use jsonrpsee::{ }; use rand::{rngs::StdRng, Rng, SeedableRng}; use rstest::*; -use tokio::sync::RwLock; use tap_aggregator::{jsonrpsee_helpers, server as agg_server}; use tap_core::{ - adapters::{ - escrow_adapter_mock::EscrowAdapterMock, executor_mock::ExecutorMock, - rav_storage_adapter_mock::RAVStorageAdapterMock, - receipt_checks_adapter_mock::ReceiptChecksAdapterMock, - receipt_storage_adapter_mock::ReceiptStorageAdapterMock, - }, + adapters::executor_mock::{ExecutorMock, QueryAppraisals}, + checks::{mock::get_full_list_of_checks, ReceiptCheck, TimestampCheck}, eip_712_signed_message::EIP712SignedMessage, tap_eip712_domain, tap_manager::SignedRAV, - tap_receipt::{Receipt, ReceiptCheck, ReceivedReceipt}, + tap_receipt::Receipt, }; use crate::indexer_mock; @@ -123,6 +117,16 @@ fn allocation_ids() -> Vec
{ ] } +#[fixture] +fn sender_ids() -> Vec
{ + vec![ + Address::from_str("0xfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfbfb").unwrap(), + Address::from_str("0xfafafafafafafafafafafafafafafafafafafafa").unwrap(), + Address::from_str("0xadadadadadadadadadadadadadadadadadadadad").unwrap(), + keys_sender().1, + ] +} + // Domain separator is used to sign receipts/RAVs according to EIP-712 #[fixture] fn domain_separator() -> Eip712Domain { @@ -131,7 +135,8 @@ fn domain_separator() -> Eip712Domain { // Query price will typically be set by the Indexer. It's assumed to be part of the Indexer service. #[fixture] -fn query_price() -> Vec { +#[once] +fn query_price() -> &'static [u128] { let seed: Vec = (0..32u8).collect(); // A seed of your choice let mut rng: StdRng = SeedableRng::from_seed(seed.try_into().unwrap()); let mut v = Vec::new(); @@ -139,133 +144,76 @@ fn query_price() -> Vec { for _ in 0..num_queries() { v.push(rng.gen::() % 100); } - v + Box::leak(v.into_boxed_slice()) } // Available escrow is set by a Sender. It's assumed the Indexer has way of knowing this value. #[fixture] -fn available_escrow(query_price: Vec, num_batches: u64) -> u128 { - (num_batches as u128) * query_price.into_iter().sum::() +fn available_escrow(query_price: &[u128], num_batches: u64) -> u128 { + (num_batches as u128) * query_price.iter().sum::() } -// The escrow adapter, a storage struct that the Indexer uses to track the available escrow for each Sender #[fixture] -fn escrow_adapter() -> EscrowAdapterMock { - EscrowAdapterMock::new(Arc::new(RwLock::new(HashMap::new()))) +fn query_appraisals(query_price: &[u128]) -> QueryAppraisals { + Arc::new(RwLock::new( + query_price + .iter() + .enumerate() + .map(|(i, p)| (i as u64, *p)) + .collect(), + )) +} + +struct ExecutorFixture { + executor: ExecutorMock, + checks: Vec, } #[fixture] fn executor( - keys_sender: (LocalWallet, Address), - query_price: Vec, + domain_separator: Eip712Domain, allocation_ids: Vec
, - receipt_storage: Arc>>, -) -> ExecutorMock { - let (_, sender_address) = keys_sender; - let query_appraisals: HashMap<_, _> = (0u64..).zip(query_price).collect(); - let query_appraisal_storage = Arc::new(RwLock::new(query_appraisals)); - let allocation_ids: Arc>> = - Arc::new(RwLock::new(HashSet::from_iter(allocation_ids))); - let sender_ids: Arc>> = - Arc::new(RwLock::new(HashSet::from([sender_address]))); + sender_ids: Vec
, + query_appraisals: QueryAppraisals, +) -> ExecutorFixture { + let receipt_storage = Arc::new(RwLock::new(HashMap::new())); + let escrow_storage = Arc::new(RwLock::new(HashMap::new())); let rav_storage = Arc::new(RwLock::new(None)); - - let sender_escrow_storage = Arc::new(RwLock::new(HashMap::new())); - - ExecutorMock::new( + let timestamp_check = Arc::new(TimestampCheck::new(0)); + let executor = ExecutorMock::new( rav_storage, + receipt_storage.clone(), + escrow_storage.clone(), + timestamp_check.clone(), + ); + let mut checks = get_full_list_of_checks( + domain_separator, + sender_ids.iter().cloned().collect(), + Arc::new(RwLock::new(allocation_ids.iter().cloned().collect())), receipt_storage, - sender_escrow_storage, - query_appraisal_storage, - allocation_ids, - sender_ids, - ) -} - -#[fixture] -fn receipt_storage() -> Arc>> { - Arc::new(RwLock::new(HashMap::new())) -} -// A storage struct used by the Indexer to store Receipts, all recieved receipts are stored here. There are flags which indicate the validity of the receipt. -#[fixture] -fn receipt_storage_adapter( - receipt_storage: Arc>>, -) -> ReceiptStorageAdapterMock { - ReceiptStorageAdapterMock::new(receipt_storage) -} - -// This adapter is used by the Indexer to check the validity of the receipt. -#[fixture] -fn receipt_checks_adapter( - keys_sender: (LocalWallet, Address), - query_price: Vec, - allocation_ids: Vec
, - receipt_storage: Arc>>, -) -> ReceiptChecksAdapterMock { - let (_, sender_address) = keys_sender; - let query_appraisals: HashMap<_, _> = (0u64..).zip(query_price).collect(); - let query_appraisals_storage = Arc::new(RwLock::new(query_appraisals)); - let allocation_ids: Arc>> = - Arc::new(RwLock::new(HashSet::from_iter(allocation_ids))); - let sender_ids: Arc>> = - Arc::new(RwLock::new(HashSet::from([sender_address]))); - - ReceiptChecksAdapterMock::new( - receipt_storage, - query_appraisals_storage, - allocation_ids, - sender_ids, - ) -} - -// A structure for storing received RAVs. -#[fixture] -fn rav_storage_adapter() -> RAVStorageAdapterMock { - RAVStorageAdapterMock::new(Arc::new(RwLock::new(None))) -} - -// These are the checks that the Indexer will perform when requesting a RAV. -// Testing with all checks enabled. -#[fixture] -fn required_checks() -> Vec { - vec![ - ReceiptCheck::CheckAllocationId, - ReceiptCheck::CheckSignature, - ReceiptCheck::CheckTimestamp, - ReceiptCheck::CheckUnique, - ReceiptCheck::CheckValue, - ] -} + query_appraisals, + ); + checks.push(timestamp_check); -// These are the checks that the Indexer will perform for each received receipt, i.e. before requesting a RAV. -// Testing with all checks enabled. -#[fixture] -fn initial_checks() -> Vec { - vec![ - ReceiptCheck::CheckAllocationId, - ReceiptCheck::CheckSignature, - ReceiptCheck::CheckTimestamp, - ReceiptCheck::CheckUnique, - ReceiptCheck::CheckValue, - ] + ExecutorFixture { executor, checks } } #[fixture] -fn indexer_1_adapters(executor: ExecutorMock) -> ExecutorMock { +fn indexer_1_adapters(executor: ExecutorFixture) -> ExecutorFixture { executor } #[fixture] -fn indexer_2_adapters(executor: ExecutorMock) -> ExecutorMock { +fn indexer_2_adapters(executor: ExecutorFixture) -> ExecutorFixture { executor } // Helper fixture to generate a batch of receipts to be sent to the Indexer. // Messages are formatted according to TAP spec and signed according to EIP-712. #[fixture] -async fn requests_1( +fn requests_1( keys_sender: (LocalWallet, Address), - query_price: Vec, + query_price: &[u128], num_batches: u64, allocation_ids: Vec
, domain_separator: Eip712Domain, @@ -278,15 +226,14 @@ async fn requests_1( &sender_key, allocation_ids[0], &domain_separator, - ) - .await?; + )?; Ok(requests) } #[fixture] -async fn requests_2( +fn requests_2( keys_sender: (LocalWallet, Address), - query_price: Vec, + query_price: &[u128], num_batches: u64, allocation_ids: Vec
, domain_separator: Eip712Domain, @@ -299,15 +246,14 @@ async fn requests_2( &sender_key, allocation_ids[1], &domain_separator, - ) - .await?; + )?; Ok(requests) } #[fixture] -async fn repeated_timestamp_request( +fn repeated_timestamp_request( keys_sender: (LocalWallet, Address), - query_price: Vec, + query_price: &[u128], allocation_ids: Vec
, domain_separator: Eip712Domain, num_batches: u64, @@ -322,8 +268,7 @@ async fn repeated_timestamp_request( &sender_key, allocation_ids[0], &domain_separator, - ) - .await?; + )?; // Create a new receipt with the timestamp equal to the latest receipt in the first RAV request batch let repeat_timestamp = requests[receipt_threshold_1 as usize - 1] @@ -340,14 +285,14 @@ async fn repeated_timestamp_request( // Sign the new receipt and insert it in the second batch requests[receipt_threshold_1 as usize].0 = - EIP712SignedMessage::new(&domain_separator, repeat_receipt, &sender_key).await?; + EIP712SignedMessage::new(&domain_separator, repeat_receipt, &sender_key)?; Ok(requests) } #[fixture] -async fn repeated_timestamp_incremented_by_one_request( +fn repeated_timestamp_incremented_by_one_request( keys_sender: (LocalWallet, Address), - query_price: Vec, + query_price: &[u128], allocation_ids: Vec
, domain_separator: Eip712Domain, num_batches: u64, @@ -361,8 +306,7 @@ async fn repeated_timestamp_incremented_by_one_request( &sender_key, allocation_ids[0], &domain_separator, - ) - .await?; + )?; // Create a new receipt with the timestamp equal to the latest receipt timestamp+1 in the first RAV request batch let repeat_timestamp = requests[receipt_threshold_1 as usize - 1] @@ -380,14 +324,14 @@ async fn repeated_timestamp_incremented_by_one_request( // Sign the new receipt and insert it in the second batch requests[receipt_threshold_1 as usize].0 = - EIP712SignedMessage::new(&domain_separator, repeat_receipt, &sender_key).await?; + EIP712SignedMessage::new(&domain_separator, repeat_receipt, &sender_key)?; Ok(requests) } #[fixture] -async fn wrong_requests( +fn wrong_requests( wrong_keys_sender: (LocalWallet, Address), - query_price: Vec, + query_price: &[u128], num_batches: u64, allocation_ids: Vec
, domain_separator: Eip712Domain, @@ -401,8 +345,7 @@ async fn wrong_requests( &sender_key, allocation_ids[0], &domain_separator, - ) - .await?; + )?; Ok(requests) } @@ -414,10 +357,8 @@ async fn single_indexer_test_server( http_request_size_limit: u32, http_response_size_limit: u32, http_max_concurrent_connections: u32, - indexer_1_adapters: ExecutorMock, + indexer_1_adapters: ExecutorFixture, available_escrow: u128, - initial_checks: Vec, - required_checks: Vec, receipt_threshold_1: u64, ) -> Result<(ServerHandle, SocketAddr, ServerHandle, SocketAddr)> { let sender_id = keys_sender.1; @@ -429,14 +370,14 @@ async fn single_indexer_test_server( http_max_concurrent_connections, ) .await?; - let executor = indexer_1_adapters; + let ExecutorFixture { executor, checks } = indexer_1_adapters; let (indexer_handle, indexer_addr) = start_indexer_server( domain_separator.clone(), executor, sender_id, available_escrow, - initial_checks, - required_checks, + checks.clone(), + checks, receipt_threshold_1, sender_aggregator_addr, ) @@ -456,11 +397,9 @@ async fn two_indexers_test_servers( http_request_size_limit: u32, http_response_size_limit: u32, http_max_concurrent_connections: u32, - indexer_1_adapters: ExecutorMock, - indexer_2_adapters: ExecutorMock, + indexer_1_adapters: ExecutorFixture, + indexer_2_adapters: ExecutorFixture, available_escrow: u128, - initial_checks: Vec, - required_checks: Vec, receipt_threshold_1: u64, ) -> Result<( ServerHandle, @@ -479,16 +418,23 @@ async fn two_indexers_test_servers( http_max_concurrent_connections, ) .await?; - let executor_1 = indexer_1_adapters; - let executor_2 = indexer_2_adapters; + let ExecutorFixture { + executor: executor_1, + checks: checks_1, + } = indexer_1_adapters; + + let ExecutorFixture { + executor: executor_2, + checks: checks_2, + } = indexer_2_adapters; let (indexer_handle, indexer_addr) = start_indexer_server( domain_separator.clone(), executor_1, sender_id, available_escrow, - initial_checks.clone(), - required_checks.clone(), + checks_1.clone(), + checks_1, receipt_threshold_1, sender_aggregator_addr, ) @@ -499,8 +445,8 @@ async fn two_indexers_test_servers( executor_2, sender_id, available_escrow, - initial_checks, - required_checks, + checks_2.clone(), + checks_2, receipt_threshold_1, sender_aggregator_addr, ) @@ -523,10 +469,8 @@ async fn single_indexer_wrong_sender_test_server( http_request_size_limit: u32, http_response_size_limit: u32, http_max_concurrent_connections: u32, - indexer_1_adapters: ExecutorMock, + indexer_1_adapters: ExecutorFixture, available_escrow: u128, - initial_checks: Vec, - required_checks: Vec, receipt_threshold_1: u64, ) -> Result<(ServerHandle, SocketAddr, ServerHandle, SocketAddr)> { let sender_id = wrong_keys_sender.1; @@ -538,15 +482,17 @@ async fn single_indexer_wrong_sender_test_server( http_max_concurrent_connections, ) .await?; - let executor = indexer_1_adapters; + let ExecutorFixture { + executor, checks, .. + } = indexer_1_adapters; let (indexer_handle, indexer_addr) = start_indexer_server( domain_separator.clone(), executor, sender_id, available_escrow, - initial_checks, - required_checks, + checks.clone(), + checks, receipt_threshold_1, sender_aggregator_addr, ) @@ -567,13 +513,13 @@ async fn test_manager_one_indexer( (ServerHandle, SocketAddr, ServerHandle, SocketAddr), Error, >, - #[future] requests_1: Result, u64)>>, + requests_1: Result, u64)>>, ) -> Result<(), Box> { let (_server_handle, socket_addr, _sender_handle, _sender_addr) = single_indexer_test_server.await?; let indexer_1_address = "http://".to_string() + &socket_addr.to_string(); let client_1 = HttpClientBuilder::default().build(indexer_1_address)?; - let requests = requests_1.await?; + let requests = requests_1?; for (receipt_1, id) in requests { let result = client_1.request("request", (id, receipt_1)).await; @@ -601,8 +547,8 @@ async fn test_manager_two_indexers( ), Error, >, - #[future] requests_1: Result, u64)>>, - #[future] requests_2: Result, u64)>>, + requests_1: Result, u64)>>, + requests_2: Result, u64)>>, ) -> Result<()> { let ( _server_handle_1, @@ -617,8 +563,8 @@ async fn test_manager_two_indexers( let indexer_2_address = "http://".to_string() + &socket_addr_2.to_string(); let client_1 = HttpClientBuilder::default().build(indexer_1_address)?; let client_2 = HttpClientBuilder::default().build(indexer_2_address)?; - let requests_1 = requests_1.await?; - let requests_2 = requests_2.await?; + let requests_1 = requests_1?; + let requests_2 = requests_2?; for ((receipt_1, id_1), (receipt_2, id_2)) in requests_1.iter().zip(requests_2) { let future_1 = client_1.request("request", (id_1, receipt_1)); @@ -638,14 +584,14 @@ async fn test_manager_wrong_aggregator_keys( (ServerHandle, SocketAddr, ServerHandle, SocketAddr), Error, >, - #[future] requests_1: Result, u64)>>, + requests_1: Result, u64)>>, receipt_threshold_1: u64, ) -> Result<()> { let (_server_handle, socket_addr, _sender_handle, _sender_addr) = single_indexer_wrong_sender_test_server.await?; let indexer_1_address = "http://".to_string() + &socket_addr.to_string(); let client_1 = HttpClientBuilder::default().build(indexer_1_address)?; - let requests = requests_1.await?; + let requests = requests_1?; let mut counter = 1; for (receipt_1, id) in requests { @@ -681,14 +627,14 @@ async fn test_manager_wrong_requestor_keys( (ServerHandle, SocketAddr, ServerHandle, SocketAddr), Error, >, - #[future] wrong_requests: Result, u64)>>, + wrong_requests: Result, u64)>>, receipt_threshold_1: u64, ) -> Result<()> { let (_server_handle, socket_addr, _sender_handle, _sender_addr) = single_indexer_test_server.await?; let indexer_1_address = "http://".to_string() + &socket_addr.to_string(); let client_1 = HttpClientBuilder::default().build(indexer_1_address)?; - let requests = wrong_requests.await?; + let requests = wrong_requests?; let mut counter = 1; for (receipt_1, id) in requests { @@ -727,10 +673,8 @@ async fn test_tap_manager_rav_timestamp_cuttoff( ), Error, >, - #[future] repeated_timestamp_request: Result, u64)>>, - #[future] repeated_timestamp_incremented_by_one_request: Result< - Vec<(EIP712SignedMessage, u64)>, - >, + repeated_timestamp_request: Result, u64)>>, + repeated_timestamp_incremented_by_one_request: Result, u64)>>, receipt_threshold_1: u64, ) -> Result<(), Box> { // This test checks that tap_core is correctly filtering receipts by timestamp. @@ -747,7 +691,7 @@ async fn test_tap_manager_rav_timestamp_cuttoff( let indexer_2_address = "http://".to_string() + &socket_addr_2.to_string(); let client_1 = HttpClientBuilder::default().build(indexer_1_address)?; let client_2 = HttpClientBuilder::default().build(indexer_2_address)?; - let requests = repeated_timestamp_request.await?; + let requests = repeated_timestamp_request?; let mut counter = 1; for (receipt_1, id) in requests { @@ -774,7 +718,7 @@ async fn test_tap_manager_rav_timestamp_cuttoff( // Here the timestamp first receipt in the second batch is equal to timestamp + 1 of the last receipt in the first batch. // No errors are expected. - let requests = repeated_timestamp_incremented_by_one_request.await?; + let requests = repeated_timestamp_incremented_by_one_request?; for (receipt_1, id) in requests { let result = client_2.request("request", (id, receipt_1)).await; match result { @@ -793,10 +737,8 @@ async fn test_tap_aggregator_rav_timestamp_cuttoff( http_request_size_limit: u32, http_response_size_limit: u32, http_max_concurrent_connections: u32, - #[future] repeated_timestamp_request: Result, u64)>>, - #[future] repeated_timestamp_incremented_by_one_request: Result< - Vec<(EIP712SignedMessage, u64)>, - >, + repeated_timestamp_request: Result, u64)>>, + repeated_timestamp_incremented_by_one_request: Result, u64)>>, receipt_threshold_1: u64, ) -> Result<(), Box> { // This test checks that tap_aggregator is correctly rejecting receipts with invalid timestamps @@ -814,7 +756,7 @@ async fn test_tap_aggregator_rav_timestamp_cuttoff( // The second batch has one receipt with the same timestamp as the latest receipt in the first batch. // The first RAV will have the same timestamp as one receipt in the second batch. // tap_aggregator should reject the second RAV request due to the repeated timestamp. - let requests = repeated_timestamp_request.await?; + let requests = repeated_timestamp_request?; let first_batch = &requests[0..receipt_threshold_1 as usize]; let second_batch = &requests[receipt_threshold_1 as usize..2 * receipt_threshold_1 as usize]; @@ -847,7 +789,7 @@ async fn test_tap_aggregator_rav_timestamp_cuttoff( // This is the second part of the test, two batches of receipts are sent to the aggregator. // The second batch has one receipt with the timestamp = timestamp+1 of the latest receipt in the first batch. // tap_aggregator should accept the second RAV request. - let requests = repeated_timestamp_incremented_by_one_request.await?; + let requests = repeated_timestamp_incremented_by_one_request?; let first_batch = &requests[0..receipt_threshold_1 as usize]; let second_batch = &requests[receipt_threshold_1 as usize..2 * receipt_threshold_1 as usize]; @@ -882,8 +824,8 @@ async fn test_tap_aggregator_rav_timestamp_cuttoff( Ok(()) } -async fn generate_requests( - query_price: Vec, +fn generate_requests( + query_price: &[u128], num_batches: u64, sender_key: &LocalWallet, allocation_id: Address, @@ -893,14 +835,13 @@ async fn generate_requests( let mut counter = 0; for _ in 0..num_batches { - for value in &query_price { + for value in query_price { requests.push(( EIP712SignedMessage::new( domain_separator, Receipt::new(allocation_id, *value)?, sender_key, - ) - .await?, + )?, counter, )); counter += 1; @@ -927,7 +868,7 @@ async fn start_indexer_server( listener.local_addr()?.port() }; - executor.increase_escrow(sender_id, available_escrow).await; + executor.increase_escrow(sender_id, available_escrow); let aggregate_server_address = "http://".to_string() + &agg_server_addr.to_string(); let (server_handle, socket_addr) = indexer_mock::run_server( @@ -939,6 +880,7 @@ async fn start_indexer_server( receipt_threshold, aggregate_server_address, aggregate_server_api_version(), + sender_id, ) .await?;