diff --git a/.gitignore b/.gitignore index 846a894..237b4c0 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,7 @@ report.json # Images *.png + +# VSCode +*.code-workspace +.vscode/* diff --git a/go.mod b/go.mod index 1fd3bdd..393c95c 100644 --- a/go.mod +++ b/go.mod @@ -12,4 +12,5 @@ require ( github.com/ugorji/go/codec v1.2.6 github.com/xmidt-org/httpaux v0.3.0 github.com/xmidt-org/webpa-common v1.3.2 + go.uber.org/multierr v1.8.0 ) diff --git a/go.sum b/go.sum index a1d9119..8370579 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,10 @@ github.com/xmidt-org/httpaux v0.3.0 h1:JdV4QceiE8EMA1Qf5rnJzHdkIPzQV12ddARHLU8hr github.com/xmidt-org/httpaux v0.3.0/go.mod h1:mviIlg5fHGb3lAv3l0sbiwVG/q9rqvXaudEYxVrzXdE= github.com/xmidt-org/webpa-common v1.3.2 h1:dE1Fi+XVnkt3tMGMjH7/hN/UGcaQ/ukKriXuMDyCWnM= github.com/xmidt-org/webpa-common v1.3.2/go.mod h1:oCpKzOC+9h2vYHVzAU/06tDTQuBN4RZz+rhgIXptpOI= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= +go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/header_wrp.go b/header_wrp.go index 26feeaf..eada008 100644 --- a/header_wrp.go +++ b/header_wrp.go @@ -17,10 +17,6 @@ package wrp -import ( - "errors" -) - // Constant HTTP header strings representing WRP fields const ( MsgTypeHeader = "X-Midt-Msg-Type" @@ -34,8 +30,6 @@ const ( SourceHeader = "X-Midt-Source" ) -var ErrInvalidMsgType = errors.New("Invalid Message Type") - // Map string to MessageType int /* func StringToMessageType(str string) MessageType { diff --git a/validator.go b/validator.go new file mode 100644 index 0000000..1b3c07e --- /dev/null +++ b/validator.go @@ -0,0 +1,101 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package wrp + +import ( + "errors" + + "go.uber.org/multierr" +) + +var ( + ErrInvalidTypeValidator = errors.New("invalid TypeValidator") + ErrInvalidValidator = errors.New("invalid WRP message type validator") + ErrInvalidMsgType = errors.New("invalid WRP message type") +) + +// AlwaysInvalid doesn't validate anything about the message and always returns an error. +var AlwaysInvalid ValidatorFunc = func(m Message) error { return ErrInvalidMsgType } + +// AlwaysValid doesn't validate anything about the message and always returns nil. +var AlwaysValid ValidatorFunc = func(msg Message) error { return nil } + +// Validator is a WRP validator that allows access to the Validate function. +type Validator interface { + Validate(m Message) error +} + +// Validators is a WRP validator that ensures messages are valid based on +// message type and each validator in the list. +type Validators []Validator + +// Validate runs messages through each validator in the validators list. +// It returns as soon as the message is considered invalid, otherwise returns nil if valid. +func (vs Validators) Validate(m Message) error { + var err error + for _, v := range vs { + if v != nil { + err = multierr.Append(err, v.Validate(m)) + } + } + + return err +} + +// ValidatorFunc is a WRP validator that takes messages and validates them +// against functions. +type ValidatorFunc func(Message) error + +// Validate executes its own ValidatorFunc receiver and returns the result. +func (vf ValidatorFunc) Validate(m Message) error { + return vf(m) +} + +// TypeValidator is a WRP validator that validates based on message type +// or using the defaultValidator if message type is unfound. +type TypeValidator struct { + m map[MessageType]Validator + defaultValidator Validator +} + +// Validate validates messages based on message type or using the defaultValidator +// if message type is unfound. +func (m TypeValidator) Validate(msg Message) error { + vs := m.m[msg.MessageType()] + if vs == nil { + return m.defaultValidator.Validate(msg) + } + + return vs.Validate(msg) +} + +// NewTypeValidator is a TypeValidator factory. +func NewTypeValidator(m map[MessageType]Validator, defaultValidator Validator) (TypeValidator, error) { + if m == nil { + return TypeValidator{}, ErrInvalidValidator + } + + if defaultValidator == nil { + defaultValidator = AlwaysInvalid + } + + return TypeValidator{ + m: m, + defaultValidator: defaultValidator, + }, nil +} diff --git a/validator_test.go b/validator_test.go new file mode 100644 index 0000000..70c3aab --- /dev/null +++ b/validator_test.go @@ -0,0 +1,437 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package wrp + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/multierr" +) + +func TestValidators(t *testing.T) { + tests := []struct { + description string + vs Validators + msg Message + expectedErr []error + }{ + // Success case + { + description: "Empty Validators success", + vs: Validators{}, + msg: Message{Type: SimpleEventMessageType}, + }, + // Failure case + { + description: "Mix Validators error", + vs: Validators{AlwaysValid, nil, AlwaysInvalid, Validators{AlwaysValid, nil, AlwaysInvalid}}, + msg: Message{Type: SimpleEventMessageType}, + expectedErr: []error{ErrInvalidMsgType, ErrInvalidMsgType}, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := tc.vs.Validate(tc.msg) + if tc.expectedErr != nil { + assert.Equal(multierr.Errors(err), tc.expectedErr) + for _, e := range tc.expectedErr { + assert.ErrorIs(err, e) + } + return + } + + assert.NoError(err) + }) + } +} + +func TestHelperValidators(t *testing.T) { + tests := []struct { + description string + test func(*testing.T) + }{ + {"AlwaysInvalid", testAlwaysInvalid}, + {"AlwaysValid", testAlwaysValid}, + } + + for _, tc := range tests { + t.Run(tc.description, tc.test) + } +} + +func TestTypeValidator(t *testing.T) { + tests := []struct { + description string + test func(*testing.T) + }{ + {"Validate", testTypeValidatorValidate}, + {"Factory", testTypeValidatorFactory}, + } + + for _, tc := range tests { + t.Run(tc.description, tc.test) + } +} + +func ExampleNewTypeValidator() { + msgv, err := NewTypeValidator( + // Validates found msg types + map[MessageType]Validator{SimpleEventMessageType: AlwaysValid}, + // Validates unfound msg types + AlwaysInvalid) + fmt.Printf("%v %T", err == nil, msgv) + // Output: true wrp.TypeValidator +} + +func ExampleTypeValidator_Validate() { + msgv, err := NewTypeValidator( + // Validates found msg types + map[MessageType]Validator{SimpleEventMessageType: AlwaysValid}, + // Validates unfound msg types + AlwaysInvalid) + if err != nil { + return + } + + foundErr := msgv.Validate(Message{Type: SimpleEventMessageType}) // Found success + unfoundErr := msgv.Validate(Message{Type: CreateMessageType}) // Unfound error + fmt.Println(foundErr == nil, unfoundErr == nil) + // Output: true false +} + +func testTypeValidatorValidate(t *testing.T) { + tests := []struct { + description string + m map[MessageType]Validator + defaultValidator Validator + msg Message + expectedErr error + }{ + // Success case + { + description: "Found success", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysValid, + }, + msg: Message{Type: SimpleEventMessageType}, + }, + { + description: "Unfound success", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidator: AlwaysValid, + msg: Message{Type: CreateMessageType}, + }, + { + description: "Unfound success, nil list of default Validators", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidator: Validators{nil}, + msg: Message{Type: CreateMessageType}, + }, + { + description: "Unfound success, empty map of default Validators", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidator: Validators{}, + msg: Message{Type: CreateMessageType}, + }, + // Failure case + { + description: "Found error", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidator: AlwaysValid, + msg: Message{Type: SimpleEventMessageType}, + expectedErr: ErrInvalidMsgType, + }, + { + description: "Found error, nil Validator", + m: map[MessageType]Validator{ + SimpleEventMessageType: nil, + }, + msg: Message{Type: SimpleEventMessageType}, + expectedErr: ErrInvalidMsgType, + }, + { + description: "Unfound error", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysValid, + }, + msg: Message{Type: CreateMessageType}, + expectedErr: ErrInvalidMsgType, + }, + { + description: "Unfound error, nil default Validators", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidator: nil, + msg: Message{Type: CreateMessageType}, + expectedErr: ErrInvalidMsgType, + }, + { + description: "Unfound error, empty map of Validators", + m: map[MessageType]Validator{}, + msg: Message{Type: CreateMessageType}, + expectedErr: ErrInvalidMsgType, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + msgv, err := NewTypeValidator(tc.m, tc.defaultValidator) + require.NoError(err) + require.NotNil(msgv) + assert.NotZero(msgv) + err = msgv.Validate(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testTypeValidatorFactory(t *testing.T) { + tests := []struct { + description string + m map[MessageType]Validator + defaultValidator Validator + expectedErr error + }{ + // Success case + { + description: "Default Validators success", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysValid, + }, + defaultValidator: AlwaysValid, + expectedErr: nil, + }, + { + description: "Omit default Validators success", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysValid, + }, + expectedErr: nil, + }, + // Failure case + { + description: "Nil map of Validators error", + m: nil, + defaultValidator: AlwaysValid, + expectedErr: ErrInvalidValidator, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + msgv, err := NewTypeValidator(tc.m, tc.defaultValidator) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + // Zero asserts that msgv is the zero value for its type and not nil. + assert.Zero(msgv) + return + } + + assert.NoError(err) + assert.NotNil(msgv) + assert.NotZero(msgv) + }) + } +} + +func testAlwaysValid(t *testing.T) { + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + tests := []struct { + description string + msg Message + expectedErr []error + }{ + // Success case + { + description: "Not UTF8 success", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + // Not UFT8 Destination string + Destination: "mac:\xed\xbf\xbf", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + // Failure case + { + description: "Filled message success", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + { + description: "Empty message success", + msg: Message{}, + }, + { + description: "Bad message type success", + msg: Message{ + Type: lastMessageType + 1, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := AlwaysValid.Validate(tc.msg) + assert.NoError(err) + }) + } +} + +func testAlwaysInvalid(t *testing.T) { + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + tests := []struct { + description string + msg Message + expectedErr []error + }{ + // Failure case + { + description: "Not UTF8 error", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + // Not UFT8 Destination string + Destination: "mac:\xed\xbf\xbf", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + { + description: "Filled message error", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + { + description: "Empty message error", + msg: Message{}, + }, + { + description: "Bad message type error", + msg: Message{ + Type: lastMessageType + 1, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := AlwaysInvalid.Validate(tc.msg) + assert.ErrorIs(err, ErrInvalidMsgType) + }) + } +}