Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: trying out plonky2 prover over gpu #1332

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 153 additions & 63 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,7 @@ lto = "fat"
lto = "thin"

[patch.crates-io]
plonky2 = { git = "https://github.com/0xmozak/plonky2.git" }
starky = { git = "https://github.com/0xmozak/plonky2.git" }
plonky2 = { git = "https://github.com/0xmozak/plonky2.git", branch = "vivek/cuda" }
starky = { git = "https://github.com/0xmozak/plonky2.git", branch = "vivek/cuda" }
# plonky2 = { path = "../plonky2/plonky2" }
# starky = { path = "../plonky2/starky" }
5 changes: 4 additions & 1 deletion circuits/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ serde = { version = "1.0", features = ["derive"] }
starky = { version = "0", default-features = false, features = ["std"] }
thiserror = "1.0"
tt-call = "1.0"
rustacuda = "0.1.3"
rustacuda_core = "0.1.2"
env_logger = { version = "0.10" }

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
env_logger = { version = "0.10" }
hex = "0.4"
im = "15.1"
mozak-examples = { path = "../examples-builder", features = ["fibonacci", "fibonacci-input-new-api"] }
Expand All @@ -41,6 +43,7 @@ enable_poseidon_starks = []
enable_register_starks = []
test = []
timing = ["plonky2/timing", "starky/timing"]
cuda = ["plonky2/cuda"]

[[test]]
name = "riscv_tests"
Expand Down
5 changes: 5 additions & 0 deletions circuits/src/cpu/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ mod tests {
Stark::prove_and_verify(&program, &record).unwrap();
}

#[test]
fn prove_add_cuda() {
prove_add::<MozakStark<F, D>>(90, 90, 5);
}

use proptest::prelude::ProptestConfig;
use proptest::proptest;
proptest! {
Expand Down
1 change: 1 addition & 0 deletions circuits/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#![allow(clippy::missing_errors_doc)]
// FIXME: Remove this, when proptest's macro is updated not to trigger clippy.
#![allow(clippy::ignored_unit_patterns)]
#![feature(allocator_api)]

pub mod bitshift;
pub mod columns_view;
Expand Down
181 changes: 160 additions & 21 deletions circuits/src/stark/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use plonky2::field::extension::Extendable;
use plonky2::field::packable::Packable;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Field;
use plonky2::fri::oracle::PolynomialBatch;
use plonky2::fri::oracle::{CudaInvContext, PolynomialBatch};
use plonky2::hash::hash_types::RichField;
use plonky2::iop::challenger::Challenger;
use plonky2::plonk::config::GenericConfig;
Expand Down Expand Up @@ -48,6 +48,7 @@ pub fn prove<F, C, const D: usize>(
config: &StarkConfig,
public_inputs: PublicInputs<F>,
timing: &mut TimingTree,
ctx: &mut Option<&mut CudaInvContext<F, D>>,
) -> Result<AllProof<F, C, D>>
where
F: RichField + Extendable<D>,
Expand All @@ -63,6 +64,7 @@ where
public_inputs,
&traces_poly_values,
timing,
ctx,
)
}

