Skip to content

Commit

Permalink
Multi GPU partition fixes (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
gswirski authored Nov 19, 2024
1 parent f8e1216 commit 8640715
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 47 deletions.
115 changes: 83 additions & 32 deletions air/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// LICENSE file in the root directory of this source tree.

use alloc::vec::Vec;
use core::{cmp, ops::Div};
use core::cmp;

use fri::FriOptions;
use math::{FieldElement, StarkField, ToElements};
Expand Down Expand Up @@ -76,16 +76,13 @@ pub enum FieldExtension {
/// collision resistance of the hash function used by the protocol. For example, if a hash function
/// with 128-bit collision resistance is used, soundness of a STARK proof cannot exceed 128 bits.
///
/// In addition to the above, the parameter `num_partitions` is used in order to specify the number
/// of partitions each of the traces committed to during proof generation is split into, and
/// the parameter `min_partition_size` gives a lower bound on the size of each such partition.
/// More precisely, and taking the main segment trace as an example, the prover will split the main
/// segment trace into `num_partitions` parts each of size at least `min_partition_size`. The prover
/// will then proceed to hash each part row-wise resulting in `num_partitions` digests per row of
/// the trace. The prover finally combines the `num_partitions` digest (per row) into one digest
/// (per row) and at this point the vector commitment scheme can be called.
/// In the case when `num_partitions` is equal to `1` the prover will just hash each row in one go
/// producing one digest per row of the trace.
/// In addition, partition options (see [PartitionOptions]) can be provided to split traces during
/// proving and distribute work across multiple devices. Taking the main segment trace as an example,
/// the prover will split the main segment trace into `num_partitions` parts, and then proceed to hash
/// each part row-wise resulting in `num_partitions` digests per row of the trace. Finally,
/// `num_partitions` digests (per row) are combined into one digest (per row) and at this point
/// a vector commitment scheme can be called. In the case when `num_partitions` is equal to `1` (default)
/// the prover will hash each row in one go producing one digest per row of the trace.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ProofOptions {
num_queries: u8,
Expand Down Expand Up @@ -177,13 +174,13 @@ impl ProofOptions {
/// # Panics
/// Panics if:
/// - `num_partitions` is zero or greater than 16.
/// - `min_partition_size` is zero or greater than 256.
/// - `hash_rate` is zero or greater than 256.
pub const fn with_partitions(
mut self,
num_partitions: usize,
min_partition_size: usize,
hash_rate: usize,
) -> ProofOptions {
self.partition_options = PartitionOptions::new(num_partitions, min_partition_size);
self.partition_options = PartitionOptions::new(num_partitions, hash_rate);

self
}
Expand Down Expand Up @@ -277,7 +274,7 @@ impl Serializable for ProofOptions {
target.write_u8(self.fri_folding_factor);
target.write_u8(self.fri_remainder_max_degree);
target.write_u8(self.partition_options.num_partitions);
target.write_u8(self.partition_options.min_partition_size);
target.write_u8(self.partition_options.hash_rate);
}
}

Expand Down Expand Up @@ -347,53 +344,74 @@ impl Deserializable for FieldExtension {
// PARTITION OPTION IMPLEMENTATION
// ================================================================================================

/// Defines the parameters used to calculate partition size when committing to the traces
/// generated during the protocol.
///
/// Using multiple partitions will change how vector commitments are calculated:
/// - Input matrix columns are split into at most num_partitions partitions
/// - For each matrix row, a hash is calculated for each partition separately
/// - The results are merged together by one more hash iteration
///
/// This is especially useful when proving with multiple GPU cards where each device holds
/// a subset of data and allows less data reshuffling when generating commitments.
///
/// Hash_rate parameter is used to find the optimal partition size to minimize the number
/// of hash iterations. It specifies how many field elements are consumed by each hash iteration.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct PartitionOptions {
    // Maximum number of partitions each trace is split into; 1 means no partitioning.
    num_partitions: u8,
    // Number of base field elements consumed by one hash iteration; implies a lower
    // bound on the partition size so no partition is smaller than one hash iteration.
    hash_rate: u8,
}

impl PartitionOptions {
/// Returns a new instance of `[PartitionOptions]`.
pub const fn new(num_partitions: usize, min_partition_size: usize) -> Self {
pub const fn new(num_partitions: usize, hash_rate: usize) -> Self {
assert!(num_partitions >= 1, "number of partitions must be greater than or eqaul to 1");
assert!(num_partitions <= 16, "number of partitions must be smaller than or equal to 16");

assert!(
min_partition_size >= 1,
"smallest partition size must be greater than or equal to 1"
hash_rate >= 1,
"hash rate must be greater than or equal to 1"
);
assert!(
min_partition_size <= 256,
"smallest partition size must be smaller than or equal to 256"
hash_rate <= 256,
"hash rate must be smaller than or equal to 256"
);

Self {
num_partitions: num_partitions as u8,
min_partition_size: min_partition_size as u8,
hash_rate: hash_rate as u8,
}
}

/// Returns the size of each partition used when committing to the main and auxiliary traces as
/// well as the constraint evaluation trace.
/// The returned size is given in terms of number of columns in the field `E`.
pub fn partition_size<E: FieldElement>(&self, num_columns: usize) -> usize {
if self.num_partitions == 1 && self.min_partition_size == 1 {
if self.num_partitions == 1 {
return num_columns;
}
let base_elements_per_partition = cmp::max(
(num_columns * E::EXTENSION_DEGREE).div_ceil(self.num_partitions as usize),
self.min_partition_size as usize,
);

base_elements_per_partition.div(E::EXTENSION_DEGREE)
// Don't separate columns that would fit inside one hash iteration. min_partition_size is
// the number of `E` elements that can be consumed in one hash iteration.
let min_partition_size = self.hash_rate as usize / E::EXTENSION_DEGREE;

cmp::max(
num_columns.div_ceil(self.num_partitions as usize),
min_partition_size,
)
}

/// The actual number of partitions, after the min partition size implied
/// by the hash rate is taken into account.
pub fn num_partitions<E: FieldElement>(&self, num_columns: usize) -> usize {
num_columns.div_ceil(self.partition_size::<E>(num_columns))
}
}

impl Default for PartitionOptions {
fn default() -> Self {
Self { num_partitions: 1, min_partition_size: 1 }
Self { num_partitions: 1, hash_rate: 1 }
}
}

Expand All @@ -402,9 +420,9 @@ impl Default for PartitionOptions {

#[cfg(test)]
mod tests {
use math::fields::f64::BaseElement;
use math::fields::{f64::BaseElement, CubeExtension};

use super::{FieldExtension, ProofOptions, ToElements};
use super::{FieldExtension, PartitionOptions, ProofOptions, ToElements};

#[test]
fn proof_options_to_elements() {
Expand Down Expand Up @@ -438,4 +456,37 @@ mod tests {
);
assert_eq!(expected, options.to_elements());
}

#[test]
fn correct_partition_sizes() {
    type E1 = BaseElement;
    type E3 = CubeExtension<BaseElement>;

    // base field: a hash rate of 8 means 8 base elements fit in one hash iteration,
    // so 7 columns collapse into a single partition of the minimum size
    let options = PartitionOptions::new(4, 8);
    assert_eq!(8, options.partition_size::<E1>(7));
    assert_eq!(1, options.num_partitions::<E1>(7));

    // 70 columns spread evenly across all 4 partitions: ceil(70 / 4) = 18
    let options = PartitionOptions::new(4, 8);
    assert_eq!(18, options.partition_size::<E1>(70));
    assert_eq!(4, options.num_partitions::<E1>(70));

    // cubic extension: one hash iteration consumes 8 / 3 = 2 extension elements
    let options = PartitionOptions::new(2, 8);
    assert_eq!(4, options.partition_size::<E3>(7));
    assert_eq!(2, options.num_partitions::<E3>(7));

    let options = PartitionOptions::new(4, 8);
    assert_eq!(2, options.partition_size::<E3>(7));
    assert_eq!(4, options.num_partitions::<E3>(7));

    // don't use all partitions if it would result in sizes smaller than
    // a single hash iteration can handle
    let options = PartitionOptions::new(4, 8);
    assert_eq!(2, options.partition_size::<E3>(3));
    assert_eq!(2, options.num_partitions::<E3>(3));
}
}
3 changes: 1 addition & 2 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,7 @@ pub trait Prover {
.in_scope(|| {
let commitment = composed_evaluations.commit_to_rows::<Self::HashFn, Self::VC>(
self.options()
.partition_options()
.partition_size::<E>(num_constraint_composition_columns),
.partition_options(),
);
ConstraintCommitment::new(composed_evaluations, commitment)
});
Expand Down
12 changes: 9 additions & 3 deletions prover/src/matrix/row_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

use air::PartitionOptions;
use alloc::vec::Vec;

use crypto::{ElementHasher, VectorCommitment};
Expand Down Expand Up @@ -180,13 +181,14 @@ impl<E: FieldElement> RowMatrix<E> {
/// * A vector commitment is computed for the resulting vector using the specified vector
/// commitment scheme.
/// * The resulting vector commitment is returned as the commitment to the entire matrix.
pub fn commit_to_rows<H, V>(&self, partition_size: usize) -> V
pub fn commit_to_rows<H, V>(&self, partition_options: PartitionOptions) -> V
where
H: ElementHasher<BaseField = E::BaseField>,
V: VectorCommitment<H>,
{
// allocate vector to store row hashes
let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };
let partition_size = partition_options.partition_size::<E>(self.num_cols());

if partition_size == self.num_cols() {
// iterate though matrix rows, hashing each row
Expand All @@ -200,17 +202,21 @@ impl<E: FieldElement> RowMatrix<E> {
}
);
} else {
let num_partitions = partition_options.num_partitions::<E>(self.num_cols());

// iterate though matrix rows, hashing each row
batch_iter_mut!(
&mut row_hashes,
128, // min batch size
|batch: &mut [H::Digest], batch_offset: usize| {
let mut buffer = vec![H::Digest::default(); partition_size];
let mut buffer = vec![H::Digest::default(); num_partitions];
for (i, row_hash) in batch.iter_mut().enumerate() {
self.row(batch_offset + i)
.chunks(partition_size)
.zip(buffer.iter_mut())
.for_each(|(chunk, buf)| *buf = H::hash_elements(chunk));
.for_each(|(chunk, buf)| {
*buf = H::hash_elements(chunk);
});
*row_hash = H::merge_many(&buffer);
}
}
Expand Down
18 changes: 9 additions & 9 deletions prover/src/trace/trace_lde/default/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub struct DefaultTraceLde<
aux_segment_oracles: Option<V>,
blowup: usize,
trace_info: TraceInfo,
partition_option: PartitionOptions,
partition_options: PartitionOptions,
_h: PhantomData<H>,
}

Expand All @@ -64,16 +64,16 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<E::BaseField>,
domain: &StarkDomain<E::BaseField>,
partition_option: PartitionOptions,
partition_options: PartitionOptions,
) -> (Self, TracePolyTable<E>) {
// extend the main execution trace and build a commitment to the extended trace
let (main_segment_lde, main_segment_vector_com, main_segment_polys) =
build_trace_commitment::<E, E::BaseField, H, V>(
main_trace,
domain,
partition_option.partition_size::<E::BaseField>(main_trace.num_cols()),
partition_options,
);

let trace_poly_table = TracePolyTable::new(main_segment_polys);
let trace_lde = DefaultTraceLde {
main_segment_lde,
Expand All @@ -82,7 +82,7 @@ where
aux_segment_oracles: None,
blowup: domain.trace_to_lde_blowup(),
trace_info: trace_info.clone(),
partition_option,
partition_options,
_h: PhantomData,
};

Expand Down Expand Up @@ -151,9 +151,9 @@ where
build_trace_commitment::<E, E, H, Self::VC>(
aux_trace,
domain,
self.partition_option.partition_size::<E>(aux_trace.num_cols()),
self.partition_options,
);

// check errors
assert!(
usize::from(self.aux_segment_lde.is_some()) < self.trace_info.num_aux_segments(),
Expand Down Expand Up @@ -276,7 +276,7 @@ where
fn build_trace_commitment<E, F, H, V>(
trace: &ColMatrix<F>,
domain: &StarkDomain<E::BaseField>,
partition_size: usize,
partition_options: PartitionOptions,
) -> (RowMatrix<F>, V, ColMatrix<F>)
where
E: FieldElement,
Expand Down Expand Up @@ -306,7 +306,7 @@ where
// build trace commitment
let commitment_domain_size = trace_lde.num_rows();
let trace_vector_com = info_span!("compute_execution_trace_commitment", commitment_domain_size)
.in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_size));
.in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_options));
assert_eq!(trace_vector_com.domain_len(), commitment_domain_size);

(trace_lde, trace_vector_com, trace_polys)
Expand Down
4 changes: 3 additions & 1 deletion verifier/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,9 @@ where
if partition_size == row.len() {
H::hash_elements(row)
} else {
let mut buffer = vec![H::Digest::default(); partition_size];
let num_partitions = row.len().div_ceil(partition_size);

let mut buffer = vec![H::Digest::default(); num_partitions];

row.chunks(partition_size)
.zip(buffer.iter_mut())
Expand Down

0 comments on commit 8640715

Please sign in to comment.