From 8e695d5050fccf9d353bcfaf406a1b2866550e98 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Fri, 29 Dec 2023 16:14:09 -0500 Subject: [PATCH] Support validate (and import) taking in a schema directly Fixes #287 --- internal/cmd/import.go | 2 +- internal/cmd/validate.go | 13 +++++--- internal/decode/decoder.go | 54 +++++++++++++++++++++++++++------ internal/decode/decoder_test.go | 47 ++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 14 deletions(-) diff --git a/internal/cmd/import.go b/internal/cmd/import.go index d879384e..9a50aa37 100644 --- a/internal/cmd/import.go +++ b/internal/cmd/import.go @@ -77,7 +77,7 @@ func importCmdFunc(cmd *cobra.Command, args []string) error { return err } var p decode.SchemaRelationships - if _, err := decoder(&p); err != nil { + if _, _, err := decoder(&p); err != nil { return err } diff --git a/internal/cmd/validate.go b/internal/cmd/validate.go index 4be83a9b..1cb33f27 100644 --- a/internal/cmd/validate.go +++ b/internal/cmd/validate.go @@ -40,8 +40,8 @@ func registerValidateCmd(rootCmd *cobra.Command) { } var validateCmd = &cobra.Command{ - Use: "validate ", - Short: "validate the given validation file", + Use: "validate ", + Short: "validate the given validation or schema file", Example: ` From a local file (with prefix): zed validate file:///Users/zed/Downloads/authzed-x7izWU8_2Gw3.yaml @@ -78,7 +78,7 @@ func validateCmdFunc(cmd *cobra.Command, args []string) error { // Decode the validation document. var parsed validationfile.ValidationFile - validateContents, err := decoder(&parsed) + validateContents, isOnlySchema, err := decoder(&parsed) if err != nil { var errWithSource spiceerrors.ErrorWithSource if errors.As(err, &errWithSource) { @@ -102,7 +102,12 @@ func validateCmdFunc(cmd *cobra.Command, args []string) error { return err } if devErrs != nil { - outputDeveloperErrorsWithLineOffset(validateContents, devErrs.InputErrors, 1 /* for the 'schema:' */) + schemaOffset := 1 /* for the 'schema:' */ + if isOnlySchema { + schemaOffset = 0 + } + + outputDeveloperErrorsWithLineOffset(validateContents, devErrs.InputErrors, schemaOffset) } // Run assertions. diff --git a/internal/decode/decoder.go b/internal/decode/decoder.go index 4c01fdf1..241615b9 100644 --- a/internal/decode/decoder.go +++ b/internal/decode/decoder.go @@ -10,6 +10,11 @@ import ( "regexp" "strings" + "github.com/authzed/spicedb/pkg/schemadsl/compiler" + "github.com/authzed/spicedb/pkg/schemadsl/input" + "github.com/authzed/spicedb/pkg/spiceerrors" + "github.com/authzed/spicedb/pkg/validationfile" + "github.com/authzed/spicedb/pkg/validationfile/blocks" "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) @@ -24,7 +29,7 @@ type SchemaRelationships struct { } // Func will decode into the supplied object. -type Func func(out interface{}) ([]byte, error) +type Func func(out interface{}) ([]byte, bool, error) // DecoderForURL returns the appropriate decoder for a given URL. // Some URLs have special handling to dereference to the actual file. @@ -43,16 +48,17 @@ func DecoderForURL(u *url.URL) (d Func, err error) { } func fileDecoder(u *url.URL) Func { - return func(out interface{}) ([]byte, error) { + return func(out interface{}) ([]byte, bool, error) { file, err := os.Open(u.Path) if err != nil { - return nil, err + return nil, false, err } data, err := io.ReadAll(file) if err != nil { - return nil, err + return nil, false, err } - return data, yaml.Unmarshal(data, out) + isOnlySchema, err := unmarshalAsYAMLOrSchema(data, out) + return data, isOnlySchema, err } } @@ -82,18 +88,48 @@ func rewriteURL(u *url.URL) { } func directHTTPDecoder(u *url.URL) Func { - return func(out interface{}) ([]byte, error) { + return func(out interface{}) ([]byte, bool, error) { log.Debug().Stringer("url", u).Send() r, err := http.Get(u.String()) if err != nil { - return nil, err + return nil, false, err } defer r.Body.Close() data, err := io.ReadAll(r.Body) if err != nil { - return nil, err + return nil, false, err } - return data, yaml.Unmarshal(data, out) + isOnlySchema, err := unmarshalAsYAMLOrSchema(data, out) + return data, isOnlySchema, err } } + +func unmarshalAsYAMLOrSchema(data []byte, out interface{}) (bool, error) { + // Check for indications of a schema-only file. + if !strings.Contains(string(data), "schema:") { + compiled, serr := compiler.Compile(compiler.InputSchema{ + Source: input.Source("schema"), SchemaString: string(data), + }) + if serr != nil { + return false, serr + } + + // If that succeeds, return the compiled schema. + vfile := *out.(*validationfile.ValidationFile) + vfile.Schema = blocks.ParsedSchema{ + CompiledSchema: compiled, + Schema: string(data), + SourcePosition: spiceerrors.SourcePosition{LineNumber: 1, ColumnPosition: 1}, + } + *out.(*validationfile.ValidationFile) = vfile + return true, nil + } + + // Try to unparse as YAML for the validation file format. + if err := yaml.Unmarshal(data, out); err != nil { + return false, err + } + + return false, nil +} diff --git a/internal/decode/decoder_test.go b/internal/decode/decoder_test.go index 06c5fcad..9d91e9ee 100644 --- a/internal/decode/decoder_test.go +++ b/internal/decode/decoder_test.go @@ -4,6 +4,7 @@ import ( "net/url" "testing" + "github.com/authzed/spicedb/pkg/validationfile" "github.com/stretchr/testify/require" ) @@ -126,3 +127,49 @@ func TestRewriteURL(t *testing.T) { }) } } + +func TestUnmarshalAsYAMLOrSchema(t *testing.T) { + tests := []struct { + name string + in []byte + isOnlySchema bool + outSchema string + wantErr bool + }{ + { + name: "valid yaml", + in: []byte(` +schema: + definition user {} +`), + outSchema: `definition user {}`, + isOnlySchema: false, + wantErr: false, + }, + { + name: "valid schema", + in: []byte(`definition user {}`), + isOnlySchema: true, + outSchema: `definition user {}`, + wantErr: false, + }, + { + name: "invalid yaml", + in: []byte(`invalid yaml`), + isOnlySchema: false, + outSchema: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + block := validationfile.ValidationFile{} + isOnlySchema, err := unmarshalAsYAMLOrSchema(tt.in, &block) + require.Equal(t, tt.wantErr, err != nil) + require.Equal(t, tt.isOnlySchema, isOnlySchema) + if !tt.wantErr { + require.Equal(t, tt.outSchema, block.Schema.Schema) + } + }) + } +}