From 0a3d662f20155f84fd831a4637e512609943bcd5 Mon Sep 17 00:00:00 2001
From: Eric Lagergren
Date: Sat, 16 Apr 2022 13:45:46 -0700
Subject: [PATCH] arm64: implement AES-CTR in assembly

Signed-off-by: Eric Lagergren
---
 README.md                  |  14 +-
 asm_amd64.s => aes_amd64.s |   0
 asm_arm64.s => aes_arm64.s |   0
 ctr_arm64.s                | 392 +++++++++++++++++++++++++++++++++++++
 fuzz_test.go               |  11 +-
 siv_asm.go => siv_amd64.go |  20 +-
 siv_arm64.go               | 180 +++++++++++++++++
 siv_test.go                | 111 ++++++++++-
 stub_arm64.go              |  11 ++
 9 files changed, 709 insertions(+), 30 deletions(-)
 rename asm_amd64.s => aes_amd64.s (100%)
 rename asm_arm64.s => aes_arm64.s (100%)
 create mode 100644 ctr_arm64.s
 rename siv_asm.go => siv_amd64.go (93%)
 create mode 100644 siv_arm64.go

diff --git a/README.md b/README.md
index 29610d4..d91d900 100644
--- a/README.md
+++ b/README.md
@@ -16,19 +16,7 @@ go get github.com/ericlagergren/siv@latest
 
 ## Performance
 
-The performance of HCTR2 is determined by two things: AES-CTR and
-POLYVAL. This module provides ARMv8 and x86-64 assembly AES-CTR
-implementations and uses a hardware-accelerated POLYVAL
-implementation (see [github.com/ericlagergren/polyval](https://pkg.go.dev/github.com/ericlagergren/polyval)).
-
-The ARMv8 assembly implementation of AES-CTR-256 with
-hardware-accelerated POLYVAL runs at about X cycle per byte.
-
-The x86-64 assembly implementation of AES-CTR-256 with
-hardware-accelerated POLYVAL runs at about X cycles per byte.
-
-The `crypto/aes` implementation of AES-CTR-256 with
-hardware-accelerated POLYVAL runs at about X cycles per byte.
+TBD
 
 ## Security
 
diff --git a/asm_amd64.s b/aes_amd64.s
similarity index 100%
rename from asm_amd64.s
rename to aes_amd64.s
diff --git a/asm_arm64.s b/aes_arm64.s
similarity index 100%
rename from asm_arm64.s
rename to aes_arm64.s
diff --git a/ctr_arm64.s b/ctr_arm64.s
new file mode 100644
index 0000000..9dbb68f
--- /dev/null
+++ b/ctr_arm64.s
@@ -0,0 +1,392 @@
+//go:build gc && !purego
+
+#include "textflag.h"
+#include "go_asm.h"
+
+// AESE_AESMC performs AESE and AESMC.
+//
+// The instructions are paired to take advantage of instruction
+// fusion.
+#define AESE_AESMC(rk, v) \ + AESE rk.B16, v.B16 \ + AESMC v.B16, v.B16 + +#define ENCRYPT256x1(v0, rk1, rk2, rk3, rk4) \ + AESE_AESMC(rk1, v0) \ + AESE_AESMC(rk2, v0) \ + AESE_AESMC(rk3, v0) \ + AESE_AESMC(rk4, v0) + +#define ENCRYPT128x1(v0, rk5, rk6, rk7, rk8, rk9, rk10, rk11, rk12, rk13, rk14, rk15) \ + AESE_AESMC(rk5, v0) \ + AESE_AESMC(rk6, v0) \ + AESE_AESMC(rk7, v0) \ + AESE_AESMC(rk8, v0) \ + AESE_AESMC(rk9, v0) \ + AESE_AESMC(rk10, v0) \ + AESE_AESMC(rk11, v0) \ + AESE_AESMC(rk12, v0) \ + AESE_AESMC(rk13, v0) \ + AESE rk14.B16, v0.B16 \ + VEOR v0.B16, rk15.B16, v0.B16 + +#define ENCRYPT256x8(v0, v1, v2, v3, v4, v5, v6, v7, rk1, rk2, rk3, rk4) \ + AESE_AESMC(rk1, v0) \ + AESE_AESMC(rk1, v1) \ + AESE_AESMC(rk1, v2) \ + AESE_AESMC(rk1, v3) \ + AESE_AESMC(rk1, v4) \ + AESE_AESMC(rk1, v5) \ + AESE_AESMC(rk1, v6) \ + AESE_AESMC(rk1, v7) \ + \ + AESE_AESMC(rk2, v0) \ + AESE_AESMC(rk2, v1) \ + AESE_AESMC(rk2, v2) \ + AESE_AESMC(rk2, v3) \ + AESE_AESMC(rk2, v4) \ + AESE_AESMC(rk2, v5) \ + AESE_AESMC(rk2, v6) \ + AESE_AESMC(rk2, v7) \ + \ + AESE_AESMC(rk3, v0) \ + AESE_AESMC(rk3, v1) \ + AESE_AESMC(rk3, v2) \ + AESE_AESMC(rk3, v3) \ + AESE_AESMC(rk3, v4) \ + AESE_AESMC(rk3, v5) \ + AESE_AESMC(rk3, v6) \ + AESE_AESMC(rk3, v7) \ + \ + AESE_AESMC(rk4, v0) \ + AESE_AESMC(rk4, v1) \ + AESE_AESMC(rk4, v2) \ + AESE_AESMC(rk4, v3) \ + AESE_AESMC(rk4, v4) \ + AESE_AESMC(rk4, v5) \ + AESE_AESMC(rk4, v6) \ + AESE_AESMC(rk4, v7) + +#define ENCRYPT128x8(v0, v1, v2, v3, v4, v5, v6, v7, rk5, rk6, rk7, rk8, rk9, rk10, rk11, rk12, rk13, rk14, rk15) \ + AESE_AESMC(rk5, v0) \ + AESE_AESMC(rk5, v1) \ + AESE_AESMC(rk5, v2) \ + AESE_AESMC(rk5, v3) \ + AESE_AESMC(rk5, v4) \ + AESE_AESMC(rk5, v5) \ + AESE_AESMC(rk5, v6) \ + AESE_AESMC(rk5, v7) \ + \ + AESE_AESMC(rk6, v0) \ + AESE_AESMC(rk6, v1) \ + AESE_AESMC(rk6, v2) \ + AESE_AESMC(rk6, v3) \ + AESE_AESMC(rk6, v4) \ + AESE_AESMC(rk6, v5) \ + AESE_AESMC(rk6, v6) \ + AESE_AESMC(rk6, v7) \ + \ + AESE_AESMC(rk7, v0) \ + AESE_AESMC(rk7, v1) \ + AESE_AESMC(rk7, v2) \ + AESE_AESMC(rk7, v3) \ + AESE_AESMC(rk7, v4) \ + AESE_AESMC(rk7, v5) \ + AESE_AESMC(rk7, v6) \ + AESE_AESMC(rk7, v7) \ + \ + AESE_AESMC(rk8, v0) \ + AESE_AESMC(rk8, v1) \ + AESE_AESMC(rk8, v2) \ + AESE_AESMC(rk8, v3) \ + AESE_AESMC(rk8, v4) \ + AESE_AESMC(rk8, v5) \ + AESE_AESMC(rk8, v6) \ + AESE_AESMC(rk8, v7) \ + \ + AESE_AESMC(rk9, v0) \ + AESE_AESMC(rk9, v1) \ + AESE_AESMC(rk9, v2) \ + AESE_AESMC(rk9, v3) \ + AESE_AESMC(rk9, v4) \ + AESE_AESMC(rk9, v5) \ + AESE_AESMC(rk9, v6) \ + AESE_AESMC(rk9, v7) \ + \ + AESE_AESMC(rk10, v0) \ + AESE_AESMC(rk10, v1) \ + AESE_AESMC(rk10, v2) \ + AESE_AESMC(rk10, v3) \ + AESE_AESMC(rk10, v4) \ + AESE_AESMC(rk10, v5) \ + AESE_AESMC(rk10, v6) \ + AESE_AESMC(rk10, v7) \ + \ + AESE_AESMC(rk11, v0) \ + AESE_AESMC(rk11, v1) \ + AESE_AESMC(rk11, v2) \ + AESE_AESMC(rk11, v3) \ + AESE_AESMC(rk11, v4) \ + AESE_AESMC(rk11, v5) \ + AESE_AESMC(rk11, v6) \ + AESE_AESMC(rk11, v7) \ + \ + AESE_AESMC(rk12, v0) \ + AESE_AESMC(rk12, v1) \ + AESE_AESMC(rk12, v2) \ + AESE_AESMC(rk12, v3) \ + AESE_AESMC(rk12, v4) \ + AESE_AESMC(rk12, v5) \ + AESE_AESMC(rk12, v6) \ + AESE_AESMC(rk12, v7) \ + \ + AESE_AESMC(rk13, v0) \ + AESE_AESMC(rk13, v1) \ + AESE_AESMC(rk13, v2) \ + AESE_AESMC(rk13, v3) \ + AESE_AESMC(rk13, v4) \ + AESE_AESMC(rk13, v5) \ + AESE_AESMC(rk13, v6) \ + AESE_AESMC(rk13, v7) \ + \ + AESE rk14.B16, v0.B16 \ + AESE rk14.B16, v1.B16 \ + AESE rk14.B16, v2.B16 \ + AESE rk14.B16, v3.B16 \ + AESE rk14.B16, v4.B16 \ + AESE rk14.B16, v5.B16 \ + AESE rk14.B16, v6.B16 \ + AESE 
rk14.B16, v7.B16 \ + \ + VEOR v0.B16, rk15.B16, v0.B16 \ + VEOR v1.B16, rk15.B16, v1.B16 \ + VEOR v2.B16, rk15.B16, v2.B16 \ + VEOR v3.B16, rk15.B16, v3.B16 \ + VEOR v4.B16, rk15.B16, v4.B16 \ + VEOR v5.B16, rk15.B16, v5.B16 \ + VEOR v6.B16, rk15.B16, v6.B16 \ + VEOR v7.B16, rk15.B16, v7.B16 + +// func aesctrAsm(nr int, enc *uint32, iv *[blockSize]byte, dst, src *byte, nblocks int) +TEXT ·aesctrAsm(SB), NOSPLIT, $0-48 +#define nrounds R0 +#define enc_ptr R1 +#define dst_ptr R2 +#define src_ptr R3 +#define remain R4 +#define block_ptr R5 +#define nwide R6 +#define nsingle R7 + +#define idx0 R8 +#define idx1 R9 +#define idx2 R10 +#define idx3 R11 +#define idx4 R12 +#define idx5 R13 +#define idx6 R14 +#define idx7 R15 + +#define block V0 + +#define rk1 V1 +#define rk2 V2 +#define rk3 V3 +#define rk4 V4 +#define rk5 V5 +#define rk6 V6 +#define rk7 V7 +#define rk8 V8 +#define rk9 V9 +#define rk10 V10 +#define rk11 V11 +#define rk12 V12 +#define rk13 V13 +#define rk14 V14 +#define rk15 V15 + +#define src0 V16 +#define src1 V17 +#define src2 V18 +#define src3 V19 +#define src4 V20 +#define src5 V21 +#define src6 V22 +#define src7 V23 + +#define ks0 V24 +#define ks1 V25 +#define ks2 V26 +#define ks3 V27 +#define ks4 V28 +#define ks5 V29 +#define ks6 V30 +#define ks7 V31 + + MOVD nr+0(FP), nrounds + MOVD enc+8(FP), enc_ptr + MOVD nblocks+40(FP), remain + MOVD iv+16(FP), block_ptr + MOVD dst+24(FP), dst_ptr + MOVD src+32(FP), src_ptr + + VLD1 (block_ptr), [block.B16] + +loadKeys: + CMP $12, nrounds + BLT load128 + +load256: + VLD1.P 64(enc_ptr), [rk1.B16, rk2.B16, rk3.B16, rk4.B16] + +load128: + VLD1.P 64(enc_ptr), [rk5.B16, rk6.B16, rk7.B16, rk8.B16] + VLD1.P 64(enc_ptr), [rk9.B16, rk10.B16, rk11.B16, rk12.B16] + VLD1.P 48(enc_ptr), [rk13.B16, rk14.B16, rk15.B16] + +initLoops: + MOVD ZR, idx0 + VMOV block.S[0], idx0 + +initSingleLoop: +#ifdef const_useMultiBlock + ANDS $7, remain, nsingle + BEQ initWideLoop + +#else + MOVD remain, nsingle + +#endif // const_useMultiBlock + +// Handle any blocks in excess of the stride. +singleLoop: + VLD1.P 16(src_ptr), [src0.B16] + + VMOV block.B16, ks0.B16 + + CMP $12, nrounds + BLT enc128x1 + +enc256x1: + ENCRYPT256x1(ks0, rk1, rk2, rk3, rk4) + +enc128x1: + ENCRYPT128x1(ks0, rk5, rk6, rk7, rk8, rk9, rk10, rk11, rk12, rk13, rk14, rk15) + + ADD $1, idx0 + VMOV idx0, block.S[0] + + VEOR ks0.B16, src0.B16, src0.B16 + VST1.P [src0.B16], 16(dst_ptr) + + SUBS $1, nsingle + BNE singleLoop + +#ifndef const_useMultiBlock + B done + +#endif // const_useMultiBlock + +initWideLoop: + ASR $3, remain, nwide + CBZ nwide, done + + // Now handle the full stride. 
+wideLoop: + ADD $1, idx0, idx1 + ADD $2, idx0, idx2 + ADD $3, idx0, idx3 + ADD $4, idx0, idx4 + ADD $5, idx0, idx5 + ADD $6, idx0, idx6 + ADD $7, idx0, idx7 + + VMOV block.B16, ks0.B16 + VMOV idx0, ks0.S[0] + VMOV block.B16, ks1.B16 + VMOV idx1, ks1.S[0] + VMOV block.B16, ks2.B16 + VMOV idx2, ks2.S[0] + VMOV block.B16, ks3.B16 + VMOV idx3, ks3.S[0] + VMOV block.B16, ks4.B16 + VMOV idx4, ks4.S[0] + VMOV block.B16, ks5.B16 + VMOV idx5, ks5.S[0] + VMOV block.B16, ks6.B16 + VMOV idx6, ks6.S[0] + VMOV block.B16, ks7.B16 + VMOV idx7, ks7.S[0] + + VLD1.P 64(src_ptr), [src0.B16, src1.B16, src2.B16, src3.B16] + VLD1.P 64(src_ptr), [src4.B16, src5.B16, src6.B16, src7.B16] + + CMP $12, nrounds + BLT enc128x8 + +enc256x8: + ENCRYPT256x8(ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, rk1, rk2, rk3, rk4) + +enc128x8: + ENCRYPT128x8(ks0, ks1, ks2, ks3, ks4, ks5, ks6, ks7, rk5, rk6, rk7, rk8, rk9, rk10, rk11, rk12, rk13, rk14, rk15) + + VEOR ks0.B16, src0.B16, src0.B16 + VEOR ks1.B16, src1.B16, src1.B16 + VEOR ks2.B16, src2.B16, src2.B16 + VEOR ks3.B16, src3.B16, src3.B16 + VEOR ks4.B16, src4.B16, src4.B16 + VEOR ks5.B16, src5.B16, src5.B16 + VEOR ks6.B16, src6.B16, src6.B16 + VEOR ks7.B16, src7.B16, src7.B16 + + VST1.P [src0.B16, src1.B16, src2.B16, src3.B16], 64(dst_ptr) + VST1.P [src4.B16, src5.B16, src6.B16, src7.B16], 64(dst_ptr) + + ADD $8, idx0 + SUBS $1, nwide + BNE wideLoop + +done: + // Clear the registers. + VEOR block.B16, block.B16, block.B16 + + VEOR src0.B16, src0.B16, src0.B16 + VEOR src1.B16, src1.B16, src1.B16 + VEOR src2.B16, src2.B16, src2.B16 + VEOR src3.B16, src3.B16, src3.B16 + + VEOR ks0.B16, ks0.B16, ks0.B16 + VEOR ks1.B16, ks1.B16, ks1.B16 + VEOR ks2.B16, ks2.B16, ks2.B16 + VEOR ks3.B16, ks3.B16, ks3.B16 + +#ifdef const_useMultiBlock + VEOR src4.B16, src4.B16, src4.B16 + VEOR src5.B16, src5.B16, src5.B16 + VEOR src6.B16, src6.B16, src6.B16 + VEOR src7.B16, src7.B16, src7.B16 + + VEOR ks4.B16, ks4.B16, ks4.B16 + VEOR ks5.B16, ks5.B16, ks5.B16 + VEOR ks6.B16, ks6.B16, ks6.B16 + VEOR ks7.B16, ks7.B16, ks7.B16 + +#endif + + VEOR rk1.B16, rk1.B16, rk1.B16 + VEOR rk2.B16, rk2.B16, rk2.B16 + VEOR rk3.B16, rk3.B16, rk3.B16 + VEOR rk4.B16, rk4.B16, rk4.B16 + VEOR rk5.B16, rk5.B16, rk5.B16 + VEOR rk6.B16, rk6.B16, rk6.B16 + VEOR rk7.B16, rk7.B16, rk7.B16 + VEOR rk8.B16, rk8.B16, rk8.B16 + VEOR rk9.B16, rk9.B16, rk9.B16 + VEOR rk10.B16, rk10.B16, rk10.B16 + VEOR rk11.B16, rk11.B16, rk11.B16 + VEOR rk12.B16, rk12.B16, rk12.B16 + VEOR rk13.B16, rk13.B16, rk13.B16 + VEOR rk14.B16, rk14.B16, rk14.B16 + VEOR rk15.B16, rk15.B16, rk15.B16 + + RET diff --git a/fuzz_test.go b/fuzz_test.go index 28379cf..13438f9 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -83,10 +83,17 @@ func testTink(t *testing.T, keySize int) { gotCt := gotAead.Seal(nil, nonce, plaintext, aad) if !bytes.Equal(wantCt, gotCt) { + wantTag := wantCt[len(wantCt)-TagSize:] + gotTag := gotCt[len(gotCt)-TagSize:] + if !bytes.Equal(wantTag, gotTag) { + t.Fatalf("expected tag %x, got %x", wantTag, gotTag) + } + wantCt = wantCt[:len(wantCt)-TagSize] + gotCt = gotCt[:len(gotCt)-TagSize] for i, c := range gotCt { if c != wantCt[i] { - t.Fatalf("bad value at index %d of %d (%d): %#x", - i, len(wantCt), len(wantCt)-i, c) + t.Fatalf("bad value at index %d (block %d of %d): %#x", + i, i/blockSize, len(wantCt)/blockSize, c) } } t.Fatalf("expected %#x, got %#x", wantCt, gotCt) diff --git a/siv_asm.go b/siv_amd64.go similarity index 93% rename from siv_asm.go rename to siv_amd64.go index 01a1f4e..ddb1bc8 100644 --- a/siv_asm.go +++ b/siv_amd64.go @@ 
-1,4 +1,4 @@
-//go:build (amd64 || arm64) && gc && !purego
+//go:build amd64 && gc && !purego
 
 package siv
 
@@ -33,9 +33,11 @@ func (a *aead) seal(out, nonce, plaintext, additionalData []byte) {
 	sum(tag, authKey[:16], nonce, plaintext, additionalData)
 	encryptBlockAsm(nr, &enc[0], &tag[0], &tag[0])
 
-	block := *tag
-	block[15] |= 0x80
-	aesctr(nr, &enc[0], &block, out, plaintext)
+	if len(plaintext) > 0 {
+		block := *tag
+		block[15] |= 0x80
+		aesctr(nr, &enc[0], &block, out, plaintext)
+	}
 }
 
 func (a *aead) open(out, nonce, ciphertext, tag, additionalData []byte) bool {
@@ -51,10 +53,12 @@ func (a *aead) open(out, nonce, ciphertext, tag, additionalData []byte) bool {
 	var enc [maxEncSize]uint32
 	expandKeyAsm(nr, &encKey[0], &enc[0])
 
-	var block [TagSize]byte
-	copy(block[:], tag)
-	block[15] |= 0x80
-	aesctr(nr, &enc[0], &block, out, ciphertext)
+	if len(ciphertext) > 0 {
+		var block [TagSize]byte
+		copy(block[:], tag)
+		block[15] |= 0x80
+		aesctr(nr, &enc[0], &block, out, ciphertext)
+	}
 
 	var wantTag [TagSize]byte
 	sum(&wantTag, authKey[:16], nonce, out, additionalData)
diff --git a/siv_arm64.go b/siv_arm64.go
new file mode 100644
index 0000000..a16a454
--- /dev/null
+++ b/siv_arm64.go
@@ -0,0 +1,180 @@
+//go:build arm64 && gc && !purego
+
+package siv
+
+import (
+	"encoding/binary"
+
+	"github.com/ericlagergren/polyval"
+	"github.com/ericlagergren/siv/internal/subtle"
+)
+
+const (
+	// maxEncSize is the maximum number of uint32s used in the
+	// AES round key expansion.
+	maxEncSize = 32 + 28
+)
+
+func (a *aead) seal(out, nonce, plaintext, additionalData []byte) {
+	if !haveAsm {
+		a.sealGeneric(out, nonce, plaintext, additionalData)
+		return
+	}
+
+	var encKey [40]byte
+	var authKey [24]byte
+	deriveKeys(&authKey, &encKey, a.key, nonce)
+
+	nr := 6 + len(a.key)/4
+	var enc [maxEncSize]uint32
+	expandKeyAsm(nr, &encKey[0], &enc[0])
+
+	tag := (*[TagSize]byte)(out[len(out)-TagSize:])
+	sum(tag, authKey[:16], nonce, plaintext, additionalData)
+	encryptBlockAsm(nr, &enc[0], &tag[0], &tag[0])
+
+	if len(plaintext) > 0 {
+		block := *tag
+		block[15] |= 0x80
+		aesctr(nr, &enc[0], &block, out, plaintext)
+	}
+}
+
+func (a *aead) open(out, nonce, ciphertext, tag, additionalData []byte) bool {
+	if !haveAsm {
+		return a.openGeneric(out, nonce, ciphertext, tag, additionalData)
+	}
+
+	var encKey [40]byte
+	var authKey [24]byte
+	deriveKeys(&authKey, &encKey, a.key, nonce)
+
+	nr := 6 + len(a.key)/4
+	var enc [maxEncSize]uint32
+	expandKeyAsm(nr, &encKey[0], &enc[0])
+
+	if len(ciphertext) > 0 {
+		var block [TagSize]byte
+		copy(block[:], tag)
+		block[15] |= 0x80
+		aesctr(nr, &enc[0], &block, out, ciphertext)
+	}
+
+	var wantTag [TagSize]byte
+	sum(&wantTag, authKey[:16], nonce, out, additionalData)
+	encryptBlockAsm(nr, &enc[0], &wantTag[0], &wantTag[0])
+
+	return subtle.ConstantTimeCompare(tag, wantTag[:]) == 1
+}
+
+func deriveKeys(authKey *[24]byte, encKey *[40]byte, keyGenKey, nonce []byte) {
+	src := make([]byte, 16)
+	copy(src[4:], nonce)
+
+	nr := 6 + len(keyGenKey)/4
+	var enc [maxEncSize]uint32
+	expandKeyAsm(nr, &keyGenKey[0], &enc[0])
+
+	// message_authentication_key =
+	//     AES(key = key_generating_key,
+	//         block = little_endian_uint32(0) ++ nonce
+	//     )[:8] ++
+	//     AES(key = key_generating_key,
+	//         block = little_endian_uint32(1) ++ nonce
+	//     )[:8]
+	binary.LittleEndian.PutUint32(src, 0)
+	encryptBlockAsm(nr, &enc[0], &authKey[0], &src[0])
+
+	binary.LittleEndian.PutUint32(src, 1)
+	encryptBlockAsm(nr, &enc[0], &authKey[8], &src[0])
+
+	// message_encryption_key =
+	//     AES(key =
key_generating_key, + // block = little_endian_uint32(2) ++ nonce + // )[:8] ++ + // AES(key = key_generating_key, + // block = little_endian_uint32(3) ++ nonce + // )[:8] + binary.LittleEndian.PutUint32(src, 2) + encryptBlockAsm(nr, &enc[0], &encKey[0], &src[0]) + + binary.LittleEndian.PutUint32(src, 3) + encryptBlockAsm(nr, &enc[0], &encKey[8], &src[0]) + + // if bytelen(key_generating_key) == 32 { + // message_encryption_key = + // AES(key = key_generating_key, + // block = little_endian_uint32(4) ++ nonce + // )[:8] ++ + // AES(key = key_generating_key, + // block = little_endian_uint32(5) ++ nonce + // )[:8] + // } + if len(keyGenKey) == 32 { + binary.LittleEndian.PutUint32(src, 4) + encryptBlockAsm(nr, &enc[0], &encKey[16], &src[0]) + + binary.LittleEndian.PutUint32(src, 5) + encryptBlockAsm(nr, &enc[0], &encKey[24], &src[0]) + } +} + +func sum(tag *[TagSize]byte, authKey, nonce, plaintext, additionalData []byte) { + length := make([]byte, 16) + binary.LittleEndian.PutUint64(length[0:8], uint64(len(additionalData))*8) + binary.LittleEndian.PutUint64(length[8:16], uint64(len(plaintext))*8) + + var p polyval.Polyval + if err := p.Init(authKey); err != nil { + panic(err) + } + + // Additional data + if len(additionalData) >= 16 { + n := len(additionalData) &^ (16 - 1) + p.Update(additionalData[:n]) + additionalData = additionalData[n:] + } + if len(additionalData) > 0 { + dst := make([]byte, 16) + copy(dst, additionalData) + p.Update(dst) + } + + // Plaintext + if len(plaintext) >= 16 { + n := len(plaintext) &^ (16 - 1) + p.Update(plaintext[:n]) + plaintext = plaintext[n:] + } + if len(plaintext) > 0 { + dst := make([]byte, 16) + copy(dst, plaintext) + p.Update(dst) + } + + // Length + p.Update(length) + + p.Sum(tag[:0]) + for i := range nonce { + tag[i] ^= nonce[i] + } + tag[15] &= 0x7f +} + +func aesctr(nr int, enc *uint32, block *[TagSize]byte, dst, src []byte) { + n := len(src) / blockSize + if n > 0 { + aesctrAsm(nr, enc, block, &dst[0], &src[0], n) + dst = dst[n*blockSize:] + src = src[n*blockSize:] + } + if len(src) > 0 { + var ks [blockSize]byte + ctr := binary.LittleEndian.Uint32(block[0:4]) + uint32(n) + binary.LittleEndian.PutUint32(block[0:4], ctr) + encryptBlockAsm(nr, enc, &ks[0], &block[0]) + xor(dst, src, ks[:], len(src)) + } +} diff --git a/siv_test.go b/siv_test.go index 2009420..2179475 100644 --- a/siv_test.go +++ b/siv_test.go @@ -17,6 +17,8 @@ import ( "testing/quick" rand "github.com/ericlagergren/saferand" + tink "github.com/google/tink/go/aead/subtle" + "github.com/ericlagergren/siv/internal/subtle" ) @@ -28,6 +30,31 @@ func randbuf(n int) []byte { return buf } +func hex16(src []byte) string { + const hextable = "0123456789abcdef" + + var dst strings.Builder + for i := 0; len(src) > TagSize; i++ { + if i > 0 && i%16 == 0 { + dst.WriteByte(' ') + } + v := src[0] + dst.WriteByte(hextable[v>>4]) + dst.WriteByte(hextable[v&0x0f]) + src = src[1:] + } + if dst.Len() > 0 { + dst.WriteByte(' ') + } + for len(src) > 0 { + v := src[0] + dst.WriteByte(hextable[v>>4]) + dst.WriteByte(hextable[v&0x0f]) + src = src[1:] + } + return dst.String() +} + func disableAsm(t *testing.T) { old := haveAsm haveAsm = false @@ -38,17 +65,17 @@ func disableAsm(t *testing.T) { // runTests runs both generic and assembly tests. 
 func runTests(t *testing.T, fn func(t *testing.T)) {
-	t.Run("generic", func(t *testing.T) {
-		t.Helper()
-		disableAsm(t)
-		fn(t)
-	})
 	if haveAsm {
 		t.Run("assembly", func(t *testing.T) {
 			t.Helper()
 			fn(t)
 		})
 	}
+	t.Run("generic", func(t *testing.T) {
+		t.Helper()
+		disableAsm(t)
+		fn(t)
+	})
 }
 
 // loadVectors reads test vectors from testdata/name into v.
@@ -350,6 +377,8 @@ func testRFC(t *testing.T, i int, tc testVector) {
 
 	ciphertext := aead.Seal(nil, tc.nonce, tc.plaintext, tc.aad)
 	if !bytes.Equal(ciphertext, tc.result) {
+		t.Logf("W: %q\n", hex16(tc.result))
+		t.Logf("G: %q\n", hex16(ciphertext))
 		t.Fatalf("#%d: expected %x, got %x",
 			i, tc.result, ciphertext)
 	}
@@ -368,6 +397,75 @@
 	}
 }
 
+// TestMultiBlock tests the code paths that handle N blocks at
+// a time.
+func TestMultiBlock(t *testing.T) {
+	runTests(t, func(t *testing.T) {
+		t.Run("128", func(t *testing.T) {
+			testMultiBlock(t, 16)
+		})
+		t.Run("256", func(t *testing.T) {
+			testMultiBlock(t, 32)
+		})
+	})
+}
+
+func testMultiBlock(t *testing.T, keySize int) {
+	key := randbuf(keySize)
+	plaintext := randbuf((blockSize * 16) + blockSize/3)
+	aad := randbuf(773)
+
+	// TODO(eric): add test vectors to testdata instead of using
+	// Tink.
+	refAead, err := tink.NewAESGCMSIV(key)
+	if err != nil {
+		t.Fatal(err)
+	}
+	nonceAndCt, err := refAead.Encrypt(plaintext, aad)
+	if err != nil {
+		t.Fatal(err)
+	}
+	nonce := nonceAndCt[:NonceSize]
+	wantCt := nonceAndCt[NonceSize:]
+
+	gotAead, err := NewGCM(key)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	gotCt := gotAead.Seal(nil, nonce, plaintext, aad)
+	if !bytes.Equal(wantCt, gotCt) {
+		wantTag := wantCt[len(wantCt)-TagSize:]
+		gotTag := gotCt[len(gotCt)-TagSize:]
+		if !bytes.Equal(wantTag, gotTag) {
+			t.Fatalf("expected tag %x, got %x", wantTag, gotTag)
+		}
+		wantCt = wantCt[:len(wantCt)-TagSize]
+		gotCt = gotCt[:len(gotCt)-TagSize]
+		t.Logf("W: %q\n", hex16(wantCt))
+		t.Logf("G: %q\n", hex16(gotCt))
+		for i, c := range gotCt {
+			if c != wantCt[i] {
+				t.Fatalf("bad value at index %d (block %d of %d): %#x",
+					i, i/blockSize, len(wantCt)/blockSize, c)
+			}
+		}
+		panic("unreachable")
+	}
+
+	wantPt, err := refAead.Decrypt(nonceAndCt, aad)
+	if err != nil {
+		t.Fatal(err)
+	}
+	gotPt, err := gotAead.Open(nil, nonce, wantCt, aad)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(wantPt, gotPt) {
+		t.Fatalf("expected %#x, got %#x", wantPt, gotPt)
+	}
+}
+
 // TestOverlap tests Seal and Open with overlapping buffers.
 func TestOverlap(t *testing.T) {
 	runTests(t, func(t *testing.T) {
@@ -388,8 +486,7 @@ func testOverlap(t *testing.T, keySize int) {
 		i, j int
 	}
 	const (
-		// max = 7789
-		max = 33
+		max = 7789
 	)
 	args := []arg{
 		{buf: randbuf(keySize), ptr: &key},
diff --git a/stub_arm64.go b/stub_arm64.go
index b118870..5aaf682 100644
--- a/stub_arm64.go
+++ b/stub_arm64.go
@@ -7,3 +7,14 @@ func encryptBlockAsm(nr int, xk *uint32, dst, src *byte)
 
 //go:noescape
 func expandKeyAsm(nr int, key *byte, enc *uint32)
+
+//go:noescape
+func aesctrAsm(nr int, enc *uint32, iv *[blockSize]byte, dst, src *byte, nblocks int)
+
+// useMultiBlock causes cmd/asm to define "const_useMultiBlock"
+// in "go_asm.h", which instructs aesctrAsm to compute multiple
+// blocks at a time.
+//
+// Commenting out or deleting this constant restricts aesctrAsm
+// to just one block at a time.
+const useMultiBlock = true
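
Reviewer note (not part of the patch): for anyone cross-checking ctr_arm64.s, the keystream it produces is the AES-GCM-SIV CTR construction that `aesctr` in siv_arm64.go drives: the starting block is the encrypted tag with the high bit of its last byte set, and only the first four bytes act as a 32-bit little-endian counter that wraps modulo 2^32. The sketch below is an illustrative pure-Go equivalent built on `crypto/aes`; the names `ctrXOR` and the standalone `main` are hypothetical and none of this is the patch's actual code path.

```go
package main

import (
	"crypto/aes"
	"crypto/cipher"
	"encoding/binary"
	"fmt"
)

// ctrXOR XORs src with the AES-GCM-SIV CTR keystream derived from tag
// and writes the result to dst. It mirrors what aesctrAsm plus the Go
// tail handling in aesctr compute, one block at a time: encrypt the
// counter block, XOR the keystream into the data, then bump the 32-bit
// little-endian counter held in the first four bytes of the block.
func ctrXOR(b cipher.Block, tag *[16]byte, dst, src []byte) {
	var block [16]byte
	copy(block[:], tag[:])
	block[15] |= 0x80 // per RFC 8452, set the MSB of the last byte

	ctr := binary.LittleEndian.Uint32(block[0:4])
	var ks [16]byte
	for len(src) > 0 {
		binary.LittleEndian.PutUint32(block[0:4], ctr)
		b.Encrypt(ks[:], block[:])
		n := len(src)
		if n > len(ks) {
			n = len(ks)
		}
		for i := 0; i < n; i++ {
			dst[i] = src[i] ^ ks[i]
		}
		dst, src = dst[n:], src[n:]
		ctr++ // wraps modulo 2^32, matching the 32-bit lane add in the assembly
	}
}

func main() {
	// All-zero key and tag purely for illustration.
	key := make([]byte, 32)
	var tag [16]byte
	b, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}
	pt := []byte("some plaintext that spans more than a single 16-byte block")
	ct := make([]byte, len(pt))
	ctrXOR(b, &tag, ct, pt)
	fmt.Printf("%x\n", ct)
}
```

The assembly's `VMOV block.S[0]` / `ADD` / `VMOV` sequence corresponds to the `ctr++` on a `uint32` above, which is why only the first word of the counter block ever changes.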
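
Also not part of the patch: a short usage sketch of the exported API the new tests exercise (`NewGCM`, `Seal`, `Open`, `NonceSize`). It assumes the surface shown in siv_test.go; on arm64 with this patch applied, `Seal` and `Open` reach `aesctrAsm` for the bulk keystream. Error handling is abbreviated.

```go
package main

import (
	"crypto/rand"
	"fmt"

	"github.com/ericlagergren/siv"
)

func main() {
	// A 32-byte key selects AES-256-GCM-SIV; a 16-byte key selects AES-128.
	key := make([]byte, 32)
	if _, err := rand.Read(key); err != nil {
		panic(err)
	}
	aead, err := siv.NewGCM(key)
	if err != nil {
		panic(err)
	}

	nonce := make([]byte, siv.NonceSize)
	if _, err := rand.Read(nonce); err != nil {
		panic(err)
	}

	plaintext := []byte("hello, world")
	aad := []byte("additional data")

	// Seal appends the ciphertext and the 16-byte tag; on arm64 the bulk
	// encryption goes through the aesctrAsm routine added by this patch.
	ciphertext := aead.Seal(nil, nonce, plaintext, aad)

	decrypted, err := aead.Open(nil, nonce, ciphertext, aad)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", decrypted)
}
```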