diff --git a/air/src/options.rs b/air/src/options.rs index a831bdad7..d97acf434 100644 --- a/air/src/options.rs +++ b/air/src/options.rs @@ -383,7 +383,11 @@ impl PartitionOptions { self.min_partition_size as usize, ); - base_elements_per_partition.div(E::EXTENSION_DEGREE) + base_elements_per_partition.div_ceil(E::EXTENSION_DEGREE) + } + + pub fn num_partitons(&self) -> usize { + self.num_partitions as usize } } diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 035d6c655..cbb8d211f 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -555,11 +555,8 @@ pub trait Prover { log_domain_size = domain_size.ilog2() ) .in_scope(|| { - let commitment = composed_evaluations.commit_to_rows::( - self.options() - .partition_options() - .partition_size::(num_constraint_composition_columns), - ); + let commitment = composed_evaluations + .commit_to_rows::(self.options().partition_options()); ConstraintCommitment::new(composed_evaluations, commitment) }); diff --git a/prover/src/matrix/row_matrix.rs b/prover/src/matrix/row_matrix.rs index 85b43122e..913250f28 100644 --- a/prover/src/matrix/row_matrix.rs +++ b/prover/src/matrix/row_matrix.rs @@ -5,6 +5,7 @@ use alloc::vec::Vec; +use air::PartitionOptions; use crypto::{ElementHasher, VectorCommitment}; use math::{fft, FieldElement, StarkField}; #[cfg(feature = "concurrent")] @@ -180,7 +181,7 @@ impl RowMatrix { /// * A vector commitment is computed for the resulting vector using the specified vector /// commitment scheme. /// * The resulting vector commitment is returned as the commitment to the entire matrix. - pub fn commit_to_rows(&self, partition_size: usize) -> V + pub fn commit_to_rows(&self, partition_options: PartitionOptions) -> V where H: ElementHasher, V: VectorCommitment, @@ -188,6 +189,9 @@ impl RowMatrix { // allocate vector to store row hashes let mut row_hashes = unsafe { uninit_vector::(self.num_rows()) }; + let partition_size = partition_options.partition_size::(self.num_cols()); + let num_partitions = partition_options.num_partitons(); + if partition_size == self.num_cols() * E::EXTENSION_DEGREE { // iterate though matrix rows, hashing each row batch_iter_mut!( @@ -205,7 +209,7 @@ impl RowMatrix { &mut row_hashes, 128, // min batch size |batch: &mut [H::Digest], batch_offset: usize| { - let mut buffer = vec![H::Digest::default(); partition_size]; + let mut buffer = vec![H::Digest::default(); num_partitions]; for (i, row_hash) in batch.iter_mut().enumerate() { self.row(batch_offset + i) .chunks(partition_size) diff --git a/prover/src/trace/trace_lde/default/mod.rs b/prover/src/trace/trace_lde/default/mod.rs index 26b5e3916..a26a67152 100644 --- a/prover/src/trace/trace_lde/default/mod.rs +++ b/prover/src/trace/trace_lde/default/mod.rs @@ -43,7 +43,7 @@ pub struct DefaultTraceLde< aux_segment_oracles: Option, blowup: usize, trace_info: TraceInfo, - partition_option: PartitionOptions, + partition_options: PartitionOptions, _h: PhantomData, } @@ -64,14 +64,14 @@ where trace_info: &TraceInfo, main_trace: &ColMatrix, domain: &StarkDomain, - partition_option: PartitionOptions, + partition_options: PartitionOptions, ) -> (Self, TracePolyTable) { // extend the main execution trace and build a commitment to the extended trace let (main_segment_lde, main_segment_vector_com, main_segment_polys) = build_trace_commitment::( main_trace, domain, - partition_option.partition_size::(main_trace.num_cols()), + partition_options, ); let trace_poly_table = TracePolyTable::new(main_segment_polys); @@ -82,7 +82,7 @@ where aux_segment_oracles: None, blowup: domain.trace_to_lde_blowup(), trace_info: trace_info.clone(), - partition_option, + partition_options, _h: PhantomData, }; @@ -148,11 +148,7 @@ where ) -> (ColMatrix, H::Digest) { // extend the auxiliary trace segment and build a commitment to the extended trace let (aux_segment_lde, aux_segment_oracles, aux_segment_polys) = - build_trace_commitment::( - aux_trace, - domain, - self.partition_option.partition_size::(aux_trace.num_cols()), - ); + build_trace_commitment::(aux_trace, domain, self.partition_options); // check errors assert!( @@ -276,7 +272,7 @@ where fn build_trace_commitment( trace: &ColMatrix, domain: &StarkDomain, - partition_size: usize, + partition_options: PartitionOptions, ) -> (RowMatrix, V, ColMatrix) where E: FieldElement, @@ -306,7 +302,7 @@ where // build trace commitment let commitment_domain_size = trace_lde.num_rows(); let trace_vector_com = info_span!("compute_execution_trace_commitment", commitment_domain_size) - .in_scope(|| trace_lde.commit_to_rows::(partition_size)); + .in_scope(|| trace_lde.commit_to_rows::(partition_options)); assert_eq!(trace_vector_com.domain_len(), commitment_domain_size); (trace_lde, trace_vector_com, trace_polys) diff --git a/verifier/src/channel.rs b/verifier/src/channel.rs index 9d7dbc426..b33673e16 100644 --- a/verifier/src/channel.rs +++ b/verifier/src/channel.rs @@ -36,6 +36,7 @@ pub struct VerifierChannel< constraint_commitment: H::Digest, constraint_queries: Option>, // partition sizes for the rows of main, auxiliary and constraint traces rows + num_partitions: usize, partition_size_main: usize, partition_size_aux: usize, partition_size_constraint: usize, @@ -120,6 +121,7 @@ where .map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))?; // --- compute the partition size for each trace ------------------------------------------ + let num_partitions = partition_options.num_partitons(); let partition_size_main = partition_options .partition_size::(air.context().trace_info().main_trace_width()); let partition_size_aux = @@ -135,6 +137,7 @@ where constraint_commitment, constraint_queries: Some(constraint_queries), // num partitions used in commitment + num_partitions, partition_size_main, partition_size_aux, partition_size_constraint, @@ -211,7 +214,9 @@ where let items: Vec = queries .main_states .rows() - .map(|row| hash_row::(row, self.partition_size_main)) + .map(|row| { + hash_row::(row, self.partition_size_main, self.num_partitions) + }) .collect(); >::verify_many( @@ -225,7 +230,7 @@ where if let Some(ref aux_states) = queries.aux_states { let items: Vec = aux_states .rows() - .map(|row| hash_row::(row, self.partition_size_aux)) + .map(|row| hash_row::(row, self.partition_size_aux, self.num_partitions)) .collect(); >::verify_many( @@ -252,7 +257,7 @@ where let items: Vec = queries .evaluations .rows() - .map(|row| hash_row::(row, self.partition_size_constraint)) + .map(|row| hash_row::(row, self.partition_size_constraint, self.num_partitions)) .collect(); >::verify_many( @@ -437,7 +442,7 @@ where // ================================================================================================ /// Hashes a row of a trace in batches where each batch is of size at most `partition_size`. -fn hash_row(row: &[E], partition_size: usize) -> H::Digest +fn hash_row(row: &[E], partition_size: usize, num_partitions: usize) -> H::Digest where E: FieldElement, H: ElementHasher, @@ -445,7 +450,7 @@ where if partition_size == row.len() * E::EXTENSION_DEGREE { H::hash_elements(row) } else { - let mut buffer = vec![H::Digest::default(); partition_size]; + let mut buffer = vec![H::Digest::default(); num_partitions]; row.chunks(partition_size) .zip(buffer.iter_mut())