Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shout d=1 #568

Merged
merged 11 commits into from
Jan 29, 2025
47 changes: 47 additions & 0 deletions jolt-core/src/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ use crate::jolt::vm::Jolt;
use crate::poly::commitment::commitment_scheme::CommitmentScheme;
use crate::poly::commitment::hyperkzg::HyperKZG;
use crate::poly::commitment::zeromorph::Zeromorph;
use crate::subprotocols::shout::ShoutProof;
use crate::utils::math::Math;
use crate::utils::transcript::{KeccakTranscript, Transcript};
use ark_bn254::{Bn254, Fr};
use ark_std::test_rng;
use rand_core::RngCore;
use serde::Serialize;

#[derive(Debug, Copy, Clone, clap::ValueEnum)]
Expand All @@ -21,6 +25,7 @@ pub enum BenchType {
Sha2,
Sha3,
Sha2Chain,
Shout,
}

#[allow(unreachable_patterns)] // good errors on new BenchTypes
Expand All @@ -41,6 +46,7 @@ pub fn benchmarks(
BenchType::Fibonacci => {
fibonacci::<Fr, Zeromorph<Bn254, KeccakTranscript>, KeccakTranscript>()
}
BenchType::Shout => shout::<Fr, KeccakTranscript>(),
_ => panic!("BenchType does not have a mapping"),
},
PCSType::HyperKZG => match bench_type {
Expand All @@ -52,12 +58,53 @@ pub fn benchmarks(
BenchType::Fibonacci => {
fibonacci::<Fr, HyperKZG<Bn254, KeccakTranscript>, KeccakTranscript>()
}
BenchType::Shout => shout::<Fr, KeccakTranscript>(),
_ => panic!("BenchType does not have a mapping"),
},
_ => panic!("PCS Type does not have a mapping"),
}
}

fn shout<F, ProofTranscript>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
ProofTranscript: Transcript,
{
let small_value_lookup_tables = F::compute_lookup_tables();
F::initialize_lookup_tables(small_value_lookup_tables);

let mut tasks = Vec::new();

const TABLE_SIZE: usize = 1 << 16;
const NUM_LOOKUPS: usize = 1 << 20;

let mut rng = test_rng();

let lookup_table: Vec<F> = (0..TABLE_SIZE).map(|_| F::random(&mut rng)).collect();
let read_addresses: Vec<usize> = (0..NUM_LOOKUPS)
.map(|_| rng.next_u32() as usize % TABLE_SIZE)
.collect();

let mut prover_transcript = ProofTranscript::new(b"test_transcript");
let r_cycle: Vec<F> = prover_transcript.challenge_vector(NUM_LOOKUPS.log_2());

let task = move || {
let _proof = ShoutProof::prove(
lookup_table,
read_addresses,
&r_cycle,
&mut prover_transcript,
);
};

tasks.push((
tracing::info_span!("Shout d=1"),
Box::new(task) as Box<dyn FnOnce()>,
));

tasks
}

fn fibonacci<F, PCS, ProofTranscript>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/jolt/vm/read_write_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::lasso::memory_checking::{
ExogenousOpenings, Initializable, StructuredPolynomialData, VerifierComputedOpening,
};
use crate::poly::compact_polynomial::{CompactPolynomial, SmallScalar};
use crate::poly::multilinear_polynomial::MultilinearPolynomial;
use crate::poly::multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation};
use crate::poly::opening_proof::{ProverOpeningAccumulator, VerifierOpeningAccumulator};
use crate::utils::thread::unsafe_allocate_zero_vec;
use rayon::prelude::*;
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/lasso/surge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
poly::{
commitment::commitment_scheme::BatchType,
compact_polynomial::{CompactPolynomial, SmallScalar},
multilinear_polynomial::MultilinearPolynomial,
multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation},
opening_proof::{ProverOpeningAccumulator, VerifierOpeningAccumulator},
},
};
Expand Down
67 changes: 67 additions & 0 deletions jolt-core/src/poly/compact_polynomial.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::ops::Index;

use crate::utils::math::Math;
use crate::utils::thread::unsafe_allocate_zero_vec;
use crate::{field::JoltField, utils};
use ark_serialize::{
CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate,
};
use num_integer::Integer;
use rayon::prelude::*;

use super::multilinear_polynomial::{BindingOrder, PolynomialBinding};

Expand Down Expand Up @@ -211,6 +213,71 @@ impl<T: SmallScalar, F: JoltField> PolynomialBinding<F> for CompactPolynomial<T,
self.len = n;
}

