diff --git a/starknet-replay/src/profiler/analysis.rs b/starknet-replay/src/profiler/analysis.rs index 0fd5568..549acdb 100644 --- a/starknet-replay/src/profiler/analysis.rs +++ b/starknet-replay/src/profiler/analysis.rs @@ -89,9 +89,10 @@ pub fn extract_libfuncs_weight( visited_pcs: &VisitedPcs, storage: &impl Storage, ) -> Result { - let mut local_cumulative_libfuncs_weight: ReplayStatistics = ReplayStatistics::new(); + let mut local_cumulative_libfuncs_weight = ReplayStatistics::new(); for (replay_class_hash, all_pcs) in visited_pcs { + tracing::info!("Processing pcs from {replay_class_hash:?}."); let Ok(contract_class) = storage.get_contract_class_at_block(replay_class_hash) else { continue; }; @@ -100,13 +101,12 @@ pub fn extract_libfuncs_weight( continue; }; - let runner = SierraProfiler::new(sierra_program.clone())?; + let runner = SierraProfiler::new(sierra_program)?; - for pcs in all_pcs { - let concrete_libfunc_weights = internal_extract_libfuncs_weight(&runner, pcs); - local_cumulative_libfuncs_weight = - local_cumulative_libfuncs_weight.add_statistics(&concrete_libfunc_weights); - } + let concrete_libfunc_weights = internal_extract_libfuncs_weight(&runner, all_pcs); + + local_cumulative_libfuncs_weight = + local_cumulative_libfuncs_weight.add_statistics(&concrete_libfunc_weights); } for (concrete_name, weight) in local_cumulative_libfuncs_weight @@ -191,14 +191,14 @@ mod tests { } } - fn compile_cairo_program(filename: &str) -> Program { + fn compile_cairo_program(filename: &str, replace_ids: bool) -> Program { let absolute_path = env::var("CARGO_MANIFEST_DIR").unwrap(); let filename = [absolute_path.as_str(), filename].iter().join(""); let file_path = Path::new(&filename); compile_cairo_project_at_path( file_path, CompilerConfig { - replace_ids: true, + replace_ids, ..CompilerConfig::default() }, ) @@ -207,12 +207,13 @@ mod tests { fn compile_cairo_contract( filename: &str, + replace_ids: bool, ) -> 
cairo_lang_starknet_classes::contract_class::ContractClass { let absolute_path = env::var("CARGO_MANIFEST_DIR").unwrap(); let filename = [absolute_path.as_str(), filename].iter().join(""); let file_path = Path::new(&filename); let config = CompilerConfig { - replace_ids: true, + replace_ids, ..CompilerConfig::default() }; let contract_path = None; @@ -241,7 +242,7 @@ mod tests { entrypoint_offset: usize, args: &[MaybeRelocatable], ) -> Vec { - let contract_class = compile_cairo_contract(filename); + let contract_class = compile_cairo_contract(filename, true); let add_pythonic_hints = false; let max_bytecode_size = 180_000; @@ -388,7 +389,7 @@ mod tests { let visited_pcs: Vec = vec![1, 4, 6, 8, 3]; let cairo_file = "/test_data/sierra_add_program.cairo"; - let sierra_program = compile_cairo_program(cairo_file); + let sierra_program = compile_cairo_program(cairo_file, true); let sierra_profiler = SierraProfiler::new(sierra_program.clone()).unwrap(); @@ -438,7 +439,7 @@ mod tests { // } let cairo_file = "/test_data/sierra_add_contract.cairo"; - let sierra_program = compile_cairo_contract(cairo_file) + let sierra_program = compile_cairo_contract(cairo_file, true) .extract_sierra_program() .unwrap(); let visited_pcs = visited_pcs_from_entrypoint(cairo_file, 0, &[]); @@ -524,7 +525,7 @@ mod tests { // } let cairo_file = "/test_data/sierra_dict.cairo"; - let sierra_program = compile_cairo_contract(cairo_file) + let sierra_program = compile_cairo_contract(cairo_file, true) .extract_sierra_program() .unwrap(); let visited_pcs = visited_pcs_from_entrypoint(cairo_file, 0, &[]); diff --git a/starknet-replay/src/profiler/mod.rs b/starknet-replay/src/profiler/mod.rs index 1e4a9cf..231ecea 100644 --- a/starknet-replay/src/profiler/mod.rs +++ b/starknet-replay/src/profiler/mod.rs @@ -20,6 +20,7 @@ use cairo_lang_sierra_to_casm::metadata::{ MetadataComputationConfig, MetadataError, }; +use itertools::Itertools; use tracing::trace; use crate::error::ProfilerError; @@ -190,12 +191,12 
@@ impl SierraProfiler { #[must_use] pub fn collect_profiling_info(&self, pcs: &[usize]) -> HashMap { let mut sierra_statement_weights = HashMap::default(); - for pc in pcs { + for (pc, frequency) in pcs.iter().counts() { let statements: Vec<&CompiledStatement> = self.commands.iter().filter(|c| c.pc == *pc).collect(); for statement in statements { let statement_idx = StatementIdx(statement.statement_idx); - *sierra_statement_weights.entry(statement_idx).or_insert(0) += 1; + *sierra_statement_weights.entry(statement_idx).or_insert(0) += frequency; } } @@ -225,7 +226,7 @@ impl SierraProfiler { &self, statements: &HashMap, ) -> HashMap { - let mut libfunc_weights = HashMap::::default(); + let mut libfunc_weights = HashMap::default(); for (statement_idx, frequency) in statements { if let Some(GenStatement::Invocation(invocation)) = self.statement_idx_to_gen_statement(*statement_idx) diff --git a/starknet-replay/src/profiler/replace_ids.rs b/starknet-replay/src/profiler/replace_ids.rs index 115f850..be75714 100644 --- a/starknet-replay/src/profiler/replace_ids.rs +++ b/starknet-replay/src/profiler/replace_ids.rs @@ -4,10 +4,11 @@ //! data. Without debug information, the [`cairo_lang_sierra::program::Program`] //! contains only numeric ids to indicate libfuncs and types. +use std::collections::HashSet; use std::sync::Arc; use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId}; -use cairo_lang_sierra::program::{self, ConcreteLibfuncLongId, Program}; +use cairo_lang_sierra::program::{self, ConcreteLibfuncLongId, Program, TypeDeclaration}; use cairo_lang_sierra_generator::db::SierraGeneratorTypeLongId; use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer; use cairo_lang_utils::extract_matches; @@ -45,11 +46,11 @@ use cairo_lang_utils::extract_matches; /// [`cairo_lang_sierra_generator::replace_ids::SierraIdReplacer`] to be able to /// perform the replacement from id to string. 
#[derive(Debug, Clone, Eq, PartialEq)] -pub struct DebugReplacer { +pub struct DebugReplacer<'a> { /// The Sierra program to replace ids from. - program: Program, + program: &'a Program, } -impl DebugReplacer { +impl DebugReplacer<'_> { /// Get the long debug name for the libfunc with id equivalent to `id`. fn lookup_intern_concrete_lib_func(&self, id: &ConcreteLibfuncId) -> ConcreteLibfuncLongId { self.program @@ -61,19 +62,85 @@ impl DebugReplacer { .long_id } - /// Get the long debug name for the type with id equivalent to `id`. - fn lookup_intern_concrete_type(&self, id: &ConcreteTypeId) -> SierraGeneratorTypeLongId { - let concrete_type = self - .program + /// Get the type declaration for a given `type_id`. + fn get_type_declaration(&self, type_id: &ConcreteTypeId) -> TypeDeclaration { + self.program .type_declarations .iter() - .find(|f| f.id.id == id.id) + .find(|f| f.id.id == type_id.id) .expect("ConcreteTypeId should be found in type_declarations.") - .clone(); - SierraGeneratorTypeLongId::Regular(Arc::new(concrete_type.long_id)) + .clone() + } + + /// This function builds the HashSet of type dependencies for `type_id`. The + /// argument `visited_types` is used to keep track of previously visited + /// dependencies to break cycles and avoid infinite recursion. 
+ fn type_dependencies( + &self, + visited_types: &mut HashSet, + type_id: &ConcreteTypeId, + ) -> HashSet { + let mut dependencies = HashSet::new(); + + if visited_types.contains(type_id) { + return dependencies; + } + visited_types.insert(type_id.clone()); + + let concrete_type = self.get_type_declaration(type_id); + + concrete_type + .long_id + .generic_args + .iter() + .for_each(|t| match t { + program::GenericArg::Type(concrete_type_id) => { + dependencies.insert(concrete_type_id.clone()); + if visited_types.contains(concrete_type_id) { + return; + } + dependencies.extend(self.type_dependencies(visited_types, concrete_type_id)); + return; + } + _ => return, + }); + + dependencies + } + + /// Returns true if `type_id` depends on `needle`. False otherwise. + fn has_in_deps(&self, type_id: &ConcreteTypeId, needle: &ConcreteTypeId) -> bool { + let mut visited_types = HashSet::new(); + let deps = self.type_dependencies(&mut visited_types, type_id); + if deps.contains(&needle) { + return true; + } + return false; + } + + /// Get the long debug name for the type with id equivalent to `id`. + /// + /// If `id` is a self-referencing type (i.e. it depends on itself), then the + /// function returns `None` as an alternative to + /// [`SierraGeneratorTypeLongId::CycleBreaker`]. It's not possible to + /// construct a [`SierraGeneratorTypeLongId::CycleBreaker`] object because + /// it requires having access to the SalsaDB of the program. 
+ fn lookup_intern_concrete_type( + &self, + id: &ConcreteTypeId, + ) -> Option { + let concrete_type = self.get_type_declaration(id); + if self.has_in_deps(id, id) { + None + } else { + Some(SierraGeneratorTypeLongId::Regular(Arc::new( + concrete_type.long_id, + ))) + } } } -impl SierraIdReplacer for DebugReplacer { + +impl SierraIdReplacer for DebugReplacer<'_> { fn replace_libfunc_id(&self, id: &ConcreteLibfuncId) -> ConcreteLibfuncId { let mut long_id = self.lookup_intern_concrete_lib_func(id); self.replace_generic_args(&mut long_id.generic_args); @@ -91,10 +158,7 @@ impl SierraIdReplacer for DebugReplacer { // It's not possible to recover the `debug_name` of `Phantom` and `CycleBreaker` because // it relies on access to the Salsa db which is available only during // contract compilation. - SierraGeneratorTypeLongId::Phantom(_) | SierraGeneratorTypeLongId::CycleBreaker(_) => { - id.clone() - } - SierraGeneratorTypeLongId::Regular(long_id) => { + Some(SierraGeneratorTypeLongId::Regular(long_id)) => { let mut long_id = long_id.as_ref().clone(); self.replace_generic_args(&mut long_id.generic_args); if long_id.generic_id == "Enum".into() || long_id.generic_id == "Struct".into() { @@ -116,6 +180,7 @@ impl SierraIdReplacer for DebugReplacer { debug_name: Some(long_id.to_string().into()), } } + _ => id.clone(), } } @@ -149,10 +214,7 @@ impl SierraIdReplacer for DebugReplacer { /// [`cairo_lang_sierra_generator::db::SierraGenGroup`] trait object. 
#[must_use] pub fn replace_sierra_ids_in_program(program: &Program) -> Program { - DebugReplacer { - program: program.clone(), - } - .apply(program) + DebugReplacer { program }.apply(program) } #[cfg(test)] diff --git a/starknet-replay/src/runner/replay_class_hash.rs b/starknet-replay/src/runner/replay_class_hash.rs index 63824f8..ed022e6 100644 --- a/starknet-replay/src/runner/replay_class_hash.rs +++ b/starknet-replay/src/runner/replay_class_hash.rs @@ -24,7 +24,7 @@ pub struct ReplayClassHash { /// The type [`VisitedPcs`] is a hashmap to store the visited program counters /// for each contract invocation during replay. -pub type VisitedPcs = HashMap>>; +pub type VisitedPcs = HashMap>; /// The type [`TransactionOutput`] contains the combination of transaction /// receipt and list of visited program counters. diff --git a/starknet-replay/src/storage/rpc/mod.rs b/starknet-replay/src/storage/rpc/mod.rs index f7ee0eb..e507a8a 100644 --- a/starknet-replay/src/storage/rpc/mod.rs +++ b/starknet-replay/src/storage/rpc/mod.rs @@ -489,15 +489,15 @@ impl ReplayStorage for RpcStorage { let visited_pcs: VisitedPcs = state .visited_pcs - .clone() .0 + .to_owned() .into_iter() .map(|(class_hash, pcs)| { let replay_class_hash = ReplayClassHash { block_number, class_hash, }; - (replay_class_hash, pcs.into_iter().collect()) + (replay_class_hash, pcs.clone()) }) .collect(); if let Some(filename) = trace_out { diff --git a/starknet-replay/src/storage/rpc/visited_pcs.rs b/starknet-replay/src/storage/rpc/visited_pcs.rs index 010427f..b4d21ed 100644 --- a/starknet-replay/src/storage/rpc/visited_pcs.rs +++ b/starknet-replay/src/storage/rpc/visited_pcs.rs @@ -12,19 +12,17 @@ use starknet_api::core::ClassHash; /// The hashmap of [`VisitedPcsRaw`] is a map from a /// [`starknet_api::core::ClassHash`] to a vector of visited program counters. -/// The vector returned from each call to [`starknet_api::core::ClassHash`] is -/// added to the nested vector. 
#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct VisitedPcsRaw(pub HashMap>>); +pub struct VisitedPcsRaw(pub HashMap>); impl VisitedPcs for VisitedPcsRaw { - type Pcs = Vec>; + type Pcs = Vec; fn new() -> Self { VisitedPcsRaw(HashMap::default()) } fn insert(&mut self, class_hash: &ClassHash, pcs: &[usize]) { - self.0.entry(*class_hash).or_default().push(pcs.to_vec()); + self.0.entry(*class_hash).or_default().extend(pcs.iter()); } fn iter(&self) -> impl Iterator { @@ -41,15 +39,13 @@ impl VisitedPcs for VisitedPcsRaw { fn to_set(pcs: Self::Pcs) -> HashSet { let mut set = HashSet::new(); - pcs.into_iter().flatten().for_each(|p| { + pcs.into_iter().for_each(|p| { set.insert(p); }); set } fn add_visited_pcs(state: &mut dyn State, class_hash: &ClassHash, pcs: Self::Pcs) { - for pc in pcs { - state.add_visited_pcs(*class_hash, &pc); - } + state.add_visited_pcs(*class_hash, &pcs); } }