Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a way to filter fields #3274

Merged
merged 5 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
Release type: minor

This release adds a new method `get_fields` on the `Schema` class.
You can use `get_fields` to hide certain field based on some conditions,
for example:

```python
@strawberry.type
class User:
name: str
email: str = strawberry.field(metadata={"tags": ["internal"]})


@strawberry.type
class Query:
user: User


def public_field_filter(field: StrawberryField) -> bool:
return "internal" not in field.metadata.get("tags", [])


class PublicSchema(strawberry.Schema):
def get_fields(
self, type_definition: StrawberryObjectDefinition
) -> List[StrawberryField]:
return list(filter(public_field_filter, type_definition.fields))


schema = PublicSchema(query=Query)
```

The schema here would only have the `name` field on the `User` type.
40 changes: 40 additions & 0 deletions docs/types/schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,43 @@ class StrawberryLogger:

cls.logger.error(error, exc_info=error.original_error, **logger_kwargs)
```

## Filtering/customising fields

You can customise the fields that are exposed on a schema by subclassing the
`Schema` class and overriding the `get_fields` method, for example you can use
this to create different GraphQL APIs, such as a public and an internal API.
Here's an example of this:

```python
@strawberry.type
class User:
name: str
email: str = strawberry.field(metadata={"tags": ["internal"]})


@strawberry.type
class Query:
user: User


def public_field_filter(field: StrawberryField) -> bool:
return "internal" not in field.metadata.get("tags", [])


class PublicSchema(strawberry.Schema):
def get_fields(
self, type_definition: StrawberryObjectDefinition
) -> List[StrawberryField]:
return list(filter(public_field_filter, type_definition.fields))


