Skip to content

Commit

Permalink
Fixed to prevent stack overflow on self-referencing types
Browse files Browse the repository at this point in the history
  • Loading branch information
Eagle941 committed Dec 23, 2024
1 parent 63414c8 commit 52fbdbd
Show file tree
Hide file tree
Showing 10 changed files with 713 additions and 54 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ jobs:
~/.cargo/registry/cache/
~/.cargo/git/db/
target/
- name: Install libraries
run: |
if [ "$(uname)" != "Darwin" ]; then
sudo apt-get install -y libfontconfig1-dev
fi
- name: Build
run: cargo build --locked
- name: Run tests
Expand All @@ -47,6 +52,8 @@ jobs:
~/.cargo/registry/cache/
~/.cargo/git/db/
target/
- name: Install libraries
run: sudo apt-get install -y libfontconfig1-dev
- name: Cargo clippy
run: cargo clippy --locked --all-targets --all-features -- -D warnings

Expand All @@ -73,6 +80,8 @@ jobs:
profile: minimal
override: true
components: rustfmt
- name: Install libraries
run: sudo apt-get install -y libfontconfig1-dev
- name: Cargo fmt
run: cargo +nightly fmt --all -- --check

Expand All @@ -92,5 +101,7 @@ jobs:
~/.cargo/registry/cache/
~/.cargo/git/db/
target/
- name: Install libraries
run: sudo apt-get install -y libfontconfig1-dev
- name: Cargo doc
run: cargo doc --no-deps --document-private-items --locked
8 changes: 4 additions & 4 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ fn main() {
///
/// - `path`: The file to write.
/// - `overwrite`: If `true`, the file can be overwritten.
fn check_file(path: &Option<PathBuf>, overwrite: bool) -> anyhow::Result<()> {
fn check_file(path: Option<&PathBuf>, overwrite: bool) -> anyhow::Result<()> {
if let Some(filename) = path {
if filename.exists() {
if !overwrite {
Expand Down Expand Up @@ -107,9 +107,9 @@ fn run(args: Args) -> anyhow::Result<()> {
let overwrite = args.overwrite;
let serial_replay = args.serial_replay;

check_file(&svg_path, overwrite)?;
check_file(&txt_out, overwrite)?;
check_file(&trace_out, overwrite)?;
check_file(svg_path.as_ref(), overwrite)?;
check_file(txt_out.as_ref(), overwrite)?;
check_file(trace_out.as_ref(), overwrite)?;

let storage = RpcStorage::new(rpc_url, serial_replay);

Expand Down
29 changes: 15 additions & 14 deletions starknet-replay/src/profiler/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,10 @@ pub fn extract_libfuncs_weight(
visited_pcs: &VisitedPcs,
storage: &impl Storage,
) -> Result<ReplayStatistics, ProfilerError> {
let mut local_cumulative_libfuncs_weight: ReplayStatistics = ReplayStatistics::new();
let mut local_cumulative_libfuncs_weight = ReplayStatistics::new();

for (replay_class_hash, all_pcs) in visited_pcs {
tracing::info!("Processing pcs from {replay_class_hash:?}.");
let Ok(contract_class) = storage.get_contract_class_at_block(replay_class_hash) else {
continue;
};
Expand All @@ -100,13 +101,12 @@ pub fn extract_libfuncs_weight(
continue;
};

let runner = SierraProfiler::new(sierra_program.clone())?;
let runner = SierraProfiler::new(sierra_program)?;

for pcs in all_pcs {
let concrete_libfunc_weights = internal_extract_libfuncs_weight(&runner, pcs);
local_cumulative_libfuncs_weight =
local_cumulative_libfuncs_weight.add_statistics(&concrete_libfunc_weights);
}
let concrete_libfunc_weights = internal_extract_libfuncs_weight(&runner, all_pcs);

local_cumulative_libfuncs_weight =
local_cumulative_libfuncs_weight.add_statistics(&concrete_libfunc_weights);
}

for (concrete_name, weight) in local_cumulative_libfuncs_weight
Expand Down Expand Up @@ -191,14 +191,14 @@ mod tests {
}
}

fn compile_cairo_program(filename: &str) -> Program {
fn compile_cairo_program(filename: &str, replace_ids: bool) -> Program {
let absolute_path = env::var("CARGO_MANIFEST_DIR").unwrap();
let filename = [absolute_path.as_str(), filename].iter().join("");
let file_path = Path::new(&filename);
compile_cairo_project_at_path(
file_path,
CompilerConfig {
replace_ids: true,
replace_ids,
..CompilerConfig::default()
},
)
Expand All @@ -207,12 +207,13 @@ mod tests {

fn compile_cairo_contract(
filename: &str,
replace_ids: bool,
) -> cairo_lang_starknet_classes::contract_class::ContractClass {
let absolute_path = env::var("CARGO_MANIFEST_DIR").unwrap();
let filename = [absolute_path.as_str(), filename].iter().join("");
let file_path = Path::new(&filename);
let config = CompilerConfig {
replace_ids: true,
replace_ids,
..CompilerConfig::default()
};
let contract_path = None;
Expand Down Expand Up @@ -241,7 +242,7 @@ mod tests {
entrypoint_offset: usize,
args: &[MaybeRelocatable],
) -> Vec<usize> {
let contract_class = compile_cairo_contract(filename);
let contract_class = compile_cairo_contract(filename, true);

let add_pythonic_hints = false;
let max_bytecode_size = 180_000;
Expand Down Expand Up @@ -388,7 +389,7 @@ mod tests {
let visited_pcs: Vec<usize> = vec![1, 4, 6, 8, 3];

let cairo_file = "/test_data/sierra_add_program.cairo";
let sierra_program = compile_cairo_program(cairo_file);
let sierra_program = compile_cairo_program(cairo_file, true);

let sierra_profiler = SierraProfiler::new(sierra_program.clone()).unwrap();

Expand Down Expand Up @@ -438,7 +439,7 @@ mod tests {
// }

let cairo_file = "/test_data/sierra_add_contract.cairo";
let sierra_program = compile_cairo_contract(cairo_file)
let sierra_program = compile_cairo_contract(cairo_file, true)
.extract_sierra_program()
.unwrap();
let visited_pcs = visited_pcs_from_entrypoint(cairo_file, 0, &[]);
Expand Down Expand Up @@ -524,7 +525,7 @@ mod tests {
// }

let cairo_file = "/test_data/sierra_dict.cairo";
let sierra_program = compile_cairo_contract(cairo_file)
let sierra_program = compile_cairo_contract(cairo_file, true)
.extract_sierra_program()
.unwrap();
let visited_pcs = visited_pcs_from_entrypoint(cairo_file, 0, &[]);
Expand Down
7 changes: 4 additions & 3 deletions starknet-replay/src/profiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use cairo_lang_sierra_to_casm::metadata::{
MetadataComputationConfig,
MetadataError,
};
use itertools::Itertools;
use tracing::trace;

use crate::error::ProfilerError;
Expand Down Expand Up @@ -190,12 +191,12 @@ impl SierraProfiler {
#[must_use]
pub fn collect_profiling_info(&self, pcs: &[usize]) -> HashMap<StatementIdx, usize> {
let mut sierra_statement_weights = HashMap::default();
for pc in pcs {
for (pc, frequency) in pcs.iter().counts() {
let statements: Vec<&CompiledStatement> =
self.commands.iter().filter(|c| c.pc == *pc).collect();
for statement in statements {
let statement_idx = StatementIdx(statement.statement_idx);
*sierra_statement_weights.entry(statement_idx).or_insert(0) += 1;
*sierra_statement_weights.entry(statement_idx).or_insert(0) += frequency;
}
}

Expand Down Expand Up @@ -225,7 +226,7 @@ impl SierraProfiler {
&self,
statements: &HashMap<StatementIdx, usize>,
) -> HashMap<String, usize> {
let mut libfunc_weights = HashMap::<String, usize>::default();
let mut libfunc_weights = HashMap::default();
for (statement_idx, frequency) in statements {
if let Some(GenStatement::Invocation(invocation)) =
self.statement_idx_to_gen_statement(*statement_idx)
Expand Down
96 changes: 76 additions & 20 deletions starknet-replay/src/profiler/replace_ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
//! data. Without debug information, the [`cairo_lang_sierra::program::Program`]
//! contains only numeric ids to indicate libfuncs and types.
use std::collections::HashSet;
use std::sync::Arc;

use cairo_lang_sierra::ids::{ConcreteLibfuncId, ConcreteTypeId, FunctionId};
use cairo_lang_sierra::program::{self, ConcreteLibfuncLongId, Program};
use cairo_lang_sierra::program::{self, ConcreteLibfuncLongId, Program, TypeDeclaration};
use cairo_lang_sierra_generator::db::SierraGeneratorTypeLongId;
use cairo_lang_sierra_generator::replace_ids::SierraIdReplacer;
use cairo_lang_utils::extract_matches;
Expand Down Expand Up @@ -45,11 +46,11 @@ use cairo_lang_utils::extract_matches;
/// [`cairo_lang_sierra_generator::replace_ids::SierraIdReplacer`] to be able to
/// perform the replacement from id to string.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DebugReplacer {
pub struct DebugReplacer<'a> {
/// The Sierra program to replace ids from.
program: Program,
program: &'a Program,
}
impl DebugReplacer {
impl DebugReplacer<'_> {
/// Get the long debug name for the libfunc with id equivalent to `id`.
fn lookup_intern_concrete_lib_func(&self, id: &ConcreteLibfuncId) -> ConcreteLibfuncLongId {
self.program
Expand All @@ -61,19 +62,79 @@ impl DebugReplacer {
.long_id
}

/// Get the long debug name for the type with id equivalent to `id`.
fn lookup_intern_concrete_type(&self, id: &ConcreteTypeId) -> SierraGeneratorTypeLongId {
let concrete_type = self
.program
/// Get the type declaration for a given `type_id`.
fn get_type_declaration(&self, type_id: &ConcreteTypeId) -> TypeDeclaration {
self.program
.type_declarations
.iter()
.find(|f| f.id.id == id.id)
.find(|f| f.id.id == type_id.id)
.expect("ConcreteTypeId should be found in type_declarations.")
.clone();
SierraGeneratorTypeLongId::Regular(Arc::new(concrete_type.long_id))
.clone()
}

/// This function builds the `HashSet` of type dependencies for `type_id`.
/// The argument `visited_types` is used to keep track of previously
/// visited dependencies to break cycles and avoid infinite recursion.
fn type_dependencies(
&self,
visited_types: &mut HashSet<ConcreteTypeId>,
type_id: &ConcreteTypeId,
) -> HashSet<ConcreteTypeId> {
let mut dependencies = HashSet::new();

if visited_types.contains(type_id) {
return dependencies;
}
visited_types.insert(type_id.clone());

let concrete_type = self.get_type_declaration(type_id);

concrete_type.long_id.generic_args.iter().for_each(|t| {
if let program::GenericArg::Type(concrete_type_id) = t {
dependencies.insert(concrete_type_id.clone());
if visited_types.contains(concrete_type_id) {
return;
}
dependencies.extend(self.type_dependencies(visited_types, concrete_type_id));
}
});

dependencies
}

/// Returns true if `type_id` depends on `needle`. False otherwise.
fn has_in_deps(&self, type_id: &ConcreteTypeId, needle: &ConcreteTypeId) -> bool {
let mut visited_types = HashSet::new();
let deps = self.type_dependencies(&mut visited_types, type_id);
if deps.contains(needle) {
return true;
}
false
}

/// Get the long debug name for the type with id equivalent to `id`.
///
/// If `id` is a self-referencing type (i.e. it depends on itself), then the
/// function returns `None` as an alternative to
/// [`SierraGeneratorTypeLongId::CycleBreaker`]. It's not possible to
/// construct a [`SierraGeneratorTypeLongId::CycleBreaker`] object because
/// it requires having access to the `SalsaDB` of the program.
fn lookup_intern_concrete_type(
&self,
id: &ConcreteTypeId,
) -> Option<SierraGeneratorTypeLongId> {
let concrete_type = self.get_type_declaration(id);
if self.has_in_deps(id, id) {
None
} else {
Some(SierraGeneratorTypeLongId::Regular(Arc::new(
concrete_type.long_id,
)))
}
}
}
impl SierraIdReplacer for DebugReplacer {

impl SierraIdReplacer for DebugReplacer<'_> {
fn replace_libfunc_id(&self, id: &ConcreteLibfuncId) -> ConcreteLibfuncId {
let mut long_id = self.lookup_intern_concrete_lib_func(id);
self.replace_generic_args(&mut long_id.generic_args);
Expand All @@ -91,10 +152,7 @@ impl SierraIdReplacer for DebugReplacer {
// It's not possible to recover the `debug_name` of `Phantom` and `CycleBreaker` because
// it relies on access to the Salsa db which is available only during
// contract compilation.
SierraGeneratorTypeLongId::Phantom(_) | SierraGeneratorTypeLongId::CycleBreaker(_) => {
id.clone()
}
SierraGeneratorTypeLongId::Regular(long_id) => {
Some(SierraGeneratorTypeLongId::Regular(long_id)) => {
let mut long_id = long_id.as_ref().clone();
self.replace_generic_args(&mut long_id.generic_args);
if long_id.generic_id == "Enum".into() || long_id.generic_id == "Struct".into() {
Expand All @@ -116,6 +174,7 @@ impl SierraIdReplacer for DebugReplacer {
debug_name: Some(long_id.to_string().into()),
}
}
_ => id.clone(),
}
}

Expand Down Expand Up @@ -149,10 +208,7 @@ impl SierraIdReplacer for DebugReplacer {
/// [`cairo_lang_sierra_generator::db::SierraGenGroup`] trait object.
#[must_use]
pub fn replace_sierra_ids_in_program(program: &Program) -> Program {
DebugReplacer {
program: program.clone(),
}
.apply(program)
DebugReplacer { program }.apply(program)
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion starknet-replay/src/runner/replay_class_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct ReplayClassHash {

/// The type [`VisitedPcs`] is a hashmap to store the visited program counters
/// for each contract invocation during replay.
pub type VisitedPcs = HashMap<ReplayClassHash, Vec<Vec<usize>>>;
pub type VisitedPcs = HashMap<ReplayClassHash, Vec<usize>>;

/// The type [`TransactionOutput`] contains the combination of transaction
/// receipt and list of visited program counters.
Expand Down
4 changes: 2 additions & 2 deletions starknet-replay/src/storage/rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,15 @@ impl ReplayStorage for RpcStorage {

let visited_pcs: VisitedPcs = state
.visited_pcs
.clone()
.0
.clone()
.into_iter()
.map(|(class_hash, pcs)| {
let replay_class_hash = ReplayClassHash {
block_number,
class_hash,
};
(replay_class_hash, pcs.into_iter().collect())
(replay_class_hash, pcs.clone())
})
.collect();
if let Some(filename) = trace_out {
Expand Down
Loading

0 comments on commit 52fbdbd

Please sign in to comment.