#[tracing::instrument(skip_all, name = "CompactPolynomial::bind")]
fn bind_parallel(&mut self, r: F, order: BindingOrder) {
let n = self.len() / 2;
if self.is_bound() {
match order {
BindingOrder::LowToHigh => {
// TODO(moodlezoup): Use `binding_scratch_space` trick
let mut new_coeffs = unsafe_allocate_zero_vec(n);
for i in 0..n {
if self.bound_coeffs[2 * i + 1] == self.bound_coeffs[2 * i] {
new_coeffs[i] = self.bound_coeffs[2 * i];
} else {
new_coeffs[i] = self.bound_coeffs[2 * i]
+ r * (self.bound_coeffs[2 * i + 1] - self.bound_coeffs[2 * i]);
}
}
self.bound_coeffs = new_coeffs;
}
BindingOrder::HighToLow => {
let (left, right) = self.bound_coeffs.split_at_mut(n);
left.par_iter_mut()
.zip(right.par_iter())
.filter(|(a, b)| a != b)
.for_each(|(a, b)| {
*a += r * (*b - *a);
});
}
}
} else {
let r_r2 = r * F::montgomery_r2().unwrap_or(F::one());
let one_minus_r_r2 = (F::one() - r) * F::montgomery_r2().unwrap_or(F::one());
match order {
BindingOrder::LowToHigh => {
self.bound_coeffs = (0..n)
.into_par_iter()
.map(|i| {
if self.coeffs[2 * i] == self.coeffs[2 * i + 1] {
self.coeffs[2 * i].to_field()
} else {
self.coeffs[2 * i].field_mul(one_minus_r_r2)
+ self.coeffs[2 * i + 1].field_mul(r_r2)
}
})
.collect();
}
BindingOrder::HighToLow => {
let (left, right) = self.coeffs.split_at(n);
self.bound_coeffs = left
.par_iter()
.zip(right.par_iter())
.map(|(&a, &b)| {
if a == b {
a.to_field()
} else {
a.field_mul(one_minus_r_r2) + b.field_mul(r_r2)
}
})
.collect();
}
}
}
self.num_vars -= 1;
self.len = n;
}

