diff --git a/air/src/options.rs b/air/src/options.rs index 01599489c..657a6381d 100644 --- a/air/src/options.rs +++ b/air/src/options.rs @@ -4,7 +4,7 @@ // LICENSE file in the root directory of this source tree. use alloc::vec::Vec; -use core::{cmp, ops::Div}; +use core::cmp; use fri::FriOptions; use math::{FieldElement, StarkField, ToElements}; @@ -76,16 +76,13 @@ pub enum FieldExtension { /// collision resistance of the hash function used by the protocol. For example, if a hash function /// with 128-bit collision resistance is used, soundness of a STARK proof cannot exceed 128 bits. /// -/// In addition to the above, the parameter `num_partitions` is used in order to specify the number -/// of partitions each of the traces committed to during proof generation is split into, and -/// the parameter `min_partition_size` gives a lower bound on the size of each such partition. -/// More precisely, and taking the main segment trace as an example, the prover will split the main -/// segment trace into `num_partitions` parts each of size at least `min_partition_size`. The prover -/// will then proceed to hash each part row-wise resulting in `num_partitions` digests per row of -/// the trace. The prover finally combines the `num_partitions` digest (per row) into one digest -/// (per row) and at this point the vector commitment scheme can be called. -/// In the case when `num_partitions` is equal to `1` the prover will just hash each row in one go -/// producing one digest per row of the trace. +/// In addition, partition options (see [PartitionOptions]) can be provided to split traces during +/// proving and distribute work across multiple devices. Taking the main segment trace as an example, +/// the prover will split the main segment trace into `num_partitions` parts, and then proceed to hash +/// each part row-wise resulting in `num_partitions` digests per row of the trace. 
Finally, +/// `num_partitions` digests (per row) are combined into one digest (per row) and at this point +/// a vector commitment scheme can be called. In the case when `num_partitions` is equal to `1` (default) +/// the prover will hash each row in one go producing one digest per row of the trace. #[derive(Debug, Clone, Eq, PartialEq)] pub struct ProofOptions { num_queries: u8, @@ -177,13 +174,13 @@ impl ProofOptions { /// # Panics /// Panics if: /// - `num_partitions` is zero or greater than 16. - /// - `min_partition_size` is zero or greater than 256. + /// - `hash_rate` is zero or greater than 256. pub const fn with_partitions( mut self, num_partitions: usize, - min_partition_size: usize, + hash_rate: usize, ) -> ProofOptions { - self.partition_options = PartitionOptions::new(num_partitions, min_partition_size); + self.partition_options = PartitionOptions::new(num_partitions, hash_rate); self } @@ -277,7 +274,7 @@ impl Serializable for ProofOptions { target.write_u8(self.fri_folding_factor); target.write_u8(self.fri_remainder_max_degree); target.write_u8(self.partition_options.num_partitions); - target.write_u8(self.partition_options.min_partition_size); + target.write_u8(self.partition_options.hash_rate); } } @@ -347,31 +344,43 @@ impl Deserializable for FieldExtension { // PARTITION OPTION IMPLEMENTATION // ================================================================================================ -/// Defines the parameters used when committing to the traces generated during the protocol. +/// Defines the parameters used to calculate partition size when committing to the traces +/// generated during the protocol. 
+/// +/// Using multiple partitions will change how vector commitments are calculated: +/// - Input matrix columns are split into at most num_partitions partitions +/// - For each matrix row, a hash is calculated for each partition separately +/// - The results are merged together by one more hash iteration +/// +/// This is especially useful when proving with multiple GPU cards where each device holds +/// a subset of data and allows less data reshuffling when generating commitments. +/// +/// Hash_rate parameter is used to find the optimal partition size to minimize the number +/// of hash iterations. It specifies how many field elements are consumed by each hash iteration. #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct PartitionOptions { num_partitions: u8, - min_partition_size: u8, + hash_rate: u8, } impl PartitionOptions { /// Returns a new instance of `[PartitionOptions]`. - pub const fn new(num_partitions: usize, min_partition_size: usize) -> Self { + pub const fn new(num_partitions: usize, hash_rate: usize) -> Self { assert!(num_partitions >= 1, "number of partitions must be greater than or eqaul to 1"); assert!(num_partitions <= 16, "number of partitions must be smaller than or equal to 16"); assert!( - min_partition_size >= 1, - "smallest partition size must be greater than or equal to 1" + hash_rate >= 1, + "hash rate must be greater than or equal to 1" ); assert!( - min_partition_size <= 256, - "smallest partition size must be smaller than or equal to 256" + hash_rate <= 256, + "hash rate must be smaller than or equal to 256" ); Self { num_partitions: num_partitions as u8, - min_partition_size: min_partition_size as u8, + hash_rate: hash_rate as u8, } } @@ -379,21 +388,30 @@ impl PartitionOptions { /// well as the constraint evaluation trace. /// The returned size is given in terms of number of columns in the field `E`. 
     pub fn partition_size<E: FieldElement>(&self, num_columns: usize) -> usize {
-        if self.num_partitions == 1 && self.min_partition_size == 1 {
+        if self.num_partitions == 1 {
             return num_columns;
         }
 
-        let base_elements_per_partition = cmp::max(
-            (num_columns * E::EXTENSION_DEGREE).div_ceil(self.num_partitions as usize),
-            self.min_partition_size as usize,
-        );
-        base_elements_per_partition.div(E::EXTENSION_DEGREE)
+        // Don't separate columns that would fit inside one hash iteration. min_partition_size is
+        // the number of `E` elements that can be consumed in one hash iteration.
+        let min_partition_size = self.hash_rate as usize / E::EXTENSION_DEGREE;
+
+        cmp::max(
+            num_columns.div_ceil(self.num_partitions as usize),
+            min_partition_size,
+        )
+    }
+
+    /// The actual number of partitions, after the min partition size implied
+    /// by the hash rate is taken into account.
+    pub fn num_partitions<E: FieldElement>(&self, num_columns: usize) -> usize {
+        num_columns.div_ceil(self.partition_size::<E>(num_columns))
     }
 }
 
 impl Default for PartitionOptions {
     fn default() -> Self {
-        Self { num_partitions: 1, min_partition_size: 1 }
+        Self { num_partitions: 1, hash_rate: 1 }
     }
 }
 
