diff --git a/prover/Cargo.toml b/prover/Cargo.toml index f75acdc8e..125ab2a4f 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -15,20 +15,16 @@ rust-version = "1.78" [lib] bench = false -[[bench]] -name = "logup_gkr" -harness = false - [[bench]] name = "logup_gkr_e2e" harness = false [[bench]] -name = "row_matrix" +name = "logup_gkr" harness = false [[bench]] -name = "lagrange_kernel" +name = "row_matrix" harness = false [features] diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 899482f22..0da61fbd8 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -8,9 +8,9 @@ use sumcheck::{ EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; -use utils::iter; #[cfg(feature = "concurrent")] -pub use utils::rayon::prelude::*; +use utils::rayon::prelude::*; +use utils::{iter, iter_mut, uninit_vector}; use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; @@ -81,7 +81,7 @@ pub fn prove_gkr( // build the MLEs of the relevant main trace columns let main_trace_mls = - build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; + build_mle_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; // build the periodic table representing periodic columns as multi-linear extensions let periodic_table = evaluator.build_periodic_values(); @@ -140,11 +140,11 @@ fn prove_input_layer< /// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for /// LogUp-GKR. #[instrument(skip_all)] -fn build_mls_from_main_trace_segment( +fn build_mle_from_main_trace_segment( oracles: &[LogUpGkrOracle], main_trace: &ColMatrix<::BaseField>, ) -> Result>, GkrProverError> { - let mut mls = vec![]; + let mut mls = Vec::with_capacity(oracles.len()); for oracle in oracles { match oracle { @@ -156,13 +156,18 @@ fn build_mls_from_main_trace_segment( }, LogUpGkrOracle::NextRow(index) => { let col = main_trace.get_column(*index); - let mut values: Vec = col.iter().map(|value| E::from(*value)).collect(); - values.rotate_left(1); + + let mut values: Vec = unsafe { uninit_vector(col.len()) }; + values[col.len() - 1] = E::from(col[0]); + iter_mut!(&mut values[..col.len() - 1]) + .enumerate() + .for_each(|(i, value)| *value = E::from(col[i + 1])); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, }; } + Ok(mls) }