Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option for partitioned trace committment #336

Merged
merged 14 commits into from
Oct 24, 2024
2 changes: 1 addition & 1 deletion air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mod errors;
pub use errors::AssertionError;

mod options;
pub use options::{FieldExtension, ProofOptions};
pub use options::{FieldExtension, PartitionOptions, ProofOptions};

mod air;
pub use air::{
Expand Down
114 changes: 103 additions & 11 deletions air/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
// LICENSE file in the root directory of this source tree.

use alloc::vec::Vec;
use core::{cmp, ops::Div};

use fri::FriOptions;
use math::{StarkField, ToElements};
use math::{FieldElement, StarkField, ToElements};
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// CONSTANTS
Expand Down Expand Up @@ -74,6 +75,17 @@ pub enum FieldExtension {
/// is the hash function used in the protocol. The soundness of a STARK proof is limited by the
/// collision resistance of the hash function used by the protocol. For example, if a hash function
/// with 128-bit collision resistance is used, soundness of a STARK proof cannot exceed 128 bits.
///
/// In addition to the above, the parameter `num_partitions` is used in order to specify the number
/// of partitions each of the traces committed to during proof generation is split into, and
/// the parameter `min_partition_size` gives a lower bound on the size of each such partition.
/// More precisely, and taking the main segment trace as an example, the prover will split the main
/// segment trace into `num_partitions` parts each of size at least `min_partition_size`. The prover
/// will then proceed to hash each part row-wise resulting in `num_partitions` digests per row of
/// the trace. The prover finally combines the `num_partitions` digest (per row) into one digest
/// (per row) and at this point the vector commitment scheme can be called.
/// In the case when `num_partitions` is equal to `1` the prover will just hash each row in one go
/// producing one digest per row of the trace.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ProofOptions {
num_queries: u8,
Expand All @@ -82,6 +94,7 @@ pub struct ProofOptions {
field_extension: FieldExtension,
fri_folding_factor: u8,
fri_remainder_max_degree: u8,
partition_options: PartitionOptions,
}

// PROOF OPTIONS IMPLEMENTATION
Expand All @@ -108,7 +121,6 @@ impl ProofOptions {
/// - `grinding_factor` is greater than 32.
/// - `fri_folding_factor` is not 2, 4, 8, or 16.
/// - `fri_remainder_max_degree` is greater than 255 or is not a power of two minus 1.
#[rustfmt::skip]
pub const fn new(
num_queries: usize,
blowup_factor: usize,
Expand All @@ -125,11 +137,20 @@ impl ProofOptions {
assert!(blowup_factor >= MIN_BLOWUP_FACTOR, "blowup factor cannot be smaller than 2");
assert!(blowup_factor <= MAX_BLOWUP_FACTOR, "blowup factor cannot be greater than 128");

assert!(grinding_factor <= MAX_GRINDING_FACTOR, "grinding factor cannot be greater than 32");
assert!(
grinding_factor <= MAX_GRINDING_FACTOR,
"grinding factor cannot be greater than 32"
);

assert!(fri_folding_factor.is_power_of_two(), "FRI folding factor must be a power of 2");
assert!(fri_folding_factor >= FRI_MIN_FOLDING_FACTOR, "FRI folding factor cannot be smaller than 2");
assert!(fri_folding_factor <= FRI_MAX_FOLDING_FACTOR, "FRI folding factor cannot be greater than 16");
assert!(
fri_folding_factor >= FRI_MIN_FOLDING_FACTOR,
"FRI folding factor cannot be smaller than 2"
);
assert!(
fri_folding_factor <= FRI_MAX_FOLDING_FACTOR,
"FRI folding factor cannot be greater than 16"
);

assert!(
(fri_remainder_max_degree + 1).is_power_of_two(),
Expand All @@ -140,16 +161,33 @@ impl ProofOptions {
"FRI polynomial remainder degree cannot be greater than 255"
);

ProofOptions {
Self {
num_queries: num_queries as u8,
blowup_factor: blowup_factor as u8,
grinding_factor: grinding_factor as u8,
field_extension,
fri_folding_factor: fri_folding_factor as u8,
fri_remainder_max_degree: fri_remainder_max_degree as u8,
partition_options: PartitionOptions::new(1, 1),
}
}

/// Updates the provided [ProofOptions] instance with the specified partition parameters.
///
/// # Panics
/// Panics if:
/// - `num_partitions` is zero or greater than 16.
/// - `min_partition_size` is zero or greater than 256.
pub const fn with_partitions(
mut self,
num_partitions: usize,
min_partition_size: usize,
) -> ProofOptions {
self.partition_options = PartitionOptions::new(num_partitions, min_partition_size);

self
}

// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -206,6 +244,11 @@ impl ProofOptions {
let remainder_max_degree = self.fri_remainder_max_degree as usize;
FriOptions::new(self.blowup_factor(), folding_factor, remainder_max_degree)
}

/// Returns the `[PartitionOptions]` used in this instance of proof options.
pub fn partition_options(&self) -> PartitionOptions {
self.partition_options
}
}

impl<E: StarkField> ToElements<E> for ProofOptions {
Expand Down Expand Up @@ -233,6 +276,8 @@ impl Serializable for ProofOptions {
target.write(self.field_extension);
target.write_u8(self.fri_folding_factor);
target.write_u8(self.fri_remainder_max_degree);
target.write_u8(self.partition_options.num_partitions);
target.write_u8(self.partition_options.min_partition_size);
}
}

Expand All @@ -242,14 +287,15 @@ impl Deserializable for ProofOptions {
/// # Errors
/// Returns an error of a valid proof options could not be read from the specified `source`.
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
Ok(ProofOptions::new(
let result = ProofOptions::new(
source.read_u8()? as usize,
source.read_u8()? as usize,
source.read_u8()? as u32,
FieldExtension::read_from(source)?,
source.read_u8()? as usize,
source.read_u8()? as usize,
))
);
Ok(result.with_partitions(source.read_u8()? as usize, source.read_u8()? as usize))
}
}

Expand All @@ -272,9 +318,6 @@ impl FieldExtension {
}
}

// SERIALIZATION
// ================================================================================================

impl Serializable for FieldExtension {
/// Serializes `self` and writes the resulting bytes into the `target`.
fn write_into<W: ByteWriter>(&self, target: &mut W) {
Expand All @@ -301,6 +344,55 @@ impl Deserializable for FieldExtension {
}
}

// PARTITION OPTION IMPLEMENTATION
// ================================================================================================

/// Defines the parameters used when committing to the traces generated during the protocol.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct PartitionOptions {
num_partitions: u8,
min_partition_size: u8,
}

impl PartitionOptions {
/// Returns a new instance of `[PartitionOptions]`.
pub const fn new(num_partitions: usize, min_partition_size: usize) -> Self {
assert!(num_partitions >= 1, "number of partitions must be greater than or eqaul to 1");
assert!(num_partitions <= 16, "number of partitions must be smaller than or equal to 16");

assert!(
min_partition_size >= 1,
"smallest partition size must be greater than or equal to 1"
);
assert!(
min_partition_size <= 256,
"smallest partition size must be smaller than or equal to 256"
);

Self {
num_partitions: num_partitions as u8,
min_partition_size: min_partition_size as u8,
}
}

/// Returns the size of each partition used when committing to the main and auxiliary traces as
/// well as the constraint evaluation trace.
pub fn partition_size<E: FieldElement>(&self, num_columns: usize) -> usize {
let base_elements_per_partition = cmp::max(
(num_columns * E::EXTENSION_DEGREE).div_ceil(self.num_partitions as usize),
self.min_partition_size as usize,
);

base_elements_per_partition.div(E::EXTENSION_DEGREE)
}
}

impl Default for PartitionOptions {
fn default() -> Self {
Self { num_partitions: 1, min_partition_size: 1 }
}
}

// TESTS
// ================================================================================================

Expand Down
9 changes: 9 additions & 0 deletions crypto/src/hash/blake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ impl<B: StarkField> Hasher for Blake3_256<B> {
ByteDigest(blake3::hash(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
ByteDigest(blake3::hash(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut data = [0; 40];
data[..32].copy_from_slice(&seed.0);
Expand Down Expand Up @@ -84,6 +88,11 @@ impl<B: StarkField> Hasher for Blake3_192<B> {
ByteDigest(result.as_bytes()[..24].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
let result = blake3::hash(ByteDigest::digests_as_bytes(values));
ByteDigest(result.as_bytes()[..24].try_into().unwrap())
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut data = [0; 32];
data[..24].copy_from_slice(&seed.0);
Expand Down
24 changes: 24 additions & 0 deletions crypto/src/hash/blake/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

use math::{fields::f62::BaseElement, FieldElement};
use rand_utils::rand_array;
use utils::Deserializable;

use super::{Blake3_256, ElementHasher, Hasher};
use crate::hash::{Blake3_192, ByteDigest};

#[test]
fn hash_padding() {
Expand All @@ -29,3 +31,25 @@ fn hash_elements_padding() {
let r2 = Blake3_256::hash_elements(&e2);
assert_ne!(r1, r2);
}

#[test]
fn merge_vs_merge_many_256() {
let digest_0 = ByteDigest::read_from_bytes(&[1_u8; 32]).unwrap();
let digest_1 = ByteDigest::read_from_bytes(&[2_u8; 32]).unwrap();

let r1 = Blake3_256::<BaseElement>::merge(&[digest_0, digest_1]);
let r2 = Blake3_256::<BaseElement>::merge_many(&[digest_0, digest_1]);

assert_eq!(r1, r2)
}

#[test]
fn merge_vs_merge_many_192() {
let digest_0 = ByteDigest::read_from_bytes(&[1_u8; 24]).unwrap();
let digest_1 = ByteDigest::read_from_bytes(&[2_u8; 24]).unwrap();

let r1 = Blake3_192::<BaseElement>::merge(&[digest_0, digest_1]);
let r2 = Blake3_192::<BaseElement>::merge_many(&[digest_0, digest_1]);

assert_eq!(r1, r2)
}
3 changes: 3 additions & 0 deletions crypto/src/hash/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ pub trait Hasher {
/// Merkle trees.
fn merge(values: &[Self::Digest; 2]) -> Self::Digest;

/// Returns a hash of many digests.
fn merge_many(values: &[Self::Digest]) -> Self::Digest;

/// Returns hash(`seed` || `value`). This method is intended for use in PRNG and PoW contexts.
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest;
}
Expand Down
4 changes: 4 additions & 0 deletions crypto/src/hash/rescue/rp62_248/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ impl Hasher for Rp62_248 {
ElementDigest::new(state[..DIGEST_SIZE].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(ElementDigest::digests_as_elements(values))
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the state.
Expand Down
14 changes: 14 additions & 0 deletions crypto/src/hash/rescue/rp62_248/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ fn hash_elements_vs_merge() {
assert_eq!(m_result, h_result);
}

#[test]
fn merge_vs_merge_many() {
let elements: [BaseElement; 8] = rand_array();

let digests: [ElementDigest; 2] = [
ElementDigest::new(elements[..4].try_into().unwrap()),
ElementDigest::new(elements[4..].try_into().unwrap()),
];

let m_result = Rp62_248::merge(&digests);
let h_result = Rp62_248::merge_many(&digests);
assert_eq!(m_result, h_result);
}

#[test]
fn hash_elements_vs_merge_with_int() {
let seed = ElementDigest::new(rand_array());
Expand Down
4 changes: 4 additions & 0 deletions crypto/src/hash/rescue/rp64_256/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ impl Hasher for Rp64_256 {
ElementDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(ElementDigest::digests_as_elements(values))
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
Expand Down
14 changes: 14 additions & 0 deletions crypto/src/hash/rescue/rp64_256/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ fn hash_elements_vs_merge() {
assert_eq!(m_result, h_result);
}

#[test]
fn merge_vs_merge_many() {
let elements: [BaseElement; 8] = rand_array();

let digests: [ElementDigest; 2] = [
ElementDigest::new(elements[..4].try_into().unwrap()),
ElementDigest::new(elements[4..].try_into().unwrap()),
];

let m_result = Rp64_256::merge(&digests);
let h_result = Rp64_256::merge_many(&digests);
assert_eq!(m_result, h_result);
}

#[test]
fn hash_elements_vs_merge_with_int() {
let seed = ElementDigest::new(rand_array());
Expand Down
4 changes: 4 additions & 0 deletions crypto/src/hash/rescue/rp64_256_jive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ impl Hasher for RpJive64_256 {
Self::apply_jive_summation(&initial_state, &state)
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(ElementDigest::digests_as_elements(values))
}

// We do not rely on the sponge construction to build our compression function. Instead, we use
// the Jive compression mode designed in https://eprint.iacr.org/2022/840.pdf.
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
Expand Down
4 changes: 4 additions & 0 deletions crypto/src/hash/sha/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ impl<B: StarkField> Hasher for Sha3_256<B> {
ByteDigest(sha3::Sha3_256::digest(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_many(values: &[Self::Digest]) -> Self::Digest {
ByteDigest(sha3::Sha3_256::digest(ByteDigest::digests_as_bytes(values)).into())
}

fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut data = [0; 40];
data[..32].copy_from_slice(&seed.0);
Expand Down
7 changes: 4 additions & 3 deletions examples/src/fibonacci/fib2/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

use winterfell::{
crypto::MerkleTree, matrix::ColMatrix, AuxRandElements, ConstraintCompositionCoefficients,
DefaultConstraintEvaluator, DefaultTraceLde, StarkDomain, Trace, TraceInfo, TracePolyTable,
TraceTable,
DefaultConstraintEvaluator, DefaultTraceLde, PartitionOptions, StarkDomain, Trace, TraceInfo,
TracePolyTable, TraceTable,
};

use super::{
Expand Down Expand Up @@ -77,8 +77,9 @@ where
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
partition_option: PartitionOptions,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain)
DefaultTraceLde::new(trace_info, main_trace, domain, partition_option)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
Expand Down
Loading