@@ -402,9 +420,9 @@
 
 #[cfg(test)]
 mod tests {
-    use math::fields::f64::BaseElement;
+    use math::fields::{f64::BaseElement, CubeExtension};
 
-    use super::{FieldExtension, ProofOptions, ToElements};
+    use super::{FieldExtension, PartitionOptions, ProofOptions, ToElements};
 
     #[test]
     fn proof_options_to_elements() {
@@ -438,4 +456,37 @@
         );
         assert_eq!(expected, options.to_elements());
     }
+
+    #[test]
+    fn correct_partition_sizes() {
+        type E1 = BaseElement;
+        type E3 = CubeExtension<BaseElement>;
+
+        let options = PartitionOptions::new(4, 8);
+        let columns = 7;
+        assert_eq!(8, options.partition_size::<E1>(columns));
+        assert_eq!(1, options.num_partitions::<E1>(columns));
+
+        let options = PartitionOptions::new(4, 8);
+        let columns = 70;
+        assert_eq!(18, options.partition_size::<E1>(columns));
+        assert_eq!(4, options.num_partitions::<E1>(columns));
+
+        let options = PartitionOptions::new(2, 8);
+        let columns = 7;
+        assert_eq!(4, options.partition_size::<E3>(columns));
+        assert_eq!(2, options.num_partitions::<E3>(columns));
+
+        let options: PartitionOptions = PartitionOptions::new(4, 8);
+        let columns = 7;
+        assert_eq!(2, options.partition_size::<E3>(columns));
+        assert_eq!(4, options.num_partitions::<E3>(columns));
+
+        // don't use all partitions if it would result in sizes smaller than
+        // a single hash iteration can handle
+        let options: PartitionOptions = PartitionOptions::new(4, 8);
+        let columns = 3;
+        assert_eq!(2, options.partition_size::<E3>(columns));
+        assert_eq!(2, options.num_partitions::<E3>(columns));
+    }
 }