schema = PublicSchema(query=Query)
```

<Note>

The `get_fields` method is only called once when creating the schema, this is
not intended to be used to dynamically customise the schema.

</Note>
9 changes: 8 additions & 1 deletion strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def __init__(
# TODO: check that the overrides are valid
scalar_registry.update(cast(SCALAR_OVERRIDES_DICT_TYPE, scalar_overrides))

self.schema_converter = GraphQLCoreConverter(self.config, scalar_registry)
self.schema_converter = GraphQLCoreConverter(
self.config, scalar_registry, self.get_fields
)
self.directives = directives
self.schema_directives = list(schema_directives)

Expand Down Expand Up @@ -231,6 +233,11 @@ def get_directive_by_name(self, graphql_name: str) -> Optional[StrawberryDirecti
None,
)

def get_fields(
self, type_definition: StrawberryObjectDefinition
) -> List[StrawberryField]:
return type_definition.fields

async def execute(
self,
query: Optional[str],
Expand Down
9 changes: 8 additions & 1 deletion strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _get_thunk_mapping(
type_definition: StrawberryObjectDefinition,
name_converter: Callable[[StrawberryField], str],
field_converter: FieldConverterProtocol[FieldType],
get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]],
) -> Dict[str, FieldType]:
"""Create a GraphQL core `ThunkMapping` mapping of field names to field types.

Expand All @@ -123,7 +124,9 @@ def _get_thunk_mapping(
"""
thunk_mapping: Dict[str, FieldType] = {}

for field in type_definition.fields:
fields = get_fields(type_definition)

for field in fields:
field_type = field.type

if field_type is UNRESOLVED:
Expand Down Expand Up @@ -178,10 +181,12 @@ def __init__(
self,
config: StrawberryConfig,
scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]],
get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]],
):
self.type_map: Dict[str, ConcreteType] = {}
self.config = config
self.scalar_registry = scalar_registry
self.get_fields = get_fields

def from_argument(self, argument: StrawberryArgument) -> GraphQLArgument:
argument_type = cast(
Expand Down Expand Up @@ -374,6 +379,7 @@ def get_graphql_fields(
type_definition=type_definition,
name_converter=self.config.name_converter.from_field,
field_converter=self.from_field,
get_fields=self.get_fields,
)

def get_graphql_input_fields(
Expand All @@ -383,6 +389,7 @@ def get_graphql_input_fields(
type_definition=type_definition,
name_converter=self.config.name_converter.from_field,
field_converter=self.from_input_field,
get_fields=self.get_fields,
)

def from_input_object(self, object_type: type) -> GraphQLInputObjectType:
Expand Down
18 changes: 9 additions & 9 deletions tests/pyright/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def test_federation_type():
results = run_pyright(CODE)

assert results == [
Result(type="error", message='No parameter named "n"', line=16, column=6),
Result(
type="error",
message='Argument missing for parameter "name"',
line=16,
column=1,
),
Result(type="error", message='No parameter named "n"', line=16, column=6),
Result(
type="information",
message='Type of "User" is "type[User]"',
Expand Down Expand Up @@ -75,15 +75,15 @@ def test_federation_interface():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=12,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=12,
column=1,
column=6,
),
Result(
type="information",
Expand Down Expand Up @@ -122,15 +122,15 @@ def test_federation_input():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=10,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=10,
column=1,
column=6,
),
Result(
type="information",
Expand Down
14 changes: 7 additions & 7 deletions tests/pyright/test_federation_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def test_pyright():
results = run_pyright(CODE)

assert results == [
Result(
type="error",
message='Argument missing for parameter "name"',
line=24,
column=1,
),
Result(
type="error",
message='No parameter named "n"',
Expand All @@ -51,7 +57,7 @@ def test_pyright():
Result(
type="error",
message='Argument missing for parameter "name"',
line=24,
line=27,
column=1,
),
Result(
Expand All @@ -60,12 +66,6 @@ def test_pyright():
line=27,
column=11,
),
Result(
type="error",
message='Argument missing for parameter "name"',
line=27,
column=1,
),
Result(
type="information",
message='Type of "User" is "type[User]"',
Expand Down
8 changes: 4 additions & 4 deletions tests/pyright/test_federation_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=11,
column=11,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=11,
column=1,
column=11,
),
]
8 changes: 4 additions & 4 deletions tests/pyright/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=11,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=11,
column=1,
column=6,
),
Result(
type="information",
Expand Down
8 changes: 4 additions & 4 deletions tests/pyright/test_fields_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=11,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=11,
column=1,
column=6,
),
Result(
type="information",
Expand Down
8 changes: 4 additions & 4 deletions tests/pyright/test_fields_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=15,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=15,
column=1,
column=6,
),
Result(
type="information",
Expand Down
8 changes: 4 additions & 4 deletions tests/pyright/test_fields_resolver_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Argument missing for parameter "name"',
line=15,
column=6,
column=1,
),
Result(
type="error",
message='Argument missing for parameter "name"',
message='No parameter named "n"',
line=15,
column=1,
column=6,
),
Result(
type="information",
Expand Down
14 changes: 7 additions & 7 deletions tests/pyright/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ def test_pyright():
results = run_pyright(CODE)

assert results == [
Result(
type="error",
message='No parameter named "n"',
line=16,
column=11,
),
Result(
type="error",
message='Argument missing for parameter "name"',
Expand All @@ -44,7 +38,7 @@ def test_pyright():
Result(
type="error",
message='No parameter named "n"',
line=19,
line=16,
column=11,
),
Result(
Expand All @@ -53,4 +47,10 @@ def test_pyright():
line=19,
column=1,
),
Result(
type="error",
message='No parameter named "n"',
line=19,
column=11,
),
]
8 changes: 4 additions & 4 deletions tests/pyright/test_private.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ def test_pyright():
assert results == [
Result(
type="error",
message='No parameter named "n"',
message='Arguments missing for parameters "name", "age"',
line=12,
column=6,
column=1,
),
Result(
type="error",
message='Arguments missing for parameters "name", "age"',
message='No parameter named "n"',
line=12,
column=1,
column=6,
),
Result(
type="information",
Expand Down
Loading
Loading