Expand All @@ -76,34 +78,68 @@ pub fn prove_with_traces<F, C, const D: usize>(
public_inputs: PublicInputs<F>,
traces_poly_values: &TableKindArray<Vec<PolynomialValues<F>>>,
timing: &mut TimingTree,
ctx: &mut Option<&mut CudaInvContext<F, D>>,
) -> Result<AllProof<F, C, D>>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>, {
let rate_bits = config.fri_config.rate_bits;
let cap_height = config.fri_config.cap_height;
let trace_commitments;

let trace_commitments = timed!(
timing,
"Compute trace commitments for each table",
traces_poly_values
.clone()
.with_kind()
.map(|(trace, table)| {
timed!(
timing,
&format!("compute trace commitment for {table:?}"),
PolynomialBatch::<F, C, D>::from_values(
trace.clone(),
rate_bits,
false,
cap_height,
#[cfg(feature = "cuda")]
{
trace_commitments = timed!(
timing,
"Compute trace commitments for each table",
traces_poly_values
.clone()
.with_kind()
.map(|(trace, table)| {
timed!(
timing,
None,
&format!("compute trace commitment for {table:?}"),
// creates merkle tree out of trace polynomials over gpu
PolynomialBatch::<F, C, D>::from_values_cuda(
trace.clone(),
rate_bits,
false,
cap_height,
timing,
trace.len(),
trace.first().expect("Not a single polynomial").len(),
ctx.as_mut().unwrap(),
)
)
)
})
);
})
);
}

#[cfg(not(feature = "cuda"))]
{
trace_commitments = timed!(
timing,
"Compute trace commitments for each table",
traces_poly_values
.clone()
.with_kind()
.map(|(trace, table)| {
timed!(
timing,
&format!("compute trace commitment for {table:?}"),
PolynomialBatch::<F, C, D>::from_values(
trace.clone(),
rate_bits,
false,
cap_height,
timing,
None,
)
)
})
);
}
// log::info!("trace_commitments {:?}", trace_commitments);

let trace_caps = trace_commitments
.each_ref()
Expand All @@ -124,6 +160,7 @@ where
&ctl_challenges
)
);
#[cfg(feature = "cuda")]
let proofs_with_metadata = timed!(
timing,
"compute all proofs given commitments",
Expand All @@ -135,7 +172,24 @@ where
&trace_commitments,
&ctl_data_per_table,
&mut challenger,
timing
timing,
ctx,
)?
);
#[cfg(not(feature = "cuda"))]
let proofs_with_metadata = timed!(
timing,
"compute all proofs given commitments",
prove_with_commitments(
mozak_stark,
config,
&public_inputs,
traces_poly_values,
&trace_commitments,
&ctl_data_per_table,
&mut challenger,
timing,
&mut None
)?
);

Expand Down Expand Up @@ -171,6 +225,7 @@ pub(crate) fn prove_single_table<F, C, S, const D: usize>(
ctl_data: &CtlData<F>,
challenger: &mut Challenger<F, C::Hasher>,
timing: &mut TimingTree,
ctx: &mut Option<&mut CudaInvContext<F, D>>,
) -> Result<StarkProofWithMetadata<F, C, D>>
where
F: RichField + Extendable<D>,
Expand Down Expand Up @@ -301,6 +356,7 @@ where
challenger,
&fri_params,
timing,
ctx,
)
);

Expand Down Expand Up @@ -332,6 +388,7 @@ pub fn prove_with_commitments<F, C, const D: usize>(
ctl_data_per_table: &TableKindArray<CtlData<F>>,
challenger: &mut Challenger<F, C::Hasher>,
timing: &mut TimingTree,
ctx: &mut Option<&mut CudaInvContext<F, D>>,
) -> Result<TableKindArray<StarkProofWithMetadata<F, C, D>>>
where
F: RichField + Extendable<D>,
Expand All @@ -353,6 +410,7 @@ where
&ctl_data_per_table[kind],
challenger,
timing,
ctx,
)?
}))
}
Expand Down Expand Up @@ -480,4 +538,85 @@ mod tests {
},
]);
}

