Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi GPU partition fixes #340

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions air/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ impl ProofOptions {
pub const fn with_partitions(
mut self,
num_partitions: usize,
min_partition_size: usize,
hash_rate: usize,
) -> ProofOptions {
irakliyk marked this conversation as resolved.
Show resolved Hide resolved
self.partition_options = PartitionOptions::new(num_partitions, min_partition_size);
self.partition_options = PartitionOptions::new(num_partitions, hash_rate);

self
}
Expand Down Expand Up @@ -277,7 +277,7 @@ impl Serializable for ProofOptions {
target.write_u8(self.fri_folding_factor);
target.write_u8(self.fri_remainder_max_degree);
target.write_u8(self.partition_options.num_partitions);
target.write_u8(self.partition_options.min_partition_size);
target.write_u8(self.partition_options.hash_rate);
}
}

Expand Down Expand Up @@ -351,49 +351,56 @@ impl Deserializable for FieldExtension {
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct PartitionOptions {
num_partitions: u8,
min_partition_size: u8,
hash_rate: u8,
}
irakliyk marked this conversation as resolved.
Show resolved Hide resolved

impl PartitionOptions {
/// Returns a new instance of `[PartitionOptions]`.
pub const fn new(num_partitions: usize, min_partition_size: usize) -> Self {
pub const fn new(num_partitions: usize, hash_rate: usize) -> Self {
assert!(num_partitions >= 1, "number of partitions must be greater than or eqaul to 1");
assert!(num_partitions <= 16, "number of partitions must be smaller than or equal to 16");

assert!(
min_partition_size >= 1,
"smallest partition size must be greater than or equal to 1"
hash_rate >= 1,
"hash rate must be greater than or equal to 1"
);
assert!(
min_partition_size <= 256,
"smallest partition size must be smaller than or equal to 256"
hash_rate <= 256,
"hash rate must be smaller than or equal to 256"
);

Self {
num_partitions: num_partitions as u8,
min_partition_size: min_partition_size as u8,
hash_rate: hash_rate as u8,
}
}

/// Returns the size of each partition used when committing to the main and auxiliary traces as
/// well as the constraint evaluation trace.
/// The returned size is given in terms of number of columns in the field `E`.
pub fn partition_size<E: FieldElement>(&self, num_columns: usize) -> usize {
if self.num_partitions == 1 && self.min_partition_size == 1 {
if self.num_partitions == 1 {
return num_columns;
}
let min_partition_size = self.hash_rate as usize / E::EXTENSION_DEGREE;
irakliyk marked this conversation as resolved.
Show resolved Hide resolved
let base_elements_per_partition = cmp::max(
(num_columns * E::EXTENSION_DEGREE).div_ceil(self.num_partitions as usize),
self.min_partition_size as usize,
min_partition_size,
);

base_elements_per_partition.div(E::EXTENSION_DEGREE)
irakliyk marked this conversation as resolved.
Show resolved Hide resolved
}

/// The actual number of partitions, after the min partition size implied
/// by the hash rate is taken into account.
pub fn num_partitions<E: FieldElement>(&self, num_columns: usize) -> usize {
num_columns.div_ceil(self.partition_size::<E>(num_columns))
}
}

impl Default for PartitionOptions {
fn default() -> Self {
Self { num_partitions: 1, min_partition_size: 1 }
Self { num_partitions: 1, hash_rate: 1 }
}
}

Expand Down
3 changes: 1 addition & 2 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,7 @@ pub trait Prover {
.in_scope(|| {
let commitment = composed_evaluations.commit_to_rows::<Self::HashFn, Self::VC>(
self.options()
.partition_options()
.partition_size::<E>(num_constraint_composition_columns),
.partition_options(),
);
ConstraintCommitment::new(composed_evaluations, commitment)
});
Expand Down
12 changes: 9 additions & 3 deletions prover/src/matrix/row_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

use air::PartitionOptions;
use alloc::vec::Vec;

use crypto::{ElementHasher, VectorCommitment};
Expand Down Expand Up @@ -180,13 +181,14 @@ impl<E: FieldElement> RowMatrix<E> {
/// * A vector commitment is computed for the resulting vector using the specified vector
/// commitment scheme.
/// * The resulting vector commitment is returned as the commitment to the entire matrix.
pub fn commit_to_rows<H, V>(&self, partition_size: usize) -> V
pub fn commit_to_rows<H, V>(&self, partition_options: PartitionOptions) -> V
where
H: ElementHasher<BaseField = E::BaseField>,
V: VectorCommitment<H>,
{
// allocate vector to store row hashes
let mut row_hashes = unsafe { uninit_vector::<H::Digest>(self.num_rows()) };
let partition_size = partition_options.partition_size::<E>(self.num_cols());

if partition_size == self.num_cols() {
// iterate though matrix rows, hashing each row
Expand All @@ -200,17 +202,21 @@ impl<E: FieldElement> RowMatrix<E> {
}
);
} else {
let num_partitions = partition_options.num_partitions::<E>(self.num_cols());

// iterate though matrix rows, hashing each row
batch_iter_mut!(
&mut row_hashes,
128, // min batch size
|batch: &mut [H::Digest], batch_offset: usize| {
let mut buffer = vec![H::Digest::default(); partition_size];
let mut buffer = vec![H::Digest::default(); num_partitions];
for (i, row_hash) in batch.iter_mut().enumerate() {
self.row(batch_offset + i)
.chunks(partition_size)
.zip(buffer.iter_mut())
.for_each(|(chunk, buf)| *buf = H::hash_elements(chunk));
.for_each(|(chunk, buf)| {
*buf = H::hash_elements(chunk);
});
*row_hash = H::merge_many(&buffer);
}
}
Expand Down
18 changes: 9 additions & 9 deletions prover/src/trace/trace_lde/default/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub struct DefaultTraceLde<
aux_segment_oracles: Option<V>,
blowup: usize,
trace_info: TraceInfo,
partition_option: PartitionOptions,
partition_options: PartitionOptions,
_h: PhantomData<H>,
}

Expand All @@ -64,16 +64,16 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<E::BaseField>,
domain: &StarkDomain<E::BaseField>,
partition_option: PartitionOptions,
partition_options: PartitionOptions,
) -> (Self, TracePolyTable<E>) {
// extend the main execution trace and build a commitment to the extended trace
let (main_segment_lde, main_segment_vector_com, main_segment_polys) =
build_trace_commitment::<E, E::BaseField, H, V>(
main_trace,
domain,
partition_option.partition_size::<E::BaseField>(main_trace.num_cols()),
partition_options,
);

let trace_poly_table = TracePolyTable::new(main_segment_polys);
let trace_lde = DefaultTraceLde {
main_segment_lde,
Expand All @@ -82,7 +82,7 @@ where
aux_segment_oracles: None,
blowup: domain.trace_to_lde_blowup(),
trace_info: trace_info.clone(),
partition_option,
partition_options,
_h: PhantomData,
};

Expand Down Expand Up @@ -151,9 +151,9 @@ where
build_trace_commitment::<E, E, H, Self::VC>(
aux_trace,
domain,
self.partition_option.partition_size::<E>(aux_trace.num_cols()),
self.partition_options,
);

// check errors
assert!(
usize::from(self.aux_segment_lde.is_some()) < self.trace_info.num_aux_segments(),
Expand Down Expand Up @@ -276,7 +276,7 @@ where
fn build_trace_commitment<E, F, H, V>(
trace: &ColMatrix<F>,
domain: &StarkDomain<E::BaseField>,
partition_size: usize,
partition_options: PartitionOptions,
) -> (RowMatrix<F>, V, ColMatrix<F>)
where
E: FieldElement,
Expand Down Expand Up @@ -306,7 +306,7 @@ where
// build trace commitment
let commitment_domain_size = trace_lde.num_rows();
let trace_vector_com = info_span!("compute_execution_trace_commitment", commitment_domain_size)
.in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_size));
.in_scope(|| trace_lde.commit_to_rows::<H, V>(partition_options));
assert_eq!(trace_vector_com.domain_len(), commitment_domain_size);

(trace_lde, trace_vector_com, trace_polys)
Expand Down
4 changes: 3 additions & 1 deletion verifier/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,9 @@ where
if partition_size == row.len() {
H::hash_elements(row)
} else {
let mut buffer = vec![H::Digest::default(); partition_size];
let num_partitions = row.len().div_ceil(partition_size);

let mut buffer = vec![H::Digest::default(); num_partitions];

row.chunks(partition_size)
.zip(buffer.iter_mut())
Expand Down
Loading