Skip to content
This repository has been archived by the owner on Oct 3, 2024. It is now read-only.

Commit

Permalink
Merge pull request #23 from bytemare/add-scalar-ops
Browse files Browse the repository at this point in the history
Add some scalar operations
  • Loading branch information
bytemare authored Dec 29, 2022
2 parents 95d00ba + 9b7be40 commit 39c00b2
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 13 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ jobs:

analyze:
name: Analyze
if: github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- name: Checkout repo
Expand Down
61 changes: 49 additions & 12 deletions internal/nist/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package nist

import (
"crypto/subtle"
"errors"
"fmt"
"math/big"
Expand Down Expand Up @@ -50,19 +51,19 @@ func (s *Scalar) assert(scalar internal.Scalar) *Scalar {
return _sc
}

// Zero sets the scalar to 0, and returns it.
// Zero sets s to 0, and returns it.
func (s *Scalar) Zero() internal.Scalar {
s.s.Set(zero)
return s
}

// One sets the scalar to 1, and returns it.
// One sets s to 1, and returns it.
func (s *Scalar) One() internal.Scalar {
s.s.Set(one)
return s
}

// Random sets the current scalar to a new random scalar and returns it.
// Random sets s to a new random scalar and returns it.
// The random source is crypto/rand, and this functions is guaranteed to return a non-zero scalar.
func (s *Scalar) Random() internal.Scalar {
for {
Expand All @@ -74,7 +75,7 @@ func (s *Scalar) Random() internal.Scalar {
}
}

// Add returns the sum of the scalars, and does not change the receiver.
// Add returns s+scalar, and returns s.
func (s *Scalar) Add(scalar internal.Scalar) internal.Scalar {
if scalar == nil {
return s
Expand All @@ -86,7 +87,7 @@ func (s *Scalar) Add(scalar internal.Scalar) internal.Scalar {
return s
}

// Subtract returns the difference between the scalars, and does not change the receiver.
// Subtract returns s-scalar, and returns s.
func (s *Scalar) Subtract(scalar internal.Scalar) internal.Scalar {
if scalar == nil {
return s
Expand All @@ -98,7 +99,7 @@ func (s *Scalar) Subtract(scalar internal.Scalar) internal.Scalar {
return s
}

// Multiply returns the multiplication of the scalars, and does not change the receiver.
// Multiply sets s to s*scalar, and returns s.
func (s *Scalar) Multiply(scalar internal.Scalar) internal.Scalar {
if scalar == nil {
return s.Zero()
Expand All @@ -110,26 +111,62 @@ func (s *Scalar) Multiply(scalar internal.Scalar) internal.Scalar {
return s
}

// Invert returns the scalar's modular inverse ( 1 / scalar ), and does not change the receiver.
// Pow sets s to s**scalar modulo the group order, and returns s. If scalar is nil, it returns 1.
func (s *Scalar) Pow(scalar internal.Scalar) internal.Scalar {
if scalar == nil || scalar.IsZero() {
return s.One()
}

if scalar.Equal(scalar.Copy().One()) == 1 {
return s
}

sc := s.assert(scalar)
s.field.exponent(&s.s, &s.s, &sc.s)

return s
}

// Invert sets s to its modular inverse ( 1 / s ).
func (s *Scalar) Invert() internal.Scalar {
s.field.inv(&s.s, &s.s)
return s
}

// Equal returns 1 if the scalars are equal, and 0 otherwise.
// Equal returns 1 if the s == scalar are equal, and 0 otherwise.
func (s *Scalar) Equal(scalar internal.Scalar) int {
if scalar == nil {
return 0
}

sc := s.assert(scalar)

switch s.s.Cmp(&sc.s) {
case 0:
return 1
default:
return subtle.ConstantTimeCompare(s.s.Bytes(), sc.s.Bytes())
}

// LessOrEqual returns 1 if s <= scalar, and 0 otherwise.
func (s *Scalar) LessOrEqual(scalar internal.Scalar) int {
sc := s.assert(scalar)

ienc := s.Encode()
jenc := sc.Encode()

leni := len(ienc)
if leni != len(jenc) {
panic(internal.ErrParamScalarLength)
}

var res bool

for i := 0; i < leni; i++ {
res = res || (ienc[i] > jenc[i])
}

if res {
return 0
}

return 1
}

// IsZero returns whether the scalar is 0.
Expand Down
46 changes: 46 additions & 0 deletions internal/ristretto/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ func (s *Scalar) Multiply(scalar internal.Scalar) internal.Scalar {
return s
}

// Pow sets s to s**scalar modulo the group order, and returns s. If scalar is nil, it returns 1.
func (s *Scalar) Pow(scalar internal.Scalar) internal.Scalar {
if scalar == nil || scalar.IsZero() {
return s.One()
}

if scalar.Equal(scalar.Copy().One()) == 1 {
return s
}

sc := assert(scalar)
sc.Subtract(&scOne)

for !sc.IsZero() {
s.Multiply(s)
sc.Subtract(&scOne)
}

return s
}

// Invert sets the receiver to the scalar's modular inverse ( 1 / scalar ), and returns it.
func (s *Scalar) Invert() internal.Scalar {
s.scalar.Invert(&s.scalar)
Expand All @@ -123,6 +144,31 @@ func (s *Scalar) Equal(scalar internal.Scalar) int {
return s.scalar.Equal(&sc.scalar)
}

// LessOrEqual returns 1 if s <= scalar and 0 otherwise.
func (s *Scalar) LessOrEqual(scalar internal.Scalar) int {
sc := assert(scalar)

ienc := s.Encode()
jenc := sc.Encode()

i := len(ienc)
if i != len(jenc) {
panic(internal.ErrParamScalarLength)
}

var res bool

for i--; i >= 0; i-- {
res = res || (ienc[i] > jenc[i])
}

if res {
return 0
}

return 1
}

// IsZero returns whether the scalar is 0.
func (s *Scalar) IsZero() bool {
return s.scalar.Equal(ristretto255.NewScalar().Zero()) == 1
Expand Down
6 changes: 6 additions & 0 deletions internal/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ type Scalar interface {
// Multiply multiplies the receiver with the input, and returns the receiver.
Multiply(Scalar) Scalar

// Pow sets s to s**scalar modulo the group order, and returns s. If scalar is nil, it returns 1.
Pow(scalar Scalar) Scalar

// Invert sets the receiver to the scalar's modular inverse ( 1 / scalar ), and returns it.
Invert() Scalar

// Equal returns 1 if the scalars are equal, and 0 otherwise.
Equal(Scalar) int

// LessOrEqual returns 1 if s <= scalar, and 0 otherwise.
LessOrEqual(scalar Scalar) int

// IsZero returns whether the scalar is 0.
IsZero() bool

Expand Down
20 changes: 20 additions & 0 deletions scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ func (s *Scalar) Multiply(scalar *Scalar) *Scalar {
return s
}

// Pow sets s to s**scalar modulo the group order, and returns s. If scalar is nil, it returns 1.
func (s *Scalar) Pow(scalar *Scalar) *Scalar {
if scalar == nil {
return s.One()
}

s.Scalar.Pow(scalar.Scalar)

return s
}

// Invert sets the receiver to the scalar's modular inverse ( 1 / scalar ), and returns it.
func (s *Scalar) Invert() *Scalar {
s.Scalar.Invert()
Expand All @@ -91,6 +102,15 @@ func (s *Scalar) Equal(scalar *Scalar) int {
return s.Scalar.Equal(scalar.Scalar)
}

// LessOrEqual returns 1 if s <= scalar, and 0 otherwise.
func (s *Scalar) LessOrEqual(scalar *Scalar) int {
if scalar == nil {
return 0
}

return s.Scalar.LessOrEqual(scalar.Scalar)
}

// IsZero returns whether the scalar is 0.
func (s *Scalar) IsZero() bool {
return s.Scalar.IsZero()
Expand Down
59 changes: 59 additions & 0 deletions tests/scalar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,12 @@ func TestScalar_Arithmetic(t *testing.T) {
scalarTestZero(t, group.id)
scalarTestOne(t, group.id)
scalarTestEqual(t, group.id)
scalarTestLessOrEqual(t, group.id)
scalarTestRandom(t, group.id)
scalarTestAdd(t, group.id)
scalarTestSubtract(t, group.id)
scalarTestMultiply(t, group.id)
scalarTestPow(t, group.id)
scalarTestInvert(t, group.id)
})
}
Expand Down Expand Up @@ -183,6 +185,32 @@ func scalarTestEqual(t *testing.T, g crypto.Group) {
}
}

func scalarTestLessOrEqual(t *testing.T, g crypto.Group) {
zero := g.NewScalar().Zero()
one := g.NewScalar().One()
two := g.NewScalar().One().Add(one)

if zero.LessOrEqual(one) != 1 {
t.Fatal("expected 0 < 1")
}

if one.LessOrEqual(two) != 1 {
t.Fatal("expected 1 < 2")
}

if one.LessOrEqual(zero) == 1 {
t.Fatal("expected 1 > 0")
}

if two.LessOrEqual(one) == 1 {
t.Fatal("expected 2 > 1")
}

if two.LessOrEqual(two) != 1 {
t.Fatal("expected 2 == 2")
}
}

func scalarTestAdd(t *testing.T, g crypto.Group) {
r := g.NewScalar().Random()
cpy := r.Copy()
Expand All @@ -206,6 +234,37 @@ func scalarTestMultiply(t *testing.T, g crypto.Group) {
}
}

func scalarTestPow(t *testing.T, g crypto.Group) {
// s**nil = 1
s := g.NewScalar().Random()
if s.Pow(nil).Equal(g.NewScalar().One()) != 1 {
t.Fatal("expected s**nil = 1")
}

// s**0 = 1
s = g.NewScalar().Random()
zero := g.NewScalar().Zero()
if s.Pow(zero).Equal(g.NewScalar().One()) != 1 {
t.Fatal("expected s**0 = 1")
}

// s**1 = s
s = g.NewScalar().Random()
one := g.NewScalar().One()
if s.Copy().Pow(one).Equal(s) != 1 {
t.Fatal("expected s**1 = s")
}

// s**2 = s*s
s = g.NewScalar().Random()
s2 := s.Copy().Multiply(s)
two := g.NewScalar().One().Add(g.NewScalar().One())

if s.Pow(two).Equal(s2) != 1 {
t.Fatal("expected s**2 = s*s")
}
}

func scalarTestInvert(t *testing.T, g crypto.Group) {
s := g.NewScalar().Random()
sqr := s.Copy().Multiply(s)
Expand Down

0 comments on commit 39c00b2

Please sign in to comment.