Skip to content

Commit

Permalink
feat: Add maybe_async_trait (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
igamigo authored Oct 4, 2024
1 parent 3345055 commit e92e541
Show file tree
Hide file tree
Showing 21 changed files with 230 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

## 0.10.0 (2024-06-11) - `utils/maybe-async` crate only
- Added `maybe-async-trait` procedural macro.
- [BREAKING] Refactored `maybe-async` macro into simpler `maybe-async` and `maybe-await` macros.

## 0.9.1 (2024-06-24) - `utils/core` crate only
Expand Down
3 changes: 1 addition & 2 deletions air/src/air/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,7 @@ impl<B: StarkField> AirContext<B> {

// we use the identity: ceil(a/b) = (a + b - 1)/b
let num_constraint_col =
(highest_constraint_degree - transition_divisior_degree + trace_length - 1)
/ trace_length;
(highest_constraint_degree - transition_divisior_degree).div_ceil(trace_length);

cmp::max(num_constraint_col, 1)
}
Expand Down
2 changes: 2 additions & 0 deletions air/src/proof/ood_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ impl Deserializable for OodFrame {
// OOD FRAME TRACE STATES
// ================================================================================================

/// Trace evaluation frame at the out-of-domain point.
///
/// Stores the trace evaluations at `z` and `gz`, where `z` is a random Field element in
/// `current_row` and `next_row`, respectively. If the Air contains a Lagrange kernel auxiliary
/// column, then that column interpolated polynomial will be evaluated at `z`, `gz`, `g^2 z`, ...
Expand Down
4 changes: 2 additions & 2 deletions air/src/proof/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ impl<'a, E: FieldElement> Iterator for RowIterator<'a, E> {
}
}

impl<'a, E: FieldElement> ExactSizeIterator for RowIterator<'a, E> {
impl<E: FieldElement> ExactSizeIterator for RowIterator<'_, E> {
fn len(&self) -> usize {
self.table.num_rows()
}
}

impl<'a, E: FieldElement> FusedIterator for RowIterator<'a, E> {}
impl<E: FieldElement> FusedIterator for RowIterator<'_, E> {}
28 changes: 14 additions & 14 deletions crypto/src/hash/mds/mds_f64_12x12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! This module contains helper functions as well as constants used to perform a 12x12 vector-matrix
//! multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce
//! the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain".
//! This follows from the simple fact that every circulant matrix has the columns of the discrete
//! Fourier transform matrix as orthogonal eigenvectors.
//! The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that
//! with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain,
//! divisions by 2 and repeated modular reductions. This is because of our explicit choice of
//! an MDS matrix that has small powers of 2 entries in frequency domain.
//! The following implementation has benefited greatly from the discussions and insights of
//! Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's implementation
//! in [Plonky2](https://github.com/mir-protocol/plonky2).
//! The circulant matrix is identified by its first row: [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8].
// FFT-BASED MDS MULTIPLICATION HELPER FUNCTIONS
// ================================================================================================

Expand All @@ -12,20 +26,6 @@ use math::{
FieldElement,
};

/// This module contains helper functions as well as constants used to perform a 12x12 vector-matrix
/// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce
/// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain".
/// This follows from the simple fact that every circulant matrix has the columns of the discrete
/// Fourier transform matrix as orthogonal eigenvectors.
/// The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that
/// with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain,
/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of
/// an MDS matrix that has small powers of 2 entries in frequency domain.
/// The following implementation has benefited greatly from the discussions and insights of
/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's implementation
/// in [Plonky2](https://github.com/mir-protocol/plonky2).
/// The circulant matrix is identified by its first row: [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8].
// MDS matrix in frequency domain.
// More precisely, this is the output of the three 4-point (real) FFTs of the first column of
// the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors
Expand Down
26 changes: 13 additions & 13 deletions crypto/src/hash/mds/mds_f64_8x8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! This module contains helper functions as well as constants used to perform a 8x8 vector-matrix
//! multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce
//! the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain".
//! This follows from the simple fact that every circulant matrix has the columns of the discrete
//! Fourier transform matrix as orthogonal eigenvectors.
//! The implementation also avoids the use of internal 2-point FFTs, and 2-point iFFTs, and substitutes
//! them with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain,
//! divisions by 2 and repeated modular reductions. This is because of our explicit choice of
//! an MDS matrix that has small powers of 2 entries in frequency domain.
//! The following implementation has benefited greatly from the discussions and insights of
//! Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero is based on Nabaglo's implementation
//! in [Plonky2](https://github.com/mir-protocol/plonky2).
// FFT-BASED MDS MULTIPLICATION HELPER FUNCTIONS
// ================================================================================================

Expand All @@ -12,20 +25,7 @@ use math::{
FieldElement,
};

/// This module contains helper functions as well as constants used to perform a 8x8 vector-matrix
/// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce
/// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain".
/// This follows from the simple fact that every circulant matrix has the columns of the discrete
/// Fourier transform matrix as orthogonal eigenvectors.
/// The implementation also avoids the use of internal 2-point FFTs, and 2-point iFFTs, and substitutes
/// them with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain,
/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of
/// an MDS matrix that has small powers of 2 entries in frequency domain.
/// The following implementation has benefited greatly from the discussions and insights of
/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero is based on Nabaglo's implementation
/// in [Plonky2](https://github.com/mir-protocol/plonky2).
/// The circulant matrix is identified by its first row: [23, 8, 13, 10, 7, 6, 21, 8].
// MDS matrix in frequency domain.
// More precisely, this is the output of the two 4-point (real) FFTs of the first column of
// the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors
Expand Down
2 changes: 2 additions & 0 deletions crypto/src/merkle/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub const MIN_CONCURRENT_LEAVES: usize = 1024;
// PUBLIC FUNCTIONS
// ================================================================================================

/// Returns internal nodes of a Merkle tree constructed from the provided leaves.
///
/// Builds all internal nodes of the Merkle using all available threads and stores the
/// results in a single vector such that root of the tree is at position 1, nodes immediately
/// under the root is at positions 2 and 3 etc.
Expand Down
2 changes: 2 additions & 0 deletions examples/src/utils/rescue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ pub const RATE_WIDTH: usize = 4;
/// Two elements (32-bytes) are returned as digest.
const DIGEST_SIZE: usize = 2;

/// Number of rounds in a single permutation of the hash function.
///
/// The number of rounds is set to 7 to provide 128-bit security level with 40% security margin;
/// computed using algorithm 7 from <https://eprint.iacr.org/2020/1143.pdf>
/// security margin here differs from Rescue Prime specification which suggests 50% security
Expand Down
2 changes: 1 addition & 1 deletion math/src/field/extensions/cubic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ impl<B: ExtensibleField<3>> TryFrom<u128> for CubeExtension<B> {
}
}

