diff --git a/README.md b/README.md index 63b25ff..b759f4a 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ ## Features - **Supports keyset-based and offset-based pagination**: You can freely choose high-performance keyset pagination based on multiple indexed columns, or use offset pagination. -- **Optional cursor encryption**: Supports encrypting cursors using `AES` or `Base64` to ensure the security of pagination information. +- **Optional cursor encryption**: Supports encrypting cursors using `GCM(AES)` or `Base64` to ensure the security of pagination information. - **Flexible query strategies**: Optionally skip the `TotalCount` query to improve performance, especially in large datasets. - **Non-generic support**: Even without using Go generics, you can paginate using the `any` type for flexible use cases. @@ -33,14 +33,16 @@ resp, err := p.Paginate(context.Background(), &relay.PaginateRequest[*User]{ ### Middleware -If you need to encrypt cursors, you can use `cursor.Base64` or `cursor.AES` middlewares: +If you need to encrypt cursors, you can use `cursor.Base64` or `cursor.GCM` middlewares: ```go // Encrypt cursors with Base64 cursor.Base64(gormrelay.NewOffsetAdapter[*User](db)) -// Encrypt cursors with AES -cursor.AES(encryptionKey)(gormrelay.NewKeysetAdapter[*User](db)) +// Encrypt cursors with GCM(AES) +gcm, err := cursor.NewGCM(encryptionKey) +require.NoError(t, err) +cursor.GCM(gcm)(gormrelay.NewKeysetAdapter[*User](db)) ``` If you need to append `PrimaryOrderBys` to `PaginateRequest.OrderBys` diff --git a/cursor/aes.go b/cursor/gcm.go similarity index 70% rename from cursor/aes.go rename to cursor/gcm.go index d04707f..7924c4e 100644 --- a/cursor/aes.go +++ b/cursor/gcm.go @@ -13,40 +13,20 @@ import ( "github.com/theplant/relay" ) -func encryptAES(plainText string, key []byte) (string, error) { - block, err := aes.NewCipher(key) - if err != nil { - return "", errors.New("could not create cipher block") - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - +func encryptGCM(gcm cipher.AEAD, plainText string) (string, error) { nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return "", err + return "", errors.Wrap(err, "could not generate nonce") } cipherText := gcm.Seal(nonce, nonce, []byte(plainText), nil) return base64.RawURLEncoding.EncodeToString(cipherText), nil } -func decryptAES(cipherText string, key []byte) (string, error) { - block, err := aes.NewCipher(key) - if err != nil { - return "", errors.New("could not create cipher block") - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - +func decryptGCM(gcm cipher.AEAD, cipherText string) (string, error) { decodedCipherText, err := base64.RawURLEncoding.DecodeString(cipherText) if err != nil { - return "", err + return "", errors.Wrap(err, "could not decode cipher text") } nonceSize := gcm.NonceSize() @@ -57,17 +37,32 @@ func decryptAES(cipherText string, key []byte) (string, error) { nonce, dataCipherText := decodedCipherText[:nonceSize], decodedCipherText[nonceSize:] plainText, err := gcm.Open(nil, nonce, dataCipherText, nil) if err != nil { - return "", err + return "", errors.Wrap(err, "could not decrypt cipher text") } return string(plainText), nil } -func AES[T any](encryptionKey []byte) relay.CursorMiddleware[T] { +// NewGCM creates a new GCM cipher +// Concurrent safe: https://github.com/golang/go/issues/41689 +func NewGCM(key []byte) (cipher.AEAD, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, errors.New("could not create cipher block") + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, errors.New("could not create GCM") + } + return gcm, nil +} + +func GCM[T any](gcm cipher.AEAD) relay.CursorMiddleware[T] { return func(next relay.ApplyCursorsFunc[T]) relay.ApplyCursorsFunc[T] { return func(ctx context.Context, req *relay.ApplyCursorsRequest) (*relay.ApplyCursorsResponse[T], error) { if req.After != nil { - decodedCursor, err := decryptAES(*req.After, encryptionKey) + decodedCursor, err := decryptGCM(gcm, *req.After) if err != nil { return nil, errors.Wrap(err, "invalid after cursor") } @@ -75,7 +70,7 @@ func AES[T any](encryptionKey []byte) relay.CursorMiddleware[T] { } if req.Before != nil { - decodedCursor, err := decryptAES(*req.Before, encryptionKey) + decodedCursor, err := decryptGCM(gcm, *req.Before) if err != nil { return nil, errors.Wrap(err, "invalid before cursor") } @@ -95,7 +90,7 @@ func AES[T any](encryptionKey []byte) relay.CursorMiddleware[T] { if err != nil { return "", err } - encryptedCursor, err := encryptAES(cursor, encryptionKey) + encryptedCursor, err := encryptGCM(gcm, cursor) if err != nil { return "", err } diff --git a/cursor/aes_test.go b/cursor/gcm_test.go similarity index 57% rename from cursor/aes_test.go rename to cursor/gcm_test.go index 1407911..4218751 100644 --- a/cursor/aes_test.go +++ b/cursor/gcm_test.go @@ -6,41 +6,45 @@ import ( "io" "testing" + "github.com/pkg/errors" "github.com/stretchr/testify/require" ) -func generateAESKey(length int) ([]byte, error) { +func generateGCMKey(length int) ([]byte, error) { key := make([]byte, length) if _, err := io.ReadFull(rand.Reader, key); err != nil { - return nil, err + return nil, errors.Wrap(err, "could not generate key") } return key, nil } -func TestAES(t *testing.T) { - aesKey, err := generateAESKey(32) +func TestGCM(t *testing.T) { + gcmKey, err := generateGCMKey(32) + require.NoError(t, err) + + gcm, err := NewGCM(gcmKey) require.NoError(t, err) plainText := `{"ID":225}` { - cipherText, err := encryptAES(plainText, aesKey) + cipherText, err := encryptGCM(gcm, plainText) require.NoError(t, err) t.Logf("cipherText: %s", cipherText) - decryptedText, err := decryptAES(cipherText, aesKey) + decryptedText, err := decryptGCM(gcm, cipherText) require.NoError(t, err) require.Equal(t, plainText, decryptedText) } { - cipherText, err := encryptAES(base64.RawURLEncoding.EncodeToString([]byte(plainText)), aesKey) + cipherText, err := encryptGCM(gcm, base64.RawURLEncoding.EncodeToString([]byte(plainText))) require.NoError(t, err) t.Logf("cipherText: %s", cipherText) - decryptedText, err := decryptAES(cipherText, aesKey) + decryptedText, err := decryptGCM(gcm, cipherText) require.NoError(t, err) plainTextData, err := base64.RawURLEncoding.DecodeString(decryptedText) diff --git a/gormrelay/relay_test.go b/gormrelay/relay_test.go index 8df675a..11f1e80 100644 --- a/gormrelay/relay_test.go +++ b/gormrelay/relay_test.go @@ -245,10 +245,10 @@ func TestTotalCountZero(t *testing.T) { t.Run("offset", func(t *testing.T) { testCase(t, NewOffsetAdapter) }) } -func generateAESKey(length int) ([]byte, error) { +func generateGCMKey(length int) ([]byte, error) { key := make([]byte, length) if _, err := io.ReadFull(rand.Reader, key); err != nil { - return nil, err + return nil, errors.Wrap(err, "could not generate key") } return key, nil } @@ -320,19 +320,22 @@ func TestMiddleware(t *testing.T) { }) }) - t.Run("AES", func(t *testing.T) { - encryptionKey, err := generateAESKey(32) + t.Run("GCM", func(t *testing.T) { + encryptionKey, err := generateGCMKey(32) + require.NoError(t, err) + + gcm, err := cursor.NewGCM(encryptionKey) require.NoError(t, err) t.Run("keyset", func(t *testing.T) { testCase(t, func(db *gorm.DB) relay.ApplyCursorsFunc[*User] { - return cursor.AES[*User](encryptionKey)(NewKeysetAdapter[*User](db)) + return cursor.GCM[*User](gcm)(NewKeysetAdapter[*User](db)) }) }) t.Run("offset", func(t *testing.T) { testCase(t, func(db *gorm.DB) relay.ApplyCursorsFunc[*User] { - return cursor.AES[*User](encryptionKey)(NewOffsetAdapter[*User](db)) + return cursor.GCM[*User](gcm)(NewOffsetAdapter[*User](db)) }) }) }) @@ -376,10 +379,13 @@ func TestMiddleware(t *testing.T) { func TestAppendCursorMiddleware(t *testing.T) { resetDB(t) - encryptionKey, err := generateAESKey(32) + encryptionKey, err := generateGCMKey(32) + require.NoError(t, err) + + gcm, err := cursor.NewGCM(encryptionKey) require.NoError(t, err) - aesMiddleware := cursor.AES[*User](encryptionKey) + gcmMiddleware := cursor.GCM[*User](gcm) testCase := func(t *testing.T, f func(db *gorm.DB) relay.ApplyCursorsFunc[*User]) { p := relay.New( @@ -387,8 +393,8 @@ func TestAppendCursorMiddleware(t *testing.T) { 10, 10, f(db), ) - p = relay.AppendCursorMiddleware(aesMiddleware)(p) // test add single middleware - p = relay.AppendCursorMiddleware(cursor.Base64[*User], aesMiddleware)(p) // test add multiple middlewares + p = relay.AppendCursorMiddleware(gcmMiddleware)(p) // test add single middleware + p = relay.AppendCursorMiddleware(cursor.Base64[*User], gcmMiddleware)(p) // test add multiple middlewares p = relay.PrimaryOrderBy[*User](relay.OrderBy{Field: "ID", Desc: false})(p) // test a pagination middleware resp, err := p.Paginate(context.Background(), &relay.PaginateRequest[*User]{