Multi GPU partition fixes #340

Merged
merged 2 commits into from
Nov 19, 2024
115 changes: 83 additions & 32 deletions air/src/options.rs
@@ -4,7 +4,7 @@
// LICENSE file in the root directory of this source tree.

use alloc::vec::Vec;
use core::{cmp, ops::Div};
use core::cmp;

use fri::FriOptions;
use math::{FieldElement, StarkField, ToElements};
@@ -76,16 +76,13 @@ pub enum FieldExtension {
/// collision resistance of the hash function used by the protocol. For example, if a hash function
/// with 128-bit collision resistance is used, soundness of a STARK proof cannot exceed 128 bits.
///
/// In addition to the above, the parameter `num_partitions` is used in order to specify the number
/// of partitions each of the traces committed to during proof generation is split into, and
/// the parameter `min_partition_size` gives a lower bound on the size of each such partition.
/// More precisely, and taking the main segment trace as an example, the prover will split the main
/// segment trace into `num_partitions` parts each of size at least `min_partition_size`. The prover
/// will then proceed to hash each part row-wise resulting in `num_partitions` digests per row of
/// the trace. The prover finally combines the `num_partitions` digest (per row) into one digest
/// (per row) and at this point the vector commitment scheme can be called.
/// In the case when `num_partitions` is equal to `1` the prover will just hash each row in one go
/// producing one digest per row of the trace.
/// In addition, partition options (see [PartitionOptions]) can be provided to split traces during
/// proving and distribute work across multiple devices. Taking the main segment trace as an example,
/// the prover will split the main segment trace into `num_partitions` parts, and then proceed to hash
/// each part row-wise resulting in `num_partitions` digests per row of the trace. Finally,
/// `num_partitions` digests (per row) are combined into one digest (per row) and at this point
/// a vector commitment scheme can be called. When `num_partitions` is equal to `1` (the default),
/// the prover will hash each row in one go, producing one digest per row of the trace.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ProofOptions {
num_queries: u8,
@@ -177,13 +174,13 @@ impl ProofOptions {
/// # Panics
/// Panics if:
/// - `num_partitions` is zero or greater than 16.
/// - `min_partition_size` is zero or greater than 256.
/// - `hash_rate` is zero or greater than 256.
pub const fn with_partitions(
mut self,
num_partitions: usize,
min_partition_size: usize,
hash_rate: usize,
) -> ProofOptions {
self.partition_options = PartitionOptions::new(num_partitions, min_partition_size);
self.partition_options = PartitionOptions::new(num_partitions, hash_rate);

self
}
@@ -277,7 +274,7 @@ impl Serializable for ProofOptions {
target.write_u8(self.fri_folding_factor);
target.write_u8(self.fri_remainder_max_degree);
target.write_u8(self.partition_options.num_partitions);
target.write_u8(self.partition_options.min_partition_size);
target.write_u8(self.partition_options.hash_rate);
}
}

@@ -347,53 +344,74 @@ impl Deserializable for FieldExtension {
// PARTITION OPTION IMPLEMENTATION
// ================================================================================================

/// Defines the parameters used when committing to the traces generated during the protocol.
/// Defines the parameters used to calculate partition size when committing to the traces
/// generated during the protocol.
///
/// Using multiple partitions changes how vector commitments are calculated:
/// - Input matrix columns are split into at most `num_partitions` partitions.
/// - For each matrix row, a hash is calculated for each partition separately.
/// - The resulting digests are merged together by one more hash iteration.
///
/// This is especially useful when proving with multiple GPUs, where each device holds a subset
/// of the data, as it requires less data reshuffling when generating commitments.
///
/// The `hash_rate` parameter is used to find the partition size that minimizes the number of
/// hash iterations. It specifies how many base field elements are consumed by each hash
/// iteration.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct PartitionOptions {
num_partitions: u8,
min_partition_size: u8,
hash_rate: u8,
}

impl PartitionOptions {
/// Returns a new instance of [PartitionOptions].
pub const fn new(num_partitions: usize, min_partition_size: usize) -> Self {
pub const fn new(num_partitions: usize, hash_rate: usize) -> Self {
assert!(num_partitions >= 1, "number of partitions must be greater than or equal to 1");
assert!(num_partitions <= 16, "number of partitions must be smaller than or equal to 16");

assert!(
min_partition_size >= 1,
"smallest partition size must be greater than or equal to 1"
hash_rate >= 1,
"hash rate must be greater than or equal to 1"
);
assert!(
min_partition_size <= 256,
"smallest partition size must be smaller than or equal to 256"
hash_rate <= 256,
"hash rate must be smaller than or equal to 256"
);

Self {
num_partitions: num_partitions as u8,
min_partition_size: min_partition_size as u8,
hash_rate: hash_rate as u8,
}
}

/// Returns the size of each partition used when committing to the main and auxiliary traces,
/// as well as to the constraint evaluation trace.
/// The returned size is expressed as a number of columns of field elements of type `E`.
pub fn partition_size<E: FieldElement>(&self, num_columns: usize) -> usize {
if self.num_partitions == 1 && self.min_partition_size == 1 {
if self.num_partitions == 1 {
return num_columns;
}
let base_elements_per_partition = cmp::max(
(num_columns * E::EXTENSION_DEGREE).div_ceil(self.num_partitions as usize),
self.min_partition_size as usize,
);

base_elements_per_partition.div(E::EXTENSION_DEGREE)
// Don't separate columns that would fit inside one hash iteration: `min_partition_size` is
// the number of `E` elements that can be consumed in one hash iteration.
let min_partition_size = self.hash_rate as usize / E::EXTENSION_DEGREE;

cmp::max(
num_columns.div_ceil(self.num_partitions as usize),
min_partition_size,
)
}

/// Returns the actual number of partitions, after the minimum partition size implied by the
/// hash rate is taken into account.
pub fn num_partitions<E: FieldElement>(&self, num_columns: usize) -> usize {
num_columns.div_ceil(self.partition_size::<E>(num_columns))
}
}
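The sizing logic in `partition_size` and `num_partitions` can be modeled standalone. The sketch below is an illustrative re-implementation, not the library API; `extension_degree` stands in for `E::EXTENSION_DEGREE`, and the expected values mirror the unit tests in this PR:

```rust
// Illustrative model of PartitionOptions::partition_size / num_partitions.
// `extension_degree` stands in for E::EXTENSION_DEGREE of the element type.
fn partition_size(
    num_partitions: usize,
    hash_rate: usize,
    extension_degree: usize,
    num_columns: usize,
) -> usize {
    if num_partitions == 1 {
        return num_columns;
    }
    // number of `E` elements one hash iteration can absorb
    let min_partition_size = hash_rate / extension_degree;
    core::cmp::max(num_columns.div_ceil(num_partitions), min_partition_size)
}

fn actual_num_partitions(
    num_partitions: usize,
    hash_rate: usize,
    extension_degree: usize,
    num_columns: usize,
) -> usize {
    num_columns.div_ceil(partition_size(num_partitions, hash_rate, extension_degree, num_columns))
}

fn main() {
    // 7 base-field columns (degree 1), 4 partitions, hash rate 8:
    // one hash iteration covers the whole row, so a single partition is used
    assert_eq!(partition_size(4, 8, 1, 7), 8);
    assert_eq!(actual_num_partitions(4, 8, 1, 7), 1);
    // 7 cubic-extension columns (degree 3): min partition size is 8 / 3 = 2
    assert_eq!(partition_size(4, 8, 3, 7), 2);
    assert_eq!(actual_num_partitions(4, 8, 3, 7), 4);
}
```

Note how the hash rate acts as a floor: requesting more partitions than one hash iteration can usefully fill simply reduces the effective partition count.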

impl Default for PartitionOptions {
fn default() -> Self {
Self { num_partitions: 1, min_partition_size: 1 }
Self { num_partitions: 1, hash_rate: 1 }
}
}

@@ -402,9 +420,9 @@ impl Default for PartitionOptions {

#[cfg(test)]
mod tests {
use math::fields::f64::BaseElement;
use math::fields::{f64::BaseElement, CubeExtension};

use super::{FieldExtension, ProofOptions, ToElements};
use super::{FieldExtension, PartitionOptions, ProofOptions, ToElements};

#[test]
fn proof_options_to_elements() {
@@ -438,4 +456,37 @@ mod tests {
);
assert_eq!(expected, options.to_elements());
}

#[test]
fn correct_partition_sizes() {
type E1 = BaseElement;
type E3 = CubeExtension<BaseElement>;

let options = PartitionOptions::new(4, 8);
let columns = 7;
assert_eq!(8, options.partition_size::<E1>(columns));
assert_eq!(1, options.num_partitions::<E1>(columns));

let options = PartitionOptions::new(4, 8);
let columns = 70;
assert_eq!(18, options.partition_size::<E1>(columns));
assert_eq!(4, options.num_partitions::<E1>(columns));

let options = PartitionOptions::new(2, 8);
let columns = 7;
assert_eq!(4, options.partition_size::<E3>(columns));
assert_eq!(2, options.num_partitions::<E3>(columns));

let options: PartitionOptions = PartitionOptions::new(4, 8);
let columns = 7;
assert_eq!(2, options.partition_size::<E3>(columns));
assert_eq!(4, options.num_partitions::<E3>(columns));

// don't use all partitions if it would result in sizes smaller than
// a single hash iteration can handle
let options: PartitionOptions = PartitionOptions::new(4, 8);
let columns = 3;
assert_eq!(2, options.partition_size::<E3>(columns));
assert_eq!(2, options.num_partitions::<E3>(columns));
}
}
3 changes: 1 addition & 2 deletions prover/src/lib.rs
@@ -558,8 +558,7 @@ pub trait Prover {
.in_scope(|| {
let commitment = composed_evaluations.commit_to_rows::<Self::HashFn, Self::VC>(
self.options()
.partition_options()
.partition_size::<E>(num_constraint_composition_columns),
.partition_options(),
);
ConstraintCommitment::new(composed_evaluations, commitment)
});
12 changes: 9 additions & 3 deletions prover/src/matrix/row_matrix.rs
@@ -3,6 +3,7 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

use air::PartitionOptions;
use alloc::vec::Vec;

use crypto::{ElementHasher, VectorCommitment};
@@ -180,13 +181,14 @@ impl<E: FieldElement> RowMatrix<E> {
/// * A vector commitment is computed for the resulting vector using the specified vector
/// commitment scheme.
/// * The resulting vector commitment is returned as the commitment to the entire matrix.
pub fn commit_to_rows<H, V>(&self, partition_size: usize) -> V
pub fn commit_to_rows<H, V>(&self, partition_options: PartitionOptions) -> V
where
H: ElementHasher<BaseField = E::BaseField>,
V: VectorCommitment<H>,
{
// allocate vector to store row hashes
let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };
let partition_size = partition_options.partition_size::<E>(self.num_cols());

if partition_size == self.num_cols() {
// iterate through matrix rows, hashing each row
@@ -200,17 +202,21 @@ impl<E: FieldElement> RowMatrix<E> {
}
);
} else {
let num_partitions = partition_options.num_partitions::<E>(self.num_cols());

// iterate through matrix rows, hashing each row
batch_iter_mut!(
&mut row_hashes,
128, // min batch size
|batch: &mut [H::Digest], batch_offset: usize| {
let mut buffer = vec![H::Digest::default(); partition_size];
let mut buffer = vec![H::Digest::default(); num_partitions];
for (i, row_hash) in batch.iter_mut().enumerate() {
self.row(batch_offset + i)
.chunks(partition_size)
.zip(buffer.iter_mut())
.for_each(|(chunk, buf)| *buf = H::hash_elements(chunk));
.for_each(|(chunk, buf)| {
*buf = H::hash_elements(chunk);
});
*row_hash = H::merge_many(&buffer);
}
}
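The two hashing paths in `commit_to_rows` can be sketched with a toy hash. Here `toy_hash` and `toy_merge` are illustrative stand-ins for `H::hash_elements` and `H::merge_many`, not the library API:

```rust
// Sketch of partitioned row hashing: hash each chunk of a row separately,
// then merge the per-partition digests into a single row digest.
fn toy_hash(elements: &[u64]) -> u64 {
    // FNV-1a-style folding, purely for illustration
    elements
        .iter()
        .fold(0xcbf29ce484222325u64, |acc, &e| (acc ^ e).wrapping_mul(0x100000001b3))
}

fn toy_merge(digests: &[u64]) -> u64 {
    toy_hash(digests)
}

fn hash_row(row: &[u64], partition_size: usize) -> u64 {
    if partition_size == row.len() {
        // single partition: hash the whole row in one go
        toy_hash(row)
    } else {
        // one digest per partition, then one more hash to combine them;
        // the final chunk may be shorter than partition_size
        let digests: Vec<u64> = row.chunks(partition_size).map(toy_hash).collect();
        toy_merge(&digests)
    }
}

fn main() {
    let row: Vec<u64> = (0..8).collect();
    // 2 partitions of 4 columns each; the digest is deterministic,
    // so the verifier can recompute it from the same row and partition size
    assert_eq!(hash_row(&row, 4), hash_row(&row, 4));
    // partition_size == row length falls through to the single-hash path
    assert_eq!(hash_row(&row, 8), toy_hash(&row));
}
```

In the multi-GPU setting, each device can compute the per-partition digests for its column subset independently; only the small digest vectors need to be gathered before the merge step.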
18 changes: 9 additions & 9 deletions prover/src/trace/trace_lde/default/mod.rs
@@ -43,7 +43,7 @@ pub struct DefaultTraceLde<
aux_segment_oracles: Option<V>,
blowup: usize,
trace_info: TraceInfo,
partition_option: PartitionOptions,
partition_options: PartitionOptions,
_h: PhantomData<H>,
}

@@ -64,16 +64,16 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<E::BaseField>,
domain: &StarkDomain<E::BaseField>,
partition_option: PartitionOptions,
partition_options: PartitionOptions,
) -> (Self, TracePolyTable<E>) {
// extend the main execution trace and build a commitment to the extended trace
let (main_segment_lde, main_segment_vector_com, main_segment_polys) =
build_trace_commitment::<E, E::BaseField, H, V>(
main_trace,
domain,
partition_option.partition_size::<E::BaseField>(main_trace.num_cols()),
partition_options,
);

let trace_poly_table = TracePolyTable::new(main_segment_polys);
let trace_lde = DefaultTraceLde {
main_segment_lde,
@@ -82,7 +82,7 @@
aux_segment_oracles: None,
blowup: domain.trace_to_lde_blowup(),
trace_info: trace_info.clone(),
partition_option,
partition_options,
_h: PhantomData,
};

@@ -151,9 +151,9 @@
build_trace_commitment::<E, E, H, Self::VC>(
aux_trace,
domain,
self.partition_option.partition_size::<E>(aux_trace.num_cols()),
self.partition_options,
);

// check errors
assert!(
usize::from(self.aux_segment_lde.is_some()) < self.trace_info.num_aux_segments(),
@@ -276,7 +276,7 @@
fn build_trace_commitment<E, F, H, V>(
trace: &ColMatrix<F>,
domain: &StarkDomain<E::BaseField>,
partition_size: usize,
partition_options: PartitionOptions,
) -> (RowMatrix<F>, V, ColMatrix<F>)
where
E: FieldElement,
@@ -306,7 +306,7 @@
// build trace commitment
let commitment_domain_size = trace_lde.num_rows();
let trace_vector_com = info_span!("compute_execution_trace_commitment", commitment_domain_size)
.in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_size));
.in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_options));
assert_eq!(trace_vector_com.domain_len(), commitment_domain_size);

(trace_lde, trace_vector_com, trace_polys)
4 changes: 3 additions & 1 deletion verifier/src/channel.rs
@@ -445,7 +445,7 @@
if partition_size == row.len() {
H::hash_elements(row)
} else {
let mut buffer = vec![H::Digest::default(); partition_size];
let num_partitions = row.len().div_ceil(partition_size);

let mut buffer = vec![H::Digest::default(); num_partitions];

row.chunks(partition_size)
.zip(buffer.iter_mut())
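The verifier-side buffer-sizing fix matters whenever the row length is not an exact multiple of `partition_size`: the buffer must hold one digest per partition, not `partition_size` digests. A minimal check of the arithmetic (values are illustrative):

```rust
// A row of 7 columns split with partition_size = 2 yields 4 chunks,
// so the digest buffer must have length div_ceil(7, 2) = 4.
fn main() {
    let row_len = 7usize;
    let partition_size = 2usize;
    let num_partitions = row_len.div_ceil(partition_size);
    assert_eq!(num_partitions, 4);
    // chunks() produces the same count, including the final short chunk
    let row = vec![0u8; row_len];
    assert_eq!(row.chunks(partition_size).count(), num_partitions);
}
```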