fn final_sumcheck_claim(&self) -> F {
assert_eq!(self.len, 1);
self.bound_coeffs[0]
Expand Down
119 changes: 112 additions & 7 deletions jolt-core/src/poly/identity_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,125 @@ use crate::field::JoltField;

use crate::utils::math::Math;

pub struct IdentityPolynomial {
size_point: usize,
use super::multilinear_polynomial::{BindingOrder, PolynomialBinding, PolynomialEvaluation};

pub struct IdentityPolynomial<F: JoltField> {
num_vars: usize,
num_bound_vars: usize,
bound_value: F,
}

impl<F: JoltField> IdentityPolynomial<F> {
pub fn new(num_vars: usize) -> Self {
IdentityPolynomial {
num_vars,
num_bound_vars: 0,
bound_value: F::zero(),
}
}
}

impl IdentityPolynomial {
pub fn new(size_point: usize) -> Self {
IdentityPolynomial { size_point }
impl<F: JoltField> PolynomialBinding<F> for IdentityPolynomial<F> {
fn is_bound(&self) -> bool {
self.num_bound_vars != 0
}

fn bind(&mut self, r: F, order: BindingOrder) {
debug_assert!(self.num_bound_vars < self.num_vars);
debug_assert_eq!(
order,
BindingOrder::LowToHigh,
"IdentityPolynomial only supports low-to-high binding"
);

self.bound_value += F::from_u32(1u32 << self.num_bound_vars) * r;
self.num_bound_vars += 1;
}

pub fn evaluate<F: JoltField>(&self, r: &[F]) -> F {
fn bind_parallel(&mut self, r: F, order: BindingOrder) {
// Binding is constant time, no parallelism necessary
self.bind(r, order);
}

fn final_sumcheck_claim(&self) -> F {
debug_assert_eq!(self.num_vars, self.num_bound_vars);
self.bound_value
}
}

impl<F: JoltField> PolynomialEvaluation<F> for IdentityPolynomial<F> {
fn evaluate(&self, r: &[F]) -> F {
let len = r.len();
assert_eq!(len, self.size_point);
assert_eq!(len, self.num_vars);
(0..len)
.map(|i| F::from_u64((len - i - 1).pow2() as u64) * r[i])
.sum()
}

fn batch_evaluate(_polys: &[&Self], _r: &[F]) -> (Vec<F>, Vec<F>) {
unimplemented!("Currently unused")
}

fn sumcheck_evals(&self, index: usize, degree: usize, order: BindingOrder) -> Vec<F> {
debug_assert!(degree > 0);
debug_assert!(index < self.num_vars.pow2() / 2);
debug_assert_eq!(
order,
BindingOrder::LowToHigh,
"IdentityPolynomial only supports low-to-high binding"
);

let mut evals = vec![F::zero(); degree];
evals[0] = self.bound_value + F::from_u64((index as u64) << (1 + self.num_bound_vars));
let m = F::from_u32(1 << self.num_bound_vars);
let mut eval = evals[0] + m;
for i in 1..degree {
eval += m;
evals[i] = eval;
}
evals
}
}

#[cfg(test)]
mod tests {
use crate::poly::multilinear_polynomial::MultilinearPolynomial;

use super::*;
use ark_bn254::Fr;
use ark_std::test_rng;

#[test]
fn identity_poly() {
const NUM_VARS: usize = 10;

let mut rng = test_rng();
let mut identity_poly: IdentityPolynomial<Fr> = IdentityPolynomial::new(NUM_VARS);
let mut reference_poly: MultilinearPolynomial<Fr> =
MultilinearPolynomial::from((0..(1 << NUM_VARS)).map(|i| i as u32).collect::<Vec<_>>());

for j in 0..reference_poly.len() / 2 {
let identity_poly_evals = identity_poly.sumcheck_evals(j, 3, BindingOrder::LowToHigh);
let reference_poly_evals = reference_poly.sumcheck_evals(j, 3, BindingOrder::LowToHigh);
assert_eq!(identity_poly_evals, reference_poly_evals);
}

for _ in 0..NUM_VARS {
let r = Fr::random(&mut rng);
identity_poly.bind(r, BindingOrder::LowToHigh);
reference_poly.bind(r, BindingOrder::LowToHigh);
for j in 0..reference_poly.len() / 2 {
let identity_poly_evals =
identity_poly.sumcheck_evals(j, 3, BindingOrder::LowToHigh);
let reference_poly_evals =
reference_poly.sumcheck_evals(j, 3, BindingOrder::LowToHigh);
assert_eq!(identity_poly_evals, reference_poly_evals);
}
}

assert_eq!(
identity_poly.final_sumcheck_claim(),
reference_poly.final_sumcheck_claim()
);
}
}
20 changes: 20 additions & 0 deletions jolt-core/src/poly/multilinear_polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub enum MultilinearPolynomial<F: JoltField> {
}

/// The order in which polynomial variables are bound in sumcheck
#[derive(Debug, PartialEq)]
pub enum BindingOrder {
LowToHigh,
HighToLow,
Expand Down Expand Up @@ -105,6 +106,7 @@ impl<F: JoltField> MultilinearPolynomial<F> {
}
}

#[tracing::instrument(skip_all)]
pub fn linear_combination(polynomials: &[&Self], coefficients: &[F]) -> Self {
debug_assert_eq!(polynomials.len(), coefficients.len());

Expand Down Expand Up @@ -412,6 +414,9 @@ pub trait PolynomialBinding<F: JoltField> {
fn is_bound(&self) -> bool;
/// Binds the polynomial to a random field element `r`.
fn bind(&mut self, r: F, order: BindingOrder);
/// Binds the polynomial to a random field element `r`, parallelizing
/// by coefficient.
fn bind_parallel(&mut self, r: F, order: BindingOrder);
/// Returns the final sumcheck claim about the polynomial.
fn final_sumcheck_claim(&self) -> F;
}
Expand Down Expand Up @@ -456,6 +461,21 @@ impl<F: JoltField> PolynomialBinding<F> for MultilinearPolynomial<F> {
}
}

#[tracing::instrument(skip_all, name = "MultilinearPolynomial::bind_parallel")]
fn bind_parallel(&mut self, r: F, order: BindingOrder) {
match self {
MultilinearPolynomial::LargeScalars(poly) => match order {
BindingOrder::LowToHigh => poly.bound_poly_var_bot_01_optimized(&r),
BindingOrder::HighToLow => poly.bound_poly_var_top_zero_optimized(&r),
},
MultilinearPolynomial::U8Scalars(poly) => poly.bind_parallel(r, order),
MultilinearPolynomial::U16Scalars(poly) => poly.bind_parallel(r, order),
MultilinearPolynomial::U32Scalars(poly) => poly.bind_parallel(r, order),
MultilinearPolynomial::U64Scalars(poly) => poly.bind_parallel(r, order),
MultilinearPolynomial::I64Scalars(poly) => poly.bind_parallel(r, order),
}
}

fn final_sumcheck_claim(&self) -> F {
match self {
MultilinearPolynomial::LargeScalars(poly) => {
Expand Down
1 change: 1 addition & 0 deletions jolt-core/src/subprotocols/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

pub mod grand_product;
pub mod grand_product_quarks;
pub mod shout;
pub mod sparse_grand_product;
pub mod sumcheck;

Expand Down
Loading