Skip to content

Commit

Permalink
Optimized linear combination of points (#380)
Browse files Browse the repository at this point in the history
Add `lincomb()` as an alias for a 2-point linear combination
  • Loading branch information
fjarri authored Jul 18, 2021
1 parent 2c12fc5 commit 937a924
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 51 deletions.
11 changes: 10 additions & 1 deletion k256/bench/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use criterion::{
use hex_literal::hex;
use k256::{
elliptic_curve::{generic_array::arr, group::ff::PrimeField},
ProjectivePoint, Scalar,
lincomb, ProjectivePoint, Scalar,
};

fn test_scalar_x() -> Scalar {
Expand Down Expand Up @@ -34,9 +34,18 @@ fn bench_point_mul<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) {
group.bench_function("point-scalar mul", |b| b.iter(|| &p * &s));
}

/// Benchmarks a two-term linear combination computed two ways:
/// as two separate point-scalar multiplications followed by an addition,
/// and with a single `lincomb()` call.
/// The same point and scalar are used for both terms.
fn bench_point_lincomb<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) {
    let scalar_bytes = hex!("AA5E28D6A97A2479A65527F7290311A3624D4CC0FA1578598EE3C2613BF99522");
    let scalar = Scalar::from_repr(scalar_bytes.into()).unwrap();
    let point = ProjectivePoint::generator();

    group.bench_function("lincomb via mul+add", |b| {
        b.iter(|| &point * &scalar + &point * &scalar)
    });
    group.bench_function("lincomb()", |b| {
        b.iter(|| lincomb(&point, &scalar, &point, &scalar))
    });
}

/// Runs the point multiplication and linear combination benchmarks
/// inside the "high-level operations" benchmark group.
fn bench_high_level(c: &mut Criterion) {
    let mut high_level = c.benchmark_group("high-level operations");
    bench_point_mul(&mut high_level);
    bench_point_lincomb(&mut high_level);
    high_level.finish();
}

Expand Down
1 change: 1 addition & 0 deletions k256/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) mod scalar;
mod util;

pub use field::FieldElement;
pub use mul::lincomb;

use affine::AffinePoint;
use projective::ProjectivePoint;
Expand Down
211 changes: 168 additions & 43 deletions k256/src/arithmetic/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ use core::ops::{Mul, MulAssign};
use elliptic_curve::subtle::{Choice, ConditionallySelectable, ConstantTimeEq};

/// Lookup table containing precomputed values `[p, 2p, 3p, ..., 8p]`
///
/// `Copy` and `Default` are derived so that a table can be passed as the
/// placeholder fill value that `static_zip_map` requires when building
/// an array of tables.
#[derive(Copy, Clone, Default)]
struct LookupTable([ProjectivePoint; 8]);

