diff --git a/src/rust-jobs/mock-test1/lib.rs b/src/rust-jobs/mock-test1/lib.rs index 4fe0fb2..816024b 100644 --- a/src/rust-jobs/mock-test1/lib.rs +++ b/src/rust-jobs/mock-test1/lib.rs @@ -12,7 +12,7 @@ pub trait Test { } fn test(mock_test: &T, x: i32, y: bool) -> i32 { - let ans = mock_test.a(x, y); + let ans = mock_test.a(x + 5, y); mock_test.b(); mock_test.c(); return ans; @@ -25,17 +25,19 @@ pub extern "C" fn entrypt() { let y: bool = verifier::any!(); verifier::assume!(x < 10); + verifier::assume!(y == true); mock .times_a(2) .times_b(2) .times_c(1) + .with_a((WithVal::Lt(15), WithVal::Eq(true))) .returning_a(|x, _y| x + 5) .returning_b(|| 4); verifier::vassert!(mock.a(x, y) < 15); verifier::vassert!(mock.b() == 4); - verifier::vassert!(test(&mock, x, y) < 15); + verifier::vassert!(test(&mock, x, y) < 20); verifier::vassert!(mock.expect_times_a(2)); verifier::vassert!(mock.expect_times_b(2)); verifier::vassert!(mock.expect_times_c(1)); diff --git a/src/seamock-lib/src/lib.rs b/src/seamock-lib/src/lib.rs index ea1fe4f..a3f715b 100644 --- a/src/seamock-lib/src/lib.rs +++ b/src/seamock-lib/src/lib.rs @@ -35,9 +35,9 @@ pub fn seamock(_args: TokenStream, input: TokenStream) -> TokenStream { generate_attr_names(method, &["times"]) }); - let mut returning_attrs = vec!{}; - let mut with_attrs = vec!{}; - let mut with_methods = vec!{}; + let mut returning_attrs = vec![]; + let mut with_attrs = vec![]; + let mut with_methods = vec![]; let ret = trait_methods.clone().flat_map(|method| { let method_output = &method.sig.output; @@ -134,8 +134,7 @@ pub fn seamock(_args: TokenStream, input: TokenStream) -> TokenStream { let method_name = &method.sig.ident; let method_output = &method.sig.output; let method_inputs = &method.sig.inputs; - let mut params = vec!{}; - // For each argument, create WithVal where T is the argument type + let mut params = vec![]; for arg in method_inputs.iter() { if let syn::FnArg::Typed(pat_type) = arg { let arg_name = match &*pat_type.pat { @@ -149,15 +148,49 @@ pub fn seamock(_args: TokenStream, input: TokenStream) -> TokenStream { let times_attr = Ident::new(&format!("times_{}", &method.sig.ident), method.sig.ident.span()); let max_times_attr = Ident::new(&format!("max_times_{}", &method.sig.ident), method.sig.ident.span()); let ret_func = Ident::new(&format!("val_returning_{}", &method.sig.ident), method.sig.ident.span()); - let error = format!("Hit times limit for {}", &method.sig.ident); + let max_attr_error = format!("Hit times limit for {}", &method.sig.ident); + let val_attr = Ident::new(&format!("val_with_{}", &method.sig.ident), method.sig.ident.span()); + let val_error = format!("Called {} with incorrect parameters", &method.sig.ident); + + let expected_val_logic = params.iter().enumerate().map(|(i, _)| { + let idx = syn::Index::from(i); + Some (quote! { + let with_val = &tuple.#idx; + let input = ¶ms.#idx; + val_match = val_match && match with_val { + WithVal::Gt(val) => input > val, + WithVal::Gte(val) => input >= val, + WithVal::Lt(val) => input < val, + WithVal::Lte(val) => input <= val, + WithVal::Eq(val) => input == val, + }; + }) + }); + + let with_matching = if params.len() > 0 { + quote! { + if let Some(tuple) = &self.#val_attr { + let params = (#(#params)*); + let mut val_match = true; + #( + #expected_val_logic + )* + if !val_match { + sea::sea_printf!(#val_error, &self.#val_attr); + verifier::vassert!(false); + } + } + } + } else { quote! {} }; Some (quote! { fn #method_name(#method_inputs) #method_output { self.#times_attr.replace_with(|&mut old| old + 1); - if (*self.#times_attr.borrow() > self.#max_times_attr) { - sea::sea_printf!(#error, self.#max_times_attr); + if *self.#times_attr.borrow() > self.#max_times_attr { + sea::sea_printf!(#max_attr_error, self.#max_times_attr); verifier::vassert!(false); } + #with_matching (self.#ret_func)(#(#params)*) } }) @@ -220,6 +253,8 @@ pub fn seamock(_args: TokenStream, input: TokenStream) -> TokenStream { // Combine the generated tokens let expanded = quote! { use core::cell::RefCell; + #[macro_use] + extern crate alloc; enum WithVal { Gt(T), Gte(T),