diff --git a/components/salsa-2022-macros/src/input.rs b/components/salsa-2022-macros/src/input.rs index 5b83efe6c..3fa593435 100644 --- a/components/salsa-2022-macros/src/input.rs +++ b/components/salsa-2022-macros/src/input.rs @@ -1,3 +1,5 @@ +use std::fmt::Formatter; + use crate::salsa_struct::{SalsaField, SalsaStruct, SalsaStructKind}; use proc_macro2::{Literal, TokenStream}; @@ -52,6 +54,7 @@ impl crate::options::AllowedOptions for InputStruct { impl InputStruct { fn generate_input(&self) -> syn::Result { let id_struct = self.id_struct(); + let (builder, builder_impl) = self.input_builder(); let inherent_impl = self.input_inherent_impl(); let ingredients_for_impl = self.input_ingredients(); let as_id_impl = self.as_id_impl(); @@ -60,6 +63,8 @@ impl InputStruct { Ok(quote! { #id_struct + #builder + #builder_impl #inherent_impl #ingredients_for_impl #as_id_impl @@ -159,6 +164,18 @@ impl InputStruct { } }; + let builder_fn: syn::ImplItemMethod = { + let builder_name = self.builder_name(); + parse_quote! { + fn new_builder() -> #builder_name { + #builder_name{ + #(#field_names: None,)* + durability: None, + } + } + } + }; + if singleton { let get: syn::ImplItemMethod = parse_quote! { #[track_caller] @@ -189,6 +206,8 @@ impl InputStruct { #(#field_getters)* #(#field_setters)* + + #builder_fn } } } else { @@ -199,11 +218,86 @@ impl InputStruct { #(#field_getters)* #(#field_setters)* + + #builder_fn } } } + } + + /// generate builder struct + fn input_builder(&self) -> (syn::ItemStruct, syn::ItemImpl) { + let ident = self.id_ident(); + let struct_name = self.builder_name(); + let field_names = self.all_field_names(); + let field_tys: Vec<_> = self.all_field_tys(); + let db_dyn_ty = self.db_dyn_ty(); + let jar_ty = self.jar_ty(); + let input_index = self.input_index(); + let field_indices = self.all_field_indices(); + + let buidler: syn::ItemStruct = parse_quote! { + struct #struct_name{ + #(#field_names: Option<#field_tys>,)* + durability: Option, + } + }; + + let build_fn: syn::ImplItemMethod = if self.0.is_isingleton() { + parse_quote! { + fn build(&mut self, __db: &#db_dyn_ty) -> std::result::Result<#ident, std::boxed::Box> { + let durability = self.durability.ok_or_else(|| "durability not provided")?; - // } + let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); + let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); + let __id = __ingredients.#input_index.new_singleton_input(__runtime); + + #( + __ingredients.#field_indices.store_new(__runtime, __id, self.#field_names.clone().ok_or_else(||"field not provided")?, durability); + )* + std::result::Result::Ok(__id) + } + } + } else { + parse_quote! { + fn build(&mut self, __db: &#db_dyn_ty) -> std::result::Result<#ident, std::boxed::Box> { + let durability = self.durability.ok_or_else(|| "durability not provided")?; + + let (__jar, __runtime) = <_ as salsa::storage::HasJar<#jar_ty>>::jar(__db); + let __ingredients = <#jar_ty as salsa::storage::HasIngredientsFor< #ident >>::ingredient(__jar); + let __id =__ingredients.#input_index.new_input(__runtime); + #( + __ingredients.#field_indices.store_new(__runtime, __id, self.#field_names.clone().ok_or_else(||"field not provided")?, durability); + )* + std::result::Result::Ok(__id) + } + } + }; + + let impls: syn::ItemImpl = { + parse_quote! { + impl #struct_name { + fn with_durability(&mut self, durability: salsa::durability::Durability) -> &mut Self { + self.durability = Some(durability); + self + } + + fn with_fields(&mut self, #(#field_names: #field_tys,)*) -> &mut Self { + #(self.#field_names = Some(#field_names);)* + self + } + + #build_fn + } + } + }; + + (buidler, impls) + } + + fn builder_name(&self) -> syn::Ident { + let ident = self.id_ident(); + syn::Ident::new(&format!("{}Builder", ident), ident.span()) } /// Generate the `IngredientsFor` impl for this entity. @@ -319,3 +413,14 @@ impl InputStruct { } } } + +#[derive(Debug)] +pub struct InputBuilderError(String); + +impl std::fmt::Display for InputBuilderError { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + write!(f, "{:?}", self.0) + } +} + +impl std::error::Error for InputBuilderError {} diff --git a/salsa-2022-tests/tests/builder_api.rs b/salsa-2022-tests/tests/builder_api.rs new file mode 100644 index 000000000..f8cb7faeb --- /dev/null +++ b/salsa-2022-tests/tests/builder_api.rs @@ -0,0 +1,65 @@ +#![allow(warnings)] + +use expect_test::expect; +use salsa::Durability; +#[salsa::jar(db = Db)] +struct Jar(MyInput, MySingletonInput); + +trait Db: salsa::DbWithJar {} + +#[salsa::input(jar = Jar)] +struct MyInput { + field1: u32, + field2: u32, +} + +#[salsa::input(singleton)] +struct MySingletonInput { + field1: u32, +} + +#[salsa::db(Jar)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database {} + +impl Db for Database {} + +#[test] +fn test_builder() { + let mut db = Database::default(); + // durability set with new is always low + let mut input = MyInput::new(&db, 12, 13); + input + .set_field1(&mut db) + .with_durability(Durability::HIGH) + .to(20); + input + .set_field2(&mut db) + .with_durability(Durability::HIGH) + .to(40); + let input_from_builder = MyInput::new_builder() + .with_durability(Durability::HIGH) + .with_fields(20, 40) + .build(&db) + .unwrap(); + + assert_eq!(input.field1(&db), input_from_builder.field1(&db)); + assert_eq!(input.field2(&db), input_from_builder.field2(&db)); +} + +#[test] +#[should_panic] +// should panic because were creating he same input twice +fn test_sg_builder_panic() { + let mut db = Database::default(); + let input1 = MySingletonInput::new(&db, 5); + let input_from_builder = MySingletonInput::new_builder() + .with_durability(Durability::LOW) + .with_fields(5) + .build(&db) + .unwrap(); +}