Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hpke: Update HPKE code to use ecdh stdlib package. #530

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions hpke/algs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/elliptic"
"crypto/ecdh"
_ "crypto/sha256" // Linking sha256.
_ "crypto/sha512" // Linking sha512.
"fmt"
Expand All @@ -13,7 +13,6 @@ import (

"github.com/cloudflare/circl/dh/x25519"
"github.com/cloudflare/circl/dh/x448"
"github.com/cloudflare/circl/ecc/p384"
"github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem/kyber/kyber768"
"github.com/cloudflare/circl/kem/xwing"
Expand Down Expand Up @@ -247,19 +246,19 @@ var (
)

func init() {
dhkemp256hkdfsha256.Curve = elliptic.P256()
dhkemp256hkdfsha256.Curve = ecdh.P256()
dhkemp256hkdfsha256.dhKemBase.id = KEM_P256_HKDF_SHA256
dhkemp256hkdfsha256.dhKemBase.name = "HPKE_KEM_P256_HKDF_SHA256"
dhkemp256hkdfsha256.dhKemBase.Hash = crypto.SHA256
dhkemp256hkdfsha256.dhKemBase.dhKEM = dhkemp256hkdfsha256

dhkemp384hkdfsha384.Curve = p384.P384()
dhkemp384hkdfsha384.Curve = ecdh.P384()
dhkemp384hkdfsha384.dhKemBase.id = KEM_P384_HKDF_SHA384
dhkemp384hkdfsha384.dhKemBase.name = "HPKE_KEM_P384_HKDF_SHA384"
dhkemp384hkdfsha384.dhKemBase.Hash = crypto.SHA384
dhkemp384hkdfsha384.dhKemBase.dhKEM = dhkemp384hkdfsha384

dhkemp521hkdfsha512.Curve = elliptic.P521()
dhkemp521hkdfsha512.Curve = ecdh.P521()
dhkemp521hkdfsha512.dhKemBase.id = KEM_P521_HKDF_SHA512
dhkemp521hkdfsha512.dhKemBase.name = "HPKE_KEM_P521_HKDF_SHA512"
dhkemp521hkdfsha512.dhKemBase.Hash = crypto.SHA512
Expand Down
1 change: 1 addition & 0 deletions hpke/hpke.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,5 +278,6 @@ var (
ErrInvalidKEMPublicKey = errors.New("hpke: invalid KEM public key")
ErrInvalidKEMPrivateKey = errors.New("hpke: invalid KEM private key")
ErrInvalidKEMSharedSecret = errors.New("hpke: invalid KEM shared secret")
ErrInvalidKEMDeriveKey = errors.New("hpke: too many tries to derive KEM key")
ErrAEADSeqOverflows = errors.New("hpke: AEAD sequence number overflows")
)
3 changes: 3 additions & 0 deletions hpke/hpke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ func BenchmarkHpkeRoundTrip(b *testing.B) {
kdf hpke.KDF
aead hpke.AEAD
}{
{hpke.KEM_P256_HKDF_SHA256, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM},
{hpke.KEM_P384_HKDF_SHA384, hpke.KDF_HKDF_SHA384, hpke.AEAD_AES256GCM},
{hpke.KEM_P521_HKDF_SHA512, hpke.KDF_HKDF_SHA512, hpke.AEAD_AES256GCM},
{hpke.KEM_X25519_HKDF_SHA256, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM},
{hpke.KEM_X25519_KYBER768_DRAFT00, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM},
{hpke.KEM_XWING, hpke.KDF_HKDF_SHA256, hpke.AEAD_AES128GCM},
Expand Down
155 changes: 68 additions & 87 deletions hpke/shortkem.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
package hpke

import (
"crypto/elliptic"
"crypto/ecdh"
"crypto/rand"
"crypto/subtle"
"fmt"
"math/big"

"github.com/cloudflare/circl/kem"
)

type shortKEM struct {
dhKemBase
elliptic.Curve
ecdh.Curve
}

func (s shortKEM) PrivateKeySize() int { return s.byteSize() }
Expand All @@ -21,19 +19,40 @@ func (s shortKEM) CiphertextSize() int { return 1 + 2*s.byteSize() }
func (s shortKEM) PublicKeySize() int { return 1 + 2*s.byteSize() }
func (s shortKEM) EncapsulationSeedSize() int { return s.byteSize() }

func (s shortKEM) byteSize() int { return (s.Params().BitSize + 7) / 8 }
func (s shortKEM) byteSize() int {
var bits int
switch s.Curve {
case ecdh.P256():
bits = 256
case ecdh.P384():
bits = 384
case ecdh.P521():
bits = 521
default:
panic(ErrInvalidKEM)
}

return (bits + 7) / 8
}

func (s shortKEM) sizeDH() int { return s.byteSize() }
func (s shortKEM) calcDH(dh []byte, sk kem.PrivateKey, pk kem.PublicKey) error {
PK := pk.(*shortKEMPubKey)
SK := sk.(*shortKEMPrivKey)
l := len(dh)
x, _ := s.ScalarMult(PK.x, PK.y, SK.priv) // only x-coordinate is used.
if x.Sign() == 0 {
return ErrInvalidKEMSharedSecret
PK, ok := pk.(*shortKEMPubKey)
if !ok {
return ErrInvalidKEMPublicKey
}

SK, ok := sk.(*shortKEMPrivKey)
if !ok {
return ErrInvalidKEMPrivateKey
}

x, err := SK.priv.ECDH(&PK.pub)
if err != nil {
return err
}
b := x.Bytes()
copy(dh[l-len(b):l], b)

copy(dh, x)
return nil
}

Expand All @@ -49,122 +68,84 @@ func (s shortKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
}

bitmask := byte(0xFF)
if s.Params().BitSize == 521 {
if s.Curve == ecdh.P521() {
bitmask = 0x01
}

dkpPrk := s.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
var bytes []byte
ctr := 0
for skBig := new(big.Int); skBig.Sign() == 0 || skBig.Cmp(s.Params().N) >= 0; ctr++ {
if ctr > 255 {
panic("derive key error")
}
bytes = s.labeledExpand(
for ctr := 0; ctr <= 255; ctr++ {
bytes := s.labeledExpand(
dkpPrk,
[]byte("candidate"),
[]byte{byte(ctr)},
uint16(s.byteSize()),
)
bytes[0] &= bitmask
skBig.SetBytes(bytes)
sk, err := s.UnmarshalBinaryPrivateKey(bytes)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I confirmed that Curve.NewPrivateKey (called by UnmarshalBinaryPrivateKey) checks that the value is lower than the order of the curve, so is equivalent to the existing check.

if err == nil {
return sk.Public(), sk
}
}
l := s.PrivateKeySize()
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(bytes):], bytes)
return sk.Public(), sk

panic(ErrInvalidKEMDeriveKey)
}

func (s shortKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
sk, x, y, err := elliptic.GenerateKey(s, rand.Reader)
pub := &shortKEMPubKey{s, x, y}
return pub, &shortKEMPrivKey{s, sk, pub}, err
key, err := s.Curve.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, err
}

sk := &shortKEMPrivKey{s, key}
return sk.Public(), sk, err
}

func (s shortKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
l := s.PrivateKeySize()
if len(data) < l {
return nil, ErrInvalidKEMPrivateKey
}
sk := &shortKEMPrivKey{s, make([]byte, l), nil}
copy(sk.priv[l-len(data):l], data[:l])
if !sk.validate() {
return nil, ErrInvalidKEMPrivateKey
key, err := s.Curve.NewPrivateKey(data)
if err != nil {
return nil, err
}

return sk, nil
return &shortKEMPrivKey{s, key}, nil
}

func (s shortKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
x, y := elliptic.Unmarshal(s, data)
if x == nil {
return nil, ErrInvalidKEMPublicKey
}
key := &shortKEMPubKey{s, x, y}
if !key.validate() {
return nil, ErrInvalidKEMPublicKey
key, err := s.Curve.NewPublicKey(data)
if err != nil {
return nil, err
}
return key, nil

return &shortKEMPubKey{s, *key}, nil
}

type shortKEMPubKey struct {
scheme shortKEM
x, y *big.Int
pub ecdh.PublicKey
}

func (k *shortKEMPubKey) String() string {
return fmt.Sprintf("x: %v\ny: %v", k.x.Text(16), k.y.Text(16))
}
func (k *shortKEMPubKey) Scheme() kem.Scheme { return k.scheme }
func (k *shortKEMPubKey) MarshalBinary() ([]byte, error) {
return elliptic.Marshal(k.scheme, k.x, k.y), nil
}
func (k *shortKEMPubKey) String() string { return fmt.Sprintf("%x", k.pub.Bytes()) }
func (k *shortKEMPubKey) Scheme() kem.Scheme { return k.scheme }
func (k *shortKEMPubKey) MarshalBinary() ([]byte, error) { return k.pub.Bytes(), nil }

func (k *shortKEMPubKey) Equal(pk kem.PublicKey) bool {
k1, ok := pk.(*shortKEMPubKey)
return ok &&
k.scheme.Params().Name == k1.scheme.Params().Name &&
k.x.Cmp(k1.x) == 0 &&
k.y.Cmp(k1.y) == 0
}

func (k *shortKEMPubKey) validate() bool {
p := k.scheme.Params().P
notAtInfinity := k.x.Sign() > 0 && k.y.Sign() > 0
lessThanP := k.x.Cmp(p) < 0 && k.y.Cmp(p) < 0
onCurve := k.scheme.IsOnCurve(k.x, k.y)
return notAtInfinity && lessThanP && onCurve
return ok && k.scheme == k1.scheme && k.pub.Equal(&k1.pub)
}

type shortKEMPrivKey struct {
scheme shortKEM
priv []byte
pub *shortKEMPubKey
priv *ecdh.PrivateKey
}

func (k *shortKEMPrivKey) String() string { return fmt.Sprintf("%x", k.priv) }
func (k *shortKEMPrivKey) Scheme() kem.Scheme { return k.scheme }
func (k *shortKEMPrivKey) MarshalBinary() ([]byte, error) {
return append(make([]byte, 0, k.scheme.PrivateKeySize()), k.priv...), nil
}
func (k *shortKEMPrivKey) String() string { return fmt.Sprintf("%x", k.priv.Bytes()) }
func (k *shortKEMPrivKey) Scheme() kem.Scheme { return k.scheme }
func (k *shortKEMPrivKey) MarshalBinary() ([]byte, error) { return k.priv.Bytes(), nil }

func (k *shortKEMPrivKey) Equal(pk kem.PrivateKey) bool {
k1, ok := pk.(*shortKEMPrivKey)
return ok &&
k.scheme.Params().Name == k1.scheme.Params().Name &&
subtle.ConstantTimeCompare(k.priv, k1.priv) == 1
return ok && k.scheme == k1.scheme && k.priv.Equal(k1.priv)
}

func (k *shortKEMPrivKey) Public() kem.PublicKey {
if k.pub == nil {
x, y := k.scheme.ScalarBaseMult(k.priv)
k.pub = &shortKEMPubKey{k.scheme, x, y}
}
return k.pub
}

func (k *shortKEMPrivKey) validate() bool {
n := new(big.Int).SetBytes(k.priv)
order := k.scheme.Curve.Params().N
return len(k.priv) == k.scheme.PrivateKeySize() && n.Cmp(order) < 0
return &shortKEMPubKey{k.scheme, *k.priv.PublicKey()}
}
Loading