impl From<&ProjectivePoint> for LookupTable {
Expand Down Expand Up @@ -147,94 +148,218 @@ fn decompose_scalar(k: &Scalar) -> (Scalar, Scalar) {
(r1, r2)
}

/// Returns `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`,
/// and `-8 <= a_j <= 7`.
/// Assumes `x < 2^128`.
fn to_radix_16_half(x: &Scalar) -> [i8; 33] {
// `x` can have up to 256 bits, so we need an additional byte to store the carry.
let mut output = [0i8; 33];

// Step 1: change radix.
// Convert from radix 256 (bytes) to radix 16 (nibbles)
let bytes = x.to_bytes();
for i in 0..16 {
output[2 * i] = (bytes[31 - i] & 0xf) as i8;
output[2 * i + 1] = ((bytes[31 - i] >> 4) & 0xf) as i8;
}
/// Signed radix-16 digits `[a_0, ..., a_32]` of a scalar, least significant digit first.
///
/// This is a newtype rather than a bare `[i8; 33]` so that `Default` can be implemented
/// for it (a default value is required because it is used in `static_map` later).
/// Otherwise we could just have a function returning an array.
#[derive(Copy, Clone)]
struct Radix16Decomposition([i8; 33]);

impl Radix16Decomposition {
/// Returns an object containing a decomposition
/// `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`,
/// and `-8 <= a_j <= 7`.
/// Assumes `x < 2^128`.
fn new(x: &Scalar) -> Self {
debug_assert!((x >> 128).is_zero().unwrap_u8() == 1);

// The resulting decomposition can be negative, so, despite the limit on `x`,
// it can have up to 256 bits, and we need an additional byte to store the carry.
let mut output = [0i8; 33];

// Step 1: change radix.
// Convert from radix 256 (bytes) to radix 16 (nibbles)
let bytes = x.to_bytes();
for i in 0..16 {
output[2 * i] = (bytes[31 - i] & 0xf) as i8;
output[2 * i + 1] = ((bytes[31 - i] >> 4) & 0xf) as i8;
}

debug_assert!((x >> 128).is_zero().unwrap_u8() == 1);
// Step 2: recenter coefficients from [0,16) to [-8,8)
for i in 0..32 {
let carry = (output[i] + 8) >> 4;
output[i] -= carry << 4;
output[i + 1] += carry;
}

// Step 2: recenter coefficients from [0,16) to [-8,8)
for i in 0..32 {
let carry = (output[i] + 8) >> 4;
output[i] -= carry << 4;
output[i + 1] += carry;
Self(output)
}

output
}

fn mul_windowed(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint {
let (r1, r2) = decompose_scalar(k);
let x_beta = x.endomorphism();
impl Default for Radix16Decomposition {
fn default() -> Self {
Self([0i8; 33])
}
}

let r1_sign = r1.is_high();
let r1_c = Scalar::conditional_select(&r1, &-r1, r1_sign);
let r2_sign = r2.is_high();
let r2_c = Scalar::conditional_select(&r2, &-r2, r2_sign);
/// Applies `f` to every element of `x`, in order, and returns the results
/// as a new array of the same length.
///
/// The output array is first filled with `default` and then overwritten slot by slot:
/// as of Rust 1.51 an iterator cannot be collected into an array, and an uninitialized
/// array is not available without `unsafe`. Every slot is overwritten, so `default`
/// never appears in the returned array.
fn static_map<T: Copy, V: Copy, const N: usize>(
    f: impl Fn(T) -> V,
    x: &[T; N],
    default: V,
) -> [V; N] {
    let mut mapped = [default; N];
    for (slot, &item) in mapped.iter_mut().zip(x.iter()) {
        *slot = f(item);
    }
    mapped
}

let table1 = LookupTable::from(&ProjectivePoint::conditional_select(x, &-x, r1_sign));
let table2 = LookupTable::from(&ProjectivePoint::conditional_select(
&x_beta, &-x_beta, r2_sign,
));
/// Combines the arrays `x` and `y` element-wise: the `i`-th output element is
/// `f(x[i], y[i])`.
///
/// The output array is pre-filled with `default` before being overwritten,
/// because an array cannot be left uninitialized without `unsafe`.
/// Every slot is overwritten, so `default` never appears in the returned array.
fn static_zip_map<T: Copy, S: Copy, V: Copy, const N: usize>(
    f: impl Fn(T, S) -> V,
    x: &[T; N],
    y: &[S; N],
    default: V,
) -> [V; N] {
    let mut combined = [default; N];
    for (slot, (&left, &right)) in combined.iter_mut().zip(x.iter().zip(y.iter())) {
        *slot = f(left, right);
    }
    combined
}

let digits1 = to_radix_16_half(&r1_c);
let digits2 = to_radix_16_half(&r2_c);
/// Calculates a linear combination `sum(x[i] * k[i])`, `i = 0..N`
#[inline(always)]
fn lincomb_generic<const N: usize>(xs: &[ProjectivePoint; N], ks: &[Scalar; N]) -> ProjectivePoint {
let rs = static_map(
|k| decompose_scalar(&k),
ks,
(Scalar::default(), Scalar::default()),
);
let r1s = static_map(|(r1, _r2)| r1, &rs, Scalar::default());
let r2s = static_map(|(_r1, r2)| r2, &rs, Scalar::default());

let xs_beta = static_map(|x| x.endomorphism(), xs, ProjectivePoint::default());

let r1_signs = static_map(|r| r.is_high(), &r1s, Choice::from(0u8));
let r2_signs = static_map(|r| r.is_high(), &r2s, Choice::from(0u8));

let r1s_c = static_zip_map(
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
&r1s,
&r1_signs,
Scalar::default(),
);
let r2s_c = static_zip_map(
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
&r2s,
&r2_signs,
Scalar::default(),
);

let tables1 = static_zip_map(
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
&xs,
&r1_signs,
LookupTable::default(),
);
let tables2 = static_zip_map(
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
&xs_beta,
&r2_signs,
LookupTable::default(),
);

let digits1 = static_map(
|r| Radix16Decomposition::new(&r),
&r1s_c,
Radix16Decomposition::default(),
);
let digits2 = static_map(
|r| Radix16Decomposition::new(&r),
&r2s_c,
Radix16Decomposition::default(),
);

let mut acc = ProjectivePoint::identity();
for component in 0..N {
acc += &tables1[component].select(digits1[component].0[32]);
acc += &tables2[component].select(digits2[component].0[32]);
}

let mut acc = table1.select(digits1[32]) + table2.select(digits2[32]);
for i in (0..32).rev() {
for _j in 0..4 {
acc = acc.double();
}

acc += &table1.select(digits1[i]);
acc += &table2.select(digits2[i]);
for component in 0..N {
acc += &tables1[component].select(digits1[component].0[i]);
acc += &tables2[component].select(digits2[component].0[i]);
}
}
acc
}

/// Calculates `x * k` by treating it as a linear combination with a single term.
#[inline(always)]
fn mul(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint {
    let points = [*x];
    let scalars = [*k];
    lincomb_generic(&points, &scalars)
}

/// Calculates `x * k + y * l`.
///
/// Both terms are handed to `lincomb_generic` as one two-term linear combination
/// instead of being multiplied separately and then added.
pub fn lincomb(
    x: &ProjectivePoint,
    k: &Scalar,
    y: &ProjectivePoint,
    l: &Scalar,
) -> ProjectivePoint {
    let points = [*x, *y];
    let scalars = [*k, *l];
    lincomb_generic(&points, &scalars)
}

impl Mul<Scalar> for ProjectivePoint {
type Output = ProjectivePoint;

fn mul(self, other: Scalar) -> ProjectivePoint {
mul_windowed(&self, &other)
mul(&self, &other)
}
}

impl Mul<&Scalar> for &ProjectivePoint {
type Output = ProjectivePoint;

fn mul(self, other: &Scalar) -> ProjectivePoint {
mul_windowed(self, other)
mul(self, other)
}
}

impl Mul<&Scalar> for ProjectivePoint {
type Output = ProjectivePoint;

fn mul(self, other: &Scalar) -> ProjectivePoint {
mul_windowed(&self, other)
mul(&self, other)
}
}

impl MulAssign<Scalar> for ProjectivePoint {
fn mul_assign(&mut self, rhs: Scalar) {
*self = mul_windowed(self, &rhs);
*self = mul(self, &rhs);
}
}

impl MulAssign<&Scalar> for ProjectivePoint {
fn mul_assign(&mut self, rhs: &Scalar) {
*self = mul_windowed(self, rhs);
*self = mul(self, rhs);
}
}

#[cfg(test)]
mod tests {
    use super::lincomb;
    use crate::arithmetic::{ProjectivePoint, Scalar};
    use elliptic_curve::rand_core::OsRng;
    use elliptic_curve::{Field, Group};

    /// `lincomb()` must agree with two separate multiplications followed by an addition
    /// for randomly chosen points and scalars.
    #[test]
    fn test_lincomb() {
        let point_a = ProjectivePoint::random(&mut OsRng);
        let point_b = ProjectivePoint::random(&mut OsRng);
        let scalar_a = Scalar::random(&mut OsRng);
        let scalar_b = Scalar::random(&mut OsRng);

        let expected = &point_a * &scalar_a + &point_b * &scalar_b;
        let actual = lincomb(&point_a, &scalar_a, &point_b, &scalar_b);
        assert_eq!(expected, actual);
    }
}
4 changes: 2 additions & 2 deletions k256/src/ecdsa/recoverable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use crate::{
consts::U32, generic_array::GenericArray, ops::Invert, subtle::Choice,
weierstrass::DecompressPoint,
},
AffinePoint, FieldBytes, NonZeroScalar, ProjectivePoint, Scalar,
lincomb, AffinePoint, FieldBytes, NonZeroScalar, ProjectivePoint, Scalar,
};

#[cfg(feature = "keccak256")]
Expand Down Expand Up @@ -185,7 +185,7 @@ impl Signature {
let r_inv = r.invert().unwrap();
let u1 = -(r_inv * z);
let u2 = r_inv * *s;
let pk = ((ProjectivePoint::generator() * u1) + (R * u2)).to_affine();
let pk = lincomb(&ProjectivePoint::generator(), &u1, &R, &u2).to_affine();

// TODO(tarcieri): ensure the signature verifies?
Ok(VerifyingKey::from(&pk))
Expand Down
14 changes: 10 additions & 4 deletions k256/src/ecdsa/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
use super::{recoverable, Error, Signature};
use crate::{
AffinePoint, CompressedPoint, EncodedPoint, ProjectivePoint, PublicKey, Scalar, Secp256k1,
lincomb, AffinePoint, CompressedPoint, EncodedPoint, ProjectivePoint, PublicKey, Scalar,
Secp256k1,
};
use core::convert::TryFrom;
use ecdsa_core::{hazmat::VerifyPrimitive, signature};
Expand Down Expand Up @@ -90,9 +91,14 @@ impl VerifyPrimitive<Secp256k1> for AffinePoint {
let u1 = z * &s_inv;
let u2 = *r * s_inv;

let x = ((ProjectivePoint::generator() * u1) + (ProjectivePoint::from(*self) * u2))
.to_affine()
.x;
let x = lincomb(
&ProjectivePoint::generator(),
&u1,
&ProjectivePoint::from(*self),
&u2,
)
.to_affine()
.x;

if Scalar::from_bytes_reduced(&x.to_bytes()).eq(&r) {
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion k256/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub mod test_vectors;
pub use elliptic_curve::{self, bigint::U256};

#[cfg(feature = "arithmetic")]
pub use arithmetic::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar};
pub use arithmetic::{affine::AffinePoint, lincomb, projective::ProjectivePoint, scalar::Scalar};

#[cfg(feature = "expose-field")]
pub use arithmetic::FieldElement;
Expand Down

0 comments on commit 937a924

Please sign in to comment.