From 595751ae9b3ce7cf163a416967bd7e85966e028f Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:22:09 +0200 Subject: [PATCH 1/9] wip --- prover/benches/logup_gkr.rs | 33 ++++- sumcheck/src/prover/mod.rs | 3 + sumcheck/src/prover/plain.rs | 227 ++++++++++++++++++++--------------- 3 files changed, 163 insertions(+), 100 deletions(-) diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs index e86c84aef..3bc4ac397 100644 --- a/prover/benches/logup_gkr.rs +++ b/prover/benches/logup_gkr.rs @@ -239,11 +239,11 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 4 + 16 } fn max_degree(&self) -> usize { - 3 + 10 } fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) @@ -264,18 +264,43 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 4); - assert_eq!(denominator.len(), 4); + assert_eq!(numerator.len(), 16); + assert_eq!(denominator.len(), 16); assert_eq!(query.len(), 5); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; numerator[3] = E::ONE; + numerator[4] = E::from(query[1]); + numerator[5] = E::ONE; + numerator[6] = E::ONE; + numerator[7] = E::ONE; + numerator[8] = E::from(query[1]); + numerator[9] = E::ONE; + numerator[10] = E::ONE; + numerator[11] = E::ONE; + numerator[12] = E::from(query[1]); + numerator[13] = E::ONE; + numerator[14] = E::ONE; + numerator[15] = E::ONE; denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); + denominator[4] = rand_values[0] - E::from(query[0]); + denominator[5] = -(rand_values[0] - E::from(query[2])); + denominator[6] = -(rand_values[0] - E::from(query[3])); + denominator[7] = -(rand_values[0] - E::from(query[4])); + denominator[8] = rand_values[0] - E::from(query[0]); + denominator[9] = -(rand_values[0] - E::from(query[2])); + denominator[10] = -(rand_values[0] - E::from(query[3])); + denominator[11] = -(rand_values[0] - E::from(query[4])); + denominator[12] = rand_values[0] - E::from(query[0]); + denominator[13] = -(rand_values[0] - E::from(query[2])); + denominator[14] = -(rand_values[0] - E::from(query[3])); + denominator[15] = -(rand_values[0] - E::from(query[4])); + } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 13d35e551..705a67918 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -11,3 +11,6 @@ pub use plain::sumcheck_prove_plain; mod error; pub use error::SumCheckProverError; + +#[cfg(feature = "concurrent")] +const MINIMAL_MLE_SIZE: usize = 1 << 4; \ No newline at end of file diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index e0092cf10..f0343f65a 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -10,6 +10,8 @@ pub use rayon::prelude::*; use smallvec::smallvec; use super::SumCheckProverError; +#[cfg(feature = "concurrent")] +use super::MINIMAL_MLE_SIZE; use crate::{ comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof, @@ -67,9 +69,130 @@ pub fn sumcheck_prove_plain( + p0: &MultiLinearPoly, + p1: &MultiLinearPoly, + q0: &MultiLinearPoly, + q1: &MultiLinearPoly, + eq: &MultiLinearPoly, + len: usize, + r_batch: E, +) -> (E, E, E) { + #[cfg(feature = "concurrent")] + let res = if p0.num_evaluations() >= MINIMAL_MLE_SIZE { + parallel(p0, p1, q0, q1, eq, len, r_batch) + } else { + serial(p0, p1, q0, q1, eq, len, r_batch) + }; + + #[cfg(not(feature = "concurrent"))] + let res = serial(p0, p1, q0, q1, eq, len, r_batch); + + res +} + +fn serial( + p0: &MultiLinearPoly, + p1: &MultiLinearPoly, + q0: &MultiLinearPoly, + q1: &MultiLinearPoly, + eq: &MultiLinearPoly, + len: usize, + r_batch: E, +) -> (E, E, E) { + (0..len).fold((E::ZERO, E::ZERO, E::ZERO), |(acc_point_1, acc_point_2, acc_point_3), i| { + let round_poly_eval_at_1 = comb_func( + p0[2 * i + 1], + p1[2 * i + 1], + q0[2 * i + 1], + q1[2 * i + 1], + eq[2 * i + 1], + r_batch, + ); + + let p0_delta = p0[2 * i + 1] - p0[2 * i]; + let p1_delta = p1[2 * i + 1] - p1[2 * i]; + let q0_delta = q0[2 * i + 1] - q0[2 * i]; + let q1_delta = q1[2 * i + 1] - q1[2 * i]; + let eq_delta = eq[2 * i + 1] - eq[2 * i]; + + let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; + let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; + let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; + let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; + let mut eq_evx = eq[2 * i + 1] + eq_delta; + let round_poly_eval_at_2 = + comb_func(p0_eval_at_x, p1_eval_at_x, q0_eval_at_x, q1_eval_at_x, eq_evx, r_batch); + + p0_eval_at_x += p0_delta; + p1_eval_at_x += p1_delta; + q0_eval_at_x += q0_delta; + q1_eval_at_x += q1_delta; + eq_evx += eq_delta; + let round_poly_eval_at_3 = + comb_func(p0_eval_at_x, p1_eval_at_x, q0_eval_at_x, q1_eval_at_x, eq_evx, r_batch); + + ( + round_poly_eval_at_1 + acc_point_1, + round_poly_eval_at_2 + acc_point_2, + round_poly_eval_at_3 + acc_point_3, + ) + }) +} + +#[cfg(feature = "concurrent")] +fn parallel( + p0: &MultiLinearPoly, + p1: &MultiLinearPoly, + q0: &MultiLinearPoly, + q1: &MultiLinearPoly, + eq: &MultiLinearPoly, + len: usize, + r_batch: E, +) -> (E, E, E) { + (0..len) + .into_par_iter() + .fold( + || (E::ZERO, E::ZERO, E::ZERO), |(acc_point_1, acc_point_2, acc_point_3), i| { let round_poly_eval_at_1 = comb_func( p0[2 * i + 1], @@ -120,97 +243,9 @@ pub fn sumcheck_prove_plain Date: Tue, 24 Sep 2024 21:44:50 +0200 Subject: [PATCH 2/9] restore benchmark --- prover/benches/logup_gkr.rs | 34 ++++++---------------------------- sumcheck/src/prover/mod.rs | 2 +- 2 files changed, 7 insertions(+), 29 deletions(-) diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs index 3bc4ac397..4b1a86976 100644 --- a/prover/benches/logup_gkr.rs +++ b/prover/benches/logup_gkr.rs @@ -239,11 +239,11 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 16 + 4 } fn max_degree(&self) -> usize { - 10 + 3 } fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) @@ -264,42 +264,20 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 16); - assert_eq!(denominator.len(), 16); + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); assert_eq!(query.len(), 5); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; numerator[3] = E::ONE; - numerator[4] = E::from(query[1]); - numerator[5] = E::ONE; - numerator[6] = E::ONE; - numerator[7] = E::ONE; - numerator[8] = E::from(query[1]); - numerator[9] = E::ONE; - numerator[10] = E::ONE; - numerator[11] = E::ONE; - numerator[12] = E::from(query[1]); - numerator[13] = E::ONE; - numerator[14] = E::ONE; - numerator[15] = E::ONE; + denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); - denominator[4] = rand_values[0] - E::from(query[0]); - denominator[5] = -(rand_values[0] - E::from(query[2])); - denominator[6] = -(rand_values[0] - E::from(query[3])); - denominator[7] = -(rand_values[0] - E::from(query[4])); - denominator[8] = rand_values[0] - E::from(query[0]); - denominator[9] = -(rand_values[0] - E::from(query[2])); - denominator[10] = -(rand_values[0] - E::from(query[3])); - denominator[11] = -(rand_values[0] - E::from(query[4])); - denominator[12] = rand_values[0] - E::from(query[0]); - denominator[13] = -(rand_values[0] - E::from(query[2])); - denominator[14] = -(rand_values[0] - E::from(query[3])); - denominator[15] = -(rand_values[0] - E::from(query[4])); + } diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 705a67918..5af31e0c6 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -13,4 +13,4 @@ mod error; pub use error::SumCheckProverError; #[cfg(feature = "concurrent")] -const MINIMAL_MLE_SIZE: usize = 1 << 4; \ No newline at end of file +const MINIMAL_MLE_SIZE: usize = 1 << 6; \ No newline at end of file From 1e576802d1971a678aabf1600ffa525005d660bb Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 24 Sep 2024 22:15:33 +0200 Subject: [PATCH 3/9] move threshold decision to GKR prover --- prover/src/logup_gkr/mod.rs | 3 + prover/src/logup_gkr/prover.rs | 18 ++- sumcheck/src/prover/mod.rs | 7 +- sumcheck/src/prover/plain.rs | 209 +++++++++++++++++++-------------- 4 files changed, 139 insertions(+), 98 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 4d2ee975a..c9ca5a16f 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -20,6 +20,9 @@ pub use utils::{ {chunks_mut, iter, iter_mut}, }; +#[cfg(feature = "concurrent")] +const MINIMAL_MLE_SIZE: usize = 1 << 4; + // EVALUATED CIRCUIT // ================================================================================================ diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 899482f22..fe55b77df 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -3,15 +3,19 @@ use alloc::vec::Vec; use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; +#[cfg(feature = "concurrent")] +use sumcheck::sumcheck_prove_plain_parallel; use sumcheck::{ - sum_check_prove_higher_degree, sumcheck_prove_plain, BeforeFinalLayerProof, CircuitOutput, - EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, + sum_check_prove_higher_degree, sumcheck_prove_plain_serial, BeforeFinalLayerProof, + CircuitOutput, EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; use utils::iter; #[cfg(feature = "concurrent")] pub use utils::rayon::prelude::*; +#[cfg(feature = "concurrent")] +use super::MINIMAL_MLE_SIZE; use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; @@ -265,7 +269,15 @@ fn sum_check_prove_num_rounds_degree_3< let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; let claim = claim.0 + claim.1 * r_batch; - let proof = sumcheck_prove_plain(claim, r_batch, p, q, eq, transcript)?; + #[cfg(feature = "concurrent")] + let proof = if p.num_evaluations() >= MINIMAL_MLE_SIZE { + sumcheck_prove_plain_parallel(claim, r_batch, p, q, eq, transcript)? + } else { + sumcheck_prove_plain_serial(claim, r_batch, p, q, eq, transcript)? + }; + + #[cfg(not(feature = "concurrent"))] + let proof = sumcheck_prove_plain_serial(claim, r_batch, p, q, eq, transcript)?; Ok(proof) } diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 5af31e0c6..adaa63298 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -7,10 +7,9 @@ mod high_degree; pub use high_degree::sum_check_prove_higher_degree; mod plain; -pub use plain::sumcheck_prove_plain; +#[cfg(feature = "concurrent")] +pub use plain::sumcheck_prove_plain_parallel; +pub use plain::sumcheck_prove_plain_serial; mod error; pub use error::SumCheckProverError; - -#[cfg(feature = "concurrent")] -const MINIMAL_MLE_SIZE: usize = 1 << 6; \ No newline at end of file diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index f0343f65a..532458757 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -10,8 +10,6 @@ pub use rayon::prelude::*; use smallvec::smallvec; use super::SumCheckProverError; -#[cfg(feature = "concurrent")] -use super::MINIMAL_MLE_SIZE; use crate::{ comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof, @@ -49,8 +47,12 @@ use crate::{ /// Note that the degree of the non-linear composition polynomial is 3. /// /// [1]: https://eprint.iacr.org/2023/1284 +#[cfg(feature = "concurrent")] #[allow(clippy::too_many_arguments)] -pub fn sumcheck_prove_plain>( +pub fn sumcheck_prove_plain_parallel< + E: FieldElement, + H: ElementHasher, +>( mut claim: E, r_batch: E, p: MultiLinearPoly, @@ -69,8 +71,65 @@ pub fn sumcheck_prove_plain( - p0: &MultiLinearPoly, - p1: &MultiLinearPoly, - q0: &MultiLinearPoly, - q1: &MultiLinearPoly, - eq: &MultiLinearPoly, - len: usize, +#[allow(clippy::too_many_arguments)] +pub fn sumcheck_prove_plain_serial>( + mut claim: E, r_batch: E, -) -> (E, E, E) { - #[cfg(feature = "concurrent")] - let res = if p0.num_evaluations() >= MINIMAL_MLE_SIZE { - parallel(p0, p1, q0, q1, eq, len, r_batch) - } else { - serial(p0, p1, q0, q1, eq, len, r_batch) - }; - - #[cfg(not(feature = "concurrent"))] - let res = serial(p0, p1, q0, q1, eq, len, r_batch); - - res -} + p: MultiLinearPoly, + q: MultiLinearPoly, + eq: &mut MultiLinearPoly, + transcript: &mut impl RandomCoin, +) -> Result, SumCheckProverError> { + let mut round_proofs = vec![]; -fn serial( - p0: &MultiLinearPoly, - p1: &MultiLinearPoly, - q0: &MultiLinearPoly, - q1: &MultiLinearPoly, - eq: &MultiLinearPoly, - len: usize, - r_batch: E, -) -> (E, E, E) { - (0..len).fold((E::ZERO, E::ZERO, E::ZERO), |(acc_point_1, acc_point_2, acc_point_3), i| { - let round_poly_eval_at_1 = comb_func( - p0[2 * i + 1], - p1[2 * i + 1], - q0[2 * i + 1], - q1[2 * i + 1], - eq[2 * i + 1], - r_batch, - ); + let mut challenges = vec![]; - let p0_delta = p0[2 * i + 1] - p0[2 * i]; - let p1_delta = p1[2 * i + 1] - p1[2 * i]; - let q0_delta = q0[2 * i + 1] - q0[2 * i]; - let q1_delta = q1[2 * i + 1] - q1[2 * i]; - let eq_delta = eq[2 * i + 1] - eq[2 * i]; - - let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; - let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; - let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; - let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; - let mut eq_evx = eq[2 * i + 1] + eq_delta; - let round_poly_eval_at_2 = - comb_func(p0_eval_at_x, p1_eval_at_x, q0_eval_at_x, q1_eval_at_x, eq_evx, r_batch); - - p0_eval_at_x += p0_delta; - p1_eval_at_x += p1_delta; - q0_eval_at_x += q0_delta; - q1_eval_at_x += q1_delta; - eq_evx += eq_delta; - let round_poly_eval_at_3 = - comb_func(p0_eval_at_x, p1_eval_at_x, q0_eval_at_x, q1_eval_at_x, eq_evx, r_batch); - - ( - round_poly_eval_at_1 + acc_point_1, - round_poly_eval_at_2 + acc_point_2, - round_poly_eval_at_3 + acc_point_3, - ) - }) -} + // construct the vector of multi-linear polynomials + let (mut p0, mut p1) = p.project_least_significant_variable(); + let (mut q0, mut q1) = q.project_least_significant_variable(); -#[cfg(feature = "concurrent")] -fn parallel( - p0: &MultiLinearPoly, - p1: &MultiLinearPoly, - q0: &MultiLinearPoly, - q1: &MultiLinearPoly, - eq: &MultiLinearPoly, - len: usize, - r_batch: E, -) -> (E, E, E) { - (0..len) - .into_par_iter() - .fold( - || (E::ZERO, E::ZERO, E::ZERO), + for _ in 0..p0.num_variables() { + let len = p0.num_evaluations() / 2; + + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len).fold( + (E::ZERO, E::ZERO, E::ZERO), |(acc_point_1, acc_point_2, acc_point_3), i| { let round_poly_eval_at_1 = comb_func( p0[2 * i + 1], @@ -243,9 +239,40 @@ fn parallel( round_poly_eval_at_3 + acc_point_3, ) }, - ) - .reduce( - || (E::ZERO, E::ZERO, E::ZERO), - |(a0, b0, c0), (a1, b1, c1)| (a0 + a1, b0 + b1, c0 + c1), - ) + ); + + let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; + let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); + let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); + + // reseed with the s_i polynomial + transcript.reseed(H::hash_elements(&compressed_round_poly.0)); + let round_proof = RoundProof { + round_poly_coefs: compressed_round_poly.clone(), + }; + + let round_challenge = + transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // fold each multi-linear using the round challenge + p0.bind_least_significant_variable(round_challenge); + p1.bind_least_significant_variable(round_challenge); + q0.bind_least_significant_variable(round_challenge); + q1.bind_least_significant_variable(round_challenge); + eq.bind_least_significant_variable(round_challenge); + + // compute the new reduced round claim + claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); + + round_proofs.push(round_proof); + challenges.push(round_challenge); + } + + Ok(SumCheckProof { + openings_claim: FinalOpeningClaim { + eval_point: challenges, + openings: vec![p0[0], p1[0], q0[0], q1[0]], + }, + round_proofs, + }) } From b32f915dac2f275141fe3fe194cc432252c02a77 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 09:20:33 +0200 Subject: [PATCH 4/9] feat: add targeted benchmark and handling of edge case in sum-check --- prover/benches/logup_gkr.rs | 3 - prover/src/logup_gkr/mod.rs | 4 +- sumcheck/src/lib.rs | 2 + sumcheck/src/prover/mod.rs | 3 + sumcheck/src/prover/plain.rs | 238 +++++++++++++++++++++-------------- 5 files changed, 151 insertions(+), 99 deletions(-) diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs index 4b1a86976..e86c84aef 100644 --- a/prover/benches/logup_gkr.rs +++ b/prover/benches/logup_gkr.rs @@ -272,13 +272,10 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { numerator[2] = E::ONE; numerator[3] = E::ONE; - denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); - - } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index c9ca5a16f..fc4c4cc53 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -21,7 +21,9 @@ pub use utils::{ }; #[cfg(feature = "concurrent")] -const MINIMAL_MLE_SIZE: usize = 1 << 4; +use sumcheck::LOG_MIN_MLE_SIZE; +#[cfg(feature = "concurrent")] +const MINIMAL_MLE_SIZE: usize = 1 << (LOG_MIN_MLE_SIZE + 2); // EVALUATED CIRCUIT // ================================================================================================ diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index b7f670a9d..b6b9d6aad 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -15,6 +15,8 @@ extern crate alloc; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; +#[cfg(feature = "concurrent")] +pub use prover::LOG_MIN_MLE_SIZE; mod prover; pub use prover::*; diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index adaa63298..f2467ecf1 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -13,3 +13,6 @@ pub use plain::sumcheck_prove_plain_serial; mod error; pub use error::SumCheckProverError; + +//#[cfg(feature = "concurrent")] +pub const LOG_MIN_MLE_SIZE: usize = 2; \ No newline at end of file diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 532458757..192a97f6f 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -10,6 +10,8 @@ pub use rayon::prelude::*; use smallvec::smallvec; use super::SumCheckProverError; +#[cfg(feature = "concurrent")] +use super::LOG_MIN_MLE_SIZE; use crate::{ comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof, @@ -68,68 +70,12 @@ pub fn sumcheck_prove_plain_parallel< let (mut p0, mut p1) = p.project_least_significant_variable(); let (mut q0, mut q1) = q.project_least_significant_variable(); + //for _ in 0..p0.num_variables() - LOG_MIN_MLE_SIZE { for _ in 0..p0.num_variables() { let len = p0.num_evaluations() / 2; - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len) - .into_par_iter() - .fold( - || (E::ZERO, E::ZERO, E::ZERO), - |(acc_point_1, acc_point_2, acc_point_3), i| { - let round_poly_eval_at_1 = comb_func( - p0[2 * i + 1], - p1[2 * i + 1], - q0[2 * i + 1], - q1[2 * i + 1], - eq[2 * i + 1], - r_batch, - ); - - let p0_delta = p0[2 * i + 1] - p0[2 * i]; - let p1_delta = p1[2 * i + 1] - p1[2 * i]; - let q0_delta = q0[2 * i + 1] - q0[2 * i]; - let q1_delta = q1[2 * i + 1] - q1[2 * i]; - let eq_delta = eq[2 * i + 1] - eq[2 * i]; - - let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; - let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; - let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; - let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; - let mut eq_evx = eq[2 * i + 1] + eq_delta; - let round_poly_eval_at_2 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); - - p0_eval_at_x += p0_delta; - p1_eval_at_x += p1_delta; - q0_eval_at_x += q0_delta; - q1_eval_at_x += q1_delta; - eq_evx += eq_delta; - let round_poly_eval_at_3 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); - - ( - round_poly_eval_at_1 + acc_point_1, - round_poly_eval_at_2 + acc_point_2, - round_poly_eval_at_3 + acc_point_3, - ) - }, - ) - .reduce( - || (E::ZERO, E::ZERO, E::ZERO), - |(a0, b0, c0), (a1, b1, c1)| (a0 + a1, b0 + b1, c0 + c1), - ); + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = + core_parallel(&p0, &p1, &q0, &q1, eq, r_batch, len); let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); @@ -157,7 +103,40 @@ pub fn sumcheck_prove_plain_parallel< round_proofs.push(round_proof); challenges.push(round_challenge); } - + /* + for _ in p0.num_variables() - LOG_MIN_MLE_SIZE..p0.num_variables() { + let len = p0.num_evaluations() / 2; + + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = + core_serial(&p0, &p1, &q0, &q1, eq, r_batch, len); + + let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; + let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); + let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); + + // reseed with the s_i polynomial + transcript.reseed(H::hash_elements(&compressed_round_poly.0)); + let round_proof = RoundProof { + round_poly_coefs: compressed_round_poly.clone(), + }; + + let round_challenge = + transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // fold each multi-linear using the round challenge + p0.bind_least_significant_variable(round_challenge); + p1.bind_least_significant_variable(round_challenge); + q0.bind_least_significant_variable(round_challenge); + q1.bind_least_significant_variable(round_challenge); + eq.bind_least_significant_variable(round_challenge); + + // compute the new reduced round claim + claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); + + round_proofs.push(round_proof); + challenges.push(round_challenge); + } + */ Ok(SumCheckProof { openings_claim: FinalOpeningClaim { eval_point: challenges, @@ -187,8 +166,59 @@ pub fn sumcheck_prove_plain_serial( + p0: &MultiLinearPoly, + p1: &MultiLinearPoly, + q0: &MultiLinearPoly, + q1: &MultiLinearPoly, + eq: &MultiLinearPoly, + r_batch: E, + len: usize, +) -> (E, E, E) { + (0..len) + .into_par_iter() + .fold( + || (E::ZERO, E::ZERO, E::ZERO), |(acc_point_1, acc_point_2, acc_point_3), i| { let round_poly_eval_at_1 = comb_func( p0[2 * i + 1], @@ -239,40 +269,58 @@ pub fn sumcheck_prove_plain_serial( + p0: &MultiLinearPoly, + p1: &MultiLinearPoly, + q0: &MultiLinearPoly, + q1: &MultiLinearPoly, + eq: &MultiLinearPoly, + r_batch: E, + len: usize, +) -> (E, E, E) { + (0..len).fold((E::ZERO, E::ZERO, E::ZERO), |(acc_point_1, acc_point_2, acc_point_3), i| { + let round_poly_eval_at_1 = comb_func( + p0[2 * i + 1], + p1[2 * i + 1], + q0[2 * i + 1], + q1[2 * i + 1], + eq[2 * i + 1], + r_batch, + ); - Ok(SumCheckProof { - openings_claim: FinalOpeningClaim { - eval_point: challenges, - openings: vec![p0[0], p1[0], q0[0], q1[0]], - }, - round_proofs, + let p0_delta = p0[2 * i + 1] - p0[2 * i]; + let p1_delta = p1[2 * i + 1] - p1[2 * i]; + let q0_delta = q0[2 * i + 1] - q0[2 * i]; + let q1_delta = q1[2 * i + 1] - q1[2 * i]; + let eq_delta = eq[2 * i + 1] - eq[2 * i]; + + let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; + let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; + let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; + let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; + let mut eq_evx = eq[2 * i + 1] + eq_delta; + let round_poly_eval_at_2 = + comb_func(p0_eval_at_x, p1_eval_at_x, q0_eval_at_x, q1_eval_at_x, eq_evx, r_batch); + + p0_eval_at_x += p0_delta; + p1_eval_at_x += p1_delta; + q0_eval_at_x += q0_delta; + q1_eval_at_x += q1_delta; + eq_evx += eq_delta; + let round_poly_eval_at_3 = + comb_func(p0_eval_at_x, p1_eval_at_x, q0_eval_at_x, q1_eval_at_x, eq_evx, r_batch); + + ( + round_poly_eval_at_1 + acc_point_1, + round_poly_eval_at_2 + acc_point_2, + round_poly_eval_at_3 + acc_point_3, + ) }) } From 8179d44cc4751af3cebbba429010176be5f0a92f Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 09:31:07 +0200 Subject: [PATCH 5/9] chore: restore previous plain sum-check implementation --- prover/src/logup_gkr/mod.rs | 5 - prover/src/logup_gkr/prover.rs | 18 +- sumcheck/src/lib.rs | 2 - sumcheck/src/prover/mod.rs | 7 +- sumcheck/src/prover/plain.rs | 292 ++++++++++----------------------- 5 files changed, 95 insertions(+), 229 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index fc4c4cc53..4d2ee975a 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -20,11 +20,6 @@ pub use utils::{ {chunks_mut, iter, iter_mut}, }; -#[cfg(feature = "concurrent")] -use sumcheck::LOG_MIN_MLE_SIZE; -#[cfg(feature = "concurrent")] -const MINIMAL_MLE_SIZE: usize = 1 << (LOG_MIN_MLE_SIZE + 2); - // EVALUATED CIRCUIT // ================================================================================================ diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index fe55b77df..899482f22 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -3,19 +3,15 @@ use alloc::vec::Vec; use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; -#[cfg(feature = "concurrent")] -use sumcheck::sumcheck_prove_plain_parallel; use sumcheck::{ - sum_check_prove_higher_degree, sumcheck_prove_plain_serial, BeforeFinalLayerProof, - CircuitOutput, EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, + sum_check_prove_higher_degree, sumcheck_prove_plain, BeforeFinalLayerProof, CircuitOutput, + EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; use utils::iter; #[cfg(feature = "concurrent")] pub use utils::rayon::prelude::*; -#[cfg(feature = "concurrent")] -use super::MINIMAL_MLE_SIZE; use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; @@ -269,15 +265,7 @@ fn sum_check_prove_num_rounds_degree_3< let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; let claim = claim.0 + claim.1 * r_batch; - #[cfg(feature = "concurrent")] - let proof = if p.num_evaluations() >= MINIMAL_MLE_SIZE { - sumcheck_prove_plain_parallel(claim, r_batch, p, q, eq, transcript)? - } else { - sumcheck_prove_plain_serial(claim, r_batch, p, q, eq, transcript)? - }; - - #[cfg(not(feature = "concurrent"))] - let proof = sumcheck_prove_plain_serial(claim, r_batch, p, q, eq, transcript)?; + let proof = sumcheck_prove_plain(claim, r_batch, p, q, eq, transcript)?; Ok(proof) } diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index b6b9d6aad..b7f670a9d 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -15,8 +15,6 @@ extern crate alloc; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -#[cfg(feature = "concurrent")] -pub use prover::LOG_MIN_MLE_SIZE; mod prover; pub use prover::*; diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index f2467ecf1..13d35e551 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -7,12 +7,7 @@ mod high_degree; pub use high_degree::sum_check_prove_higher_degree; mod plain; -#[cfg(feature = "concurrent")] -pub use plain::sumcheck_prove_plain_parallel; -pub use plain::sumcheck_prove_plain_serial; +pub use plain::sumcheck_prove_plain; mod error; pub use error::SumCheckProverError; - -//#[cfg(feature = "concurrent")] -pub const LOG_MIN_MLE_SIZE: usize = 2; \ No newline at end of file diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 192a97f6f..e0092cf10 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -10,8 +10,6 @@ pub use rayon::prelude::*; use smallvec::smallvec; use super::SumCheckProverError; -#[cfg(feature = "concurrent")] -use super::LOG_MIN_MLE_SIZE; use crate::{ comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof, @@ -49,12 +47,8 @@ use crate::{ /// Note that the degree of the non-linear composition polynomial is 3. /// /// [1]: https://eprint.iacr.org/2023/1284 -#[cfg(feature = "concurrent")] #[allow(clippy::too_many_arguments)] -pub fn sumcheck_prove_plain_parallel< - E: FieldElement, - H: ElementHasher, ->( +pub fn sumcheck_prove_plain>( mut claim: E, r_batch: E, p: MultiLinearPoly, @@ -70,155 +64,12 @@ pub fn sumcheck_prove_plain_parallel< let (mut p0, mut p1) = p.project_least_significant_variable(); let (mut q0, mut q1) = q.project_least_significant_variable(); - //for _ in 0..p0.num_variables() - LOG_MIN_MLE_SIZE { for _ in 0..p0.num_variables() { let len = p0.num_evaluations() / 2; - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = - core_parallel(&p0, &p1, &q0, &q1, eq, r_batch, len); - - let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; - let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); - let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); - - // reseed with the s_i polynomial - transcript.reseed(H::hash_elements(&compressed_round_poly.0)); - let round_proof = RoundProof { - round_poly_coefs: compressed_round_poly.clone(), - }; - - let round_challenge = - transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; - - // fold each multi-linear using the round challenge - p0.bind_least_significant_variable(round_challenge); - p1.bind_least_significant_variable(round_challenge); - q0.bind_least_significant_variable(round_challenge); - q1.bind_least_significant_variable(round_challenge); - eq.bind_least_significant_variable(round_challenge); - - // compute the new reduced round claim - claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); - - round_proofs.push(round_proof); - challenges.push(round_challenge); - } - /* - for _ in p0.num_variables() - LOG_MIN_MLE_SIZE..p0.num_variables() { - let len = p0.num_evaluations() / 2; - - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = - core_serial(&p0, &p1, &q0, &q1, eq, r_batch, len); - - let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; - let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); - let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); - - // reseed with the s_i polynomial - transcript.reseed(H::hash_elements(&compressed_round_poly.0)); - let round_proof = RoundProof { - round_poly_coefs: compressed_round_poly.clone(), - }; - - let round_challenge = - transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; - - // fold each multi-linear using the round challenge - p0.bind_least_significant_variable(round_challenge); - p1.bind_least_significant_variable(round_challenge); - q0.bind_least_significant_variable(round_challenge); - q1.bind_least_significant_variable(round_challenge); - eq.bind_least_significant_variable(round_challenge); - - // compute the new reduced round claim - claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); - - round_proofs.push(round_proof); - challenges.push(round_challenge); - } - */ - Ok(SumCheckProof { - openings_claim: FinalOpeningClaim { - eval_point: challenges, - openings: vec![p0[0], p1[0], q0[0], q1[0]], - }, - round_proofs, - }) -} - -#[allow(clippy::too_many_arguments)] -pub fn sumcheck_prove_plain_serial>( - mut claim: E, - r_batch: E, - p: MultiLinearPoly, - q: MultiLinearPoly, - eq: &mut MultiLinearPoly, - transcript: &mut impl RandomCoin, -) -> Result, SumCheckProverError> { - let mut round_proofs = vec![]; - - let mut challenges = vec![]; - - // construct the vector of multi-linear polynomials - let (mut p0, mut p1) = p.project_least_significant_variable(); - let (mut q0, mut q1) = q.project_least_significant_variable(); - - for _ in 0..p0.num_variables() { - let len = p0.num_evaluations() / 2; - - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = - core_serial(&p0, &p1, &q0, &q1, eq, r_batch, len); - - let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; - let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); - let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); - - // reseed with the s_i polynomial - transcript.reseed(H::hash_elements(&compressed_round_poly.0)); - let round_proof = RoundProof { - round_poly_coefs: compressed_round_poly.clone(), - }; - - let round_challenge = - transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; - - // fold each multi-linear using the round challenge - p0.bind_least_significant_variable(round_challenge); - p1.bind_least_significant_variable(round_challenge); - q0.bind_least_significant_variable(round_challenge); - q1.bind_least_significant_variable(round_challenge); - eq.bind_least_significant_variable(round_challenge); - - // compute the new reduced round claim - claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); - - round_proofs.push(round_proof); - challenges.push(round_challenge); - } - - Ok(SumCheckProof { - openings_claim: FinalOpeningClaim { - eval_point: challenges, - openings: vec![p0[0], p1[0], q0[0], q1[0]], - }, - round_proofs, - }) -} - -#[cfg(feature = "concurrent")] -fn core_parallel( - p0: &MultiLinearPoly, - p1: &MultiLinearPoly, - q0: &MultiLinearPoly, - q1: &MultiLinearPoly, - eq: &MultiLinearPoly, - r_batch: E, - len: usize, -) -> (E, E, E) { - (0..len) - .into_par_iter() - .fold( - || (E::ZERO, E::ZERO, E::ZERO), + #[cfg(not(feature = "concurrent"))] + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len).fold( + (E::ZERO, E::ZERO, E::ZERO), |(acc_point_1, acc_point_2, acc_point_3), i| { let round_poly_eval_at_1 = comb_func( p0[2 * i + 1], @@ -269,58 +120,97 @@ fn core_parallel( round_poly_eval_at_3 + acc_point_3, ) }, - ) - .reduce( - || (E::ZERO, E::ZERO, E::ZERO), - |(a0, b0, c0), (a1, b1, c1)| (a0 + a1, b0 + b1, c0 + c1), - ) -} - -fn core_serial( - p0: &MultiLinearPoly, - p1: &MultiLinearPoly, - q0: &MultiLinearPoly, - q1: &MultiLinearPoly, - eq: &MultiLinearPoly, - r_batch: E, - len: usize, -) -> (E, E, E) { - (0..len).fold((E::ZERO, E::ZERO, E::ZERO), |(acc_point_1, acc_point_2, acc_point_3), i| { - let round_poly_eval_at_1 = comb_func( - p0[2 * i + 1], - p1[2 * i + 1], - q0[2 * i + 1], - q1[2 * i + 1], - eq[2 * i + 1], - r_batch, ); - let p0_delta = p0[2 * i + 1] - p0[2 * i]; - let p1_delta = p1[2 * i + 1] - p1[2 * i]; - let q0_delta = q0[2 * i + 1] - q0[2 * i]; - let q1_delta = q1[2 * i + 1] - q1[2 * i]; - let eq_delta = eq[2 * i + 1] - eq[2 * i]; + #[cfg(feature = "concurrent")] + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len) + .into_par_iter() + .fold( + || (E::ZERO, E::ZERO, E::ZERO), + |(a, b, c), i| { + let round_poly_eval_at_1 = comb_func( + p0[2 * i + 1], + p1[2 * i + 1], + q0[2 * i + 1], + q1[2 * i + 1], + eq[2 * i + 1], + r_batch, + ); + + let p0_delta = p0[2 * i + 1] - p0[2 * i]; + let p1_delta = p1[2 * i + 1] - p1[2 * i]; + let q0_delta = q0[2 * i + 1] - q0[2 * i]; + let q1_delta = q1[2 * i + 1] - q1[2 * i]; + let eq_delta = eq[2 * i + 1] - eq[2 * i]; + + let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; + let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; + let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; + let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; + let mut eq_evx = eq[2 * i + 1] + eq_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + p0_eval_at_x += p0_delta; + p1_eval_at_x += p1_delta; + q0_eval_at_x += q0_delta; + q1_eval_at_x += q1_delta; + eq_evx += eq_delta; + let round_poly_eval_at_3 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + (round_poly_eval_at_1 + a, round_poly_eval_at_2 + b, round_poly_eval_at_3 + c) + }, + ) + .reduce( + || (E::ZERO, E::ZERO, E::ZERO), + |(a0, b0, c0), (a1, b1, c1)| (a0 + a1, b0 + b1, c0 + c1), + ); + + let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; + let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); + let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); + + // reseed with the s_i polynomial + transcript.reseed(H::hash_elements(&compressed_round_poly.0)); + let round_proof = RoundProof { + round_poly_coefs: compressed_round_poly.clone(), + }; + + let round_challenge = + transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // fold each multi-linear using the round challenge + p0.bind_least_significant_variable(round_challenge); + p1.bind_least_significant_variable(round_challenge); + q0.bind_least_significant_variable(round_challenge); + q1.bind_least_significant_variable(round_challenge); + eq.bind_least_significant_variable(round_challenge); - let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; - let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; - let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; - let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; - let mut eq_evx = eq[2 * i + 1] + eq_delta; - let round_poly_eval_at_2 = - comb_func(p0_eval_at_x, p1_eval_at_x, q0_eval_at_x, q1_eval_at_x, eq_evx, r_batch); + // compute the new reduced round claim + claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); - p0_eval_at_x += p0_delta; - p1_eval_at_x += p1_delta; - q0_eval_at_x += q0_delta; - q1_eval_at_x += q1_delta; - eq_evx += eq_delta; - let round_poly_eval_at_3 = - comb_func(p0_eval_at_x, p1_eval_at_x, q0_eval_at_x, q1_eval_at_x, eq_evx, r_batch); + round_proofs.push(round_proof); + challenges.push(round_challenge); + } - ( - round_poly_eval_at_1 + acc_point_1, - round_poly_eval_at_2 + acc_point_2, - round_poly_eval_at_3 + acc_point_3, - ) + Ok(SumCheckProof { + openings_claim: FinalOpeningClaim { + eval_point: challenges, + openings: vec![p0[0], p1[0], q0[0], q1[0]], + }, + round_proofs, }) } From 812a97bdb3543db4d04f0a88e14091dc0914fde6 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:11:15 +0200 Subject: [PATCH 6/9] feat: improve way construct mles --- prover/src/logup_gkr/prover.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 899482f22..67ccf8713 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -8,9 +8,9 @@ use sumcheck::{ EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; -use utils::iter; #[cfg(feature = "concurrent")] -pub use utils::rayon::prelude::*; +use utils::rayon::prelude::*; +use utils::{iter, iter_mut, uninit_vector}; use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; @@ -144,7 +144,7 @@ fn build_mls_from_main_trace_segment( oracles: &[LogUpGkrOracle], main_trace: &ColMatrix<::BaseField>, ) -> Result>, GkrProverError> { - let mut mls = vec![]; + let mut mls = Vec::with_capacity(oracles.len()); for oracle in oracles { match oracle { @@ -156,8 +156,12 @@ fn build_mls_from_main_trace_segment( }, LogUpGkrOracle::NextRow(index) => { let col = main_trace.get_column(*index); - let mut values: Vec = col.iter().map(|value| E::from(*value)).collect(); - values.rotate_left(1); + + let mut values: Vec = unsafe { uninit_vector(col.len()) }; + values[col.len() - 1] = E::from(col[0]); + iter_mut!(&mut values[..col.len() - 1]) + .enumerate() + .for_each(|(i, value)| *value = E::from(col[i + 1])); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, From dee3c6240f1a34de94ed1b91cdd4c697c39e6a38 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Sun, 22 Sep 2024 15:44:19 +0200 Subject: [PATCH 7/9] wip: bench within vs across --- prover/Cargo.toml | 8 ++++++-- prover/src/logup_gkr/prover.rs | 30 ++++++++++++++++++++++++++---- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/prover/Cargo.toml b/prover/Cargo.toml index f75acdc8e..00bf549a9 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -16,7 +16,7 @@ rust-version = "1.78" bench = false [[bench]] -name = "logup_gkr" +name = "build_mle" harness = false [[bench]] @@ -28,7 +28,11 @@ name = "row_matrix" harness = false [[bench]] -name = "lagrange_kernel" +name = "logup_gkr" +harness = false + +[[bench]] +name = "row_matrix" harness = false [features] diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 67ccf8713..81126c697 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -146,11 +146,11 @@ fn build_mls_from_main_trace_segment( ) -> Result>, GkrProverError> { let mut mls = Vec::with_capacity(oracles.len()); - for oracle in oracles { + iter!(oracles).for_each(|oracle| { match oracle { LogUpGkrOracle::CurrentRow(index) => { let col = main_trace.get_column(*index); - let values: Vec = iter!(col).map(|value| E::from(*value)).collect(); + let values: Vec = col.iter().map(|value| E::from(*value)).collect(); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, @@ -159,14 +159,36 @@ fn build_mls_from_main_trace_segment( let mut values: Vec = unsafe { uninit_vector(col.len()) }; values[col.len() - 1] = E::from(col[0]); - iter_mut!(&mut values[..col.len() - 1]) + (&mut values[..col.len() - 1]) + .iter_mut() .enumerate() .for_each(|(i, value)| *value = E::from(col[i + 1])); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, }; - } + }); + //for oracle in oracles { + //match oracle { + //LogUpGkrOracle::CurrentRow(index) => { + //let col = main_trace.get_column(*index); + //let values: Vec = iter!(col).map(|value| E::from(*value)).collect(); + //let ml = MultiLinearPoly::from_evaluations(values); + //mls.push(ml) + //}, + //LogUpGkrOracle::NextRow(index) => { + //let col = main_trace.get_column(*index); + + //let mut values: Vec = unsafe { uninit_vector(col.len()) }; + //values[col.len() - 1] = E::from(col[0]); + //iter_mut!(&mut values[..col.len() - 1]) + //.enumerate() + //.for_each(|(i, value)| *value = E::from(col[i + 1])); + //let ml = MultiLinearPoly::from_evaluations(values); + //mls.push(ml) + //}, + //}; + //} Ok(mls) } From 654b7bcaf889df2a32c36a4d4d8d4bd0d09ee777 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:31:00 +0200 Subject: [PATCH 8/9] feat: go with the within method for building MLEs from main trace segment --- prover/Cargo.toml | 4 ---- prover/src/logup_gkr/prover.rs | 35 +++++++--------------------------- 2 files changed, 7 insertions(+), 32 deletions(-) diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 00bf549a9..32d65a158 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -23,10 +23,6 @@ harness = false name = "logup_gkr_e2e" harness = false -[[bench]] -name = "row_matrix" -harness = false - [[bench]] name = "logup_gkr" harness = false diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 81126c697..0da61fbd8 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -81,7 +81,7 @@ pub fn prove_gkr( // build the MLEs of the relevant main trace columns let main_trace_mls = - build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; + build_mle_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; // build the periodic table representing periodic columns as multi-linear extensions let periodic_table = evaluator.build_periodic_values(); @@ -140,17 +140,17 @@ fn prove_input_layer< /// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for /// LogUp-GKR. #[instrument(skip_all)] -fn build_mls_from_main_trace_segment( +fn build_mle_from_main_trace_segment( oracles: &[LogUpGkrOracle], main_trace: &ColMatrix<::BaseField>, ) -> Result>, GkrProverError> { let mut mls = Vec::with_capacity(oracles.len()); - iter!(oracles).for_each(|oracle| { + for oracle in oracles { match oracle { LogUpGkrOracle::CurrentRow(index) => { let col = main_trace.get_column(*index); - let values: Vec = col.iter().map(|value| E::from(*value)).collect(); + let values: Vec = iter!(col).map(|value| E::from(*value)).collect(); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, @@ -159,36 +159,15 @@ fn build_mls_from_main_trace_segment( let mut values: Vec = unsafe { uninit_vector(col.len()) }; values[col.len() - 1] = E::from(col[0]); - (&mut values[..col.len() - 1]) - .iter_mut() + iter_mut!(&mut values[..col.len() - 1]) .enumerate() .for_each(|(i, value)| *value = E::from(col[i + 1])); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, }; - }); - //for oracle in oracles { - //match oracle { - //LogUpGkrOracle::CurrentRow(index) => { - //let col = main_trace.get_column(*index); - //let values: Vec = iter!(col).map(|value| E::from(*value)).collect(); - //let ml = MultiLinearPoly::from_evaluations(values); - //mls.push(ml) - //}, - //LogUpGkrOracle::NextRow(index) => { - //let col = main_trace.get_column(*index); - - //let mut values: Vec = unsafe { uninit_vector(col.len()) }; - //values[col.len() - 1] = E::from(col[0]); - //iter_mut!(&mut values[..col.len() - 1]) - //.enumerate() - //.for_each(|(i, value)| *value = E::from(col[i + 1])); - //let ml = MultiLinearPoly::from_evaluations(values); - //mls.push(ml) - //}, - //}; - //} + } + Ok(mls) } From 0fc8fd8cf585b8bf407579ce6f79e3b825f231eb Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:35:37 +0200 Subject: [PATCH 9/9] chore: remove unused bench --- prover/Cargo.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 32d65a158..125ab2a4f 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -15,10 +15,6 @@ rust-version = "1.78" [lib] bench = false -[[bench]] -name = "build_mle" -harness = false - [[bench]] name = "logup_gkr_e2e" harness = false