diff --git a/.gitignore b/.gitignore index 846a894..237b4c0 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,7 @@ report.json # Images *.png + +# VSCode +*.code-workspace +.vscode/* diff --git a/go.mod b/go.mod index 1fd3bdd..5cb024a 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,11 @@ require ( github.com/go-kit/kit v0.8.0 github.com/go-logfmt/logfmt v0.3.0 // indirect github.com/go-stack/stack v1.8.0 // indirect + github.com/google/uuid v1.3.0 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 // indirect github.com/stretchr/testify v1.7.0 github.com/ugorji/go/codec v1.2.6 github.com/xmidt-org/httpaux v0.3.0 github.com/xmidt-org/webpa-common v1.3.2 + go.uber.org/multierr v1.8.0 ) diff --git a/go.sum b/go.sum index a1d9119..6684580 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/go-logfmt/logfmt v0.3.0 h1:8HUsc87TaSWLKwrnumgC8/YconD2fJQsRJAsWaPg2i github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -25,6 +27,10 @@ github.com/xmidt-org/httpaux v0.3.0 h1:JdV4QceiE8EMA1Qf5rnJzHdkIPzQV12ddARHLU8hr github.com/xmidt-org/httpaux v0.3.0/go.mod h1:mviIlg5fHGb3lAv3l0sbiwVG/q9rqvXaudEYxVrzXdE= github.com/xmidt-org/webpa-common v1.3.2 h1:dE1Fi+XVnkt3tMGMjH7/hN/UGcaQ/ukKriXuMDyCWnM= github.com/xmidt-org/webpa-common v1.3.2/go.mod h1:oCpKzOC+9h2vYHVzAU/06tDTQuBN4RZz+rhgIXptpOI= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= +go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/header_wrp.go b/header_wrp.go index 26feeaf..eada008 100644 --- a/header_wrp.go +++ b/header_wrp.go @@ -17,10 +17,6 @@ package wrp -import ( - "errors" -) - // Constant HTTP header strings representing WRP fields const ( MsgTypeHeader = "X-Midt-Msg-Type" @@ -34,8 +30,6 @@ const ( SourceHeader = "X-Midt-Source" ) -var ErrInvalidMsgType = errors.New("Invalid Message Type") - // Map string to MessageType int /* func StringToMessageType(str string) MessageType { diff --git a/spec_validator.go b/spec_validator.go new file mode 100644 index 0000000..7cb8de9 --- /dev/null +++ b/spec_validator.go @@ -0,0 +1,147 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package wrp + +import ( + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "unicode" + + "github.com/google/uuid" +) + +const ( + serialPrefix = "serial" + uuidPrefix = "uuid" + eventPrefix = "event" + dnsPrefix = "dns" +) + +var ( + ErrorInvalidMessageEncoding = errors.New("invalid message encoding") + ErrorInvalidMessageType = errors.New("invalid message type") + ErrorInvalidSource = errors.New("invalid Source name") + ErrorInvalidDestination = errors.New("invalid Destination name") +) + +// locatorPattern is the precompiled regular expression that all source and dest locators must match. +// Matching is partial, as everything after the authority (ID) is ignored. https://xmidt.io/docs/wrp/basics/#locators +var LocatorPattern = regexp.MustCompile( + `^(?P(?i)` + macPrefix + `|` + uuidPrefix + `|` + eventPrefix + `|` + dnsPrefix + `|` + serialPrefix + `):(?P[^/]+)?`, +) + +// SpecValidators is a WRP validator that ensures messages are valid based on +// each spec validator in the list. Only validates the opinionated portions of the spec. +func SpecValidators() Validators { + return Validators{UTF8Validator, MessageTypeValidator, SourceValidator, DestinationValidator} +} + +// UTF8Validator is a WRP validator that takes messages and validates that it contains UTF-8 strings. +var UTF8Validator ValidatorFunc = func(m Message) error { + if err := UTF8(m); err != nil { + return fmt.Errorf("%w: %v", ErrorInvalidMessageEncoding, err) + } + + return nil +} + +// MessageTypeValidator is a WRP validator that takes messages and validates their Type. +var MessageTypeValidator ValidatorFunc = func(m Message) error { + t := m.MessageType() + if t < Invalid0MessageType || t > lastMessageType { + return ErrorInvalidMessageType + } + + switch t { + case Invalid0MessageType, Invalid1MessageType, lastMessageType: + return ErrorInvalidMessageType + } + + return nil +} + +// SourceValidator is a WRP validator that takes messages and validates their Source. +// Only mac and uuid sources are validated. Serial, event and dns sources are +// not validated. +var SourceValidator ValidatorFunc = func(m Message) error { + if err := validateLocator(m.Source); err != nil { + return fmt.Errorf("%w: %v", ErrorInvalidSource, err) + } + + return nil +} + +// DestinationValidator is a WRP validator that takes messages and validates their Destination. +// Only mac and uuid destinations are validated. Serial, event and dns destinations are +// not validated. +var DestinationValidator ValidatorFunc = func(m Message) error { + if err := validateLocator(m.Destination); err != nil { + return fmt.Errorf("%w: %v", ErrorInvalidDestination, err) + } + + return nil +} + +// validateLocator validates a given locator's scheme and authority (ID). +// Only mac and uuid schemes' IDs are validated. IDs from serial, event and dns schemes are +// not validated. +func validateLocator(l string) error { + match := LocatorPattern.FindStringSubmatch(l) + if match == nil { + return fmt.Errorf("spec scheme not found") + } + + idPart := match[2] + switch strings.ToLower(match[1]) { + case macPrefix: + var invalidCharacter rune = -1 + idPart = strings.Map( + func(r rune) rune { + switch { + case strings.ContainsRune(hexDigits, r): + return unicode.ToLower(r) + case strings.ContainsRune(macDelimiters, r): + return -1 + default: + invalidCharacter = r + return -1 + } + }, + idPart, + ) + + if invalidCharacter != -1 { + return fmt.Errorf("invalid character %v", strconv.QuoteRune(invalidCharacter)) + } else if len(idPart) != macLength { + return errors.New("invalid mac length") + } + case uuidPrefix: + if _, err := uuid.Parse(idPart); err != nil { + return err + } + case serialPrefix, eventPrefix, dnsPrefix: + if len(idPart) == 0 { + return fmt.Errorf("invalid %v empty authority (ID)", serialPrefix) + } + } + + return nil +} diff --git a/spec_validator_test.go b/spec_validator_test.go new file mode 100644 index 0000000..5f11970 --- /dev/null +++ b/spec_validator_test.go @@ -0,0 +1,617 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package wrp + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSpecHelperValidators(t *testing.T) { + tests := []struct { + description string + test func(*testing.T) + }{ + {"UTF8Validator", testUTF8Validator}, + {"MessageTypeValidator", testMessageTypeValidator}, + {"SourceValidator", testSourceValidator}, + {"DestinationValidator", testDestinationValidator}, + {"validateLocator", testValidateLocator}, + } + + for _, tc := range tests { + t.Run(tc.description, tc.test) + } +} + +func TestSpecValidators(t *testing.T) { + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + + tests := []struct { + description string + msg Message + expectedErr []error + }{ + // Success case + { + description: "Valid spec success", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + // Failure case + { + description: "Invaild spec error", + msg: Message{ + Type: Invalid0MessageType, + // Missing scheme + Source: "external.com", + // Invalid Mac + Destination: "MAC:+++BB-44-55", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + // Not UFT8 Payload + Payload: []byte{1, 2, 3, 4, 0xff /* \xed\xbf\xbf is invalid */, 0xce}, + ServiceName: "ServiceName", + // Not UFT8 URL string + URL: "someURL\xed\xbf\xbf.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + expectedErr: []error{ErrorInvalidMessageType, ErrorInvalidSource, ErrorInvalidDestination, ErrorInvalidMessageEncoding}, + }, + { + description: "Invaild spec error, empty message", + msg: Message{}, + expectedErr: []error{ErrorInvalidMessageType, ErrorInvalidSource, ErrorInvalidDestination}, + }, + { + description: "Invaild spec error, nonexistent MessageType", + msg: Message{ + Type: lastMessageType + 1, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + }, + expectedErr: []error{ErrorInvalidMessageType}, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := SpecValidators().Validate(tc.msg) + if tc.expectedErr != nil { + for _, e := range tc.expectedErr { + assert.ErrorIs(err, e) + } + return + } + + assert.NoError(err) + }) + } +} + +func ExampleTypeValidator_Validate_specValidators() { + msgv, err := NewTypeValidator( + // Validates found msg types + map[MessageType]Validator{ + // Validates opinionated portions of the spec + SimpleEventMessageType: SpecValidators(), + // Only validates Source and nothing else + SimpleRequestResponseMessageType: SourceValidator, + }, + // Validates unfound msg types + AlwaysInvalid) + if err != nil { + return + } + + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + foundErrFailure := msgv.Validate(Message{ + Type: SimpleEventMessageType, + // Missing scheme + Source: "external.com", + // Invalid Mac + Destination: "MAC:+++BB-44-55", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + // Not UFT8 Payload + Payload: []byte{1, 2, 3, 4, 0xff /* \xed\xbf\xbf is invalid */, 0xce}, + ServiceName: "ServiceName", + // Not UFT8 URL string + URL: "someURL\xed\xbf\xbf.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }) // Found error + foundErrSuccess1 := msgv.Validate(Message{ + Type: SimpleEventMessageType, + Source: "MAC:11:22:33:44:55:66", + Destination: "MAC:11:22:33:44:55:61", + }) // Found success + foundErrSuccess2 := msgv.Validate(Message{ + Type: SimpleRequestResponseMessageType, + Source: "MAC:11:22:33:44:55:66", + Destination: "invalid:a-BB-44-55", + }) // Found success + unfoundErrFailure := msgv.Validate(Message{Type: CreateMessageType}) // Unfound error + fmt.Println(foundErrFailure == nil, foundErrSuccess1 == nil, foundErrSuccess2 == nil, unfoundErrFailure == nil) + // Output: false true true false +} + +func testUTF8Validator(t *testing.T) { + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + + tests := []struct { + description string + msg Message + expectedErr error + }{ + // Success case + { + description: "UTF8 success", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + { + description: "Not UTF8 error", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + // Not UFT8 Destination string + Destination: "MAC:\xed\xbf\xbf", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + expectedErr: ErrorInvalidMessageEncoding, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := UTF8Validator(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testMessageTypeValidator(t *testing.T) { + tests := []struct { + description string + msg Message + expectedErr error + }{ + // Success case + { + description: "AuthorizationMessageType success", + msg: Message{Type: AuthorizationMessageType}, + }, + { + description: "SimpleRequestResponseMessageType success", + msg: Message{Type: SimpleRequestResponseMessageType}, + }, + { + description: "SimpleEventMessageType success", + msg: Message{Type: SimpleEventMessageType}, + }, + { + description: "CreateMessageType success", + msg: Message{Type: CreateMessageType}, + }, + { + description: "RetrieveMessageType success", + msg: Message{Type: RetrieveMessageType}, + }, + { + description: "UpdateMessageType success", + msg: Message{Type: UpdateMessageType}, + }, + { + description: "DeleteMessageType success", + msg: Message{Type: DeleteMessageType}, + }, + { + description: "ServiceRegistrationMessageType success", + msg: Message{Type: ServiceRegistrationMessageType}, + }, + { + description: "ServiceAliveMessageType success", + msg: Message{Type: ServiceAliveMessageType}, + }, + { + description: "UnknownMessageType success", + msg: Message{Type: UnknownMessageType}, + }, + // Failure case + { + description: "Invalid0MessageType error", + msg: Message{Type: Invalid0MessageType}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "Invalid0MessageType error", + msg: Message{Type: Invalid0MessageType}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "Invalid1MessageType error", + msg: Message{Type: Invalid1MessageType}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "lastMessageType error", + msg: Message{Type: lastMessageType}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "Nonexistent negative MessageType error", + msg: Message{Type: -10}, + expectedErr: ErrorInvalidMessageType, + }, + { + description: "Nonexistent positive MessageType error", + msg: Message{Type: lastMessageType + 1}, + expectedErr: ErrorInvalidMessageType, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := MessageTypeValidator(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testSourceValidator(t *testing.T) { + // SourceValidator is mainly a wrapper for validateLocator. + // This test mainly ensures that SourceValidator returns nil for non errors + // and wraps errors with ErrorInvalidSource. + // testValidateLocator covers the actual spectrum of test cases. + + tests := []struct { + description string + msg Message + expectedErr error + }{ + // Success case + { + description: "Source success", + msg: Message{Source: "MAC:11:22:33:44:55:66"}, + }, + // Failures + { + description: "Source error", + msg: Message{Source: "invalid:a-BB-44-55"}, + expectedErr: ErrorInvalidSource, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := SourceValidator(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testDestinationValidator(t *testing.T) { + // DestinationValidator is mainly a wrapper for validateLocator. + // This test mainly ensures that DestinationValidator returns nil for non errors + // and wraps errors with ErrorInvalidDestination. + // testValidateLocator covers the actual spectrum of test cases. + + tests := []struct { + description string + msg Message + expectedErr error + }{ + // Success case + { + description: "Destination success", + msg: Message{Destination: "MAC:11:22:33:44:55:66"}, + }, + // Failures + { + description: "Destination error", + msg: Message{Destination: "invalid:a-BB-44-55"}, + expectedErr: ErrorInvalidDestination, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := DestinationValidator(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testValidateLocator(t *testing.T) { + tests := []struct { + description string + value string + shouldErr bool + }{ + // mac success case + { + description: "Mac ID success, ':' delimiter", + value: "MAC:11:22:33:44:55:66", + shouldErr: false, + }, + { + description: "Mac ID success, no delimiter", + value: "MAC:11aaBB445566", + shouldErr: false, + }, + { + description: "Mac ID success, '-' delimiter", + value: "mac:11-aa-BB-44-55-66", + shouldErr: false, + }, + { + description: "Mac ID success, ',' delimiter", + value: "mac:11,aa,BB,44,55,66", + shouldErr: false, + }, + { + description: "Mac service success", + value: "mac:11,aa,BB,44,55,66/parodus/tag/test0", + shouldErr: false, + }, + // Mac failure case + { + description: "Mac ID error, invalid mac ID character", + value: "MAC:invalid45566", + shouldErr: true, + }, + { + description: "Mac ID error, invalid mac ID length", + value: "mac:11-aa-BB-44-55", + shouldErr: true, + }, + { + description: "Mac ID error, no ID", + value: "mac:", + shouldErr: true, + }, + // Serial success case + { + description: "Serial ID success", + value: "serial:anything Goes!", + shouldErr: false, + }, + // Serial failure case + { + description: "Invalid serial ID error, no ID", + value: "serial:", + shouldErr: true, + }, + // UUID success case + // The variant specified in RFC4122 + { + description: "UUID RFC4122 variant ID success", + value: "uuid:f47ac10b-58cc-0372-8567-0e02b2c3d479", + shouldErr: false, + }, + { + description: "UUID RFC4122 variant ID success, with Microsoft encoding", + value: "uuid:{f47ac10b-58cc-0372-8567-0e02b2c3d479}", + shouldErr: false, + }, + // Reserved, NCS backward compatibility. + { + description: "UUID Reserved variant ID success, with URN lower case ", + value: "UUID:urn:uuid:f47ac10b-58cc-4372-0567-0e02b2c3d479", + shouldErr: false, + }, + { + description: "UUID Reserved variant ID success, with URN upper case", + value: "UUID:URN:UUID:f47ac10b-58cc-4372-0567-0e02b2c3d479", + shouldErr: false, + }, + { + description: "UUID Reserved variant ID success, without URN", + value: "UUID:f47ac10b-58cc-4372-0567-0e02b2c3d479", + shouldErr: false, + }, + // Reserved, Microsoft Corporation backward compatibility. + { + description: "UUID Microsoft variant ID success", + value: "uuid:f47ac10b-58cc-4372-c567-0e02b2c3d479", + shouldErr: false, + }, + // Reserved for future definition. + { + description: "UUID Future variant ID success", + value: "uuid:f47ac10b-58cc-4372-e567-0e02b2c3d479", + shouldErr: false, + }, + // UUID failure case + { + description: "Invalid UUID ID error", + value: "uuid:invalid45566", + shouldErr: true, + }, + { + description: "Invalid UUID ID error, with URN", + value: "uuid:URN:UUID:invalid45566", + shouldErr: true, + }, + { + description: "Invalid UUID ID error, with Microsoft encoding", + value: "uuid:{invalid45566}", + shouldErr: true, + }, + { + description: "Invalid UUID ID error, no ID", + value: "uuid:", + shouldErr: true, + }, + // Event success case + { + description: "Event ID success", + value: "event:anything Goes!", + shouldErr: false, + }, + // Event failure case + { + description: "Invalid event ID error, no ID", + value: "event:", + shouldErr: true, + }, + // DNS success case + { + description: "DNS ID success", + value: "dns:anything Goes!", + shouldErr: false, + }, + // DNS failure case + { + description: "Invalid DNS ID error, no ID", + value: "dns:", + shouldErr: true, + }, + // Scheme failure case + { + description: "Invalid scheme error", + value: "invalid:a-BB-44-55", + shouldErr: true, + }, + { + description: "Invalid scheme error, empty string", + value: "", + shouldErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := validateLocator(tc.value) + if tc.shouldErr { + assert.Error(err) + return + } + + assert.NoError(err) + }) + } +} diff --git a/utf8.go b/utf8.go index 5aa0417..d1e26a2 100644 --- a/utf8.go +++ b/utf8.go @@ -26,7 +26,7 @@ import ( var ( ErrNotUTF8 = errors.New("field contains non-utf-8 characters") - ErrUnexpectedKind = errors.New("A struct or non-nil pointer to struct is required") + ErrUnexpectedKind = errors.New("a struct or non-nil pointer to struct is required") ) // UTF8 takes any struct verifies that it contains UTF-8 strings. @@ -55,7 +55,6 @@ func UTF8(v interface{}) error { if !utf8.ValidString(s) { return fmt.Errorf("%w: '%s:%v'", ErrNotUTF8, ft.Name, s) } - fmt.Println(s) } } diff --git a/validator.go b/validator.go new file mode 100644 index 0000000..f3ed27e --- /dev/null +++ b/validator.go @@ -0,0 +1,101 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package wrp + +import ( + "errors" + + "go.uber.org/multierr" +) + +var ( + ErrInvalidTypeValidator = errors.New("invalid TypeValidator") + ErrInvalidValidator = errors.New("invalid WRP message type validator") + ErrInvalidMsgType = errors.New("invalid WRP message type") +) + +// AlwaysInvalid doesn't validate anything about the message and always returns an error. +var AlwaysInvalid ValidatorFunc = func(m Message) error { return ErrInvalidMsgType } + +// AlwaysValid doesn't validate anything about the message and always returns nil. +var AlwaysValid ValidatorFunc = func(msg Message) error { return nil } + +// Validator is a WRP validator that allows access to the Validate function. +type Validator interface { + Validate(m Message) error +} + +// Validators is a WRP validator that ensures messages are valid based on +// message type and each validator in the list. +type Validators []Validator + +// Validate runs messages through each validator in the validators list. +// It returns as soon as the message is considered invalid, otherwise returns nil if valid. +func (vs Validators) Validate(m Message) error { + var err error + for _, v := range vs { + if v != nil { + err = multierr.Append(err, v.Validate(m)) + } + } + + return err +} + +// ValidatorFunc is a WRP validator that takes messages and validates them +// against functions. +type ValidatorFunc func(Message) error + +// Validate executes its own ValidatorFunc receiver and returns the result. +func (vf ValidatorFunc) Validate(m Message) error { + return vf(m) +} + +// TypeValidator is a WRP validator that validates based on message type +// or using the defaultValidators if message type is unfound. +type TypeValidator struct { + m map[MessageType]Validator + defaultValidators Validator +} + +// Validate validates messages based on message type or using the defaultValidators +// if message type is unfound. +func (m TypeValidator) Validate(msg Message) error { + vs := m.m[msg.MessageType()] + if vs == nil { + return m.defaultValidators.Validate(msg) + } + + return vs.Validate(msg) +} + +// NewTypeValidator is a TypeValidator factory. +func NewTypeValidator(m map[MessageType]Validator, defaultValidators ...Validator) (TypeValidator, error) { + if m == nil { + return TypeValidator{}, ErrInvalidValidator + } + + if defaultValidators == nil { + defaultValidators = Validators{AlwaysInvalid} + } + + return TypeValidator{ + m: m, + defaultValidators: Validators(defaultValidators), + }, nil +} diff --git a/validator_test.go b/validator_test.go new file mode 100644 index 0000000..8f9278f --- /dev/null +++ b/validator_test.go @@ -0,0 +1,437 @@ +/** + * Copyright (c) 2022 Comcast Cable Communications Management, LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package wrp + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/multierr" +) + +func TestValidators(t *testing.T) { + tests := []struct { + description string + vs Validators + msg Message + expectedErr []error + }{ + // Success case + { + description: "Empty Validators success", + vs: Validators{}, + msg: Message{Type: SimpleEventMessageType}, + }, + // Failure case + { + description: "Mix Validators error", + vs: Validators{AlwaysValid, nil, AlwaysInvalid, Validators{AlwaysValid, nil, AlwaysInvalid}}, + msg: Message{Type: SimpleEventMessageType}, + expectedErr: []error{ErrInvalidMsgType, ErrInvalidMsgType}, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := tc.vs.Validate(tc.msg) + if tc.expectedErr != nil { + assert.Equal(multierr.Errors(err), tc.expectedErr) + for _, e := range tc.expectedErr { + assert.ErrorIs(err, e) + } + return + } + + assert.NoError(err) + }) + } +} + +func TestHelperValidators(t *testing.T) { + tests := []struct { + description string + test func(*testing.T) + }{ + {"AlwaysInvalid", testAlwaysInvalid}, + {"AlwaysValid", testAlwaysValid}, + } + + for _, tc := range tests { + t.Run(tc.description, tc.test) + } +} + +func TestTypeValidator(t *testing.T) { + tests := []struct { + description string + test func(*testing.T) + }{ + {"Validate", testTypeValidatorValidate}, + {"Factory", testTypeValidatorFactory}, + } + + for _, tc := range tests { + t.Run(tc.description, tc.test) + } +} + +func ExampleNewTypeValidator() { + msgv, err := NewTypeValidator( + // Validates found msg types + map[MessageType]Validator{SimpleEventMessageType: AlwaysValid}, + // Validates unfound msg types + AlwaysInvalid) + fmt.Printf("%v %T", err == nil, msgv) + // Output: true wrp.TypeValidator +} + +func ExampleTypeValidator_Validate() { + msgv, err := NewTypeValidator( + // Validates found msg types + map[MessageType]Validator{SimpleEventMessageType: AlwaysValid}, + // Validates unfound msg types + AlwaysInvalid) + if err != nil { + return + } + + foundErr := msgv.Validate(Message{Type: SimpleEventMessageType}) // Found success + unfoundErr := msgv.Validate(Message{Type: CreateMessageType}) // Unfound error + fmt.Println(foundErr == nil, unfoundErr == nil) + // Output: true false +} + +func testTypeValidatorValidate(t *testing.T) { + tests := []struct { + description string + m map[MessageType]Validator + defaultValidators []Validator + msg Message + expectedErr error + }{ + // Success case + { + description: "Found success", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysValid, + }, + msg: Message{Type: SimpleEventMessageType}, + }, + { + description: "Unfound success", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidators: []Validator{AlwaysValid}, + msg: Message{Type: CreateMessageType}, + }, + { + description: "Unfound success, nil list of default Validators", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidators: []Validator{nil}, + msg: Message{Type: CreateMessageType}, + }, + { + description: "Unfound success, empty map of default Validators", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidators: []Validator{}, + msg: Message{Type: CreateMessageType}, + }, + // Failure case + { + description: "Found error", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidators: []Validator{AlwaysValid}, + msg: Message{Type: SimpleEventMessageType}, + expectedErr: ErrInvalidMsgType, + }, + { + description: "Found error, nil Validator", + m: map[MessageType]Validator{ + SimpleEventMessageType: nil, + }, + msg: Message{Type: SimpleEventMessageType}, + expectedErr: ErrInvalidMsgType, + }, + { + description: "Unfound error", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysValid, + }, + msg: Message{Type: CreateMessageType}, + expectedErr: ErrInvalidMsgType, + }, + { + description: "Unfound error, nil default Validators", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysInvalid, + }, + defaultValidators: nil, + msg: Message{Type: CreateMessageType}, + expectedErr: ErrInvalidMsgType, + }, + { + description: "Unfound error, empty map of Validators", + m: map[MessageType]Validator{}, + msg: Message{Type: CreateMessageType}, + expectedErr: ErrInvalidMsgType, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + msgv, err := NewTypeValidator(tc.m, tc.defaultValidators...) + require.NoError(err) + require.NotNil(msgv) + assert.NotZero(msgv) + err = msgv.Validate(tc.msg) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + return + } + + assert.NoError(err) + }) + } +} + +func testTypeValidatorFactory(t *testing.T) { + tests := []struct { + description string + m map[MessageType]Validator + defaultValidators []Validator + expectedErr error + }{ + // Success case + { + description: "Default Validators success", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysValid, + }, + defaultValidators: []Validator{AlwaysValid}, + expectedErr: nil, + }, + { + description: "Omit default Validators success", + m: map[MessageType]Validator{ + SimpleEventMessageType: AlwaysValid, + }, + expectedErr: nil, + }, + // Failure case + { + description: "Nil map of Validators error", + m: nil, + defaultValidators: []Validator{AlwaysValid}, + expectedErr: ErrInvalidValidator, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + msgv, err := NewTypeValidator(tc.m, tc.defaultValidators...) + if tc.expectedErr != nil { + assert.ErrorIs(err, tc.expectedErr) + // Zero asserts that msgv is the zero value for its type and not nil. + assert.Zero(msgv) + return + } + + assert.NoError(err) + assert.NotNil(msgv) + assert.NotZero(msgv) + }) + } +} + +func testAlwaysValid(t *testing.T) { + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + tests := []struct { + description string + msg Message + expectedErr []error + }{ + // Success case + { + description: "Not UTF8 success", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + // Not UFT8 Destination string + Destination: "mac:\xed\xbf\xbf", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + // Failure case + { + description: "Filled message success", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + { + description: "Empty message success", + msg: Message{}, + }, + { + description: "Bad message type success", + msg: Message{ + Type: lastMessageType + 1, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := AlwaysValid.Validate(tc.msg) + assert.NoError(err) + }) + } +} + +func testAlwaysInvalid(t *testing.T) { + var ( + expectedStatus int64 = 3471 + expectedRequestDeliveryResponse int64 = 34 + expectedIncludeSpans bool = true + ) + tests := []struct { + description string + msg Message + expectedErr []error + }{ + // Failure case + { + description: "Not UTF8 error", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + // Not UFT8 Destination string + Destination: "mac:\xed\xbf\xbf", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + { + description: "Filled message error", + msg: Message{ + Type: SimpleRequestResponseMessageType, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + TransactionUUID: "DEADBEEF", + ContentType: "ContentType", + Accept: "Accept", + Status: &expectedStatus, + RequestDeliveryResponse: &expectedRequestDeliveryResponse, + Headers: []string{"Header1", "Header2"}, + Metadata: map[string]string{"name": "value"}, + Spans: [][]string{{"1", "2"}, {"3"}}, + IncludeSpans: &expectedIncludeSpans, + Path: "/some/where/over/the/rainbow", + Payload: []byte{1, 2, 3, 4, 0xff, 0xce}, + ServiceName: "ServiceName", + URL: "someURL.com", + PartnerIDs: []string{"foo"}, + SessionID: "sessionID123", + }, + }, + { + description: "Empty message error", + msg: Message{}, + }, + { + description: "Bad message type error", + msg: Message{ + Type: lastMessageType + 1, + Source: "dns:external.com", + Destination: "MAC:11:22:33:44:55:66", + }, + }, + } + + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + assert := assert.New(t) + err := AlwaysInvalid.Validate(tc.msg) + assert.ErrorIs(err, ErrInvalidMsgType) + }) + } +}