diff --git a/components/salsa-macro-rules/src/setup_accumulator_impl.rs b/components/salsa-macro-rules/src/setup_accumulator_impl.rs index f10318a2a..e8d4da49b 100644 --- a/components/salsa-macro-rules/src/setup_accumulator_impl.rs +++ b/components/salsa-macro-rules/src/setup_accumulator_impl.rs @@ -24,7 +24,8 @@ macro_rules! setup_accumulator_impl { fn $ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl<$Struct> { $CACHE.get_or_create(db, || { - db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Struct>>::default()) + db.zalsa() + .add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Struct>>() }) } diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index b506b28b0..7db4a09a5 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -67,7 +67,7 @@ macro_rules! setup_input_struct { use salsa::plumbing as $zalsa; use $zalsa::input as $zalsa_struct; - struct $Configuration; + $vis struct $Configuration; impl $zalsa_struct::Configuration for $Configuration { const DEBUG_NAME: &'static str = stringify!($Struct); @@ -89,13 +89,13 @@ macro_rules! setup_input_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); CACHE.get_or_create(db, || { - db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { let zalsa_mut = db.zalsa_mut(); - let index = zalsa_mut.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()); + let index = zalsa_mut.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); let current_revision = zalsa_mut.current_revision(); let (ingredient, runtime) = zalsa_mut.lookup_ingredient_mut(index); let ingredient = ingredient.assert_type_mut::<$zalsa_struct::IngredientImpl>(); @@ -124,8 +124,17 @@ macro_rules! setup_input_struct { } impl $zalsa::SalsaStructInDb for $Struct { - fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { - aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() + } + + #[inline] + fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { + if type_id == $zalsa::TypeId::of::<$Struct>() { + $zalsa::Some($Struct(id)) + } else { + $zalsa::None + } } } @@ -187,7 +196,7 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database, { - $Configuration::ingredient(db.as_dyn_database()).get_singleton_input() + $Configuration::ingredient(db.as_dyn_database()).get_singleton_input(db) } #[track_caller] diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index f3eeb83c3..82ab8637e 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -119,7 +119,7 @@ macro_rules! setup_interned_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); CACHE.get_or_create(db.as_dyn_database(), || { - db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } } @@ -149,8 +149,17 @@ macro_rules! setup_interned_struct { } impl $zalsa::SalsaStructInDb for $Struct<'_> { - fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { - aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() + } + + #[inline] + fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { + if type_id == $zalsa::TypeId::of::<$Struct>() { + $zalsa::Some(<$Struct as $zalsa::FromId>::from_id(id)) + } else { + $zalsa::None + } } } diff --git a/components/salsa-macro-rules/src/setup_interned_struct_sans_lifetime.rs b/components/salsa-macro-rules/src/setup_interned_struct_sans_lifetime.rs index d1ae12cac..c28a7ca92 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct_sans_lifetime.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct_sans_lifetime.rs @@ -127,7 +127,7 @@ macro_rules! setup_interned_struct_sans_lifetime { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); CACHE.get_or_create(db.as_dyn_database(), || { - db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } } @@ -157,8 +157,17 @@ macro_rules! setup_interned_struct_sans_lifetime { } impl $zalsa::SalsaStructInDb for $Struct { - fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { - aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() + } + + #[inline] + fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { + if type_id == $zalsa::TypeId::of::<$Struct>() { + $zalsa::Some(<$Struct as $zalsa::FromId>::from_id(id)) + } else { + $zalsa::None + } } } diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index accffbf44..01abbe971 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -99,8 +99,17 @@ macro_rules! setup_tracked_fn { $zalsa::IngredientCache::new(); impl $zalsa::SalsaStructInDb for $InternedData<'_> { - fn lookup_ingredient_index(_aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { - None + fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + $zalsa::IngredientIndices::uninitialized() + } + + #[inline] + fn cast(id: $zalsa::Id, type_id: ::core::any::TypeId) -> Option { + if type_id == ::core::any::TypeId::of::<$InternedData>() { + Some($InternedData(id, ::core::marker::PhantomData)) + } else { + None + } } } @@ -130,7 +139,7 @@ macro_rules! setup_tracked_fn { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { $FN_CACHE.get_or_create(db.as_dyn_database(), || { ::zalsa_db(db); - db.zalsa().add_or_lookup_jar_by_type(&$Configuration) + db.zalsa().add_or_lookup_jar_by_type::<$Configuration>() }) } @@ -139,7 +148,7 @@ macro_rules! setup_tracked_fn { db: &dyn $Db, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { $INTERN_CACHE.get_or_create(db.as_dyn_database(), || { - db.zalsa().add_or_lookup_jar_by_type(&$Configuration).successor(0) + db.zalsa().add_or_lookup_jar_by_type::<$Configuration>().successor(0) }) } } @@ -190,33 +199,43 @@ macro_rules! setup_tracked_fn { if $needs_interner { $Configuration::intern_ingredient(db).data(db.as_dyn_database(), key).clone() } else { - $zalsa::FromId::from_id(key) + $zalsa::FromIdWithDb::from_id(key, db) } } } } impl $zalsa::Jar for $Configuration { + fn create_dependencies(zalsa: &$zalsa::Zalsa) -> $zalsa::IngredientIndices + where + Self: Sized + { + $zalsa::macro_if! { + if $needs_interner { + $zalsa::IngredientIndices::uninitialized() + } else { + <$InternedData as $zalsa::SalsaStructInDb>::lookup_or_create_ingredient_index(zalsa) + } + } + } + fn create_ingredients( - &self, - aux: &dyn $zalsa::JarAux, + zalsa: &$zalsa::Zalsa, first_index: $zalsa::IngredientIndex, + struct_index: $zalsa::IngredientIndices, ) -> Vec> { let struct_index = $zalsa::macro_if! { if $needs_interner { - first_index.successor(0) + first_index.successor(0).into() } else { - <$InternedData as $zalsa::SalsaStructInDb>::lookup_ingredient_index(aux) - .expect( - "Salsa struct is passed as an argument of a tracked function, but its ingredient hasn't been added!" - ) + struct_index } }; let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new( struct_index, first_index, - aux, + zalsa, ); fn_ingredient.set_capacity($lru); $zalsa::macro_if! { @@ -235,8 +254,8 @@ macro_rules! setup_tracked_fn { } } - fn salsa_struct_type_id(&self) -> Option { - None + fn id_struct_type_id() -> $zalsa::TypeId { + $zalsa::TypeId::of::<$InternedData<'static>>() } } diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index a783e3762..06cba1675 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -134,7 +134,7 @@ macro_rules! setup_tracked_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); CACHE.get_or_create(db, || { - db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl::<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } } @@ -152,8 +152,17 @@ macro_rules! setup_tracked_struct { } impl $zalsa::SalsaStructInDb for $Struct<'_> { - fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { - aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() + } + + #[inline] + fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { + if type_id == $zalsa::TypeId::of::<$Struct>() { + $zalsa::Some(<$Struct as $zalsa::FromId>::from_id(id)) + } else { + $zalsa::None + } } } diff --git a/components/salsa-macros/src/enum_.rs b/components/salsa-macros/src/enum_.rs new file mode 100644 index 000000000..a6b9c8abd --- /dev/null +++ b/components/salsa-macros/src/enum_.rs @@ -0,0 +1,142 @@ +use crate::token_stream_with_error; +use proc_macro2::TokenStream; + +/// For an entity struct `Foo` with fields `f1: T1, ..., fN: TN`, we generate... +/// +/// * the "id struct" `struct Foo(salsa::Id)` +/// * the entity ingredient, which maps the id fields to the `Id` +/// * for each value field, a function ingredient +pub(crate) fn enum_(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let enum_item = parse_macro_input!(input as syn::ItemEnum); + match enum_impl(enum_item) { + Ok(v) => v.into(), + Err(e) => token_stream_with_error(input, e), + } +} + +fn enum_impl(enum_item: syn::ItemEnum) -> syn::Result { + let enum_name = enum_item.ident.clone(); + let mut variant_names = Vec::new(); + let mut variant_types = Vec::new(); + if enum_item.variants.is_empty() { + return Err(syn::Error::new( + enum_item.enum_token.span, + "empty enums are not permitted", + )); + } + for variant in &enum_item.variants { + let valid = match &variant.fields { + syn::Fields::Unnamed(fields) => { + variant_names.push(variant.ident.clone()); + variant_types.push(fields.unnamed[0].ty.clone()); + fields.unnamed.len() == 1 + } + syn::Fields::Unit | syn::Fields::Named(_) => false, + }; + if !valid { + return Err(syn::Error::new( + variant.ident.span(), + "the only form allowed is `Variant(SalsaStruct)`", + )); + } + } + + let (impl_generics, type_generics, where_clause) = enum_item.generics.split_for_impl(); + + let as_id = quote! { + impl #impl_generics zalsa::AsId for #enum_name #type_generics + #where_clause { + #[inline] + fn as_id(&self) -> zalsa::Id { + match self { + #( Self::#variant_names(__v) => zalsa::AsId::as_id(__v), )* + } + } + } + }; + + let from_id = quote! { + impl #impl_generics zalsa::FromIdWithDb for #enum_name #type_generics + #where_clause { + #[inline] + fn from_id(__id: zalsa::Id, __db: &(impl ?Sized + zalsa::Database)) -> Self { + let __zalsa = __db.zalsa(); + let __type_id = __zalsa.lookup_page_type_id(__id); + ::cast(__id, __type_id).expect("invalid enum variant") + } + } + }; + + let salsa_struct_in_db = quote! { + impl #impl_generics zalsa::SalsaStructInDb for #enum_name #type_generics + #where_clause { + #[inline] + fn lookup_or_create_ingredient_index(__zalsa: &zalsa::Zalsa) -> zalsa::IngredientIndices { + let mut __result = zalsa::IngredientIndices::uninitialized(); + #( + __result.merge( + &<#variant_types as zalsa::SalsaStructInDb>::lookup_or_create_ingredient_index(__zalsa) + ); + )* + __result + } + + #[inline] + fn cast(id: zalsa::Id, type_id: ::core::any::TypeId) -> Option { + #( + // Subtle: the ingredient can be missing, but in this case the id cannot come + // from it - because it wasn't initialized yet. + if let Some(result) = <#variant_types as zalsa::SalsaStructInDb>::cast(id, type_id) { + Some(Self::#variant_names(result)) + } else + )* + { + None + } + } + } + }; + + let std_traits = quote! { + impl #impl_generics ::core::marker::Copy for #enum_name #type_generics + #where_clause {} + + impl #impl_generics ::core::clone::Clone for #enum_name #type_generics + #where_clause { + #[inline] + fn clone(&self) -> Self { *self } + } + + impl #impl_generics ::core::cmp::Eq for #enum_name #type_generics + #where_clause {} + + impl #impl_generics ::core::cmp::PartialEq for #enum_name #type_generics + #where_clause { + #[inline] + fn eq(&self, __other: &Self) -> bool { + zalsa::AsId::as_id(self) == zalsa::AsId::as_id(__other) + } + } + + impl #impl_generics ::core::hash::Hash for #enum_name #type_generics + #where_clause { + #[inline] + fn hash<__H: ::core::hash::Hasher>(&self, __state: &mut __H) { + ::core::hash::Hash::hash(&zalsa::AsId::as_id(self), __state); + } + } + }; + + let all_impls = quote! { + const _: () = { + use salsa::plumbing as zalsa; + + #as_id + #from_id + #salsa_struct_in_db + + #std_traits + }; + }; + Ok(all_impls) +} diff --git a/components/salsa-macros/src/lib.rs b/components/salsa-macros/src/lib.rs index d7abe6643..4fda558ec 100644 --- a/components/salsa-macros/src/lib.rs +++ b/components/salsa-macros/src/lib.rs @@ -38,6 +38,7 @@ mod accumulator; mod db; mod db_lifetime; mod debug; +mod enum_; mod fn_util; mod hygiene; mod input; @@ -67,6 +68,11 @@ pub fn interned(args: TokenStream, input: TokenStream) -> TokenStream { interned::interned(args, input) } +#[proc_macro_derive(Enum)] +pub fn enum_(input: TokenStream) -> TokenStream { + enum_::enum_(input) +} + /// A discouraged variant of `#[salsa::interned]`. /// /// `#[salsa::interned_sans_lifetime]` is intended to be used in codebases that are migrating from diff --git a/src/accumulator.rs b/src/accumulator.rs index 8cb38e2e1..886613c58 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -1,7 +1,7 @@ //! Basic test of accumulator functionality. use std::{ - any::Any, + any::{Any, TypeId}, fmt::{self, Debug}, marker::PhantomData, }; @@ -13,8 +13,8 @@ use accumulated_map::AccumulatedMap; use crate::{ cycle::CycleRecoveryStrategy, ingredient::{fmt_index, Ingredient, Jar}, - plumbing::JarAux, - zalsa::IngredientIndex, + plumbing::IngredientIndices, + zalsa::{IngredientIndex, Zalsa}, zalsa_local::QueryOrigin, Database, DatabaseKeyIndex, Id, Revision, }; @@ -47,15 +47,15 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - &self, - _aux: &dyn JarAux, + _zalsa: &Zalsa, first_index: IngredientIndex, + _dependencies: IngredientIndices, ) -> Vec> { vec![Box::new(>::new(first_index))] } - fn salsa_struct_type_id(&self) -> Option { - None + fn id_struct_type_id() -> TypeId { + TypeId::of::() } } @@ -70,9 +70,8 @@ impl IngredientImpl { where Db: ?Sized + Database, { - let jar: JarImpl = Default::default(); let zalsa = db.zalsa(); - let index = zalsa.add_or_lookup_jar_by_type(&jar); + let index = zalsa.add_or_lookup_jar_by_type::>(); let ingredient = zalsa.lookup_ingredient(index).assert_type::(); Some(ingredient) } diff --git a/src/function.rs b/src/function.rs index b06be486c..9577fddfc 100644 --- a/src/function.rs +++ b/src/function.rs @@ -5,7 +5,7 @@ use crate::{ cycle::CycleRecoveryStrategy, ingredient::fmt_index, key::DatabaseKeyIndex, - plumbing::JarAux, + memo_ingredient_indices::{IngredientIndices, MemoIngredientIndices}, salsa_struct::SalsaStructInDb, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, zalsa_local::QueryOrigin, @@ -95,7 +95,7 @@ pub struct IngredientImpl { index: IngredientIndex, /// The index for the memo/sync tables - memo_ingredient_index: MemoIngredientIndex, + memo_ingredient_indices: MemoIngredientIndices, /// Used to find memos to throw out when we have too many memoized values. lru: lru::Lru, @@ -126,10 +126,12 @@ impl IngredientImpl where C: Configuration, { - pub fn new(struct_index: IngredientIndex, index: IngredientIndex, aux: &dyn JarAux) -> Self { + pub fn new(struct_indices: IngredientIndices, index: IngredientIndex, zalsa: &Zalsa) -> Self { + let memo_ingredient_indices = struct_indices + .memo_indices(|struct_index| zalsa.next_memo_ingredient_index(struct_index, index)); Self { index, - memo_ingredient_index: aux.next_memo_ingredient_index(struct_index, index), + memo_ingredient_indices, lru: Default::default(), deleted_entries: Default::default(), } @@ -165,6 +167,7 @@ where zalsa: &'db Zalsa, id: Id, memo: memo::Memo>, + memo_ingredient_index: MemoIngredientIndex, ) -> &'db memo::Memo> { let memo = Arc::new(memo); let db_memo = unsafe { @@ -172,13 +175,21 @@ where // value is returned) and anything removed from map is added to deleted entries (ensured elsewhere). self.extend_memo_lifetime(&memo) }; - if let Some(old_value) = self.insert_memo_into_table_for(zalsa, id, memo) { + if let Some(old_value) = + self.insert_memo_into_table_for(zalsa, id, memo, memo_ingredient_index) + { // In case there is a reference to the old memo out there, we have to store it // in the deleted entries. This will get cleared when a new revision starts. self.deleted_entries.push(old_value); } db_memo } + + #[inline] + fn memo_ingredient_index(&self, zalsa: &Zalsa, id: Id) -> MemoIngredientIndex { + self.memo_ingredient_indices + .find(zalsa.ingredient_index(id)) + } } impl Ingredient for IngredientImpl diff --git a/src/function/execute.rs b/src/function/execute.rs index 4171fe6d4..21ba3c21a 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -85,6 +85,12 @@ where tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}"); - self.insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)) + let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); + self.insert_memo( + zalsa, + id, + Memo::new(Some(value), revision_now, revisions), + memo_ingredient_index, + ) } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index f6d495dff..78c7f74ef 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,5 +1,6 @@ use super::{memo::Memo, Configuration, IngredientImpl}; use crate::accumulator::accumulated_map::InputAccumulatedValues; +use crate::zalsa::MemoIngredientIndex; use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id}; impl IngredientImpl @@ -37,17 +38,26 @@ where db: &'db C::DbView, id: Id, ) -> &'db Memo> { + let memo_ingredient_index = self.memo_ingredient_index(db.zalsa(), id); loop { - if let Some(memo) = self.fetch_hot(db, id).or_else(|| self.fetch_cold(db, id)) { + if let Some(memo) = self + .fetch_hot(db, id, memo_ingredient_index) + .or_else(|| self.fetch_cold(db, id, memo_ingredient_index)) + { return memo; } } } #[inline] - fn fetch_hot<'db>(&'db self, db: &'db C::DbView, id: Id) -> Option<&'db Memo>> { + fn fetch_hot<'db>( + &'db self, + db: &'db C::DbView, + id: Id, + memo_ingredient_index: MemoIngredientIndex, + ) -> Option<&'db Memo>> { let zalsa = db.zalsa(); - let memo_guard = self.get_memo_from_table_for(zalsa, id); + let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = &memo_guard { if memo.value.is_some() && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo) @@ -61,7 +71,12 @@ where None } - fn fetch_cold<'db>(&'db self, db: &'db C::DbView, id: Id) -> Option<&'db Memo>> { + fn fetch_cold<'db>( + &'db self, + db: &'db C::DbView, + id: Id, + memo_ingredient_index: MemoIngredientIndex, + ) -> Option<&'db Memo>> { let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(id); @@ -70,7 +85,7 @@ where db.as_dyn_database(), zalsa_local, database_key_index, - self.memo_ingredient_index, + memo_ingredient_index, )?; // Push the query on the stack. @@ -78,7 +93,7 @@ where // Now that we've claimed the item, check again to see if there's a "hot" value. let zalsa = db.zalsa(); - let opt_old_memo = self.get_memo_from_table_for(zalsa, id); + let opt_old_memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(old_memo) = &opt_old_memo { if old_memo.value.is_some() && self.deep_verify_memo(db, old_memo, &active_query) { // Unsafety invariant: memo is present in memo_map. diff --git a/src/function/inputs.rs b/src/function/inputs.rs index 8dce73da0..40060dddf 100644 --- a/src/function/inputs.rs +++ b/src/function/inputs.rs @@ -7,7 +7,8 @@ where C: Configuration, { pub(super) fn origin(&self, zalsa: &Zalsa, key: Id) -> Option { - self.get_memo_from_table_for(zalsa, key) + let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); + self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) .map(|m| m.revisions.origin.clone()) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index b1d671a36..eb79289ad 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,6 +1,6 @@ use crate::{ key::DatabaseKeyIndex, - zalsa::{Zalsa, ZalsaDatabase}, + zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin}, AsDynDatabase as _, Id, Revision, }; @@ -18,6 +18,7 @@ where revision: Revision, ) -> bool { let (zalsa, zalsa_local) = db.zalsas(); + let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database()); loop { @@ -26,13 +27,15 @@ where tracing::debug!("{database_key_index:?}: maybe_changed_after(revision = {revision:?})"); // Check if we have a verified version: this is the hot path. - let memo_guard = self.get_memo_from_table_for(zalsa, id); + let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = &memo_guard { if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { return memo.revisions.changed_at > revision; } drop(memo_guard); // release the arc-swap guard before cold path - if let Some(mcs) = self.maybe_changed_after_cold(db, id, revision) { + if let Some(mcs) = + self.maybe_changed_after_cold(db, id, revision, memo_ingredient_index) + { return mcs; } else { // We failed to claim, have to retry. @@ -49,6 +52,7 @@ where db: &'db C::DbView, key_index: Id, revision: Revision, + memo_ingredient_index: MemoIngredientIndex, ) -> Option { let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(key_index); @@ -57,12 +61,13 @@ where db.as_dyn_database(), zalsa_local, database_key_index, - self.memo_ingredient_index, + memo_ingredient_index, )?; let active_query = zalsa_local.push_query(database_key_index); // Load the current memo, if any. - let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index) else { + let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index) + else { return Some(true); }; diff --git a/src/function/memo.rs b/src/function/memo.rs index 304c6c314..68f04b20b 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use crossbeam::atomic::AtomicCell; +use crate::zalsa::MemoIngredientIndex; use crate::zalsa_local::QueryOrigin; use crate::{ key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, @@ -37,11 +38,12 @@ impl IngredientImpl { zalsa: &'db Zalsa, id: Id, memo: ArcMemo<'db, C>, + memo_ingredient_index: MemoIngredientIndex, ) -> Option> { let static_memo = unsafe { self.to_static(memo) }; let old_static_memo = zalsa .memo_table_for(id) - .insert(self.memo_ingredient_index, static_memo)?; + .insert(memo_ingredient_index, static_memo)?; unsafe { Some(self.to_self(old_static_memo)) } } @@ -52,8 +54,9 @@ impl IngredientImpl { &'db self, zalsa: &'db Zalsa, id: Id, + memo_ingredient_index: MemoIngredientIndex, ) -> Option> { - let static_memo = zalsa.memo_table_for(id).get(self.memo_ingredient_index)?; + let static_memo = zalsa.memo_table_for(id).get(memo_ingredient_index)?; unsafe { Some(self.to_self(static_memo)) } } @@ -62,7 +65,7 @@ impl IngredientImpl { /// or has values assigned as output of another query, this has no effect. pub(super) fn evict_value_from_memo_for<'db>(&'db self, zalsa: &'db Zalsa, id: Id) { zalsa.memo_table_for(id).map_memo::>>( - self.memo_ingredient_index, + self.memo_ingredient_index(zalsa, id), |memo| { match memo.revisions.origin { QueryOrigin::Assigned(_) diff --git a/src/function/specify.rs b/src/function/specify.rs index 9eccad65b..42774fd98 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -72,7 +72,8 @@ where accumulated: Default::default(), }; - if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) { + let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); + if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { self.backdate_if_appropriate(&old_memo, &mut revisions, &value); self.diff_outputs(db, database_key_index, &old_memo, &mut revisions); } @@ -88,7 +89,7 @@ where memo.tracing_debug(), key ); - self.insert_memo(zalsa, key, memo); + self.insert_memo(zalsa, key, memo, memo_ingredient_index); // Record that the current query *specified* a value for this cell. let database_key_index = self.database_key_index(key); @@ -106,8 +107,9 @@ where key: Id, ) { let zalsa = db.zalsa(); + let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); - let memo = match self.get_memo_from_table_for(zalsa, key) { + let memo = match self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { Some(m) => m, None => return, }; diff --git a/src/id.rs b/src/id.rs index 06b54b88b..c1c57b1be 100644 --- a/src/id.rs +++ b/src/id.rs @@ -2,6 +2,8 @@ use std::fmt::Debug; use std::hash::Hash; use std::num::NonZeroU32; +use crate::Database; + /// The `Id` of a salsa struct in the database [`Table`](`crate::table::Table`). /// /// The higher-order bits of an `Id` identify a [`Page`](`crate::table::Page`) @@ -72,3 +74,16 @@ impl FromId for Id { id } } + +/// Enums cannot use [`FromId`] because they need access to the DB to tell the `TypeId` of the variant, +/// so they use this trait instead, that has a blanket implementation for `FromId`. +pub trait FromIdWithDb: AsId + Copy + Eq + Hash + Debug { + fn from_id(id: Id, db: &(impl ?Sized + Database)) -> Self; +} + +impl FromIdWithDb for T { + #[inline] + fn from_id(id: Id, _db: &(impl ?Sized + Database)) -> Self { + FromId::from_id(id) + } +} diff --git a/src/ingredient.rs b/src/ingredient.rs index 8a46205d2..59658e0ad 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -6,7 +6,8 @@ use std::{ use crate::{ accumulator::accumulated_map::AccumulatedMap, cycle::CycleRecoveryStrategy, - zalsa::{IngredientIndex, MemoIngredientIndex}, + plumbing::IngredientIndices, + zalsa::{IngredientIndex, Zalsa}, zalsa_local::QueryOrigin, Database, DatabaseKeyIndex, Id, }; @@ -16,40 +17,33 @@ use super::Revision; /// A "jar" is a group of ingredients that are added atomically. /// Each type implementing jar can be added to the database at most once. pub trait Jar: Any { + /// This creates the ingredient dependencies of this jar. We need to split this from `create_ingredients()` + /// because while `create_ingredients()` is called, a lock on the ingredient map is held (to guarantee + /// atomicity), so other ingredients could not be created. + /// + /// Only tracked fns use this. + fn create_dependencies(_zalsa: &Zalsa) -> IngredientIndices + where + Self: Sized, + { + IngredientIndices::uninitialized() + } + /// Create the ingredients given the index of the first one. /// All subsequent ingredients will be assigned contiguous indices. fn create_ingredients( - &self, - aux: &dyn JarAux, + zalsa: &Zalsa, first_index: IngredientIndex, - ) -> Vec>; - - /// If this jar's first ingredient is a salsa struct, return its `TypeId` - fn salsa_struct_type_id(&self) -> Option; -} - -/// Methods on the Salsa database available to jars while they are creating their ingredients. -pub trait JarAux { - /// Return index of first ingredient from `jar` (based on the dynamic type of `jar`). - /// Returns `None` if the jar has not yet been added. - /// Used by tracked functions to lookup the ingredient index for the salsa struct they take as argument. - fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option; - - /// Returns the memo ingredient index that should be used to attach data from the given tracked function - /// to the given salsa struct (which the fn accepts as argument). - /// - /// The memo ingredient indices for a given function must be distinct from the memo indices - /// of all other functions that take the same salsa struct. - /// - /// # Parameters - /// - /// * `struct_ingredient_index`, the index of the salsa struct the memo will be attached to - /// * `ingredient_index`, the index of the tracked function whose data is stored in the memo - fn next_memo_ingredient_index( - &self, - struct_ingredient_index: IngredientIndex, - ingredient_index: IngredientIndex, - ) -> MemoIngredientIndex; + dependencies: IngredientIndices, + ) -> Vec> + where + Self: Sized; + + /// This returns the [`TypeId`] of the ID struct, that is, the struct that wraps `salsa::Id` + /// and carry the name of the jar. + fn id_struct_type_id() -> TypeId + where + Self: Sized; } pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { diff --git a/src/input.rs b/src/input.rs index e4bbffd85..822919917 100644 --- a/src/input.rs +++ b/src/input.rs @@ -14,10 +14,10 @@ use parking_lot::Mutex; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, - id::{AsId, FromId}, + id::{AsId, FromIdWithDb}, ingredient::{fmt_index, Ingredient}, key::{DatabaseKeyIndex, DependencyIndex}, - plumbing::{Jar, JarAux, Stamp}, + plumbing::{Jar, Stamp}, table::{memo::MemoTable, sync::SyncTable, Slot, Table}, zalsa::{IngredientIndex, Zalsa}, zalsa_local::QueryOrigin, @@ -30,7 +30,7 @@ pub trait Configuration: Any { const IS_SINGLETON: bool; /// The input struct (which wraps an `Id`) - type Struct: FromId + 'static + Send + Sync; + type Struct: FromIdWithDb + 'static + Send + Sync; /// A (possibly empty) tuple of the fields for this struct. type Fields: Send + Sync; @@ -53,9 +53,9 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - &self, - _aux: &dyn JarAux, + _zalsa: &Zalsa, struct_index: crate::zalsa::IngredientIndex, + _dependencies: crate::memo_ingredient_indices::IngredientIndices, ) -> Vec> { let struct_ingredient: IngredientImpl = IngredientImpl::new(struct_index); @@ -66,8 +66,8 @@ impl Jar for JarImpl { .collect() } - fn salsa_struct_type_id(&self) -> Option { - Some(TypeId::of::<::Struct>()) + fn id_struct_type_id() -> TypeId { + TypeId::of::() } } @@ -129,7 +129,7 @@ impl IngredientImpl { drop(guard); } - FromId::from_id(id) + FromIdWithDb::from_id(id, db) } /// Change the value of the field `field_index` to a new value. @@ -168,12 +168,14 @@ impl IngredientImpl { } /// Get the singleton input previously created (if any). - pub fn get_singleton_input(&self) -> Option { + pub fn get_singleton_input(&self, db: &(impl ?Sized + Database)) -> Option { assert!( C::IS_SINGLETON, "get_singleton_input invoked on a non-singleton" ); - self.singleton_index.load().map(FromId::from_id) + self.singleton_index + .load() + .map(|id| FromIdWithDb::from_id(id, db)) } /// Access field of an input. diff --git a/src/interned.rs b/src/interned.rs index 76b1dda9b..7683a0659 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -4,11 +4,11 @@ use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::durability::Durability; use crate::ingredient::fmt_index; use crate::key::DependencyIndex; -use crate::plumbing::{Jar, JarAux}; +use crate::plumbing::{IngredientIndices, Jar}; use crate::table::memo::MemoTable; use crate::table::sync::SyncTable; use crate::table::Slot; -use crate::zalsa::IngredientIndex; +use crate::zalsa::{IngredientIndex, Zalsa}; use crate::zalsa_local::QueryOrigin; use crate::{Database, DatabaseKeyIndex, Id}; use std::any::TypeId; @@ -89,15 +89,15 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - &self, - _aux: &dyn JarAux, + _zalsa: &Zalsa, first_index: IngredientIndex, + _dependencies: IngredientIndices, ) -> Vec> { vec![Box::new(IngredientImpl::::new(first_index)) as _] } - fn salsa_struct_type_id(&self) -> Option { - Some(TypeId::of::<::Struct<'static>>()) + fn id_struct_type_id() -> TypeId { + TypeId::of::>() } } diff --git a/src/lib.rs b/src/lib.rs index 64e6cb301..3540047bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ mod ingredient; mod input; mod interned; mod key; +mod memo_ingredient_indices; mod nonce; mod par_map; mod revision; @@ -53,6 +54,7 @@ pub use salsa_macros::input; pub use salsa_macros::interned; pub use salsa_macros::interned_sans_lifetime; pub use salsa_macros::tracked; +pub use salsa_macros::Enum; pub use salsa_macros::Update; pub mod prelude { @@ -67,6 +69,9 @@ pub mod prelude { /// /// The contents of this module are NOT subject to semver. pub mod plumbing { + pub use std::any::TypeId; + pub use std::option::Option::{self, None, Some}; + pub use crate::accumulator::Accumulator; pub use crate::array::Array; pub use crate::attach::attach; @@ -78,11 +83,12 @@ pub mod plumbing { pub use crate::function::should_backdate_value; pub use crate::id::AsId; pub use crate::id::FromId; + pub use crate::id::FromIdWithDb; pub use crate::id::Id; pub use crate::ingredient::Ingredient; pub use crate::ingredient::Jar; - pub use crate::ingredient::JarAux; pub use crate::key::DatabaseKeyIndex; + pub use crate::memo_ingredient_indices::IngredientIndices; pub use crate::revision::Revision; pub use crate::runtime::stamp; pub use crate::runtime::Runtime; diff --git a/src/memo_ingredient_indices.rs b/src/memo_ingredient_indices.rs new file mode 100644 index 000000000..0c7e3bf8d --- /dev/null +++ b/src/memo_ingredient_indices.rs @@ -0,0 +1,114 @@ +use std::fmt; + +use crate::zalsa::MemoIngredientIndex; +use crate::IngredientIndex; + +/// The maximum number of memo ingredient indices we can hold. This affects the +/// maximum number of variants possible in `#[derive(salsa::Enum)]`. We use a const +/// so that we don't allocate and to perhaps allow the compiler to vectorize the search. +pub const MAX_MEMO_INGREDIENT_INDICES: usize = 20; + +/// An ingredient has an [ingredient index][IngredientIndex]. However, Salsa also supports +/// enums of salsa structs, and those don't have a constant ingredient index, because they +/// are not ingredients by themselves but rather composed of them. However, an enum can be +/// viewed as a *set* of [`IngredientIndex`], where each instance of the enum can belong +/// to one, potentially different, index. This is what this type represents: a set of +/// `IngredientIndex`. +/// +/// This type is represented as an array, for efficiency, and supports up to 20 indices. +/// That means that Salsa enums can have at most 20 variants. Alternatively, they can also +/// contain Salsa enums as variants, but then the total number of variants is counter - because +/// what matters is the number of unique `IngredientIndex`s. +#[derive(Clone)] +pub struct IngredientIndices { + indices: [IngredientIndex; MAX_MEMO_INGREDIENT_INDICES], + len: u8, +} + +impl From for IngredientIndices { + #[inline] + fn from(value: IngredientIndex) -> Self { + let mut result = Self::uninitialized(); + result.indices[0] = value; + result.len = 1; + result + } +} + +impl fmt::Debug for IngredientIndices { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entries(&self.indices[..self.len.into()]) + .finish() + } +} + +impl IngredientIndices { + #[inline] + pub(crate) fn memo_indices( + &self, + mut memo_index: impl FnMut(IngredientIndex) -> MemoIngredientIndex, + ) -> MemoIngredientIndices { + let mut memo_ingredient_indices = [( + IngredientIndex::from((u32::MAX - 1) as usize), + MemoIngredientIndex::from_usize((u32::MAX - 1) as usize), + ); MAX_MEMO_INGREDIENT_INDICES]; + for i in 0..usize::from(self.len) { + let memo_ingredient_index = memo_index(self.indices[i]); + memo_ingredient_indices[i] = (self.indices[i], memo_ingredient_index); + } + MemoIngredientIndices { + indices: memo_ingredient_indices, + len: self.len, + } + } + + #[inline] + pub fn uninitialized() -> Self { + Self { + indices: [IngredientIndex::from((u32::MAX - 1) as usize); MAX_MEMO_INGREDIENT_INDICES], + len: 0, + } + } + + #[track_caller] + #[inline] + pub fn merge(&mut self, other: &Self) { + if usize::from(self.len) + usize::from(other.len) > MAX_MEMO_INGREDIENT_INDICES { + panic!("too many variants in the salsa enum"); + } + self.indices[usize::from(self.len)..][..usize::from(other.len)] + .copy_from_slice(&other.indices[..usize::from(other.len)]); + self.len += other.len; + } +} + +/// This type is to [`MemoIngredientIndex`] what [`IngredientIndices`] is to [`IngredientIndex`]: +/// since enums can contain different ingredient indices, they can also have different memo indices, +/// so we need to keep track of them. +#[derive(Clone)] +pub struct MemoIngredientIndices { + indices: [(IngredientIndex, MemoIngredientIndex); MAX_MEMO_INGREDIENT_INDICES], + len: u8, +} + +impl fmt::Debug for MemoIngredientIndices { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entries(&self.indices[..self.len.into()]) + .finish() + } +} + +impl MemoIngredientIndices { + #[inline] + pub(crate) fn find(&self, ingredient_index: IngredientIndex) -> MemoIngredientIndex { + for &(ingredient, memo_ingredient_index) in &self.indices[..(self.len - 1).into()] { + if ingredient == ingredient_index { + return memo_ingredient_index; + } + } + // It must be the last. + self.indices[usize::from(self.len - 1)].1 + } +} diff --git a/src/salsa_struct.rs b/src/salsa_struct.rs index 8674dc125..bba9fd34b 100644 --- a/src/salsa_struct.rs +++ b/src/salsa_struct.rs @@ -1,5 +1,52 @@ -use crate::{plumbing::JarAux, IngredientIndex}; +use std::any::TypeId; -pub trait SalsaStructInDb { - fn lookup_ingredient_index(aux: &dyn JarAux) -> Option; +use crate::memo_ingredient_indices::IngredientIndices; +use crate::zalsa::Zalsa; +use crate::Id; + +pub trait SalsaStructInDb: Sized { + /// This method is used to create ingredient indices. Note, it does *not* create the ingredients + /// themselves, that is the job of [`Zalsa::add_or_lookup_jar_by_type()`]. This method only creates + /// or lookup the indices. Naturally, implementors may call `add_or_lookup_jar_by_type()` to + /// create the ingredient, but they do not must, e.g. enums recursively call + /// `lookup_or_create_ingredient_index()` for their variants and combine them. + fn lookup_or_create_ingredient_index(zalsa: &Zalsa) -> IngredientIndices; + + /// This method is used to support nested Salsa enums, a.k.a.: + /// ```ignore + /// #[salsa::input] + /// struct Input {} + /// + /// #[salsa::interned] + /// struct Interned1 {} + /// + /// #[salsa::interned] + /// struct Interned2 {} + /// + /// #[derive(Debug, salsa::Enum)] + /// enum InnerEnum { + /// Input(Input), + /// Interned1(Interned1), + /// } + /// + /// #[derive(Debug, salsa::Enum)] + /// enum OuterEnum { + /// InnerEnum(InnerEnum), + /// Interned2(Interned2), + /// } + /// ``` + /// Imagine `OuterEnum` got a [`salsa::Id`][Id] and it wants to know which variant it belongs to. + /// It cannot ask each variant "what is your ingredient index?" and compare, because `InnerEnum` + /// has multiple possible ingredient indices. + /// + /// It could ask each variant "is this value yours?" and then invoke [`FromId`][crate::id::FromId] + /// with the correct variant, but that will duplicate the work: now `InnerEnum` will have to do + /// the same thing for its variants. + /// + /// Instead, we keep track of the [`TypeId`] of the ID struct, and ask each variant to "cast" it. If + /// it succeeds, we return that value; if not, we go to the next variant. + /// + /// Why `TypeId` and not `IngredientIndex`? Because it's cheaper and easier. The `TypeId` is readily + /// available at compile time, while the `IngredientIndex` requires a runtime lookup. + fn cast(id: Id, type_id: TypeId) -> Option; } diff --git a/src/table.rs b/src/table.rs index ce0747038..00529cc9d 100644 --- a/src/table.rs +++ b/src/table.rs @@ -30,6 +30,8 @@ pub(crate) struct Table { pub(crate) trait TablePage: Any + Send + Sync { fn hidden_type_name(&self) -> &'static str; + fn ingredient_index(&self) -> IngredientIndex; + /// Access the memos attached to `slot`. /// /// # Safety condition @@ -47,7 +49,6 @@ pub(crate) trait TablePage: Any + Send + Sync { pub(crate) struct Page { /// The ingredient for elements on this page. - #[allow(dead_code)] // pretty sure we'll need this ingredient: IngredientIndex, /// Number of elements of `data` that are initialized. @@ -118,6 +119,13 @@ impl Default for Table { } impl Table { + /// Returns the [`IngredientIndex`] for an [`Id`]. + #[inline] + pub fn ingredient_index(&self, id: Id) -> IngredientIndex { + let (page_idx, _) = split_id(id); + self.pages[page_idx.0].ingredient_index() + } + /// Get a reference to the data for `id`, which must have been allocated from this table with type `T`. /// /// # Panics @@ -254,6 +262,10 @@ impl TablePage for Page { std::any::type_name::() } + fn ingredient_index(&self) -> IngredientIndex { + self.ingredient + } + unsafe fn memos(&self, slot: SlotIndex, current_revision: Revision) -> &MemoTable { self.get(slot).memos(current_revision) } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 197687082..2b21a22a4 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -6,7 +6,7 @@ use tracked_field::FieldIngredientImpl; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, Ingredient, Jar, JarAux}, + ingredient::{fmt_index, Ingredient, Jar}, key::{DatabaseKeyIndex, DependencyIndex}, plumbing::ZalsaLocal, runtime::StampedValue, @@ -101,9 +101,9 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - &self, - _aux: &dyn JarAux, + _zalsa: &Zalsa, struct_index: crate::zalsa::IngredientIndex, + _dependencies: crate::memo_ingredient_indices::IngredientIndices, ) -> Vec> { let struct_ingredient = >::new(struct_index); @@ -114,8 +114,8 @@ impl Jar for JarImpl { .collect() } - fn salsa_struct_type_id(&self) -> Option { - Some(TypeId::of::<::Struct<'static>>()) + fn id_struct_type_id() -> TypeId { + TypeId::of::>() } } diff --git a/src/zalsa.rs b/src/zalsa.rs index 65684e488..53090cc2f 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -2,11 +2,13 @@ use append_only_vec::AppendOnlyVec; use parking_lot::{Mutex, RwLock}; use rustc_hash::FxHashMap; use std::any::{Any, TypeId}; +use std::collections::hash_map; use std::marker::PhantomData; use std::thread::ThreadId; use crate::cycle::CycleRecoveryStrategy; -use crate::ingredient::{Ingredient, Jar, JarAux}; +use crate::hash::FxDashMap; +use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::{Runtime, WaitResult}; use crate::table::memo::MemoTable; @@ -136,6 +138,9 @@ pub struct Zalsa { /// adding new kinds of ingredients. jar_map: Mutex>, + /// A map from the `IngredientIndex` to the `TypeId` of its ID struct. + ingredient_to_id_struct_type_id_map: FxDashMap, + /// Vector of ingredients. /// /// Immutable unless the mutex on `ingredients_map` is held. @@ -155,6 +160,7 @@ impl Zalsa { views_of: Views::new::(), nonce: NONCE.nonce(), jar_map: Default::default(), + ingredient_to_id_struct_type_id_map: Default::default(), ingredients_vec: AppendOnlyVec::new(), ingredients_requiring_reset: AppendOnlyVec::new(), runtime: Runtime::default(), @@ -187,34 +193,62 @@ impl Zalsa { unsafe { self.table().syncs(id, self.current_revision()) } } + pub(crate) fn next_memo_ingredient_index( + &self, + struct_ingredient_index: IngredientIndex, + ingredient_index: IngredientIndex, + ) -> MemoIngredientIndex { + let mut memo_ingredients = self.memo_ingredient_indices.write(); + let idx = struct_ingredient_index.as_usize(); + let memo_ingredients = if let Some(memo_ingredients) = memo_ingredients.get_mut(idx) { + memo_ingredients + } else { + memo_ingredients.resize_with(idx + 1, Vec::new); + memo_ingredients.get_mut(idx).unwrap() + }; + let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); + memo_ingredients.push(ingredient_index); + mi + } + + #[inline] + pub fn lookup_page_type_id(&self, id: Id) -> TypeId { + let ingredient_index = self.ingredient_index(id); + *self + .ingredient_to_id_struct_type_id_map + .get(&ingredient_index) + .expect("should have the ingredient index available") + } + /// **NOT SEMVER STABLE** - pub fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex { - { - let jar_type_id = jar.type_id(); - let mut jar_map = self.jar_map.lock(); - let mut should_create = false; - // First record the index we will use into the map and then go and create the ingredients. - // Those ingredients may invoke methods on the `JarAux` trait that read from this map - // to lookup ingredient indices for already created jars. - // - // Note that we still hold the lock above so only one jar is being created at a time and hence - // ingredient indices cannot overlap. - let index = *jar_map.entry(jar_type_id).or_insert_with(|| { - should_create = true; - IngredientIndex::from(self.ingredients_vec.len()) - }); - if should_create { - let aux = JarAuxImpl(self, &jar_map); - let ingredients = jar.create_ingredients(&aux, index); - for ingredient in ingredients { - let expected_index = ingredient.ingredient_index(); - - if ingredient.requires_reset_for_new_revision() { - self.ingredients_requiring_reset.push(expected_index); - } - - let actual_index = self.ingredients_vec.push(ingredient); - assert_eq!( + pub fn add_or_lookup_jar_by_type(&self) -> IngredientIndex { + let jar_type_id = TypeId::of::(); + let mut jar_map = self.jar_map.lock(); + if let Some(index) = jar_map.get(&jar_type_id) { + return *index; + }; + drop(jar_map); + let dependencies = J::create_dependencies(self); + + jar_map = self.jar_map.lock(); + let index = IngredientIndex::from(self.ingredients_vec.len()); + match jar_map.entry(jar_type_id) { + hash_map::Entry::Occupied(entry) => { + // Someone made it earlier than us. + return *entry.get(); + } + hash_map::Entry::Vacant(entry) => entry.insert(index), + }; + let ingredients = J::create_ingredients(self, index, dependencies); + for ingredient in ingredients { + let expected_index = ingredient.ingredient_index(); + + if ingredient.requires_reset_for_new_revision() { + self.ingredients_requiring_reset.push(expected_index); + } + + let actual_index = self.ingredients_vec.push(ingredient); + assert_eq!( expected_index.as_usize(), actual_index, "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", @@ -222,11 +256,13 @@ impl Zalsa { expected_index, actual_index, ); - } - } - - index } + + drop(jar_map); + self.ingredient_to_id_struct_type_id_map + .insert(index, J::id_struct_type_id()); + + index } pub(crate) fn lookup_ingredient(&self, index: IngredientIndex) -> &dyn Ingredient { @@ -309,31 +345,10 @@ impl Zalsa { self.memo_ingredient_indices.read()[struct_ingredient_index.as_usize()] [memo_ingredient_index.as_usize()] } -} - -struct JarAuxImpl<'a>(&'a Zalsa, &'a FxHashMap); - -impl JarAux for JarAuxImpl<'_> { - fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option { - self.1.get(&jar.type_id()).map(ToOwned::to_owned) - } - fn next_memo_ingredient_index( - &self, - struct_ingredient_index: IngredientIndex, - ingredient_index: IngredientIndex, - ) -> MemoIngredientIndex { - let mut memo_ingredients = self.0.memo_ingredient_indices.write(); - let idx = struct_ingredient_index.as_usize(); - let memo_ingredients = if let Some(memo_ingredients) = memo_ingredients.get_mut(idx) { - memo_ingredients - } else { - memo_ingredients.resize_with(idx + 1, Vec::new); - memo_ingredients.get_mut(idx).unwrap() - }; - let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); - memo_ingredients.push(ingredient_index); - mi + #[inline] + pub fn ingredient_index(&self, id: Id) -> IngredientIndex { + self.table().ingredient_index(id) } } diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs index 8bb104023..103f6f058 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -1,8 +1,10 @@ //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. +use std::any::TypeId; use std::convert::identity; +use salsa::plumbing::Zalsa; use test_log::test; #[test] @@ -86,7 +88,7 @@ const _: () = { zalsa_::IngredientCache::new(); CACHE.get_or_create(db.as_dyn_database(), || { db.zalsa() - .add_or_lookup_jar_by_type(&>::default()) + .add_or_lookup_jar_by_type::>() }) } } @@ -110,10 +112,18 @@ const _: () = { } } impl zalsa_::SalsaStructInDb for InternedString<'_> { - fn lookup_ingredient_index( - aux: &dyn salsa::plumbing::JarAux, - ) -> Option { - aux.lookup_jar_by_type(&zalsa_struct_::JarImpl::::default()) + fn lookup_or_create_ingredient_index(aux: &Zalsa) -> salsa::plumbing::IngredientIndices { + aux.add_or_lookup_jar_by_type::>() + .into() + } + + #[inline] + fn cast(id: zalsa_::Id, type_id: TypeId) -> Option { + if type_id == TypeId::of::() { + Some(::from_id(id)) + } else { + None + } } } diff --git a/tests/tracked_fn_on_interned_enum.rs b/tests/tracked_fn_on_interned_enum.rs new file mode 100644 index 000000000..b1a025585 --- /dev/null +++ b/tests/tracked_fn_on_interned_enum.rs @@ -0,0 +1,93 @@ +//! Test that a `tracked` fn on a `salsa::interned` +//! compiles and executes successfully. + +#[salsa::interned_sans_lifetime] +struct Name { + name: String, +} + +#[salsa::interned] +struct NameAndAge<'db> { + name_and_age: String, +} + +#[salsa::interned_sans_lifetime] +struct Age { + age: u32, +} + +#[derive(Debug, salsa::Enum)] +enum Enum<'db> { + Name(Name), + NameAndAge(NameAndAge<'db>), + Age(Age), +} + +#[salsa::input] +struct Input { + value: String, +} + +#[derive(Debug, salsa::Enum)] +enum EnumOfEnum<'db> { + Enum(Enum<'db>), + Input(Input), +} + +#[salsa::tracked] +fn tracked_fn<'db>(db: &'db dyn salsa::Database, enum_: Enum<'db>) -> String { + match enum_ { + Enum::Name(name) => name.name(db), + Enum::NameAndAge(name_and_age) => name_and_age.name_and_age(db), + Enum::Age(age) => age.age(db).to_string(), + } +} + +#[salsa::tracked] +fn tracked_fn2<'db>(db: &'db dyn salsa::Database, enum_: EnumOfEnum<'db>) -> String { + match enum_ { + EnumOfEnum::Enum(enum_) => tracked_fn(db, enum_), + EnumOfEnum::Input(input) => input.value(db), + } +} + +#[test] +fn execute() { + let db = salsa::DatabaseImpl::new(); + let name = Name::new(&db, "Salsa".to_string()); + let name_and_age = NameAndAge::new(&db, "Salsa 3".to_string()); + let age = Age::new(&db, 123); + + assert_eq!(tracked_fn(&db, Enum::Name(name)), "Salsa"); + assert_eq!(tracked_fn(&db, Enum::NameAndAge(name_and_age)), "Salsa 3"); + assert_eq!(tracked_fn(&db, Enum::Age(age)), "123"); + assert_eq!(tracked_fn(&db, Enum::Name(name)), "Salsa"); + assert_eq!(tracked_fn(&db, Enum::NameAndAge(name_and_age)), "Salsa 3"); + assert_eq!(tracked_fn(&db, Enum::Age(age)), "123"); + + assert_eq!( + tracked_fn2(&db, EnumOfEnum::Enum(Enum::Name(name))), + "Salsa" + ); + assert_eq!( + tracked_fn2(&db, EnumOfEnum::Enum(Enum::NameAndAge(name_and_age))), + "Salsa 3" + ); + assert_eq!(tracked_fn2(&db, EnumOfEnum::Enum(Enum::Age(age))), "123"); + assert_eq!( + tracked_fn2(&db, EnumOfEnum::Enum(Enum::Name(name))), + "Salsa" + ); + assert_eq!( + tracked_fn2(&db, EnumOfEnum::Enum(Enum::NameAndAge(name_and_age))), + "Salsa 3" + ); + assert_eq!(tracked_fn2(&db, EnumOfEnum::Enum(Enum::Age(age))), "123"); + assert_eq!( + tracked_fn2( + &db, + EnumOfEnum::Input(Input::new(&db, "Hello world!".to_string())) + ), + "Hello world!" + ); +}