Skip to content

Commit

Permalink
Parallelize input layer generation (#324)
Browse files Browse the repository at this point in the history
  • Loading branch information
Al-Kindi-0 authored Sep 20, 2024
1 parent 3d095e5 commit 09ed09b
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 36 deletions.
88 changes: 53 additions & 35 deletions prover/src/logup_gkr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator};
use math::FieldElement;
use sumcheck::{EqFunction, MultiLinearPoly, SumCheckProverError};
use tracing::instrument;
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use utils::{
batch_iter_mut, chunks, uninit_vector, ByteReader, ByteWriter, Deserializable,
DeserializationError, Serializable,
};

use crate::Trace;

mod prover;
pub use prover::prove_gkr;
#[cfg(feature = "concurrent")]
pub use utils::rayon::{current_num_threads as rayon_num_threads, prelude::*};

// EVALUATED CIRCUIT
// ================================================================================================
Expand Down Expand Up @@ -106,53 +111,66 @@ impl<E: FieldElement> EvaluatedCircuit<E> {
/// Generates the input layer of the circuit from the main trace columns and some randomness
/// provided by the verifier.
fn generate_input_layer(
main_trace: &impl Trace<BaseField = E::BaseField>,
trace: &impl Trace<BaseField = E::BaseField>,
evaluator: &impl LogUpGkrEvaluator<BaseField = E::BaseField>,
log_up_randomness: &[E],
) -> CircuitLayer<E> {
let num_fractions = evaluator.get_num_fractions();
let periodic_values = evaluator.build_periodic_values();

let mut input_layer_wires =
Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions);
let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols());

let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()];
let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()];
let mut numerators = vec![E::ZERO; num_fractions];
let mut denominators = vec![E::ZERO; num_fractions];
for i in 0..main_trace.main_segment().num_rows() {
let wires_from_trace_row = {
main_trace.read_main_frame(i, &mut main_frame);
periodic_values.fill_periodic_values_at(i, &mut periodic_values_row);
evaluator.build_query(&main_frame, &mut query);

evaluator.evaluate_query(
&query,
&periodic_values_row,
log_up_randomness,
&mut numerators,
&mut denominators,
);
let input_gates_values: Vec<CircuitWire<E>> = numerators
.iter()
.zip(denominators.iter())
.map(|(numerator, denominator)| CircuitWire::new(*numerator, *denominator))
.collect();
input_gates_values
};

input_layer_wires.extend(wires_from_trace_row);
}
unsafe { uninit_vector(trace.main_segment().num_rows() * num_fractions) };
let num_cols = trace.main_segment().num_cols();
let num_oracles = evaluator.get_oracles().len();
let num_periodic_cols = periodic_values.num_columns();

batch_iter_mut!(
&mut input_layer_wires,
1024,
|batch: &mut [CircuitWire<E>], batch_offset: usize| {
let mut main_frame = EvaluationFrame::new(num_cols);
let mut query = vec![E::BaseField::ZERO; num_oracles];
let mut periodic_values_row = vec![E::BaseField::ZERO; num_periodic_cols];
let mut numerators = vec![E::ZERO; num_fractions];
let mut denominators = vec![E::ZERO; num_fractions];

let row_offset = batch_offset / num_fractions;
let batch_size = batch.len();
let num_rows_per_batch = batch_size / num_fractions;

for i in
(0..trace.main_segment().num_rows()).skip(row_offset).take(num_rows_per_batch)
{
trace.read_main_frame(i, &mut main_frame);
periodic_values.fill_periodic_values_at(i, &mut periodic_values_row);
evaluator.build_query(&main_frame, &mut query);

evaluator.evaluate_query(
&query,
&periodic_values_row,
log_up_randomness,
&mut numerators,
&mut denominators,
);

let n = (i - row_offset) * num_fractions;
for ((wire, numerator), denominator) in batch[n..n + num_fractions]
.iter_mut()
.zip(numerators.iter())
.zip(denominators.iter())
{
*wire = CircuitWire::new(*numerator, *denominator);
}
}
}
);

CircuitLayer::new(input_layer_wires)
}

/// Computes the subsequent layer of the circuit from a given layer.
fn compute_next_layer(prev_layer: &CircuitLayer<E>) -> CircuitLayer<E> {
let next_layer_wires = prev_layer
.wires()
.chunks_exact(2)
let next_layer_wires = chunks!(prev_layer.wires(), 2)
.map(|input_wires| {
let left_input_wire = input_wires[0];
let right_input_wire = input_wires[1];
Expand Down
18 changes: 18 additions & 0 deletions utils/core/src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,21 @@ macro_rules! batch_iter_mut {
$c($e, 0);
};
}

/// Returns either a regular or a parallel iterator over at most `chunk_size` elements depending
/// on whether `concurrent` feature is enabled.
///
/// When `concurrent` feature is enabled, creates a parallel iterator; otherwise, creates a
/// regular iterator.
#[macro_export]
macro_rules! chunks {
($e: expr, $chunk_size: expr) => {{
#[cfg(feature = "concurrent")]
let result = $e.par_chunks($chunk_size);

#[cfg(not(feature = "concurrent"))]
let result = $e.chunks($chunk_size);

result
}};
}
2 changes: 1 addition & 1 deletion winterfell/src/tests/logup_gkr_periodic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
#[test]
fn test_logup_gkr_periodic() {
let aux_trace_width = 1;
let trace = LogUpGkrPeriodic::new(2_usize.pow(7), aux_trace_width);
let trace = LogUpGkrPeriodic::new(2_usize.pow(12), aux_trace_width);
let prover = LogUpGkrPeriodicProver::new(aux_trace_width);

let proof = prover.prove(trace).unwrap();
Expand Down

0 comments on commit 09ed09b

Please sign in to comment.