diff --git a/.github/workflows/challenge-bypass-tests.yaml b/.github/workflows/challenge-bypass-tests.yaml index da2f2ebe..2c9e0ee8 100644 --- a/.github/workflows/challenge-bypass-tests.yaml +++ b/.github/workflows/challenge-bypass-tests.yaml @@ -10,5 +10,7 @@ jobs: steps: - name: checkout repo uses: actions/checkout@v3 + - name: run lint + run: make lint - name: run tests run: make docker-test diff --git a/.github/workflows/golangci-lint.yaml b/.github/workflows/golangci-lint.yaml deleted file mode 100644 index db4e211e..00000000 --- a/.github/workflows/golangci-lint.yaml +++ /dev/null @@ -1,26 +0,0 @@ -name: golangci-lint -on: - push: - tags: - - v* - branches: - - master - - main - pull_request: -permissions: - contents: read - -jobs: - golangci: - name: lint - runs-on: ubuntu-latest - steps: - - uses: actions/setup-go@v3 - with: - go-version: 1.18 - - uses: actions/checkout@v3 - - name: golangci-lint - uses: golangci/golangci-lint-action@v3 - with: - version: v1.46 - args: -v diff --git a/.golangci.yaml b/.golangci.yaml index 5f2b1a37..1e5d90a3 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -3,14 +3,14 @@ run: timeout: 3m linters-settings: - cyclop: + #cyclop: # The maximal code complexity to report. # Default: 10 - max-complexity: 10 + # max-complexity: 10 # The maximal average package complexity. # If it's higher than 0.0 (float) the check is enabled # Default: 0.0 - package-average: 10.0 + #package-average: 10.0 errcheck: # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. @@ -41,23 +41,23 @@ linters: disable-all: true enable: ## enabled by default - - deadcode # Finds unused code + #- deadcode # Finds unused code - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - gosimple # Linter for Go source code that specializes in simplifying a code - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - ineffassign # Detects when assignments to existing variables are not used - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - structcheck # Finds unused struct fields + #- structcheck # Finds unused struct fields - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unused # Checks Go code for unused constants, variables, functions and types - - varcheck # Finds unused global variables and constants + #- varcheck # Finds unused global variables and constants # ## disabled by default - contextcheck # check the function whether use a non-inherited context - - cyclop # checks function and package cyclomatic complexity + #- cyclop # checks function and package cyclomatic complexity - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - gocritic # Provides diagnostics that check for bugs, performance and style issues. - - gocyclo # Computes and checks the cyclomatic complexity of functions - - nestif # Reports deeply nested if statements + #- gocyclo # Computes and checks the cyclomatic complexity of functions + #- nestif # Reports deeply nested if statements - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. 
- sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed - stylecheck # Stylecheck is a replacement for golint diff --git a/Makefile b/Makefile index 0c95b45d..d7206218 100644 --- a/Makefile +++ b/Makefile @@ -22,4 +22,4 @@ generate-avro: sed -i 's/"public_key/"issuer_public_key/g' ./avro/generated/signing_result*.go lint: - docker run --rm -v "$$(pwd):/app" --workdir /app golangci/golangci-lint:v1.46.2 go get ./... && golangci-lint run -v ./... + docker run --rm -v "$$(pwd):/app" --workdir /app golangci/golangci-lint:v1.49.0 golangci-lint run -v ./... diff --git a/avro/generated/redeem_result.go b/avro/generated/redeem_result.go index b8d88ca0..c9bb3688 100644 --- a/avro/generated/redeem_result.go +++ b/avro/generated/redeem_result.go @@ -31,7 +31,7 @@ type RedeemResult struct { Associated_data Bytes `json:"associated_data"` } -const RedeemResultAvroCRC64Fingerprint = "֯*\xbf+\xa0\x84\xe0" +const RedeemResultAvroCRC64Fingerprint = "\x11T\xa5\xba@д;" func NewRedeemResult() RedeemResult { r := RedeemResult{} @@ -87,7 +87,7 @@ func (r RedeemResult) Serialize(w io.Writer) error { } func (r RedeemResult) Schema() string { - return "{\"fields\":[{\"name\":\"issuer_name\",\"type\":\"string\"},{\"name\":\"issuer_cohort\",\"type\":\"int\"},{\"name\":\"status\",\"type\":{\"name\":\"RedeemResultStatus\",\"symbols\":[\"ok\",\"duplicate_redemption\",\"unverified\",\"error\"],\"type\":\"enum\"}},{\"doc\":\"contains METADATA\",\"name\":\"associated_data\",\"type\":\"bytes\"}],\"name\":\"brave.cbp.RedeemResult\",\"type\":\"record\"}" + return "{\"fields\":[{\"name\":\"issuer_name\",\"type\":\"string\"},{\"name\":\"issuer_cohort\",\"type\":\"int\"},{\"name\":\"status\",\"type\":{\"name\":\"RedeemResultStatus\",\"symbols\":[\"ok\",\"duplicate_redemption\",\"unverified\",\"error\",\"idempotent_redemption\"],\"type\":\"enum\"}},{\"doc\":\"contains METADATA\",\"name\":\"associated_data\",\"type\":\"bytes\"}],\"name\":\"brave.cbp.RedeemResult\",\"type\":\"record\"}" } func (r RedeemResult) SchemaName() string { diff --git a/avro/generated/redeem_result_set.go b/avro/generated/redeem_result_set.go index 68b2897f..53b5e42b 100644 --- a/avro/generated/redeem_result_set.go +++ b/avro/generated/redeem_result_set.go @@ -28,7 +28,7 @@ type RedeemResultSet struct { Data []RedeemResult `json:"data"` } -const RedeemResultSetAvroCRC64Fingerprint = "\xa5a\x92\xe9\xfb@i\"" +const RedeemResultSetAvroCRC64Fingerprint = "\x04\xe6\xb5@7\xfb\xc28" func NewRedeemResultSet() RedeemResultSet { r := RedeemResultSet{} @@ -78,7 +78,7 @@ func (r RedeemResultSet) Serialize(w io.Writer) error { } func (r RedeemResultSet) Schema() string { - return "{\"doc\":\"Top level request containing the data to be processed, as well as any top level metadata for this message.\",\"fields\":[{\"name\":\"request_id\",\"type\":\"string\"},{\"name\":\"data\",\"type\":{\"items\":{\"fields\":[{\"name\":\"issuer_name\",\"type\":\"string\"},{\"name\":\"issuer_cohort\",\"type\":\"int\"},{\"name\":\"status\",\"type\":{\"name\":\"RedeemResultStatus\",\"symbols\":[\"ok\",\"duplicate_redemption\",\"unverified\",\"error\"],\"type\":\"enum\"}},{\"doc\":\"contains METADATA\",\"name\":\"associated_data\",\"type\":\"bytes\"}],\"name\":\"RedeemResult\",\"namespace\":\"brave.cbp\",\"type\":\"record\"},\"type\":\"array\"}}],\"name\":\"brave.cbp.RedeemResultSet\",\"type\":\"record\"}" + return "{\"doc\":\"Top level request containing the data to be processed, as well as any top level metadata for this 
message.\",\"fields\":[{\"name\":\"request_id\",\"type\":\"string\"},{\"name\":\"data\",\"type\":{\"items\":{\"fields\":[{\"name\":\"issuer_name\",\"type\":\"string\"},{\"name\":\"issuer_cohort\",\"type\":\"int\"},{\"name\":\"status\",\"type\":{\"name\":\"RedeemResultStatus\",\"symbols\":[\"ok\",\"duplicate_redemption\",\"unverified\",\"error\",\"idempotent_redemption\"],\"type\":\"enum\"}},{\"doc\":\"contains METADATA\",\"name\":\"associated_data\",\"type\":\"bytes\"}],\"name\":\"RedeemResult\",\"namespace\":\"brave.cbp\",\"type\":\"record\"},\"type\":\"array\"}}],\"name\":\"brave.cbp.RedeemResultSet\",\"type\":\"record\"}" } func (r RedeemResultSet) SchemaName() string { diff --git a/avro/generated/redeem_result_status.go b/avro/generated/redeem_result_status.go index 337cb52f..d9362e48 100644 --- a/avro/generated/redeem_result_status.go +++ b/avro/generated/redeem_result_status.go @@ -23,10 +23,11 @@ var _ = fmt.Printf type RedeemResultStatus int32 const ( - RedeemResultStatusOk RedeemResultStatus = 0 - RedeemResultStatusDuplicate_redemption RedeemResultStatus = 1 - RedeemResultStatusUnverified RedeemResultStatus = 2 - RedeemResultStatusError RedeemResultStatus = 3 + RedeemResultStatusOk RedeemResultStatus = 0 + RedeemResultStatusDuplicate_redemption RedeemResultStatus = 1 + RedeemResultStatusUnverified RedeemResultStatus = 2 + RedeemResultStatusError RedeemResultStatus = 3 + RedeemResultStatusIdempotent_redemption RedeemResultStatus = 4 ) func (e RedeemResultStatus) String() string { @@ -39,6 +40,8 @@ func (e RedeemResultStatus) String() string { return "unverified" case RedeemResultStatusError: return "error" + case RedeemResultStatusIdempotent_redemption: + return "idempotent_redemption" } return "unknown" } @@ -57,6 +60,8 @@ func NewRedeemResultStatusValue(raw string) (r RedeemResultStatus, err error) { return RedeemResultStatusUnverified, nil case "error": return RedeemResultStatusError, nil + case "idempotent_redemption": + return RedeemResultStatusIdempotent_redemption, nil } return -1, fmt.Errorf("invalid value for RedeemResultStatus: '%s'", raw) diff --git a/avro/schemas/redeem_result.avsc b/avro/schemas/redeem_result.avsc index 0c26c873..f22406f1 100644 --- a/avro/schemas/redeem_result.avsc +++ b/avro/schemas/redeem_result.avsc @@ -21,7 +21,7 @@ "type": { "name": "RedeemResultStatus", "type": "enum", - "symbols": ["ok", "duplicate_redemption", "unverified", "error"] + "symbols": ["ok", "duplicate_redemption", "unverified", "error", "idempotent_redemption"] } }, {"name": "associated_data", "type": "bytes", "doc": "contains METADATA"} diff --git a/btd/issuer.go b/btd/issuer.go index 1a93a141..51b1d225 100644 --- a/btd/issuer.go +++ b/btd/issuer.go @@ -9,8 +9,10 @@ import ( ) var ( - ErrInvalidMAC = errors.New("binding MAC didn't match derived MAC") - ErrInvalidBatchProof = errors.New("New batch proof for signed tokens is invalid") + // ErrInvalidMAC - the mac was invalid + ErrInvalidMAC = errors.New("binding MAC didn't match derived MAC") + // ErrInvalidBatchProof - the batch proof was invalid + ErrInvalidBatchProof = errors.New("new batch proof for signed tokens is invalid") latencyBuckets = []float64{.25, .5, 1, 2.5, 5, 10} @@ -74,7 +76,7 @@ func init() { func ApproveTokens(blindedTokens []*crypto.BlindedToken, key *crypto.SigningKey) ([]*crypto.SignedToken, *crypto.BatchDLEQProof, error) { var err error if len(blindedTokens) < 1 { - err = errors.New("Provided blindedTokens array was empty.") + err = errors.New("provided blindedTokens array was empty") return 
[]*crypto.SignedToken{}, nil, err } diff --git a/go.mod b/go.mod index 98915694..4caf8099 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/actgardner/gogen-avro/v10 v10.2.1 github.com/aws/aws-sdk-go v1.44.136 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.17.4 - github.com/brave-intl/bat-go/libs v0.0.0-20220823005459-d3a4d8ccf976 + github.com/brave-intl/bat-go/libs v0.0.0-20220913154833-730f36b772de github.com/brave-intl/challenge-bypass-ristretto-ffi v0.0.0-20190717223301-f88d942ddfaf github.com/getsentry/raven-go v0.2.0 github.com/go-chi/chi v4.1.2+incompatible @@ -32,7 +32,6 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect - github.com/containerd/containerd v1.6.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/getsentry/sentry-go v0.13.0 // indirect github.com/go-chi/chi/v5 v5.0.7 // indirect @@ -48,7 +47,6 @@ require ( github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect - github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect github.com/pierrec/lz4/v4 v4.1.15 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index 90f8cb4f..81bddd4f 100644 --- a/go.sum +++ b/go.sum @@ -184,8 +184,8 @@ github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnweb github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= -github.com/brave-intl/bat-go/libs v0.0.0-20220823005459-d3a4d8ccf976 h1:kz83/D17IsaIVrDSqYvNgyLgMRvgSExcHvFOsRAOPEM= -github.com/brave-intl/bat-go/libs v0.0.0-20220823005459-d3a4d8ccf976/go.mod h1:bIOgpByIK7sC11XzdMZlM1Ri17g0eYqLFs5sd/D1wF8= +github.com/brave-intl/bat-go/libs v0.0.0-20220913154833-730f36b772de h1:A7l6jiuZW6ED7SuDK331LhkCqQNUYNv0RclciTwvIZU= +github.com/brave-intl/bat-go/libs v0.0.0-20220913154833-730f36b772de/go.mod h1:Hdx1PUXLp4TevCH6X7hzfCBcjaQnuechLVUWqD2I3aQ= github.com/brave-intl/challenge-bypass-ristretto-ffi v0.0.0-20190717223301-f88d942ddfaf h1:ZAsT/fM7Kxipf3wtoY7xa2bpFmAxzYPhVJ3hUcSdTRI= github.com/brave-intl/challenge-bypass-ristretto-ffi v0.0.0-20190717223301-f88d942ddfaf/go.mod h1:I9sAUIQc7AvvUU0Ustl5WMTdqmlNjXsX6dRLnDNxXiE= github.com/bshuster-repo/logrus-logstash-hook v0.4.1/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk= @@ -274,7 +274,6 @@ github.com/containerd/containerd v1.5.7/go.mod h1:gyvv6+ugqY25TiXxcZC3L5yOeYgEw0 github.com/containerd/containerd v1.5.8/go.mod h1:YdFSv5bTFLpG2HIYmfqDpSYYTDX+mc5qtSuYx1YUb/s= github.com/containerd/containerd v1.6.1/go.mod h1:1nJz5xCZPusx6jJU8Frfct988y0NpumIq9ODB0kLtoE= github.com/containerd/containerd v1.6.6 h1:xJNPhbrmz8xAMDNoVjHy9YHtWwEQNS+CDkcIRh7t8Y0= -github.com/containerd/containerd v1.6.6/go.mod h1:ZoP1geJldzCVY3Tonoz7b1IXk8rIX0Nltt5QE4OMNk0= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/containerd/continuity v0.0.0-20190815185530-f2a389ac0a02/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/containerd/continuity 
v0.0.0-20191127005431-f65d91d395eb/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= @@ -935,7 +934,6 @@ github.com/opencontainers/image-spec v1.0.1/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zM github.com/opencontainers/image-spec v1.0.2-0.20211117181255-693428a734f5/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 h1:rc3tiVYb5z54aKaDfakKn0dDjIyPpTtszkjuMzyt7ec= -github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= github.com/opencontainers/runc v0.0.0-20190115041553-12f6a991201f/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= github.com/opencontainers/runc v1.0.0-rc8.0.20190926000215-3e425f80a8c9/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= diff --git a/kafka/avro_test.go b/kafka/avro_test.go index 93bb4836..f4a42283 100644 --- a/kafka/avro_test.go +++ b/kafka/avro_test.go @@ -2,11 +2,12 @@ package kafka import ( "bytes" + "testing" + "time" + avroSchema "github.com/brave-intl/challenge-bypass-server/avro/generated" "github.com/brave-intl/challenge-bypass-server/utils/test" "github.com/stretchr/testify/assert" - "testing" - "time" ) // Tests v2 adds new fields validTo, validFrom and BlindedTokens. @@ -37,8 +38,8 @@ func TestSchemaCompatability_SigningResult_V2ToV1(t *testing.T) { assert.Equal(t, v2.Status.String(), v1.Status.String()) } -//// Tests v2 consumers reading v1 messages. -//func TestSchemaCompatability_SigningResult_V1ToV2(t *testing.T) { +// Tests v2 consumers reading v1 messages. +// func TestSchemaCompatability_SigningResult_V1ToV2(t *testing.T) { // v1 := &avroSchema.SigningResultV1{ // Signed_tokens: []string{test.RandomString()}, // Issuer_public_key: test.RandomString(), @@ -50,7 +51,6 @@ func TestSchemaCompatability_SigningResult_V2ToV1(t *testing.T) { // var buf bytes.Buffer // err := v1.Serialize(&buf) // assert.NoError(t, err) -// // v2, err := avroSchema.DeserializeSigningResultV2(&buf) // assert.NoError(t, err) // @@ -61,4 +61,4 @@ func TestSchemaCompatability_SigningResult_V2ToV1(t *testing.T) { // //assert.Nil(t, v2.Valid_to) // //assert.Nil(t, v2.Valid_from) // assert.Empty(t, v2.Blinded_tokens) -//} +// } diff --git a/kafka/main.go b/kafka/main.go index c1d9e5ee..ab3fb2cc 100644 --- a/kafka/main.go +++ b/kafka/main.go @@ -1,9 +1,11 @@ +// Package kafka manages kafka interaction package kafka import ( "context" + "errors" + "io" "os" - "strconv" "strings" "time" @@ -17,9 +19,21 @@ import ( var brokers []string -// Processor is an interface that represents functions which can be used to process kafka -// messages in our pipeline. -type Processor func([]byte, *kafka.Writer, *server.Server, *zerolog.Logger) error +// Processor is a function that is used to process Kafka messages +type Processor func( + kafka.Message, + *kafka.Writer, + *server.Server, + *zerolog.Logger, +) error + +// ProcessingResult contains a message and the topic to which the message should be +// emitted +type ProcessingResult struct { + ResultProducer *kafka.Writer + Message []byte + RequestID string +} // TopicMapping represents a kafka topic, how to process it, and where to emit the result. 
type TopicMapping struct { @@ -29,6 +43,12 @@ type TopicMapping struct { Topic string Group string } +// MessageContext is used for channel coordination when processing batches of messages +type MessageContext struct { + errorResult chan error + msg kafka.Message +} + // StartConsumers reads configuration variables and starts the associated kafka consumers func StartConsumers(providedServer *server.Server, logger *zerolog.Logger) error { adsRequestRedeemV1Topic := os.Getenv("REDEEM_CONSUMER_TOPIC") @@ -67,94 +87,154 @@ func StartConsumers(providedServer *server.Server, logger *zerolog.Logger) error topics = append(topics, topicMapping.Topic) } - consumerCount, err := strconv.Atoi(os.Getenv("KAFKA_CONSUMERS_PER_NODE")) + reader := newConsumer(topics, adsConsumerGroupV1, logger) + + batchPipeline := make(chan *MessageContext, 100) + ctx := context.Background() + go processMessagesIntoBatchPipeline(ctx, topicMappings, providedServer, reader, batchPipeline, logger) + for { + err := readAndCommitBatchPipelineResults(ctx, reader, batchPipeline, logger) + if err != nil { + // If readAndCommitBatchPipelineResults returns an error, close the pipeline and stop consuming. + close(batchPipeline) + return err + } + } +} + +// readAndCommitBatchPipelineResults does a blocking read of the batchPipeline channel and +// then does a blocking read of the errorResult in the MessageContext in the batchPipeline. +// When an error appears it means that the channel was closed or a temporary error was +// encountered. In the case of a temporary error, the application returns an error without +// committing so that the next reader gets the same message to try again. +func readAndCommitBatchPipelineResults( + ctx context.Context, + reader *kafka.Reader, + batchPipeline chan *MessageContext, + logger *zerolog.Logger, +) error { + msgCtx, ok := <-batchPipeline + if !ok { + logger.Error().Msg("batchPipeline channel closed") + return errors.New("batchPipeline channel closed") + } + err := <-msgCtx.errorResult if err != nil { - logger.Error().Err(err).Msg("Failed to convert KAFKA_CONSUMERS_PER_NODE variable to a usable integer. Defaulting to 1.") - consumerCount = 1 + logger.Error().Err(err).Msg("temporary failure encountered") + return errors.New("temporary failure encountered") + } + logger.Info().Msgf("Committing offset %d", msgCtx.msg.Offset) + if err := reader.CommitMessages(ctx, msgCtx.msg); err != nil { + logger.Error().Err(err).Msg("failed to commit") + return errors.New("failed to commit") + } + return nil +} - logger.Trace().Msgf("Spawning %d consumer goroutines", consumerCount) - - for i := 1; i <= consumerCount; i++ { - go func(topicMappings []TopicMapping) { - consumer := newConsumer(topics, adsConsumerGroupV1, logger) - var ( - failureCount = 0 - failureLimit = 10 - ) - logger.Trace().Msg("Beginning message processing") - for { - // `FetchMessage` blocks until the next event. Do not block main. 
- ctx := context.Background() - logger.Trace().Msgf("Fetching messages from Kafka") - msg, err := consumer.FetchMessage(ctx) - if err != nil { - logger.Error().Err(err).Msg("") - if failureCount > failureLimit { - break - } - failureCount++ - continue - } - logger.Debug().Msgf("Processing message for topic %s at offset %d", msg.Topic, msg.Offset) - logger.Debug().Msgf("Reader Stats: %#v", consumer.Stats()) - logger.Debug().Msgf("topicMappings: %+v", topicMappings) - for _, topicMapping := range topicMappings { - logger.Debug().Msgf("topic: %+v, topicMapping: %+v", msg.Topic, topicMapping.Topic) - if msg.Topic == topicMapping.Topic { - go func( - msg kafka.Message, - topicMapping TopicMapping, - providedServer *server.Server, - logger *zerolog.Logger, - ) { - err := topicMapping.Processor( - msg.Value, - topicMapping.ResultProducer, - providedServer, - logger, - ) - if err != nil { - logger.Error().Err(err).Msg("Processing failed.") - } - }(msg, topicMapping, providedServer, logger) - - if err := consumer.CommitMessages(ctx, msg); err != nil { - logger.Error().Msgf("Failed to commit: %s", err) - } - } - } - } +// processMessagesIntoBatchPipeline fetches messages from Kafka indefinitely, pushes a +// MessageContext into the batchPipeline to maintain message order, and then spawns a +// goroutine that will process the message and push to errorResult of the MessageContext +// when the processing completes. In case of an error, we panic from this function, +// triggering the deferral which closes the batchPipeline channel. This will result in +// readAndCommitBatchPipelineResults returning an error and the processing loop being recreated. +func processMessagesIntoBatchPipeline( + ctx context.Context, + topicMappings []TopicMapping, + providedServer *server.Server, + reader *kafka.Reader, + batchPipeline chan *MessageContext, + logger *zerolog.Logger, +) { + // During normal operation processMessagesIntoBatchPipeline will never complete and + // this deferral should not run. It's only called if we encounter some unrecoverable + // error. + defer func() { + close(batchPipeline) + }() - // The below block will close the producer connection when the error threshold is reached. - // @TODO: Test to determine if this Close() impacts the other goroutines that were passed - // the same topicMappings before re-enabling this block. - //for _, topicMapping := range topicMappings { - // logger.Trace().Msg(fmt.Sprintf("Closing producer connection %v", topicMapping)) - // if err := topicMapping.ResultProducer.Close(); err != nil { - // logger.Error().Msg(fmt.Sprintf("Failed to close writer: %e", err)) - // } - //} - }(topicMappings) + for { + msg, err := reader.FetchMessage(ctx) + if err != nil { + // Indicates batch has no more messages. End the loop for + // this batch and fetch another. + if err == io.EOF { + logger.Info().Msg("Batch complete") + } else if errors.Is(err, context.DeadlineExceeded) { + logger.Error().Err(err).Msg("batch item error") + panic("failed to fetch kafka messages, closing channel") + } + // There are other possible errors, but the underlying consumer + // group handler handles retryable failures well. 
If further + // investigation is needed you can review the handler here: + // https://github.com/segmentio/kafka-go/blob/main/consumergroup.go#L729 + continue + } + msgCtx := &MessageContext{ + errorResult: make(chan error), + msg: msg, + } + // If batchPipeline has been closed by an error in readAndCommitBatchPipelineResults, + // this write will panic, which is desired behavior, as the rest of the context + // will also have died and will be restarted from kafka/main.go + batchPipeline <- msgCtx + logger.Debug().Msgf("Processing message for topic %s at offset %d", msg.Topic, msg.Offset) + logger.Debug().Msgf("Reader Stats: %#v", reader.Stats()) + logger.Debug().Msgf("topicMappings: %+v", topicMappings) + // Check if any of the existing topicMappings match the fetched message + matchFound := false + for _, topicMapping := range topicMappings { + logger.Debug().Msgf("topic: %+v, topicMapping: %+v", msg.Topic, topicMapping.Topic) + if msg.Topic == topicMapping.Topic { + matchFound = true + go processMessageIntoErrorResultChannel( + msg, + topicMapping, + providedServer, + msgCtx.errorResult, + logger, + ) + } + } + if !matchFound { + logger.Error().Msgf("Message received on a topic that is not configured: %s", msg.Topic) + } } +} - return nil +// processMessageIntoErrorResultChannel executes the processor defined by a topicMapping +// on a provided message. It then puts the result into the errChan. This result will be +// nil in cases of success or permanent failures and will be some error in the case that +// a temporary error is encountered. +func processMessageIntoErrorResultChannel( + msg kafka.Message, + topicMapping TopicMapping, + providedServer *server.Server, + errChan chan error, + logger *zerolog.Logger, +) { + errChan <- topicMapping.Processor( + msg, + topicMapping.ResultProducer, + providedServer, + logger, + ) } // newConsumer returns a Kafka reader configured for the given topic and group. func newConsumer(topics []string, groupID string, logger *zerolog.Logger) *kafka.Reader { brokers = strings.Split(os.Getenv("KAFKA_BROKERS"), ",") logger.Info().Msgf("Subscribing to kafka topic %s on behalf of group %s using brokers %s", topics, groupID, brokers) kafkaLogger := logrus.New() kafkaLogger.SetLevel(logrus.WarnLevel) + dialer := getDialer(logger) reader := kafka.NewReader(kafka.ReaderConfig{ Brokers: brokers, - Dialer: getDialer(logger), + Dialer: dialer, GroupTopics: topics, GroupID: groupID, StartOffset: kafka.FirstOffset, Logger: kafkaLogger, - MaxWait: time.Second * 20, // default 10s + MaxWait: time.Second * 20, // double the 10s default CommitInterval: time.Second, // flush commits to Kafka every second MinBytes: 1e3, // 1KB MaxBytes: 10e6, // 10MB @@ -170,7 +250,7 @@ func Emit(producer *kafka.Writer, message []byte, logger *zerolog.Logger) error messageKey := uuid.New() marshaledMessageKey, err := messageKey.MarshalBinary() if err != nil { - logger.Error().Msgf("Failed to marshal UUID into binary. Using default key value. %e", err) + logger.Error().Msgf("failed to marshal UUID into binary. 
Using default key value: %v", err) marshaledMessageKey = []byte("default") } @@ -182,7 +262,7 @@ func Emit(producer *kafka.Writer, message []byte, logger *zerolog.Logger) error }, ) if err != nil { - logger.Error().Msgf("Failed to write messages: %e", err) + logger.Error().Msgf("failed to write messages: %v", err) return err } @@ -190,14 +270,22 @@ func Emit(producer *kafka.Writer, message []byte, logger *zerolog.Logger) error return nil } +// getDialer returns a reference to a Kafka dialer. The dialer is TLS enabled in non-local +// environments. func getDialer(logger *zerolog.Logger) *kafka.Dialer { var dialer *kafka.Dialer - brokers = strings.Split(os.Getenv("KAFKA_BROKERS"), ",") if os.Getenv("ENV") != "local" { + logger.Info().Msg("Generating TLSDialer") tlsDialer, _, err := batgo_kafka.TLSDialer() dialer = tlsDialer if err != nil { - logger.Error().Msgf("Failed to initialize TLS dialer: %e", err) + logger.Error().Msgf("failed to initialize TLS dialer: %v", err) + } + } else { + logger.Info().Msg("Generating Dialer") + dialer = &kafka.Dialer{ + Timeout: 10 * time.Second, + DualStack: true, + } } return dialer diff --git a/kafka/signed_blinded_token_issuer_handler.go b/kafka/signed_blinded_token_issuer_handler.go index 426bc5a9..22d935d1 100644 --- a/kafka/signed_blinded_token_issuer_handler.go +++ b/kafka/signed_blinded_token_issuer_handler.go @@ -11,20 +11,35 @@ import ( avroSchema "github.com/brave-intl/challenge-bypass-server/avro/generated" "github.com/brave-intl/challenge-bypass-server/btd" cbpServer "github.com/brave-intl/challenge-bypass-server/server" + "github.com/brave-intl/challenge-bypass-server/utils" "github.com/rs/zerolog" "github.com/segmentio/kafka-go" ) -// SignedBlindedTokenIssuerHandler emits signed, blinded tokens based on provided blinded tokens. -// @TODO: It would be better for the Server implementation and the Kafka implementation of -// this behavior to share utility functions rather than passing an instance of the server -// as an argument here. That will require a bit of refactoring. -func SignedBlindedTokenIssuerHandler(data []byte, producer *kafka.Writer, server *cbpServer.Server, log *zerolog.Logger) error { +/* +SignedBlindedTokenIssuerHandler emits signed, blinded tokens based on provided blinded tokens. + In cases where there are unrecoverable errors that prevent progress we will return nil. + These permanent failure cases are different from cases where we encounter temporary + errors inside the request data. For permanent failures inside the data processing loop we + simply add the error to the results. However, temporary errors inside the loop should break + the loop and return non-nil just like the errors outside the data processing loop. This is + because future attempts to process permanent failure cases will not succeed. + @TODO: It would be better for the Server implementation and the Kafka implementation of + this behavior to share utility functions rather than passing an instance of the server + as an argument here. That will require a bit of refactoring. 
+*/ +func SignedBlindedTokenIssuerHandler( + msg kafka.Message, + producer *kafka.Writer, + server *cbpServer.Server, + log *zerolog.Logger, +) error { const ( issuerOk = 0 issuerInvalid = 1 issuerError = 2 ) + data := msg.Value log.Info().Msg("starting blinded token processor") @@ -38,8 +53,23 @@ blindedTokenRequestSet, err := avroSchema.DeserializeSigningRequestSet(bytes.NewReader(data)) if err != nil { - log.Error().Err(err).Msg("failed to deserialize avro request message") - return fmt.Errorf("request %s: failed avro deserialization: %w", blindedTokenRequestSet.Request_id, err) + message := fmt.Sprintf( + "request %s: failed avro deserialization", + blindedTokenRequestSet.Request_id, + ) + handlePermanentIssuanceError( + message, + nil, + nil, + nil, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + log, + ) + return nil } logger := log.With().Str("request_id", blindedTokenRequestSet.Request_id).Logger() @@ -50,9 +80,23 @@ if len(blindedTokenRequestSet.Data) > 1 { // NOTE: When we start supporting multiple requests we will need to review // errors and return values as well. - return fmt.Errorf(`request %s: data array unexpectedly contained more than a single message. this array is - intended to make future extension easier, but no more than a single value is currently expected`, - blindedTokenRequestSet.Request_id) + message := fmt.Sprintf( + "request %s: data array unexpectedly contained more than a single message. This array is intended to make future extension easier, but no more than a single value is currently expected", + blindedTokenRequestSet.Request_id, + ) + handlePermanentIssuanceError( + message, + nil, + nil, + nil, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } OUTER: @@ -66,7 +110,7 @@ Status: issuerError, Associated_data: request.Associated_data, }) - break OUTER + continue OUTER } // check to see if issuer cohort will overflow @@ -79,20 +123,26 @@ Status: issuerError, Associated_data: request.Associated_data, }) - break OUTER + continue OUTER } logger.Info().Msgf("getting latest issuer: %+v - %+v", request.Issuer_type, request.Issuer_cohort) - issuer, appErr := server.GetLatestIssuer(request.Issuer_type, int16(request.Issuer_cohort)) + issuer, appErr := server.GetLatestIssuerKafka(request.Issuer_type, int16(request.Issuer_cohort)) if appErr != nil { logger.Error().Err(appErr).Msg("error retrieving issuer") + var processingError *utils.ProcessingError + if errors.As(appErr, &processingError) { + if processingError.Temporary { + return appErr + } + } blindedTokenResults = append(blindedTokenResults, avroSchema.SigningResultV2{ Signed_tokens: nil, Issuer_public_key: "", Status: issuerInvalid, Associated_data: request.Associated_data, }) - break OUTER + continue OUTER } logger.Info().Msgf("checking if issuer is version 3: %+v", issuer) @@ -106,7 +156,7 @@ Status: issuerError, Associated_data: request.Associated_data, }) - break OUTER + continue OUTER } } @@ -192,39 +242,88 @@ OUTER: marshaledDLEQProof, err := DLEQProof.MarshalText() if err != nil { - return fmt.Errorf("request %s: could not marshal dleq proof: %w", blindedTokenRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: could not marshal dleq proof: %s", blindedTokenRequestSet.Request_id, err) + handlePermanentIssuanceError( + message, 
nil, + nil, + nil, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } - var marshalledBlindedTokens []string + var marshaledBlindedTokens []string for _, token := range blindedTokensSlice { marshaledToken, err := token.MarshalText() if err != nil { - return fmt.Errorf("request %s: could not marshal blinded token slice to bytes: %w", - blindedTokenRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: could not marshal blinded token slice to bytes: %s", blindedTokenRequestSet.Request_id, err) + handlePermanentIssuanceError( + message, + marshaledBlindedTokens, + nil, + nil, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } - marshalledBlindedTokens = append(marshalledBlindedTokens, string(marshaledToken[:])) + marshaledBlindedTokens = append(marshaledBlindedTokens, string(marshaledToken)) } var marshaledSignedTokens []string for _, token := range signedTokens { marshaledToken, err := token.MarshalText() if err != nil { - return fmt.Errorf("request %s: could not marshal new tokens to bytes: %w", - blindedTokenRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: could not marshal new tokens to bytes: %s", blindedTokenRequestSet.Request_id, err) + handlePermanentIssuanceError( + message, + marshaledBlindedTokens, + marshaledSignedTokens, + nil, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } - marshaledSignedTokens = append(marshaledSignedTokens, string(marshaledToken[:])) + marshaledSignedTokens = append(marshaledSignedTokens, string(marshaledToken)) } logger.Info().Msg("getting public key") publicKey := signingKey.PublicKey() marshaledPublicKey, err := publicKey.MarshalText() if err != nil { - return fmt.Errorf("request %s: could not marshal signing key: %w", - blindedTokenRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: could not marshal signing key: %s", blindedTokenRequestSet.Request_id, err) + handlePermanentIssuanceError( + message, + marshaledBlindedTokens, + marshaledSignedTokens, + marshaledDLEQProof, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } blindedTokenResults = append(blindedTokenResults, avroSchema.SigningResultV2{ - Blinded_tokens: marshalledBlindedTokens, + Blinded_tokens: marshaledBlindedTokens, Signed_tokens: marshaledSignedTokens, Proof: string(marshaledDLEQProof), Issuer_public_key: string(marshaledPublicKey), @@ -234,7 +333,7 @@ OUTER: Associated_data: request.Associated_data, }) logger.Info(). - Str("blinded_tokens", fmt.Sprintf("%+v", marshalledBlindedTokens)). + Str("blinded_tokens", fmt.Sprintf("%+v", marshaledBlindedTokens)). Str("signed_tokens", fmt.Sprintf("%+v", marshaledSignedTokens)). Str("proof", string(marshaledDLEQProof)). Str("public_key", string(marshaledPublicKey)). 
@@ -262,42 +361,93 @@ OUTER: Status: issuerError, Associated_data: request.Associated_data, }) - break OUTER + continue OUTER } marshaledDLEQProof, err := DLEQProof.MarshalText() if err != nil { - return fmt.Errorf("request %s: could not marshal dleq proof: %w", + message := fmt.Sprintf("request %s: could not marshal dleq proof: %s", blindedTokenRequestSet.Request_id, err) + handlePermanentIssuanceError( + message, + nil, + nil, + marshaledDLEQProof, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } - var marshalledBlindedTokens []string + var marshaledBlindedTokens []string for _, token := range blindedTokens { marshaledToken, err := token.MarshalText() if err != nil { - return fmt.Errorf("request %s: could not marshal blinded token slice to bytes: %w", - blindedTokenRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: could not marshal blinded token slice to bytes: %s", blindedTokenRequestSet.Request_id, err) + handlePermanentIssuanceError( + message, + marshaledBlindedTokens, + nil, + marshaledDLEQProof, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } - marshalledBlindedTokens = append(marshalledBlindedTokens, string(marshaledToken[:])) + marshaledBlindedTokens = append(marshaledBlindedTokens, string(marshaledToken)) } var marshaledSignedTokens []string for _, token := range signedTokens { marshaledToken, err := token.MarshalText() if err != nil { - return fmt.Errorf("error could not marshal new tokens to bytes: %w", err) + message := fmt.Sprintf("request %s: could not marshal new tokens to bytes: %s", blindedTokenRequestSet.Request_id, err) + handlePermanentIssuanceError( + message, + marshaledBlindedTokens, + marshaledSignedTokens, + marshaledDLEQProof, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } - marshaledSignedTokens = append(marshaledSignedTokens, string(marshaledToken[:])) + marshaledSignedTokens = append(marshaledSignedTokens, string(marshaledToken)) } publicKey := signingKey.PublicKey() marshaledPublicKey, err := publicKey.MarshalText() if err != nil { - return fmt.Errorf("error could not marshal signing key: %w", err) + message := fmt.Sprintf("request %s: could not marshal signing key: %s", blindedTokenRequestSet.Request_id, err) + handlePermanentIssuanceError( + message, + marshaledBlindedTokens, + marshaledSignedTokens, + marshaledDLEQProof, + marshaledPublicKey, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } blindedTokenResults = append(blindedTokenResults, avroSchema.SigningResultV2{ - Blinded_tokens: marshalledBlindedTokens, + Blinded_tokens: marshaledBlindedTokens, Signed_tokens: marshaledSignedTokens, Proof: string(marshaledDLEQProof), Issuer_public_key: string(marshaledPublicKey), @@ -316,19 +466,116 @@ OUTER: var resultSetBuffer bytes.Buffer err = resultSet.Serialize(&resultSetBuffer) if err != nil { - return fmt.Errorf("request %s: failed to serialize result set: %s: %w", - blindedTokenRequestSet.Request_id, resultSetBuffer.String(), err) + message := fmt.Sprintf( + "request %s: failed to serialize ResultSet: %+v", + blindedTokenRequestSet.Request_id, + resultSet, + ) + handlePermanentIssuanceError( + message, + nil, + nil, + nil, + nil, + issuerError, + blindedTokenRequestSet.Request_id, + msg, + producer, + &logger, + ) + return nil } logger.Info().Msg("ending blinded token request processor loop") logger.Info().Msgf("about to emit: %+v", resultSet) err = Emit(producer, resultSetBuffer.Bytes(), log) 
if err != nil { - logger.Error().Msgf("failed to emit: %+v", resultSet) - return fmt.Errorf("request %s: failed to emit results to topic %s: %w", - blindedTokenRequestSet.Request_id, producer.Topic, err) + message := fmt.Sprintf( + "request %s: failed to emit to topic %s with result: %v", + resultSet.Request_id, + producer.Topic, + resultSet, + ) + log.Error().Err(err).Msg(message) + return err } logger.Info().Msgf("emitted: %+v", resultSet) return nil } + +// avroIssuerErrorResultFromError returns a ProcessingResult that is constructed from the +// provided values. +func avroIssuerErrorResultFromError( + message string, + marshaledBlindedTokens []string, + marshaledSignedTokens []string, + marshaledDLEQProof []byte, + marshaledPublicKey []byte, + issuerResultStatus int32, + requestID string, + msg kafka.Message, + producer *kafka.Writer, + logger *zerolog.Logger, +) *ProcessingResult { + signingResult := avroSchema.SigningResultV2{ + Blinded_tokens: marshaledBlindedTokens, + Signed_tokens: marshaledSignedTokens, + Proof: string(marshaledDLEQProof), + Issuer_public_key: string(marshaledPublicKey), + Status: avroSchema.SigningResultV2Status(issuerResultStatus), + Associated_data: []byte(message), + } + resultSet := avroSchema.SigningResultV2Set{ + Request_id: "", + Data: []avroSchema.SigningResultV2{signingResult}, + } + var resultSetBuffer bytes.Buffer + err := resultSet.Serialize(&resultSetBuffer) + if err != nil { + message := fmt.Sprintf("request %s: failed to serialize result set", requestID) + return &ProcessingResult{ + Message: []byte(message), + ResultProducer: producer, + RequestID: requestID, + } + } + + return &ProcessingResult{ + Message: resultSetBuffer.Bytes(), + ResultProducer: producer, + RequestID: requestID, + } +} + +// handlePermanentIssuanceError is a convenience function to both generate a result from +// an error and emit it. +func handlePermanentIssuanceError( + message string, + marshaledBlindedTokens []string, + marshaledSignedTokens []string, + marshaledDLEQProof []byte, + marshaledPublicKey []byte, + issuerResultStatus int32, + requestID string, + msg kafka.Message, + producer *kafka.Writer, + logger *zerolog.Logger, +) { + processingResult := avroIssuerErrorResultFromError( + message, + marshaledBlindedTokens, + marshaledSignedTokens, + marshaledDLEQProof, + marshaledPublicKey, + issuerResultStatus, + requestID, + msg, + producer, + logger, + ) + + if err := Emit(producer, processingResult.Message, logger); err != nil { + logger.Error().Err(err).Msg("failed to emit") + } +} diff --git a/kafka/signed_token_redeem_handler.go b/kafka/signed_token_redeem_handler.go index 92d3fd62..27da8af1 100644 --- a/kafka/signed_token_redeem_handler.go +++ b/kafka/signed_token_redeem_handler.go @@ -2,6 +2,7 @@ package kafka import ( "bytes" + "errors" "fmt" "strings" "time" @@ -10,50 +11,82 @@ import ( avroSchema "github.com/brave-intl/challenge-bypass-server/avro/generated" "github.com/brave-intl/challenge-bypass-server/btd" cbpServer "github.com/brave-intl/challenge-bypass-server/server" + "github.com/brave-intl/challenge-bypass-server/utils" "github.com/rs/zerolog" kafka "github.com/segmentio/kafka-go" ) -// SignedTokenRedeemHandler emits payment tokens that correspond to the signed confirmation -// tokens provided. +/* +SignedTokenRedeemHandler emits payment tokens that correspond to the signed confirmation + tokens provided. If it encounters a permanent error, it emits a permanent result for that + item. 
If the error is temporary, an error is returned to indicate that progress cannot be + made. +*/ func SignedTokenRedeemHandler( - data []byte, + msg kafka.Message, producer *kafka.Writer, server *cbpServer.Server, - logger *zerolog.Logger, + log *zerolog.Logger, ) error { - const ( - redeemOk = 0 - redeemDuplicateRedemption = 1 - redeemUnverified = 2 - redeemError = 3 - ) + data := msg.Value + // Deserialize request into usable struct tokenRedeemRequestSet, err := avroSchema.DeserializeRedeemRequestSet(bytes.NewReader(data)) if err != nil { - return fmt.Errorf("request %s: failed avro deserialization: %w", tokenRedeemRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: failed avro deserialization", tokenRedeemRequestSet.Request_id) + handlePermanentRedemptionError( + message, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil } - defer func() { - if recover() != nil { - logger.Error(). - Err(fmt.Errorf("request %s: redeem attempt panicked", tokenRedeemRequestSet.Request_id)). - Msg("signed token redeem handler") - } - }() + + logger := log.With().Str("request_id", tokenRedeemRequestSet.Request_id).Logger() + var redeemedTokenResults []avroSchema.RedeemResult + // For the time being, we are only accepting one message at a time in this data set. + // Therefore, we will error if more than a single message is present in the message. if len(tokenRedeemRequestSet.Data) > 1 { // NOTE: When we start supporting multiple requests we will need to review // errors and return values as well. - return fmt.Errorf("request %s: data array unexpectedly contained more than a single message. this array is intended to make future extension easier, but no more than a single value is currently expected", tokenRedeemRequestSet.Request_id) + message := fmt.Sprintf("request %s: data array unexpectedly contained more than a single message. This array is intended to make future extension easier, but no more than a single value is currently expected", tokenRedeemRequestSet.Request_id) + handlePermanentRedemptionError( + message, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil } issuers, err := server.FetchAllIssuers() if err != nil { - return fmt.Errorf("request %s: failed to fetch all issuers: %w", tokenRedeemRequestSet.Request_id, err) + if processingError, ok := err.(*utils.ProcessingError); ok && processingError.Temporary { + return processingError + } + message := fmt.Sprintf("request %s: failed to fetch all issuers", tokenRedeemRequestSet.Request_id) + handlePermanentRedemptionError( + message, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil } + + // Iterate over requests (only one at this point but the schema can support more + // in the future if needed) for _, request := range tokenRedeemRequestSet.Data { var ( - verified = false - verifiedIssuer = &cbpServer.Issuer{} - verifiedCohort int32 = 0 + verified = false + verifiedIssuer = &cbpServer.Issuer{} + verifiedCohort int32 ) if request.Public_key == "" { logger.Error(). 
@@ -62,12 +95,13 @@ func SignedTokenRedeemHandler( redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ Issuer_name: "", Issuer_cohort: 0, - Status: redeemError, + Status: avroSchema.RedeemResultStatusError, Associated_data: request.Associated_data, }) continue } + // preimage, signature, and binding are all required to proceed if request.Token_preimage == "" || request.Signature == "" || request.Binding == "" { logger.Error(). Err(fmt.Errorf("request %s: empty request", tokenRedeemRequestSet.Request_id)). @@ -75,7 +109,7 @@ func SignedTokenRedeemHandler( redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ Issuer_name: "", Issuer_cohort: 0, - Status: redeemError, + Status: avroSchema.RedeemResultStatusError, Associated_data: request.Associated_data, }) continue @@ -83,15 +117,33 @@ func SignedTokenRedeemHandler( tokenPreimage := crypto.TokenPreimage{} err = tokenPreimage.UnmarshalText([]byte(request.Token_preimage)) + // Unmarshaling failure is a data issue and is probably permanent. if err != nil { - return fmt.Errorf("request %s: could not unmarshal text into preimage: %w", - tokenRedeemRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: could not unmarshal text into preimage", tokenRedeemRequestSet.Request_id) + handlePermanentRedemptionError( + message, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil } verificationSignature := crypto.VerificationSignature{} err = verificationSignature.UnmarshalText([]byte(request.Signature)) + // Unmarshaling failure is a data issue and is probably permanent. if err != nil { - return fmt.Errorf("request %s: could not unmarshal text into verification signature: %w", - tokenRedeemRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: could not unmarshal text into verification signature", tokenRedeemRequestSet.Request_id) + handlePermanentRedemptionError( + message, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil } for _, issuer := range *issuers { if !issuer.ExpiresAt.IsZero() && issuer.ExpiresAt.Before(time.Now()) { @@ -117,9 +169,18 @@ func SignedTokenRedeemHandler( // Only attempt token verification with the issuer that was provided. issuerPublicKey := signingKey.PublicKey() marshaledPublicKey, err := issuerPublicKey.MarshalText() + // Marshaling failure is a data issue and is probably permanent. if err != nil { - return fmt.Errorf("request %s: could not unmarshal issuer public key into text: %w", - tokenRedeemRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: could not marshal issuer public key into text", tokenRedeemRequestSet.Request_id) + handlePermanentRedemptionError( + message, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil } logger.Trace(). 
@@ -151,35 +212,109 @@ func SignedTokenRedeemHandler( redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ Issuer_name: "", Issuer_cohort: 0, - Status: redeemUnverified, + Status: avroSchema.RedeemResultStatusUnverified, Associated_data: request.Associated_data, }) continue } else { - logger.Trace().Msgf("request %s: validated", tokenRedeemRequestSet.Request_id) + logger.Info().Msgf("request %s: validated", tokenRedeemRequestSet.Request_id) } - if err := server.RedeemToken(verifiedIssuer, &tokenPreimage, request.Binding); err != nil { - logger.Error().Err(fmt.Errorf("request %s: token redemption failed: %w", - tokenRedeemRequestSet.Request_id, err)). - Msg("signed token redeem handler") + redemption, equivalence, err := server.CheckRedeemedTokenEquivalence(verifiedIssuer, &tokenPreimage, string(request.Binding), msg.Offset) + if err != nil { + var processingError *utils.ProcessingError + if errors.As(err, &processingError) { + if processingError.Temporary { + return err + } + } + message := fmt.Sprintf("request %s: failed to check redemption equivalence", tokenRedeemRequestSet.Request_id) + handlePermanentRedemptionError( + message, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil + } + + // Continue if there is a duplicate + switch equivalence { + case cbpServer.IDEquivalence: + redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ + Issuer_name: verifiedIssuer.IssuerType, + Issuer_cohort: int32(verifiedIssuer.IssuerCohort), + Status: avroSchema.RedeemResultStatusDuplicate_redemption, + Associated_data: request.Associated_data, + }) + continue + case cbpServer.BindingEquivalence: + redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ + Issuer_name: verifiedIssuer.IssuerType, + Issuer_cohort: int32(verifiedIssuer.IssuerCohort), + Status: avroSchema.RedeemResultStatusIdempotent_redemption, + Associated_data: request.Associated_data, + }) + continue + } + + // If no equivalent record was found in the database, persist. + if err := server.PersistRedemption(*redemption); err != nil { + logger.Error().Err(err).Msgf("request %s: token redemption failed: %v", tokenRedeemRequestSet.Request_id, err) + // In the unlikely event that there is a race condition that results + // in a duplicate error upon save that was not detected previously + // we will check equivalence upon receipt of a duplicate error. - if strings.Contains(err.Error(), "Duplicate") { + if strings.Contains(strings.ToLower(err.Error()), "duplicate") { + _, equivalence, equivErr := server.CheckRedeemedTokenEquivalence(verifiedIssuer, &tokenPreimage, string(request.Binding), msg.Offset) + if equivErr != nil { + message := fmt.Sprintf("request %s: failed to check redemption equivalence", tokenRedeemRequestSet.Request_id) + var processingError *utils.ProcessingError + if errors.As(equivErr, &processingError) { + if processingError.Temporary { + return equivErr + } + } + handlePermanentRedemptionError( + message, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil + } logger.Error().Err(fmt.Errorf("request %s: duplicate redemption: %w", tokenRedeemRequestSet.Request_id, err)). 
Msg("signed token redeem handler") - redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ - Issuer_name: "", - Issuer_cohort: 0, - Status: redeemDuplicateRedemption, - Associated_data: request.Associated_data, - }) + // Continue if there is a duplicate + switch equivalence { + case cbpServer.IDEquivalence: + redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ + Issuer_name: verifiedIssuer.IssuerType, + Issuer_cohort: int32(verifiedIssuer.IssuerCohort), + Status: avroSchema.RedeemResultStatusDuplicate_redemption, + Associated_data: request.Associated_data, + }) + continue + case cbpServer.BindingEquivalence: + redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ + Issuer_name: verifiedIssuer.IssuerType, + Issuer_cohort: int32(verifiedIssuer.IssuerCohort), + Status: avroSchema.RedeemResultStatusIdempotent_redemption, + Associated_data: request.Associated_data, + }) + continue + } } logger.Error().Err(fmt.Errorf("request %s: could not mark token redemption", tokenRedeemRequestSet.Request_id)). Msg("signed token redeem handler") redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ - Issuer_name: "", - Issuer_cohort: 0, - Status: redeemError, + Issuer_name: verifiedIssuer.IssuerType, + Issuer_cohort: int32(verifiedIssuer.IssuerCohort), + Status: avroSchema.RedeemResultStatusError, Associated_data: request.Associated_data, }) continue @@ -189,7 +324,7 @@ func SignedTokenRedeemHandler( redeemedTokenResults = append(redeemedTokenResults, avroSchema.RedeemResult{ Issuer_name: issuerName, Issuer_cohort: verifiedCohort, - Status: redeemOk, + Status: avroSchema.RedeemResultStatusOk, Associated_data: request.Associated_data, }) } @@ -200,14 +335,88 @@ func SignedTokenRedeemHandler( var resultSetBuffer bytes.Buffer err = resultSet.Serialize(&resultSetBuffer) if err != nil { - return fmt.Errorf("request %s: failed to serialize result set: %w", - tokenRedeemRequestSet.Request_id, err) + message := fmt.Sprintf("request %s: failed to serialize result set", tokenRedeemRequestSet.Request_id) + handlePermanentRedemptionError( + message, + msg, + producer, + tokenRedeemRequestSet.Request_id, + int32(avroSchema.RedeemResultStatusError), + log, + ) + return nil } - err = Emit(producer, resultSetBuffer.Bytes(), logger) + err = Emit(producer, resultSetBuffer.Bytes(), log) if err != nil { - return fmt.Errorf("request %s: failed to emit results to topic %s: %w", - tokenRedeemRequestSet.Request_id, producer.Topic, err) + message := fmt.Sprintf( + "request %s: failed to emit results to topic %s", + resultSet.Request_id, + producer.Topic, + ) + log.Error().Err(err).Msgf(message) + return err } + return nil } + +// avroRedeemErrorResultFromError returns a ProcessingResult that is constructed from the +// provided values. 
+func avroRedeemErrorResultFromError( + message string, + msg kafka.Message, + producer *kafka.Writer, + requestID string, + redeemResultStatus int32, + logger *zerolog.Logger, +) *ProcessingResult { + redeemResult := avroSchema.RedeemResult{ + Issuer_name: "", + Issuer_cohort: 0, + Status: avroSchema.RedeemResultStatus(redeemResultStatus), + Associated_data: []byte(message), + } + resultSet := avroSchema.RedeemResultSet{ + Request_id: "", + Data: []avroSchema.RedeemResult{redeemResult}, + } + var resultSetBuffer bytes.Buffer + err := resultSet.Serialize(&resultSetBuffer) + if err != nil { + message := fmt.Sprintf("request %s: failed to serialize result set", requestID) + return &ProcessingResult{ + Message: []byte(message), + ResultProducer: producer, + RequestID: requestID, + } + } + return &ProcessingResult{ + Message: resultSetBuffer.Bytes(), + ResultProducer: producer, + RequestID: requestID, + } +} + +// handlePermanentRedemptionError is a convenience function that executes a call pattern shared +// when handling all errors in the redeem flow +func handlePermanentRedemptionError( + message string, + msg kafka.Message, + producer *kafka.Writer, + requestID string, + redeemResultStatus int32, + logger *zerolog.Logger, +) { + processingResult := avroRedeemErrorResultFromError( + message, + msg, + producer, + requestID, + redeemResultStatus, + logger, + ) + if err := Emit(producer, processingResult.Message, logger); err != nil { + logger.Error().Err(err).Msg("failed to emit") + } +} diff --git a/main.go b/main.go index b609a2de..5843d4d1 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( _ "net/http/pprof" "os" "strconv" + "time" "github.com/brave-intl/challenge-bypass-server/kafka" "github.com/brave-intl/challenge-bypass-server/server" @@ -23,14 +24,17 @@ func main() { serverCtx, logger := server.SetupLogger(context.Background()) zeroLogger := zerolog.New(os.Stderr).With().Timestamp().Caller().Logger() if os.Getenv("ENV") != "production" { - zerolog.SetGlobalLevel(zerolog.TraceLevel) + zerolog.SetGlobalLevel(zerolog.WarnLevel) + if os.Getenv("ENV") == "local" { + zerolog.SetGlobalLevel(zerolog.TraceLevel) + } } -zerolog.SetGlobalLevel(zerolog.TraceLevel) srv := *server.DefaultServer flag.StringVar(&configFile, "config", "", "local config file for development (overrides cli options)") - flag.StringVar(&srv.DbConfigPath, "db_config", "", "path to the json file with database configuration") + flag.StringVar(&srv.DBConfigPath, "db_config", "", "path to the json file with database configuration") flag.IntVar(&srv.ListenPort, "p", 2416, "port to listen on") flag.Parse() @@ -48,7 +52,7 @@ func main() { } } - err = srv.InitDbConfig() + err = srv.InitDBConfig() if err != nil { logger.Panic(err) } @@ -56,7 +60,7 @@ func main() { zeroLogger.Trace().Msg("Initializing persistence and cron jobs") // Initialize databases and cron tasks before the Kafka processors and server start - srv.InitDb() + srv.InitDB() srv.InitDynamo() // Run the cron job unless it's explicitly disabled. 
if os.Getenv("CRON_ENABLED") != "false" { @@ -82,15 +86,7 @@ func main() { if os.Getenv("KAFKA_ENABLED") != "false" { zeroLogger.Trace().Msg("Spawning Kafka goroutine") - go func() { - zeroLogger.Trace().Msg("Initializing Kafka consumers") - err = kafka.StartConsumers(&srv, &zeroLogger) - - if err != nil { - zeroLogger.Error().Err(err).Msg("Failed to initialize Kafka consumers") - return - } - }() + go startKafka(srv, zeroLogger) } zeroLogger.Trace().Msg("Initializing API server") @@ -104,3 +100,15 @@ func main() { return } } + +func startKafka(srv server.Server, zeroLogger zerolog.Logger) { + zeroLogger.Trace().Msg("Initializing Kafka consumers") + err := kafka.StartConsumers(&srv, &zeroLogger) + + if err != nil { + zeroLogger.Error().Err(err).Msg("Failed to initialize Kafka consumers") + // If err is something then start consumer again + time.Sleep(10 * time.Second) + startKafka(srv, zeroLogger) + } +} diff --git a/server/db.go b/server/db.go index f4bb06b8..1ffe1fe8 100644 --- a/server/db.go +++ b/server/db.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + "github.com/brave-intl/challenge-bypass-server/utils" "github.com/brave-intl/challenge-bypass-server/utils/metrics" "github.com/brave-intl/challenge-bypass-server/utils/ptr" @@ -31,8 +32,8 @@ type CachingConfig struct { ExpirationSec int `json:"expirationSec"` } -// DbConfig defines app configurations -type DbConfig struct { +// DBConfig defines app configurations +type DBConfig struct { ConnectionURI string `json:"connectionURI"` CachingConfig CachingConfig `json:"caching"` MaxConnection int `json:"maxConnection"` @@ -117,6 +118,7 @@ type RedemptionV2 struct { Timestamp time.Time `json:"timestamp"` Payload string `json:"payload"` TTL int64 `json:"TTL"` + Offset int64 `json:"offset"` } // CacheInterface cache functions @@ -127,19 +129,19 @@ type CacheInterface interface { } var ( - errIssuerNotFound = errors.New("Issuer with the given name does not exist") - errIssuerCohortNotFound = errors.New("Issuer with the given name and cohort does not exist") - errDuplicateRedemption = errors.New("Duplicate Redemption") - errRedemptionNotFound = errors.New("Redemption with the given id does not exist") + errIssuerNotFound = errors.New("issuer with the given name does not exist") + errIssuerCohortNotFound = errors.New("issuer with the given name and cohort does not exist") + errDuplicateRedemption = errors.New("duplicate Redemption") + errRedemptionNotFound = errors.New("redemption with the given id does not exist") ) -// LoadDbConfig loads config into server variable -func (c *Server) LoadDbConfig(config DbConfig) { +// LoadDBConfig loads config into server variable +func (c *Server) LoadDBConfig(config DBConfig) { c.dbConfig = config } -// InitDb initialzes the database connection based on a server's configuration -func (c *Server) InitDb() { +// InitDB initialzes the database connection based on a server's configuration +func (c *Server) InitDB() { cfg := c.dbConfig db, err := sqlx.Open("postgres", cfg.ConnectionURI) @@ -255,16 +257,10 @@ func incrementCounter(c prometheus.Counter) { func (c *Server) fetchIssuer(issuerID string) (*Issuer, error) { defer incrementCounter(fetchIssuerCounter) - tx := c.db.MustBegin() - var err error = nil - - defer func() { - if err != nil { - err = tx.Rollback() - return - } - err = tx.Commit() - }() + var ( + err error + temporary = false + ) if c.caches != nil { if cached, found := c.caches["issuer"].Get(issuerID); found { @@ -273,26 +269,26 @@ func (c *Server) fetchIssuer(issuerID string) (*Issuer, error) { } 
fetchedIssuer := issuer{} - err = tx.Get(&fetchedIssuer, ` + err = c.db.Get(&fetchedIssuer, ` SELECT * FROM v3_issuers WHERE issuer_id=$1 `, issuerID) if err != nil { - return nil, errIssuerNotFound + if !isPostgresNotFoundError(err) { + temporary = true + } + return nil, utils.ProcessingErrorFromError(errIssuerNotFound, temporary) } - convertedIssuer, err := c.convertDBIssuer(fetchedIssuer) - if err != nil { - return nil, err - } + convertedIssuer := c.convertDBIssuer(fetchedIssuer) // get the signing keys if convertedIssuer.Keys == nil { convertedIssuer.Keys = []IssuerKeys{} } var fetchIssuerKeys = []issuerKeys{} - err = tx.Select( + err = c.db.Select( &fetchIssuerKeys, `SELECT * FROM v3_issuer_keys where issuer_id=$1 and @@ -304,15 +300,18 @@ func (c *Server) fetchIssuer(issuerID string) (*Issuer, error) { convertedIssuer.ID, ) if err != nil { - c.Logger.Error("Failed to extract issuer keys from DB") - return nil, err + if !isPostgresNotFoundError(err) { + c.Logger.Error("Postgres encountered temporary error") + temporary = true + } + return nil, utils.ProcessingErrorFromError(err, temporary) } for _, v := range fetchIssuerKeys { k, err := c.convertDBIssuerKeys(v) if err != nil { c.Logger.Error("Failed to convert issuer keys from DB") - return nil, err + return nil, utils.ProcessingErrorFromError(err, temporary) } convertedIssuer.Keys = append(convertedIssuer.Keys, *k) } @@ -333,45 +332,40 @@ func (c *Server) fetchIssuersByCohort(issuerType string, issuerCohort int16) (*[ } } - tx := c.db.MustBegin() - var err error = nil - - defer func() { - if err != nil { - err = tx.Rollback() - return - } - err = tx.Commit() - }() + var ( + err error + temporary = false + ) fetchedIssuers := []issuer{} - err = tx.Select( + err = c.db.Select( &fetchedIssuers, `SELECT i.* FROM v3_issuers i join v3_issuer_keys k on (i.issuer_id=k.issuer_id) WHERE i.issuer_type=$1 AND k.cohort=$2 ORDER BY i.expires_at DESC NULLS FIRST, i.created_at DESC`, issuerType, issuerCohort) if err != nil { - return nil, err + c.Logger.Error("Failed to extract issuers from DB") + if !isPostgresNotFoundError(err) { + temporary = true + } + return nil, utils.ProcessingErrorFromError(err, temporary) } if len(fetchedIssuers) < 1 { - return nil, errIssuerCohortNotFound + return nil, utils.ProcessingErrorFromError(errIssuerCohortNotFound, temporary) } issuers := []Issuer{} for _, fetchedIssuer := range fetchedIssuers { - convertedIssuer, err := c.convertDBIssuer(fetchedIssuer) - if err != nil { - return nil, err - } + convertedIssuer := c.convertDBIssuer(fetchedIssuer) // get the keys for the Issuer if convertedIssuer.Keys == nil { convertedIssuer.Keys = []IssuerKeys{} } var fetchIssuerKeys = []issuerKeys{} - err = tx.Select( + err = c.db.Select( &fetchIssuerKeys, `SELECT * FROM v3_issuer_keys where issuer_id=$1 and @@ -383,15 +377,18 @@ func (c *Server) fetchIssuersByCohort(issuerType string, issuerCohort int16) (*[ convertedIssuer.ID, ) if err != nil { - c.Logger.Error("Failed to extract issuer keys from DB") - return nil, err + if !isPostgresNotFoundError(err) { + c.Logger.Error("Postgres encountered temporary error") + temporary = true + } + return nil, utils.ProcessingErrorFromError(err, temporary) } for _, v := range fetchIssuerKeys { k, err := c.convertDBIssuerKeys(v) if err != nil { c.Logger.Error("Failed to convert issuer keys from DB") - return nil, err + return nil, utils.ProcessingErrorFromError(err, temporary) } convertedIssuer.Keys = append(convertedIssuer.Keys, *k) }
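Reviewer note, not part of the diff: the single-row lookup above must use sqlx's Get, since Select requires a slice destination and fails at runtime on a struct. A hedged sketch of that convention with an illustrative connection string; note that Get signals absence with sql.ErrNoRows, which is not a *pq.Error, so it needs separate handling anywhere it should count as "not found".

package main

import (
	"database/sql"
	"errors"
	"fmt"
	"log"

	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq"
)

type issuerRow struct {
	IssuerID   string `db:"issuer_id"`
	IssuerType string `db:"issuer_type"`
}

func main() {
	db, err := sqlx.Connect("postgres", "postgres://localhost/cbp?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}

	// Get scans exactly one row into a struct...
	var one issuerRow
	err = db.Get(&one, `SELECT issuer_id, issuer_type FROM v3_issuers WHERE issuer_id = $1`, "some-id")
	if errors.Is(err, sql.ErrNoRows) {
		fmt.Println("no such issuer") // absence is an error for Get
	}

	// ...while Select requires a slice and treats an empty result as success.
	var many []issuerRow
	if err := db.Select(&many, `SELECT issuer_id, issuer_type FROM v3_issuers WHERE issuer_type = $1`, "some-type"); err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(many))
}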
@@ -424,10 +421,7 @@ func (c *Server) fetchIssuerByType(ctx context.Context, issuerType string) (*Iss return nil, err } - convertedIssuer, err := c.convertDBIssuer(issuerV3) - if err != nil { - return nil, err - } + convertedIssuer := c.convertDBIssuer(issuerV3) if convertedIssuer.Keys == nil { convertedIssuer.Keys = []IssuerKeys{} @@ -467,45 +461,40 @@ func (c *Server) fetchIssuers(issuerType string) (*[]Issuer, error) { } } - tx := c.db.MustBegin() - var err error = nil - - defer func() { - if err != nil { - err = tx.Rollback() - return - } - err = tx.Commit() - }() + var ( + err error + temporary = false + ) fetchedIssuers := []issuer{} - err = tx.Select( + err = c.db.Select( &fetchedIssuers, `SELECT * FROM v3_issuers WHERE issuer_type=$1 ORDER BY expires_at DESC NULLS LAST, created_at DESC`, issuerType) if err != nil { - return nil, err + c.Logger.Error("Failed to extract issuers from DB") + if !isPostgresNotFoundError(err) { + temporary = true + } + return nil, utils.ProcessingErrorFromError(err, temporary) } if len(fetchedIssuers) < 1 { - return nil, errIssuerNotFound + return nil, utils.ProcessingErrorFromError(errIssuerNotFound, temporary) } issuers := []Issuer{} for _, fetchedIssuer := range fetchedIssuers { - convertedIssuer, err := c.convertDBIssuer(fetchedIssuer) - if err != nil { - return nil, err - } + convertedIssuer := c.convertDBIssuer(fetchedIssuer) // get the keys for the Issuer if convertedIssuer.Keys == nil { convertedIssuer.Keys = []IssuerKeys{} } var fetchIssuerKeys = []issuerKeys{} - err = tx.Select( + err = c.db.Select( &fetchIssuerKeys, `SELECT * FROM v3_issuer_keys where issuer_id=$1 and @@ -517,15 +506,18 @@ func (c *Server) fetchIssuers(issuerType string) (*[]Issuer, error) { convertedIssuer.ID, ) if err != nil { - c.Logger.Error("Failed to extract issuer keys from DB") - return nil, err + if !isPostgresNotFoundError(err) { + c.Logger.Error("Failed to extract issuer keys from DB") + temporary = true + } + return nil, utils.ProcessingErrorFromError(err, temporary) } for _, v := range fetchIssuerKeys { k, err := c.convertDBIssuerKeys(v) if err != nil { c.Logger.Error("Failed to convert issuer keys from DB") - return nil, err + return nil, utils.ProcessingErrorFromError(err, temporary) } convertedIssuer.Keys = append(convertedIssuer.Keys, *k) } @@ -549,42 +541,37 @@ func (c *Server) FetchAllIssuers() (*[]Issuer, error) { } } - tx := c.db.MustBegin() - var err error = nil - - defer func() { - if err != nil { - err = tx.Rollback() - return - } - err = tx.Commit() - }() + var ( + err error + temporary = false + ) fetchedIssuers := []issuer{} - err = tx.Select( + err = c.db.Select( &fetchedIssuers, `SELECT * FROM v3_issuers ORDER BY expires_at DESC NULLS LAST, created_at DESC`) if err != nil { c.Logger.Error("Failed to extract issuers from DB") - return nil, err + if !isPostgresNotFoundError(err) { + temporary = true + } + return nil, utils.ProcessingErrorFromError(err, temporary) } issuers := []Issuer{} for _, fetchedIssuer := range fetchedIssuers { - convertedIssuer, err := c.convertDBIssuer(fetchedIssuer) - if err != nil { - c.Logger.Error("Error converting extracted Issuer") - return nil, err - } + convertedIssuer := c.convertDBIssuer(fetchedIssuer) if convertedIssuer.Keys == nil { convertedIssuer.Keys = []IssuerKeys{} } var fetchIssuerKeys = []issuerKeys{} - err = tx.Select( + err = c.db.Select( &fetchIssuerKeys, `SELECT * FROM v3_issuer_keys where issuer_id=$1 and @@ -596,15 +583,18 @@ func (c *Server) FetchAllIssuers() (*[]Issuer, error) {
convertedIssuer.ID, ) if err != nil { - c.Logger.Error("Failed to extract issuer keys from DB") - return nil, err + if !isPostgresNotFoundError(err) { + c.Logger.Error("Postgres encountered temporary error") + temporary = true + } + return nil, utils.ProcessingErrorFromError(err, temporary) } for _, v := range fetchIssuerKeys { k, err := c.convertDBIssuerKeys(v) if err != nil { c.Logger.Error("Failed to convert issuer keys from DB") - return nil, err + return nil, utils.ProcessingErrorFromError(err, temporary) } convertedIssuer.Keys = append(convertedIssuer.Keys, *k) } @@ -625,7 +615,7 @@ func (c *Server) rotateIssuers() error { tx := c.db.MustBegin() - var err error = nil + var err error defer func() { if err != nil { @@ -651,14 +641,9 @@ func (c *Server) rotateIssuers() error { for _, v := range fetchedIssuers { // converted - issuer, err := c.convertDBIssuer(v) - if err != nil { - tx.Rollback() - return fmt.Errorf("failed to convert rows on v3 issuer creation: %w", err) - } + issuer := c.convertDBIssuer(v) // populate keys in db if err := txPopulateIssuerKeys(c.Logger, tx, *issuer); err != nil { - tx.Rollback() return fmt.Errorf("failed to populate v3 issuer keys: %w", err) } @@ -675,10 +660,9 @@ func (c *Server) rotateIssuers() error { // rotateIssuersV3 rotates the signing keys of v3 issuers func (c *Server) rotateIssuersV3() error { tx := c.db.MustBegin() - var err error = nil + var err error defer func() { if err != nil { @@ -716,9 +700,8 @@ // for each issuer fetched for _, issuer := range fetchedIssuers { - issuerDTO, err := parseIssuer(issuer) + issuerDTO := parseIssuer(issuer) - if err != nil { - tx.Rollback() - return fmt.Errorf("error failed to parse db issuer to dto: %w", err) - } // get this issuer's keys populated @@ -753,7 +736,6 @@ func (c *Server) rotateIssuersV3() error { // populate the buffer of keys for the v3 issuer if err := txPopulateIssuerKeys(c.Logger, tx, issuerDTO); err != nil { - tx.Rollback() return fmt.Errorf("failed to close rows on v3 issuer creation: %w", err) } // denote that the v3 issuer was rotated at this time @@ -784,7 +766,7 @@ func (c *Server) deleteIssuerKeys(duration string) (int64, error) { } // createV3Issuer - creation of a v3 issuer -func (c *Server) createV3Issuer(issuer Issuer) error { +func (c *Server) createV3Issuer(issuer Issuer) (err error) { defer incrementCounter(createIssuerCounter) if issuer.MaxTokens == 0 { issuer.MaxTokens = 40 } @@ -796,6 +778,13 @@ func (c *Server) createV3Issuer(issuer Issuer) (err error) { } tx := c.db.MustBegin() + defer func() { + if err != nil { + // keep the original error rather than overwriting it with the rollback result + _ = tx.Rollback() + return + } + err = tx.Commit() + }() queryTimer := prometheus.NewTimer(createTimeLimitedIssuerDBDuration) row := tx.QueryRowx( @@ -826,16 +815,14 @@ func (c *Server) createV3Issuer(issuer Issuer) (err error) { ) // get the newly inserted issuer identifier if err := row.Scan(&issuer.ID); err != nil { - tx.Rollback() return fmt.Errorf("failed to get v3 issuer id: %w", err) } if err := txPopulateIssuerKeys(c.Logger, tx, issuer); err != nil { - tx.Rollback() return fmt.Errorf("failed to close rows on v3 issuer creation: %w", err) } queryTimer.ObserveDuration() - return tx.Commit() + return nil } // on the transaction, populate v3 issuer keys for the v3 issuer @@ -893,7 +880,6 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err // start/end, increment every iteration end, err = duration.From(*start) if err != nil { - tx.Rollback() return fmt.Errorf("unable to calculate end time: %w", err) } }
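Reviewer note, not part of the diff: the named-return defer added to createV3Issuer follows the pattern sketched below. The key detail is that the rollback result must not overwrite the error that triggered it, which is why the version above discards tx.Rollback()'s return value. The helper name is assumed, illustrative only.

package dbutil

import (
	"fmt"

	"github.com/jmoiron/sqlx"
)

// withTx runs fn inside a transaction, committing on success and rolling
// back on failure without clobbering the original error.
func withTx(db *sqlx.DB, fn func(tx *sqlx.Tx) error) (err error) {
	tx, err := db.Beginx()
	if err != nil {
		return fmt.Errorf("begin: %w", err)
	}
	defer func() {
		if err != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				// Keep the original failure; note the rollback problem alongside it.
				err = fmt.Errorf("%w (rollback also failed: %v)", err, rbErr)
			}
			return
		}
		err = tx.Commit()
	}()
	return fn(tx)
}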
@@ -901,21 +887,18 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err signingKey, err := crypto.RandomSigningKey() if err != nil { logger.Error("Error generating key") - tx.Rollback() return err } signingKeyTxt, err := signingKey.MarshalText() if err != nil { logger.Error("Error marshalling signing key") - tx.Rollback() return err } pubKeyTxt, err := signingKey.PublicKey().MarshalText() if err != nil { logger.Error("Error marshalling public key") - tx.Rollback() return err } logger.Infof("iteration key pubkey: %s", string(pubKeyTxt)) @@ -949,10 +932,8 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err position += 6 // increment start - if start != nil && end != nil { - tmp := *end - start = &tmp - } + tmp := *end + start = &tmp } var values []interface{} @@ -981,10 +962,10 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err VALUES %s`, valueFmtStr), values...) if err != nil { logger.Error("Could not insert the new issuer keys into the DB") - tx.Rollback() return err } - return rows.Close() + defer rows.Close() + return nil } func (c *Server) createIssuerV2(issuerType string, issuerCohort int16, maxTokens int, expiresAt *time.Time) error { @@ -1025,14 +1006,14 @@ type Queryable interface { } // RedeemToken redeems a token given an issuer and a preimage -func (c *Server) RedeemToken(issuerForRedemption *Issuer, preimage *crypto.TokenPreimage, payload string) error { +func (c *Server) RedeemToken(issuerForRedemption *Issuer, preimage *crypto.TokenPreimage, payload string, offset int64) error { defer incrementCounter(redeemTokenCounter) if issuerForRedemption.Version == 1 { return redeemTokenWithDB(c.db, issuerForRedemption.IssuerType, preimage, payload) } else if issuerForRedemption.Version == 2 || issuerForRedemption.Version == 3 { - return c.redeemTokenWithDynamo(issuerForRedemption, preimage, payload) + return c.redeemTokenWithDynamo(issuerForRedemption, preimage, payload, offset) } - return errors.New("Wrong Issuer Version") + return errors.New("wrong issuer version") } func redeemTokenWithDB(db Queryable, stringIssuer string, preimage *crypto.TokenPreimage, payload string) error { @@ -1044,37 +1025,29 @@ func redeemTokenWithDB(db Queryable, stringIssuer string, preimage *crypto.Token queryTimer := prometheus.NewTimer(createRedemptionDBDuration) rows, err := db.Query( `INSERT INTO redemptions(id, issuer_type, ts, payload) VALUES ($1, $2, NOW(), $3)`, preimageTxt, stringIssuer, payload) - defer func() error { - if rows != nil { - err := rows.Close() - if err != nil { - return err - } - } - return nil - }() if err != nil { if err, ok := err.(*pq.Error); ok && err.Code == "23505" { // unique constraint violation return errDuplicateRedemption } return err } + defer rows.Close() queryTimer.ObserveDuration() return nil } -func (c *Server) fetchRedemption(issuerType, ID string) (*Redemption, error) { +func (c *Server) fetchRedemption(issuerType, id string) (*Redemption, error) { defer incrementCounter(fetchRedemptionCounter) if c.caches != nil { - if cached, found := c.caches["redemptions"].Get(fmt.Sprintf("%s:%s", issuerType, ID)); found { + if cached, found := c.caches["redemptions"].Get(fmt.Sprintf("%s:%s", issuerType, id)); found { return cached.(*Redemption), nil } } queryTimer := prometheus.NewTimer(fetchRedemptionDBDuration) rows, err := c.db.Query( - `SELECT id, issuer_id, ts, payload FROM redemptions WHERE id = $1 AND issuer_type = $2`, ID, issuerType)
+ `SELECT id, issuer_id, ts, payload FROM redemptions WHERE id = $1 AND issuer_type = $2`, id, issuerType) queryTimer.ObserveDuration() if err != nil { @@ -1091,7 +1064,7 @@ func (c *Server) fetchRedemption(issuerType, ID string) (*Redemption, error) { } if c.caches != nil { - c.caches["redemptions"].SetDefault(fmt.Sprintf("%s:%s", issuerType, ID), redemption) + c.caches["redemptions"].SetDefault(fmt.Sprintf("%s:%s", issuerType, id), redemption) } return redemption, nil @@ -1123,24 +1096,25 @@ func (c *Server) convertDBIssuerKeys(issuerKeyToConvert issuerKeys) (*IssuerKeys return &parsedIssuerKeys, nil } -func (c *Server) convertDBIssuer(issuerToConvert issuer) (*Issuer, error) { +// convertDBIssuer takes an issuer from the database and returns a reference to that issuer +// represented as an Issuer struct. It returns from the cache when possible; otherwise the +// database issuer is parsed into an Issuer, the cache is updated, and the reference is +// returned. +func (c *Server) convertDBIssuer(issuerToConvert issuer) *Issuer { stringifiedID := string(issuerToConvert.ID.String()) if c.caches != nil { if cached, found := c.caches["convertedissuers"].Get(stringifiedID); found { - return cached.(*Issuer), nil + return cached.(*Issuer) } } - parsedIssuer, err := parseIssuer(issuerToConvert) - if err != nil { - return nil, err - } + parsedIssuer := parseIssuer(issuerToConvert) if c.caches != nil { - c.caches["issuer"].SetDefault(stringifiedID, parseIssuer) + c.caches["convertedissuers"].SetDefault(stringifiedID, &parsedIssuer) } - return &parsedIssuer, nil + return &parsedIssuer } func parseIssuerKeys(issuerKeysToParse issuerKeys) (IssuerKeys, error) { @@ -1162,7 +1136,8 @@ func parseIssuerKeys(issuerKeysToParse issuerKeys) (IssuerKeys, error) { return parsedIssuerKey, nil } -func parseIssuer(issuerToParse issuer) (Issuer, error) { +// parseIssuer converts a database issuer into an Issuer struct with no additional side-effects +func parseIssuer(issuerToParse issuer) Issuer { parsedIssuer := Issuer{ ID: issuerToParse.ID, Version: issuerToParse.Version, @@ -1184,5 +1159,19 @@ func parseIssuer(issuerToParse issuer) (Issuer, error) { parsedIssuer.RotatedAt = issuerToParse.RotatedAt.Time } - return parsedIssuer, nil + return parsedIssuer +} +
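Reviewer note, not part of the diff: the cache fix above matters because the read path type-asserts to *Issuer, so the write path must store a pointer under the same cache name it reads from. A hedged sketch of that contract, assuming the caches are backed by patrickmn/go-cache, whose Get/SetDefault signatures match the calls in the diff:

package main

import (
	"fmt"
	"time"

	gocache "github.com/patrickmn/go-cache"
)

type Issuer struct {
	ID         string
	IssuerType string
}

func main() {
	c := gocache.New(5*time.Minute, 10*time.Minute)

	parsed := Issuer{ID: "abc123", IssuerType: "test"}
	c.SetDefault(parsed.ID, &parsed) // store *Issuer: the reader asserts a pointer

	if cached, found := c.Get(parsed.ID); found {
		issuer := cached.(*Issuer) // would panic if an Issuer value had been stored
		fmt.Println(issuer.IssuerType)
	}
}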
+// isPostgresNotFoundError uses the error map found at the below URL to determine if an +// error is a Postgres no_data_found error. +// https://github.com/lib/pq/blob/d5affd5073b06f745459768de35356df2e5fd91d/error.go#L348 +func isPostgresNotFoundError(err error) bool { + pqError, ok := err.(*pq.Error) + if !ok { + return false + } + return pqError.Code.Class().Name() == "no_data_found" +} diff --git a/server/dynamo.go b/server/dynamo.go index e5eac036..95f72538 100644 --- a/server/dynamo.go +++ b/server/dynamo.go @@ -1,18 +1,37 @@ package server import ( + "errors" "os" "time" + awsDynamoTypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" // nolint "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" crypto "github.com/brave-intl/challenge-bypass-ristretto-ffi" + "github.com/brave-intl/challenge-bypass-server/utils" "github.com/google/uuid" ) +// Equivalence represents the type of equality discovered when checking DynamoDB data +type Equivalence int64 + +const ( + // UnknownEquivalence means equivalence could not be determined + UnknownEquivalence Equivalence = iota + // NoEquivalence means there was no matching record of any kind in Dynamo + NoEquivalence + // IDEquivalence means a record with the same ID as the subject was found, but one + // or more of its other fields did not match the subject + IDEquivalence + // BindingEquivalence means a record that matched all of the fields of the + // subject was found + BindingEquivalence +) + // InitDynamo initializes the dynamo database connection func (c *Server) InitDynamo() { sess := session.Must(session.NewSessionWithOptions(session.Options{ @@ -67,7 +86,7 @@ func (c *Server) fetchRedemptionV2(id uuid.UUID) (*RedemptionV2, error) { return &redemption, nil } -func (c *Server) redeemTokenWithDynamo(issuer *Issuer, preimage *crypto.TokenPreimage, payload string) error { +func (c *Server) redeemTokenWithDynamo(issuer *Issuer, preimage *crypto.TokenPreimage, payload string, offset int64) error { preimageTxt, err := preimage.MarshalText() if err != nil { c.Logger.Error("Error Marshalling preimage") @@ -83,8 +102,35 @@ func (c *Server) redeemTokenWithDynamo(issuer *Issuer, preimage *crypto.TokenPre Payload: payload, Timestamp: time.Now(), TTL: issuer.ExpiresAt.Unix(), + Offset: offset, + } + + av, err := dynamodbattribute.MarshalMap(redemption) + if err != nil { + c.Logger.Error("Error marshalling redemption") + return err + } + + input := &dynamodb.PutItemInput{ + Item: av, + ConditionExpression: aws.String("attribute_not_exists(id)"), + TableName: aws.String("redemptions"), + } + + _, err = c.dynamo.PutItem(input) + if err != nil { + if err, ok := err.(awserr.Error); ok && err.Code() == "ConditionalCheckFailedException" { // unique constraint violation + c.Logger.Error("Duplicate redemption") + return errDuplicateRedemption + } + c.Logger.Error("Error creating item") + return err } + return nil +} +// PersistRedemption saves the redemption in the database +func (c *Server) PersistRedemption(redemption RedemptionV2) error { av, err := dynamodbattribute.MarshalMap(redemption) if err != nil { c.Logger.Error("Error marshalling redemption") @@ -108,3 +154,52 @@ func (c *Server) redeemTokenWithDynamo(issuer *Issuer, preimage *crypto.TokenPre } return nil }
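Reviewer note, not part of the diff: isPostgresNotFoundError above is corrected to return true only for the no_data_found class, so call sites that mark !isPostgresNotFoundError(err) as temporary behave as intended. The Dynamo side applies the same temporary/permanent split; here it is extracted into a standalone sketch (the helper name is assumed):

package main

import (
	"errors"
	"fmt"

	awsDynamoTypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)

// isTemporaryDynamoError mirrors the errors.As checks in
// CheckRedeemedTokenEquivalence: throttling and server-side failures are
// retryable, everything else is permanent.
func isTemporaryDynamoError(err error) bool {
	var (
		ptee *awsDynamoTypes.ProvisionedThroughputExceededException
		rle  *awsDynamoTypes.RequestLimitExceeded
		ise  *awsDynamoTypes.InternalServerError
	)
	return errors.As(err, &ptee) || errors.As(err, &rle) || errors.As(err, &ise)
}

func main() {
	err := fmt.Errorf("put item: %w", &awsDynamoTypes.InternalServerError{})
	fmt.Println(isTemporaryDynamoError(err)) // true: worth retrying with backoff
}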
+ +// CheckRedeemedTokenEquivalence returns whether just the ID of a given RedemptionV2 token +// matches an existing persisted record, the whole value matches, or neither match and +// this is a new token to be redeemed. +func (c *Server) CheckRedeemedTokenEquivalence(issuer *Issuer, preimage *crypto.TokenPreimage, payload string, offset int64) (*RedemptionV2, Equivalence, error) { + var temporary = false + preimageTxt, err := preimage.MarshalText() + if err != nil { + c.Logger.Error("Error Marshalling preimage") + return nil, UnknownEquivalence, utils.ProcessingErrorFromError(err, temporary) + } + + id := uuid.NewSHA1(*issuer.ID, preimageTxt) + + redemption := RedemptionV2{ + IssuerID: issuer.ID.String(), + ID: id.String(), + PreImage: string(preimageTxt), + Payload: payload, + Timestamp: time.Now(), + TTL: issuer.ExpiresAt.Unix(), + Offset: offset, + } + + existingRedemption, err := c.fetchRedemptionV2(id) + + // If err is nil the record already exists in the database and we need to determine + // whether the whole body is equivalent to what was provided or just the id. + if err == nil { + if redemption.Payload == existingRedemption.Payload { + return &redemption, BindingEquivalence, nil + } + return &redemption, IDEquivalence, nil + } + + var ( + ptee *awsDynamoTypes.ProvisionedThroughputExceededException + rle *awsDynamoTypes.RequestLimitExceeded + ise *awsDynamoTypes.InternalServerError + ) + + // is this a temporary error? + if errors.As(err, &ptee) || + errors.As(err, &rle) || + errors.As(err, &ise) { + temporary = true + } + return &redemption, NoEquivalence, utils.ProcessingErrorFromError(err, temporary) +} diff --git a/server/issuers.go b/server/issuers.go index 71968c41..a2583d5c 100644 --- a/server/issuers.go +++ b/server/issuers.go @@ -46,6 +46,7 @@ type issuerFetchRequestV2 struct { Cohort int16 `json:"cohort"` } +// GetLatestIssuer - get the latest issuer by type/cohort func (c *Server) GetLatestIssuer(issuerType string, issuerCohort int16) (*Issuer, *handlers.AppError) { issuer, err := c.fetchIssuersByCohort(issuerType, issuerCohort) if err != nil { @@ -67,6 +68,17 @@ func (c *Server) GetLatestIssuer(issuerType string, issuerCohort int16) (*Issuer return &(*issuer)[0], nil } +// GetLatestIssuerKafka - get the issuer and any processing error +func (c *Server) GetLatestIssuerKafka(issuerType string, issuerCohort int16) (*Issuer, error) { + issuer, err := c.fetchIssuersByCohort(issuerType, issuerCohort) + if err != nil { + return nil, err + } + + return &(*issuer)[0], nil +} + +// GetIssuers - get all issuers by issuer type func (c *Server) GetIssuers(issuerType string) (*[]Issuer, error) { issuers, err := c.getIssuers(issuerType) if err != nil { @@ -225,7 +237,6 @@ func (c *Server) issuerV3CreateHandler(w http.ResponseWriter, r *http.Request) * ValidFrom: req.ValidFrom, Duration: &req.Duration, }); err != nil { var pqErr *pq.Error if errors.As(err, &pqErr) { if pqErr.Code == "23505" { // unique violation diff --git a/server/server.go b/server/server.go index 21c63a97..8504c87f 100644 --- a/server/server.go +++ b/server/server.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" "net/http" "os" "strconv" @@ -23,11 +22,15 @@ import ( ) var ( + // Version - the build version of the server
Version = "dev" maxRequestSize = int64(1024 * 1024) // 1MiB - ErrNoSecretKey = errors.New("server config does not contain a key") - ErrRequestTooLarge = errors.New("request too large to process") + // ErrNoSecretKey - configuration error, no secret key + ErrNoSecretKey = errors.New("server config does not contain a key") + // ErrRequestTooLarge - processing error, request is too big + ErrRequestTooLarge = errors.New("request too large to process") + // ErrUnrecognizedRequest - processing error, request unrecognized ErrUnrecognizedRequest = errors.New("received unrecognized request type") ) @@ -45,13 +48,14 @@ func init() { prometheus.MustRegister(fetchRedemptionDBDuration) } +// Server - base server type type Server struct { ListenPort int `json:"listen_port,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` - DbConfigPath string `json:"db_config_path"` + DBConfigPath string `json:"db_config_path"` Logger *logrus.Logger `json:",omitempty"` dynamo *dynamodb.DynamoDB - dbConfig DbConfig + dbConfig DBConfig db *sqlx.DB caches map[string]CacheInterface @@ -65,7 +69,7 @@ var DefaultServer = &Server{ // LoadConfigFile loads a file into conf and returns func LoadConfigFile(filePath string) (Server, error) { conf := *DefaultServer - data, err := ioutil.ReadFile(filePath) + data, err := os.ReadFile(filePath) if err != nil { return conf, err } @@ -76,9 +80,9 @@ func LoadConfigFile(filePath string) (Server, error) { return conf, nil } -// InitDbConfig reads os environment and update conf -func (c *Server) InitDbConfig() error { - conf := DbConfig{ +// InitDBConfig reads os environment and update conf +func (c *Server) InitDBConfig() error { + conf := DBConfig{ DefaultDaysBeforeExpiry: 7, DefaultIssuerValidDays: 30, MaxConnection: 100, @@ -111,7 +115,7 @@ func (c *Server) InitDbConfig() error { } } - c.LoadDbConfig(conf) + c.LoadDBConfig(conf) return nil } diff --git a/server/server_test.go b/server/server_test.go index c50845d6..07f9c8de 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "os" @@ -46,12 +45,12 @@ func (suite *ServerTestSuite) SetupSuite() { suite.srv = &Server{} - err = suite.srv.InitDbConfig() + err = suite.srv.InitDBConfig() suite.Require().NoError(err, "Failed to setup db conn") suite.handler = chi.ServerBaseContext(suite.srv.setupRouter(SetupLogger(context.Background()))) - suite.srv.InitDb() + suite.srv.InitDB() suite.srv.InitDynamo() err = test.SetupDynamodbTables(suite.srv.dynamo) @@ -77,7 +76,7 @@ func (suite *ServerTestSuite) TestPing() { suite.Assert().Equal(http.StatusOK, resp.StatusCode) expected := "." 
- actual, err := ioutil.ReadAll(resp.Body) + actual, err := io.ReadAll(resp.Body) suite.Assert().NoError(err, "Reading response body should succeed") suite.Assert().Equal(expected, string(actual), "Message should match") } @@ -135,7 +134,7 @@ func (suite *ServerTestSuite) TestIssueRedeemV2() { suite.Assert().NoError(err, "HTTP Request should complete") suite.Assert().Equal(http.StatusOK, resp.StatusCode, "Attempted redemption request should succeed") - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) suite.Require().NoError(err, "Redemption response body read must succeed") var issuerResp blindedTokenRedeemResponse @@ -157,14 +156,15 @@ func (suite *ServerTestSuite) TestIssueRedeemV2() { suite.Assert().NoError(err, "HTTP Request should complete") suite.Assert().Equal(http.StatusOK, resp.StatusCode, "Attempted redemption request should succeed") - body, err = ioutil.ReadAll(resp.Body) + body, err = io.ReadAll(resp.Body) suite.Require().NoError(err, "Redemption response body read must succeed") err = json.Unmarshal(body, &issuerResp) suite.Require().NoError(err, "Redemption response body unmarshal must succeed") suite.Assert().NotEqual(issuerResp.Cohort, 1-issuerCohort, "Redemption of a token should return the same cohort with which it was signed") - _, err = suite.srv.db.Query(`UPDATE v3_issuers SET expires_at=$1 WHERE issuer_id=$2`, time.Now().AddDate(0, 0, -1), issuer.ID) + r, err := suite.srv.db.Query(`UPDATE v3_issuers SET expires_at=$1 WHERE issuer_id=$2`, time.Now().AddDate(0, 0, -1), issuer.ID) suite.Require().NoError(err, "failed to expire issuer") + defer r.Close() // keys are what rotate now, not the issuer itself issuer, _ = suite.srv.GetLatestIssuer(issuerType, issuerCohort) @@ -176,8 +176,9 @@ func (suite *ServerTestSuite) TestIssueRedeemV2() { var signingKey = issuer.Keys[len(issuer.Keys)-1].SigningKey publicKey = signingKey.PublicKey() - _, err = suite.srv.db.Query(`UPDATE v3_issuers SET expires_at=$1 WHERE issuer_id=$2`, time.Now().AddDate(0, 0, +1), issuer.ID) + r, err = suite.srv.db.Query(`UPDATE v3_issuers SET expires_at=$1 WHERE issuer_id=$2`, time.Now().AddDate(0, 0, +1), issuer.ID) suite.Require().NoError(err, "failed to unexpire issuer") + defer r.Close() unblindedToken = suite.createToken(server.URL, issuerType, publicKey) preimageText, sigText = suite.prepareRedemption(unblindedToken, msg) @@ -204,7 +205,7 @@ func (suite *ServerTestSuite) TestNewIssueRedeemV2() { suite.Assert().NoError(err, "HTTP Request should complete") suite.Assert().Equal(http.StatusOK, resp.StatusCode, "Attempted redemption request should succeed") - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) suite.Require().NoError(err, "Redemption response body read must succeed") var issuerResp blindedTokenRedeemResponse @@ -225,15 +226,16 @@ func (suite *ServerTestSuite) TestNewIssueRedeemV2() { suite.Assert().NoError(err, "HTTP Request should complete") suite.Assert().Equal(http.StatusOK, resp.StatusCode, "Attempted redemption request should succeed") - body, err = ioutil.ReadAll(resp.Body) + body, err = io.ReadAll(resp.Body) suite.Require().NoError(err, "Redemption response body read must succeed") err = json.Unmarshal(body, &issuerResp) suite.Require().NoError(err, "Redemption response body unmarshal must succeed") suite.Assert().NotEqual(issuerResp.Cohort, 1-issuerCohort, "Redemption of a token should return the same cohort with which it was signed") - _, err = suite.srv.db.Query(`UPDATE v3_issuers SET expires_at=$1 WHERE issuer_id=$2`, 
time.Now().AddDate(0, 0, -1), issuer.ID) + r, err := suite.srv.db.Query(`UPDATE v3_issuers SET expires_at=$1 WHERE issuer_id=$2`, time.Now().AddDate(0, 0, -1), issuer.ID) suite.Require().NoError(err, "failed to expire issuer") + defer r.Close() resp, err = suite.attemptRedeem(server.URL, preimageText2, sigText2, issuerType, msg) suite.Assert().NoError(err, "HTTP Request should complete") @@ -267,17 +269,16 @@ func (suite *ServerTestSuite) TestRotateTimeAwareIssuer() { // wait a few intervals after creation and check number of signing keys left time.Sleep(2 * time.Second) myIssuer, err := suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) + suite.Require().NoError(err) suite.Require().Equal(len(myIssuer.Keys), 1) // should be one left // rotate issuers should pick up that there are some new intervals to make up buffer and populate err = suite.srv.rotateIssuersV3() suite.Require().NoError(err) - myIssuer, err = suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) + myIssuer, _ = suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) suite.Require().Equal(len(myIssuer.Keys), 3) // should be 3 now - time.Sleep(1) - // rotate issuers should pick up that there are some new intervals to make up buffer and populate err = suite.srv.rotateIssuersV3() suite.Require().NoError(err) @@ -285,7 +286,7 @@ func (suite *ServerTestSuite) TestRotateTimeAwareIssuer() { // wait a few intervals after creation and check number of signing keys left time.Sleep(2 * time.Second) - myIssuer, err = suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) + myIssuer, _ = suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) suite.Require().Equal(len(myIssuer.Keys), 1) // should be one left } @@ -310,6 +311,7 @@ func (suite *ServerTestSuite) TestCreateIssuerV3() { createIssuerURL := fmt.Sprintf("%s/v3/issuer/", server.URL) resp, err := suite.request("POST", createIssuerURL, bytes.NewBuffer(payload)) + suite.Require().NoError(err) suite.Assert().Equal(http.StatusCreated, resp.StatusCode) @@ -395,13 +397,13 @@ func (suite *ServerTestSuite) TestRunRotate() { suite.Require().NoError(err) } -func (suite *ServerTestSuite) request(method string, URL string, payload io.Reader) (*http.Response, error) { +func (suite *ServerTestSuite) request(method string, url string, payload io.Reader) (*http.Response, error) { var req *http.Request var err error if payload != nil { - req, err = http.NewRequest(method, URL, payload) + req, err = http.NewRequest(method, url, payload) } else { - req, err = http.NewRequest(method, URL, nil) + req, err = http.NewRequest(method, url, nil) } if err != nil { return nil, err @@ -426,7 +428,7 @@ func (suite *ServerTestSuite) createIssuer(serverURL string, issuerType string, suite.Require().NoError(err, "Issuer fetch must succeed") suite.Assert().Equal(http.StatusOK, resp.StatusCode) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) suite.Require().NoError(err, "Issuer fetch body read must succeed") var issuerResp issuerResponse @@ -446,7 +448,7 @@ func (suite *ServerTestSuite) getAllIssuers(serverURL string) []issuerResponse { suite.Require().NoError(err, "Getting all Issuers must succeed") suite.Assert().Equal(http.StatusOK, resp.StatusCode) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) suite.Require().NoError(err, "Issuer fetch body read must succeed") var issuerResp []issuerResponse @@ -475,7 +477,7 @@ func (suite *ServerTestSuite) createIssuerWithExpiration(serverURL string, issue
suite.Require().NoError(err, "Issuer fetch must succeed") suite.Assert().Equal(http.StatusOK, resp.StatusCode) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) suite.Require().NoError(err, "Issuer fetch body read must succeed") var issuerResp issuerResponse @@ -517,7 +519,7 @@ func (suite *ServerTestSuite) createTokens(serverURL string, issuerType string, suite.Require().NoError(err, "Token signing must succeed") suite.Assert().Equal(http.StatusOK, resp.StatusCode) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) suite.Require().NoError(err, "Token signing body read must succeed") var decodedResp blindedTokenIssueResponse @@ -582,7 +584,7 @@ func (suite *ServerTestSuite) createCohortTokens(serverURL string, issuerType st suite.Require().NoError(err, "Token signing must succeed") suite.Assert().Equal(http.StatusOK, resp.StatusCode) - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) suite.Require().NoError(err, "Token signing body read must succeed") var decodedResp blindedTokenIssueResponse @@ -617,10 +619,7 @@ func (suite *ServerTestSuite) TestRedeemV3() { err := suite.srv.createV3Issuer(issuer) suite.Require().NoError(err) - //err = suite.srv.rotateIssuersV3() - //suite.Require().NoError(err) - - issuerKey, err := suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) + issuerKey, _ := suite.srv.GetLatestIssuer(issuer.IssuerType, issuer.IssuerCohort) tokens := make([]*crypto.Token, buffer) blindedTokensSlice := make([]*crypto.BlindedToken, buffer) @@ -671,7 +670,6 @@ func (suite *ServerTestSuite) TestRedeemV3() { for i := 0; i < buffer; i++ { var unblindedToken *crypto.UnblindedToken for _, v := range redemptions { - if v.validFrom.Before(time.Now()) && v.validTo.After(time.Now()) { unblindedToken = v.unblindedTokens[0] } diff --git a/server/tokens.go b/server/tokens.go index 0a130164..73e4c606 100644 --- a/server/tokens.go +++ b/server/tokens.go @@ -27,6 +27,7 @@ type blindedTokenIssueRequest struct { BlindedTokens []*crypto.BlindedToken `json:"blinded_tokens"` } +// BlindedTokenIssueRequestV2 - version 2 blinded token issue request type BlindedTokenIssueRequestV2 struct { BlindedTokens []*crypto.BlindedToken `json:"blinded_tokens"` IssuerCohort int16 `json:"cohort"` @@ -48,23 +49,24 @@ type blindedTokenRedeemResponse struct { Cohort int16 `json:"cohort"` } +// BlindedTokenRedemptionInfo - this is the redemption information type BlindedTokenRedemptionInfo struct { TokenPreimage *crypto.TokenPreimage `json:"t"` Signature *crypto.VerificationSignature `json:"signature"` Issuer string `json:"issuer"` } +// BlindedTokenBulkRedeemRequest - this is the redemption in bulk form type BlindedTokenBulkRedeemRequest struct { Payload string `json:"payload"` Tokens []BlindedTokenRedemptionInfo `json:"tokens"` } +// BlindedTokenIssuerHandlerV2 - handler for token issuer v2 func (c *Server) BlindedTokenIssuerHandlerV2(w http.ResponseWriter, r *http.Request) *handlers.AppError { var response blindedTokenIssueResponse if issuerType := chi.URLParam(r, "type"); issuerType != "" { - var request BlindedTokenIssueRequestV2 - if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxRequestSize)).Decode(&request); err != nil { c.Logger.WithError(err) return handlers.WrapError(err, "Could not parse the request body", 400) @@ -172,7 +174,6 @@ func (c *Server) blindedTokenIssuerHandler(w http.ResponseWriter, r *http.Reques func (c *Server) blindedTokenRedeemHandlerV3(w http.ResponseWriter, r *http.Request) 
*handlers.AppError { var response blindedTokenRedeemResponse if issuerType := chi.URLParam(r, "type"); issuerType != "" { - issuer, err := c.fetchIssuerByType(r.Context(), issuerType) if err != nil { switch { @@ -259,7 +260,7 @@ func (c *Server) blindedTokenRedeemHandlerV3(w http.ResponseWriter, r *http.Requ } } - if err := c.RedeemToken(issuer, request.TokenPreimage, request.Payload); err != nil { + if err := c.RedeemToken(issuer, request.TokenPreimage, request.Payload, 0); err != nil { c.Logger.Error("error redeeming token") if errors.Is(err, errDuplicateRedemption) { return &handlers.AppError{ @@ -272,7 +273,6 @@ func (c *Server) blindedTokenRedeemHandlerV3(w http.ResponseWriter, r *http.Requ Message: "Could not mark token redemption", Code: http.StatusInternalServerError, } - } response = blindedTokenRedeemResponse{issuer.IssuerCohort} } @@ -342,7 +342,7 @@ func (c *Server) blindedTokenRedeemHandler(w http.ResponseWriter, r *http.Reques } } - if err := c.RedeemToken(verifiedIssuer, request.TokenPreimage, request.Payload); err != nil { + if err := c.RedeemToken(verifiedIssuer, request.TokenPreimage, request.Payload, 0); err != nil { if errors.Is(err, errDuplicateRedemption) { return &handlers.AppError{ Message: err.Error(), @@ -421,15 +421,13 @@ func (c *Server) blindedTokenBulkRedeemHandler(w http.ResponseWriter, r *http.Re Message: err.Error(), Code: http.StatusConflict, } - } else { - return &handlers.AppError{ - Cause: err, - Message: "Could not mark token redemption", - Code: http.StatusInternalServerError, - } + } + return &handlers.AppError{ + Cause: err, + Message: "Could not mark token redemption", + Code: http.StatusInternalServerError, } } - } err = tx.Commit() if err != nil { diff --git a/utils/errors.go b/utils/errors.go index dee807e3..8126d9c7 100644 --- a/utils/errors.go +++ b/utils/errors.go @@ -1,30 +1,21 @@ package utils import ( - "errors" "fmt" - "time" - - awsDynamoTypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" - "github.com/rs/zerolog" - "github.com/segmentio/kafka-go" ) -// ProcessingError is an error used for Kafka processing that communicates retry data for -// failures. +// ProcessingError is an error used to communicate whether an error is temporary. 
type ProcessingError struct { OriginalError error FailureMessage string Temporary bool - Backoff time.Duration - KafkaMessage kafka.Message } // Error makes ProcessingError an error func (e ProcessingError) Error() string { msg := fmt.Sprintf("error: %s", e.FailureMessage) - if e.OriginalError != nil { - msg = fmt.Sprintf("%s: %s", msg, e.OriginalError) + if e.Cause() != nil { + msg = fmt.Sprintf("%s: %s", msg, e.Cause()) } return msg } @@ -34,43 +25,11 @@ func (e ProcessingError) Cause() error { return e.OriginalError } -// ProcessingErrorFromErrorWithMessage converts an error into a ProcessingError -func ProcessingErrorFromErrorWithMessage( - err error, - message string, - kafkaMessage kafka.Message, - logger *zerolog.Logger, -) *ProcessingError { - temporary, backoff := ErrorIsTemporary(err, logger) +// ProcessingErrorFromError - wrap the given error in a ProcessingError, recording whether it is temporary +func ProcessingErrorFromError(cause error, isTemporary bool) error { return &ProcessingError{ - OriginalError: err, - FailureMessage: message, - Temporary: temporary, - Backoff: backoff, - KafkaMessage: kafkaMessage, - } -} - -// ErrorIsTemporary takes an error and determines -func ErrorIsTemporary(err error, logger *zerolog.Logger) (bool, time.Duration) { - var ( - dynamoProvisionedThroughput *awsDynamoTypes.ProvisionedThroughputExceededException - dynamoRequestLimitExceeded *awsDynamoTypes.RequestLimitExceeded - dynamoInternalServerError *awsDynamoTypes.InternalServerError - ) - - if errors.As(err, &dynamoProvisionedThroughput) { - logger.Error().Err(err).Msg("Temporary message processing failure") - return true, 1 * time.Minute - } - if errors.As(err, &dynamoRequestLimitExceeded) { - logger.Error().Err(err).Msg("Temporary message processing failure") - return true, 1 * time.Minute + OriginalError: cause, + FailureMessage: cause.Error(), + Temporary: isTemporary, } - if errors.As(err, &dynamoInternalServerError) { - logger.Error().Err(err).Msg("Temporary message processing failure") - return true, 1 * time.Minute - } - - return false, 1 * time.Millisecond } diff --git a/utils/ptr/ptr.go b/utils/ptr/ptr.go index a9fd5007..0714bc46 100644 --- a/utils/ptr/ptr.go +++ b/utils/ptr/ptr.go @@ -20,6 +20,7 @@ func StringOr(s *string, or string) string { return *s } +// FromTime - return a pointer to the given time func FromTime(t time.Time) *time.Time { return &t }
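Reviewer note, not part of the diff: with Backoff and KafkaMessage removed, ProcessingError now only carries the temporary flag, leaving retry policy to the caller. A hedged sketch of how a consumer might branch on it; the retry decision shown is illustrative, not something this diff implements.

package main

import (
	"errors"
	"fmt"

	"github.com/brave-intl/challenge-bypass-server/utils"
)

func handle(err error) {
	var pe *utils.ProcessingError
	if errors.As(err, &pe) && pe.Temporary {
		fmt.Println("temporary failure, caller may retry:", pe)
		return
	}
	fmt.Println("permanent failure, emit an error result:", err)
}

func main() {
	handle(utils.ProcessingErrorFromError(errors.New("dynamo throttled"), true))
	handle(utils.ProcessingErrorFromError(errors.New("duplicate redemption"), false))
}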