impl<'a, B: ExtensibleField<3>> TryFrom<&'a [u8]> for CubeExtension<B> {
impl<B: ExtensibleField<3>> TryFrom<&'_ [u8]> for CubeExtension<B> {
type Error = DeserializationError;

/// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
Expand Down
2 changes: 1 addition & 1 deletion math/src/field/extensions/quadratic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ impl<B: ExtensibleField<2>> TryFrom<u128> for QuadExtension<B> {
}
}

impl<'a, B: ExtensibleField<2>> TryFrom<&'a [u8]> for QuadExtension<B> {
impl<B: ExtensibleField<2>> TryFrom<&'_ [u8]> for QuadExtension<B> {
type Error = DeserializationError;

/// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
Expand Down
2 changes: 1 addition & 1 deletion math/src/field/f128/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ impl TryFrom<u128> for BaseElement {
}
}

impl<'a> TryFrom<&'a [u8]> for BaseElement {
impl TryFrom<&'_ [u8]> for BaseElement {
type Error = String;

/// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
Expand Down
2 changes: 1 addition & 1 deletion math/src/field/f62/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ impl TryFrom<[u8; 8]> for BaseElement {
}
}

impl<'a> TryFrom<&'a [u8]> for BaseElement {
impl TryFrom<&'_ [u8]> for BaseElement {
type Error = DeserializationError;

/// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
Expand Down
3 changes: 2 additions & 1 deletion math/src/field/f64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

//! An implementation of a 64-bit STARK-friendly prime field with modulus $2^{64} - 2^{32} + 1$
//! using Montgomery representation.
//!
//! Our implementation follows <https://eprint.iacr.org/2022/274.pdf> and is constant-time.
//!
//! This field supports very fast modular arithmetic and has a number of other attractive
Expand Down Expand Up @@ -571,7 +572,7 @@ impl TryFrom<[u8; 8]> for BaseElement {
}
}

impl<'a> TryFrom<&'a [u8]> for BaseElement {
impl TryFrom<&'_ [u8]> for BaseElement {
type Error = DeserializationError;

/// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
Expand Down
2 changes: 1 addition & 1 deletion prover/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ where
// FRI PROVER CHANNEL IMPLEMENTATION
// ================================================================================================

impl<'a, A, E, H, R, V> fri::ProverChannel<E> for ProverChannel<'a, A, E, H, R, V>
impl<A, E, H, R, V> fri::ProverChannel<E> for ProverChannel<'_, A, E, H, R, V>
where
A: Air,
E: FieldElement<BaseField = A::BaseField>,
Expand Down
2 changes: 1 addition & 1 deletion prover/src/constraints/evaluation_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ pub struct EvaluationTableFragment<'a, E: FieldElement> {
ta_evaluations: Vec<&'a mut [E]>,
}