/// Test for ensuring Polynomial batch computed over gpu is same as
/// that computed over cpu
#[test]
#[cfg(feature = "cuda")]
fn test_cuda_poly_batch() {
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Sample;
use plonky2::fri::oracle::PolynomialBatch;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::util::timing::TimingTree;

const D: usize = 2;
// the cuda code only supports Poseidon for now (Not Poseidon2!)
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;

// number of rows of trace table
let values_num_per_poly = 1 << 6;
// number of columns of trace table
let poly_num = 8;
// generate random trace
let mut polys = vec![];
for _i in 0..poly_num {
let poly: Vec<F> = (0..values_num_per_poly).map(|_| F::rand()).collect();
let poly_as_value = PolynomialValues::new(poly);
polys.push(poly_as_value);
}
// rate_bits = log2ceil(constraint_degree)
let rate_bits = 2;
let cap_height = 4;
let len_cap = 1 << cap_height;
// flattened len of all lde polynomials
let _all_len = poly_num * values_num_per_poly * (1 << rate_bits);
let num_digests = 2 * (values_num_per_poly * (1 << rate_bits) - len_cap);
let _num_digests_and_caps = num_digests + len_cap;
let blinding = false;
let timing = &mut TimingTree::default();
// merkle tree over over cpu
let batch: PolynomialBatch<F, C, D> = PolynomialBatch::from_values(
polys.clone(),
rate_bits,
blinding,
cap_height,
timing,
None,
);
let ctx = &mut crate::test_utils::cuda_ctx();
// merkle tree over gpu
let cuda_batch: PolynomialBatch<F, C, D> = PolynomialBatch::from_values_cuda(
polys,
rate_bits,
blinding,
cap_height,
timing,
poly_num,
values_num_per_poly,
ctx,
);

// check that polynomials were computed in coefficient form
// are same for cpu and gpu.
assert_eq!(batch.polynomials, cuda_batch.polynomials);
let leaves = batch
.merkle_tree
.leaves
.into_iter()
.flatten()
.collect::<Vec<F>>();
// check that merkle tree computed over cpu is same as gpu
assert_eq!(leaves, *cuda_batch.merkle_tree.my_leaves);
assert_eq!(batch.merkle_tree.cap, cuda_batch.merkle_tree.cap);
assert_eq!(
batch.merkle_tree.digests.len(),
cuda_batch.merkle_tree.my_digests[..num_digests].len()
);
assert_eq!(
batch.merkle_tree.digests,
cuda_batch.merkle_tree.my_digests[..num_digests]
);
}
}
10 changes: 9 additions & 1 deletion circuits/src/stark/recursive_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@ where
}

#[cfg(test)]
#[allow(unused_imports)]
mod tests {
use std::panic;
use std::panic::AssertUnwindSafe;
Expand Down Expand Up @@ -639,6 +640,9 @@ mod tests {
#[test]
#[ignore]
fn recursive_verify_mozak_starks() -> Result<()> {
#[cfg(not(feature = "cuda"))]
{
type S = MozakStark<F, D>;
let stark = S::default();
let mut config = StarkConfig::standard_fast_config();
config.fri_config.cap_height = 1;
Expand Down Expand Up @@ -678,7 +682,9 @@ mod tests {
);

let recursive_proof = mozak_stark_circuit.prove(&mozak_proof)?;
mozak_stark_circuit.circuit.verify(recursive_proof)
mozak_stark_circuit.circuit.verify(recursive_proof)?;
}
Ok(())
}

#[test]
Expand Down Expand Up @@ -708,6 +714,7 @@ mod tests {
&stark_config0,
public_inputs,
&mut TimingTree::default(),
&mut None,
)?;

let (program1, record1) = execute_code(vec![inst; 128], &[], &[(6, 100), (7, 200)]);
Expand All @@ -722,6 +729,7 @@ mod tests {
&stark_config1,
public_inputs,
&mut TimingTree::default(),
&mut None,
)?;

// The degree bits should be different for the two proofs.
Expand Down
6 changes: 5 additions & 1 deletion circuits/src/stark/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize> A
}

#[cfg(test)]
#[allow(unused_imports)]
mod tests {
use mozak_runner::util::execute_code;
use plonky2::util::timing::TimingTree;
Expand All @@ -45,13 +46,15 @@ mod tests {

#[test]
fn test_serialization_deserialization() {

#[cfg(not(feature = "cuda"))]
{
let (program, record) = execute_code([], &[], &[]);
let stark = S::default();
let config = fast_test_config();
let public_inputs = PublicInputs {
entry_point: from_u32(program.entry_point),
};

let all_proof = prove::<F, C, D>(
&program,
&record,
Expand All @@ -68,5 +71,6 @@ mod tests {
AllProof::<F, C, D>::deserialize_proof_from_flexbuffer(s.view())
.expect("deserialization failed");
verify_proof(&stark, all_proof_deserialized, &config).unwrap();
}
}
}
Loading