diff --git a/prover/src/lib.rs b/prover/src/lib.rs
index 1a2e157ea..c72c0c766 100644
--- a/prover/src/lib.rs
+++ b/prover/src/lib.rs
@@ -558,8 +558,7 @@ pub trait Prover {
             .in_scope(|| {
                 let commitment = composed_evaluations.commit_to_rows::<Self::HashFn, Self::VC>(
                     self.options()
-                        .partition_options()
-                        .partition_size::<E>(num_constraint_composition_columns),
+                        .partition_options(),
                 );
                 ConstraintCommitment::new(composed_evaluations, commitment)
             });
diff --git a/prover/src/matrix/row_matrix.rs b/prover/src/matrix/row_matrix.rs
index 6cb9ef60c..ef146643e 100644
--- a/prover/src/matrix/row_matrix.rs
+++ b/prover/src/matrix/row_matrix.rs
@@ -3,6 +3,7 @@
 // This source code is licensed under the MIT license found in the
 // LICENSE file in the root directory of this source tree.
 
+use air::PartitionOptions;
 use alloc::vec::Vec;
 
 use crypto::{ElementHasher, VectorCommitment};
@@ -180,13 +181,14 @@ impl<E: FieldElement> RowMatrix<E> {
     /// * A vector commitment is computed for the resulting vector using the specified vector
     ///   commitment scheme.
     /// * The resulting vector commitment is returned as the commitment to the entire matrix.
-    pub fn commit_to_rows<H, V>(&self, partition_size: usize) -> V
+    pub fn commit_to_rows<H, V>(&self, partition_options: PartitionOptions) -> V
     where
         H: ElementHasher<BaseField = E::BaseField>,
         V: VectorCommitment<H>,
     {
         // allocate vector to store row hashes
         let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };
+        let partition_size = partition_options.partition_size::<E>(self.num_cols());
 
         if partition_size == self.num_cols() {
             // iterate though matrix rows, hashing each row
@@ -200,17 +202,21 @@ impl<E: FieldElement> RowMatrix<E> {
                 }
             );
         } else {
+            let num_partitions = partition_options.num_partitions::<E>(self.num_cols());
+
             // iterate though matrix rows, hashing each row
             batch_iter_mut!(
                 &mut row_hashes,
                 128, // min batch size
                 |batch: &mut [H::Digest], batch_offset: usize| {
-                    let mut buffer = vec![H::Digest::default(); partition_size];
+                    let mut buffer = vec![H::Digest::default(); num_partitions];
                     for (i, row_hash) in batch.iter_mut().enumerate() {
                         self.row(batch_offset + i)
                             .chunks(partition_size)
                             .zip(buffer.iter_mut())
-                            .for_each(|(chunk, buf)| *buf = H::hash_elements(chunk));
+                            .for_each(|(chunk, buf)| {
+                                *buf = H::hash_elements(chunk);
+                            });
                         *row_hash = H::merge_many(&buffer);
                     }
                 }
diff --git a/prover/src/trace/trace_lde/default/mod.rs b/prover/src/trace/trace_lde/default/mod.rs
index 26b5e3916..850ce0d90 100644
--- a/prover/src/trace/trace_lde/default/mod.rs
+++ b/prover/src/trace/trace_lde/default/mod.rs
@@ -43,7 +43,7 @@ pub struct DefaultTraceLde<
     aux_segment_oracles: Option<V>,
     blowup: usize,
     trace_info: TraceInfo,
-    partition_option: PartitionOptions,
+    partition_options: PartitionOptions,
     _h: PhantomData<H>,
 }
 
@@ -64,16 +64,16 @@ where
         trace_info: &TraceInfo,
         main_trace: &ColMatrix<E::BaseField>,
         domain: &StarkDomain<E::BaseField>,
-        partition_option: PartitionOptions,
+        partition_options: PartitionOptions,
     ) -> (Self, TracePolyTable<E>) {
         // extend the main execution trace and build a commitment to the extended trace
         let (main_segment_lde, main_segment_vector_com, main_segment_polys) =
             build_trace_commitment::<E, E::BaseField, H, V>(
                 main_trace,
                 domain,
-                partition_option.partition_size::<E::BaseField>(main_trace.num_cols()),
+                partition_options,
             );
-
+
         let trace_poly_table = TracePolyTable::new(main_segment_polys);
         let trace_lde = DefaultTraceLde {
             main_segment_lde,
@@ -82,7 +82,7 @@ where
             aux_segment_oracles: None,
             blowup: domain.trace_to_lde_blowup(),
             trace_info: trace_info.clone(),
-            partition_option,
+            partition_options,
             _h: PhantomData,
         };
 
@@ -151,9 +151,9 @@ where
             build_trace_commitment::<E, E, H, V>(
                 aux_trace,
                 domain,
-                self.partition_option.partition_size::<E>(aux_trace.num_cols()),
+                self.partition_options,
             );
-
+
         // check errors
         assert!(
             usize::from(self.aux_segment_lde.is_some()) < self.trace_info.num_aux_segments(),
@@ -276,7 +276,7 @@ where
 fn build_trace_commitment<E, F, H, V>(
     trace: &ColMatrix<F>,
     domain: &StarkDomain<E::BaseField>,
-    partition_size: usize,
+    partition_options: PartitionOptions,
 ) -> (RowMatrix<F>, V, ColMatrix<F>)
 where
     E: FieldElement,
@@ -306,7 +306,7 @@ where
     // build trace commitment
     let commitment_domain_size = trace_lde.num_rows();
     let trace_vector_com = info_span!("compute_execution_trace_commitment", commitment_domain_size)
-        .in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_size));
+        .in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_options));
     assert_eq!(trace_vector_com.domain_len(), commitment_domain_size);
 
     (trace_lde, trace_vector_com, trace_polys)
diff --git a/verifier/src/channel.rs b/verifier/src/channel.rs
index e6c511cd8..1425d86aa 100644
--- a/verifier/src/channel.rs
+++ b/verifier/src/channel.rs
@@ -445,7 +445,9 @@ where
         if partition_size == row.len() {
             H::hash_elements(row)
         } else {
-            let mut buffer = vec![H::Digest::default(); partition_size];
+            let num_partitions = row.len().div_ceil(partition_size);
+
+            let mut buffer = vec![H::Digest::default(); num_partitions];
             row.chunks(partition_size)
                 .zip(buffer.iter_mut())