impl<'a, E: FieldElement> EvaluationTableFragment<'a, E> {
impl<E: FieldElement> EvaluationTableFragment<'_, E> {
/// Returns the row at which the fragment starts.
pub fn offset(&self) -> usize {
self.offset
Expand Down
2 changes: 1 addition & 1 deletion prover/src/constraints/evaluator/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub struct DefaultConstraintEvaluator<'a, A: Air, E: FieldElement<BaseField = A:
periodic_values: PeriodicValueTable<E::BaseField>,
}

impl<'a, A, E> ConstraintEvaluator<E> for DefaultConstraintEvaluator<'a, A, E>
impl<A, E> ConstraintEvaluator<E> for DefaultConstraintEvaluator<'_, A, E>
where
A: Air,
E: FieldElement<BaseField = A::BaseField>,
Expand Down
10 changes: 5 additions & 5 deletions prover/src/matrix/col_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,15 +333,15 @@ impl<'a, E: FieldElement> Iterator for ColumnIter<'a, E> {
}
}

impl<'a, E: FieldElement> ExactSizeIterator for ColumnIter<'a, E> {
impl<E: FieldElement> ExactSizeIterator for ColumnIter<'_, E> {
fn len(&self) -> usize {
self.matrix.map(|matrix| matrix.num_cols()).unwrap_or_default()
}
}

impl<'a, E: FieldElement> FusedIterator for ColumnIter<'a, E> {}
impl<E: FieldElement> FusedIterator for ColumnIter<'_, E> {}

impl<'a, E: FieldElement> Default for ColumnIter<'a, E> {
impl<E: FieldElement> Default for ColumnIter<'_, E> {
fn default() -> Self {
Self::empty()
}
Expand Down Expand Up @@ -382,10 +382,10 @@ impl<'a, E: FieldElement> Iterator for ColumnIterMut<'a, E> {
}
}

impl<'a, E: FieldElement> ExactSizeIterator for ColumnIterMut<'a, E> {
impl<E: FieldElement> ExactSizeIterator for ColumnIterMut<'_, E> {
fn len(&self) -> usize {
self.matrix.num_cols()
}
}

impl<'a, E: FieldElement> FusedIterator for ColumnIterMut<'a, E> {}
impl<E: FieldElement> FusedIterator for ColumnIterMut<'_, E> {}
2 changes: 1 addition & 1 deletion prover/src/trace/trace_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ pub struct TraceTableFragment<'a, B: StarkField> {
data: Vec<&'a mut [B]>,
}

impl<'a, B: StarkField> TraceTableFragment<'a, B> {
impl<B: StarkField> TraceTableFragment<'_, B> {
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------

Expand Down
4 changes: 2 additions & 2 deletions utils/core/src/serde/byte_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ impl<'a> ReadAdapter<'a> {
}

#[cfg(feature = "std")]
impl<'a> ByteReader for ReadAdapter<'a> {
impl ByteReader for ReadAdapter<'_> {
#[inline(always)]
fn read_u8(&mut self) -> Result<u8, DeserializationError> {
self.pop()
Expand Down Expand Up @@ -638,7 +638,7 @@ impl<'a> SliceReader<'a> {
}
}

impl<'a> ByteReader for SliceReader<'a> {
impl ByteReader for SliceReader<'_> {
fn read_u8(&mut self) -> Result<u8, DeserializationError> {
self.check_eor(1)?;
let result = self.source[self.pos];
Expand Down
52 changes: 52 additions & 0 deletions utils/maybe_async/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,58 @@ async fn world() -> String {
}
```

## maybe_async_trait

The `maybe_async_trait` macro can be applied to traits, and it will conditionally add the `async` keyword to trait methods annotated with `#[maybe_async]`, depending on the async feature being enabled. It also applies `#[async_trait::async_trait(?Send)]` to the trait or impl block when the async feature is on.

For example:

```rust
// Adding `maybe_async_trait` to a trait definition
#[maybe_async_trait]
trait ExampleTrait {
#[maybe_async]
fn hello_world(&self);

fn get_hello(&self) -> String;
}

// Adding `maybe_async_trait` to an implementation of the trait
#[maybe_async_trait]
impl ExampleTrait for MyStruct {
#[maybe_async]
fn hello_world(&self) {
// ...
}

fn get_hello(&self) -> String {
// ...
}
}
```

When `async` is set, it gets transformed into:

```rust
#[async_trait::async_trait(?Send)]
trait ExampleTrait {
async fn hello_world(&self);

fn get_hello(&self) -> String;
}

#[async_trait::async_trait(?Send)]
impl ExampleTrait for MyStruct {
async fn hello_world(&self) {
// ...
}

fn get_hello(&self) -> String {
// ...
}
}
```

## License

This project is [MIT licensed](../../LICENSE).
Loading

0 comments on commit e92e541

Please sign in to comment.