diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..02cd81b71c --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,33 @@ +Release type: minor + +This release adds a new method `get_fields` on the `Schema` class. +You can use `get_fields` to hide certain field based on some conditions, +for example: + +```python +@strawberry.type +class User: + name: str + email: str = strawberry.field(metadata={"tags": ["internal"]}) + + +@strawberry.type +class Query: + user: User + + +def public_field_filter(field: StrawberryField) -> bool: + return "internal" not in field.metadata.get("tags", []) + + +class PublicSchema(strawberry.Schema): + def get_fields( + self, type_definition: StrawberryObjectDefinition + ) -> List[StrawberryField]: + return list(filter(public_field_filter, type_definition.fields)) + + +schema = PublicSchema(query=Query) +``` + +The schema here would only have the `name` field on the `User` type. diff --git a/docs/types/schema.md b/docs/types/schema.md index 9c919aa7f6..f49969cd4e 100644 --- a/docs/types/schema.md +++ b/docs/types/schema.md @@ -233,3 +233,43 @@ class StrawberryLogger: cls.logger.error(error, exc_info=error.original_error, **logger_kwargs) ``` + +## Filtering/customising fields + +You can customise the fields that are exposed on a schema by subclassing the +`Schema` class and overriding the `get_fields` method, for example you can use +this to create different GraphQL APIs, such as a public and an internal API. +Here's an example of this: + +```python +@strawberry.type +class User: + name: str + email: str = strawberry.field(metadata={"tags": ["internal"]}) + + +@strawberry.type +class Query: + user: User + + +def public_field_filter(field: StrawberryField) -> bool: + return "internal" not in field.metadata.get("tags", []) + + +class PublicSchema(strawberry.Schema): + def get_fields( + self, type_definition: StrawberryObjectDefinition + ) -> List[StrawberryField]: + return list(filter(public_field_filter, type_definition.fields)) + + +schema = PublicSchema(query=Query) +``` + + + +The `get_fields` method is only called once when creating the schema, this is +not intended to be used to dynamically customise the schema. + + diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index dcb961baad..f613646716 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -100,7 +100,9 @@ def __init__( # TODO: check that the overrides are valid scalar_registry.update(cast(SCALAR_OVERRIDES_DICT_TYPE, scalar_overrides)) - self.schema_converter = GraphQLCoreConverter(self.config, scalar_registry) + self.schema_converter = GraphQLCoreConverter( + self.config, scalar_registry, self.get_fields + ) self.directives = directives self.schema_directives = list(schema_directives) @@ -231,6 +233,11 @@ def get_directive_by_name(self, graphql_name: str) -> Optional[StrawberryDirecti None, ) + def get_fields( + self, type_definition: StrawberryObjectDefinition + ) -> List[StrawberryField]: + return type_definition.fields + async def execute( self, query: Optional[str], diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 6040702158..f89c7b1d99 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -109,6 +109,7 @@ def _get_thunk_mapping( type_definition: StrawberryObjectDefinition, name_converter: Callable[[StrawberryField], str], field_converter: FieldConverterProtocol[FieldType], + get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]], ) -> Dict[str, FieldType]: """Create a GraphQL core `ThunkMapping` mapping of field names to field types. @@ -123,7 +124,9 @@ def _get_thunk_mapping( """ thunk_mapping: Dict[str, FieldType] = {} - for field in type_definition.fields: + fields = get_fields(type_definition) + + for field in fields: field_type = field.type if field_type is UNRESOLVED: @@ -178,10 +181,12 @@ def __init__( self, config: StrawberryConfig, scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]], + get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]], ): self.type_map: Dict[str, ConcreteType] = {} self.config = config self.scalar_registry = scalar_registry + self.get_fields = get_fields def from_argument(self, argument: StrawberryArgument) -> GraphQLArgument: argument_type = cast( @@ -374,6 +379,7 @@ def get_graphql_fields( type_definition=type_definition, name_converter=self.config.name_converter.from_field, field_converter=self.from_field, + get_fields=self.get_fields, ) def get_graphql_input_fields( @@ -383,6 +389,7 @@ def get_graphql_input_fields( type_definition=type_definition, name_converter=self.config.name_converter.from_field, field_converter=self.from_input_field, + get_fields=self.get_fields, ) def from_input_object(self, object_type: type) -> GraphQLInputObjectType: diff --git a/tests/pyright/test_federation.py b/tests/pyright/test_federation.py index 2e87edaa24..83040ecfce 100644 --- a/tests/pyright/test_federation.py +++ b/tests/pyright/test_federation.py @@ -29,13 +29,13 @@ def test_federation_type(): results = run_pyright(CODE) assert results == [ - Result(type="error", message='No parameter named "n"', line=16, column=6), Result( type="error", message='Argument missing for parameter "name"', line=16, column=1, ), + Result(type="error", message='No parameter named "n"', line=16, column=6), Result( type="information", message='Type of "User" is "type[User]"', @@ -75,15 +75,15 @@ def test_federation_interface(): assert results == [ Result( type="error", - message='No parameter named "n"', + message='Argument missing for parameter "name"', line=12, - column=6, + column=1, ), Result( type="error", - message='Argument missing for parameter "name"', + message='No parameter named "n"', line=12, - column=1, + column=6, ), Result( type="information", @@ -122,15 +122,15 @@ def test_federation_input(): assert results == [ Result( type="error", - message='No parameter named "n"', + message='Argument missing for parameter "name"', line=10, - column=6, + column=1, ), Result( type="error", - message='Argument missing for parameter "name"', + message='No parameter named "n"', line=10, - column=1, + column=6, ), Result( type="information", diff --git a/tests/pyright/test_federation_fields.py b/tests/pyright/test_federation_fields.py index 72c61d991e..6981fdc849 100644 --- a/tests/pyright/test_federation_fields.py +++ b/tests/pyright/test_federation_fields.py @@ -42,6 +42,12 @@ def test_pyright(): results = run_pyright(CODE) assert results == [ + Result( + type="error", + message='Argument missing for parameter "name"', + line=24, + column=1, + ), Result( type="error", message='No parameter named "n"', @@ -51,7 +57,7 @@ def test_pyright(): Result( type="error", message='Argument missing for parameter "name"', - line=24, + line=27, column=1, ), Result( @@ -60,12 +66,6 @@ def test_pyright(): line=27, column=11, ), - Result( - type="error", - message='Argument missing for parameter "name"', - line=27, - column=1, - ), Result( type="information", message='Type of "User" is "type[User]"', diff --git a/tests/pyright/test_federation_params.py b/tests/pyright/test_federation_params.py index e455135e09..69ae95d839 100644 --- a/tests/pyright/test_federation_params.py +++ b/tests/pyright/test_federation_params.py @@ -23,14 +23,14 @@ def test_pyright(): assert results == [ Result( type="error", - message='No parameter named "n"', + message='Argument missing for parameter "name"', line=11, - column=11, + column=1, ), Result( type="error", - message='Argument missing for parameter "name"', + message='No parameter named "n"', line=11, - column=1, + column=11, ), ] diff --git a/tests/pyright/test_fields.py b/tests/pyright/test_fields.py index 00296a72ae..445a048a47 100644 --- a/tests/pyright/test_fields.py +++ b/tests/pyright/test_fields.py @@ -26,15 +26,15 @@ def test_pyright(): assert results == [ Result( type="error", - message='No parameter named "n"', + message='Argument missing for parameter "name"', line=11, - column=6, + column=1, ), Result( type="error", - message='Argument missing for parameter "name"', + message='No parameter named "n"', line=11, - column=1, + column=6, ), Result( type="information", diff --git a/tests/pyright/test_fields_input.py b/tests/pyright/test_fields_input.py index 6effeb0797..360ad6389b 100644 --- a/tests/pyright/test_fields_input.py +++ b/tests/pyright/test_fields_input.py @@ -25,15 +25,15 @@ def test_pyright(): assert results == [ Result( type="error", - message='No parameter named "n"', + message='Argument missing for parameter "name"', line=11, - column=6, + column=1, ), Result( type="error", - message='Argument missing for parameter "name"', + message='No parameter named "n"', line=11, - column=1, + column=6, ), Result( type="information", diff --git a/tests/pyright/test_fields_resolver.py b/tests/pyright/test_fields_resolver.py index c3a4add78a..0888b7cc2a 100644 --- a/tests/pyright/test_fields_resolver.py +++ b/tests/pyright/test_fields_resolver.py @@ -29,15 +29,15 @@ def test_pyright(): assert results == [ Result( type="error", - message='No parameter named "n"', + message='Argument missing for parameter "name"', line=15, - column=6, + column=1, ), Result( type="error", - message='Argument missing for parameter "name"', + message='No parameter named "n"', line=15, - column=1, + column=6, ), Result( type="information", diff --git a/tests/pyright/test_fields_resolver_async.py b/tests/pyright/test_fields_resolver_async.py index bdbb432491..6127d1384d 100644 --- a/tests/pyright/test_fields_resolver_async.py +++ b/tests/pyright/test_fields_resolver_async.py @@ -29,15 +29,15 @@ def test_pyright(): assert results == [ Result( type="error", - message='No parameter named "n"', + message='Argument missing for parameter "name"', line=15, - column=6, + column=1, ), Result( type="error", - message='Argument missing for parameter "name"', + message='No parameter named "n"', line=15, - column=1, + column=6, ), Result( type="information", diff --git a/tests/pyright/test_params.py b/tests/pyright/test_params.py index 171875bdde..ec7e4aa495 100644 --- a/tests/pyright/test_params.py +++ b/tests/pyright/test_params.py @@ -29,12 +29,6 @@ def test_pyright(): results = run_pyright(CODE) assert results == [ - Result( - type="error", - message='No parameter named "n"', - line=16, - column=11, - ), Result( type="error", message='Argument missing for parameter "name"', @@ -44,7 +38,7 @@ def test_pyright(): Result( type="error", message='No parameter named "n"', - line=19, + line=16, column=11, ), Result( @@ -53,4 +47,10 @@ def test_pyright(): line=19, column=1, ), + Result( + type="error", + message='No parameter named "n"', + line=19, + column=11, + ), ] diff --git a/tests/pyright/test_private.py b/tests/pyright/test_private.py index 8b83f6de81..ae228aadce 100644 --- a/tests/pyright/test_private.py +++ b/tests/pyright/test_private.py @@ -27,15 +27,15 @@ def test_pyright(): assert results == [ Result( type="error", - message='No parameter named "n"', + message='Arguments missing for parameters "name", "age"', line=12, - column=6, + column=1, ), Result( type="error", - message='Arguments missing for parameters "name", "age"', + message='No parameter named "n"', line=12, - column=1, + column=6, ), Result( type="information", diff --git a/tests/pyright/utils.py b/tests/pyright/utils.py index 3f44cc2e8c..29ce41ba4e 100644 --- a/tests/pyright/utils.py +++ b/tests/pyright/utils.py @@ -70,7 +70,7 @@ def run_pyright(code: str, strict: bool = True) -> List[Result]: pyright_result: PyrightCLIResult = json.loads(process_result.stdout.decode("utf-8")) - return [ + result = [ Result( type=cast(ResultType, diagnostic["severity"].strip()), message=diagnostic["message"].strip(), @@ -80,6 +80,11 @@ def run_pyright(code: str, strict: bool = True) -> List[Result]: for diagnostic in pyright_result["generalDiagnostics"] ] + # make sure that results are sorted by line and column and then message + result.sort(key=lambda x: (x.line, x.column, x.message)) + + return result + def pyright_exist() -> bool: return shutil.which("pyright") is not None diff --git a/tests/schema/test_schema_hooks.py b/tests/schema/test_schema_hooks.py new file mode 100644 index 0000000000..51038e2f49 --- /dev/null +++ b/tests/schema/test_schema_hooks.py @@ -0,0 +1,43 @@ +import textwrap +from typing import List + +import strawberry +from strawberry.field import StrawberryField +from strawberry.types.types import StrawberryObjectDefinition + + +def test_can_change_which_fields_are_exposed(): + @strawberry.type + class User: + name: str + email: str = strawberry.field(metadata={"tags": ["internal"]}) + + @strawberry.type + class Query: + user: User + + def public_field_filter(field: StrawberryField) -> bool: + return "internal" not in field.metadata.get("tags", []) + + class PublicSchema(strawberry.Schema): + def get_fields( + self, type_definition: StrawberryObjectDefinition + ) -> List[StrawberryField]: + fields = super().get_fields(type_definition) + return list(filter(public_field_filter, fields)) + + schema = PublicSchema(query=Query) + + expected_schema = textwrap.dedent( + """ + type Query { + user: User! + } + + type User { + name: String! + } + """ + ).strip() + + assert schema.as_str() == expected_schema