Skip to content

Commit

Permalink
Add support for optional connections (#3707)
Browse files Browse the repository at this point in the history
* Support nullable Connection types in relay field decorator

Enable nullable Connection types in the connection field decorator by updating type checking logic and adding validation for inner types. Update documentation and add tests to ensure compatibility with permission extensions and different nullable syntax.

New Features:
- Support nullable Connection types in the connection field decorator in strawberry.relay.fields.

Enhancements:
- Update type checking logic to handle Optional[Connection[T]] and Connection[T] | None annotations.

Documentation:
- Update documentation to reflect that connection fields can now be nullable.

Tests:
- Add tests to verify nullable connection fields work correctly with permission extensions and both Optional[Connection[T]] and Connection[T] | None syntax.

Resolves #3703

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Initial working version

* Fix typos and types

* Pre-commit

* Fix missing `@`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add release file

* Update tests/relay/test_connection.py

Co-authored-by: Thiago Bellini Ribeiro <thiago@bellini.dev>

* Add check

---------

Co-authored-by: sourcery-ai[bot] <sourcery-ai[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Thiago Bellini Ribeiro <thiago@bellini.dev>
  • Loading branch information
4 people authored Dec 20, 2024
1 parent c9dac1d commit 8a8e3aa
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 43 deletions.
51 changes: 51 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
Release type: minor

This release adds support for making Relay connection optional, this is useful
when you want to add permission classes to the connection and not fail the whole
query if the user doesn't have permission to access the connection.

Example:

```python
import strawberry
from strawberry import relay
from strawberry.permission import BasePermission


class IsAuthenticated(BasePermission):
message = "User is not authenticated"

# This method can also be async!
def has_permission(
self, source: typing.Any, info: strawberry.Info, **kwargs
) -> bool:
return False


@strawberry.type
class Fruit(relay.Node):
code: relay.NodeID[int]
name: str
weight: float

@classmethod
def resolve_nodes(
cls,
*,
info: strawberry.Info,
node_ids: Iterable[str],
):
return []


@strawberry.type
class Query:
node: relay.Node = relay.node()

@relay.connection(
relay.ListConnection[Fruit] | None, permission_classes=[IsAuthenticated()]
)
def fruits(self) -> Iterable[Fruit]:
# This can be a database query, a generator, an async generator, etc
return all_fruits.values()
```
2 changes: 1 addition & 1 deletion strawberry/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def create_optional(self, evaled_type: Any) -> StrawberryOptional:
)

# Note that passing a single type to `Union` is equivalent to not using `Union`
# at all. This allows us to not di any checks for how many types have been
# at all. This allows us to not do any checks for how many types have been
# passed as we can safely use `Union` for both optional types
# (e.g. `Optional[str]`) and optional unions (e.g.
# `Optional[Union[TypeA, TypeB]]`)
Expand Down
64 changes: 23 additions & 41 deletions strawberry/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
Type,
Union,
cast,
overload,
)
from typing_extensions import Annotated, get_origin
from typing_extensions import Annotated, get_args, get_origin

from strawberry.annotation import StrawberryAnnotation
from strawberry.extensions.field_extension import (
Expand All @@ -44,9 +43,9 @@
from strawberry.types.fields.resolver import StrawberryResolver
from strawberry.types.lazy_type import LazyType
from strawberry.utils.aio import asyncgen_to_list
from strawberry.utils.typing import eval_type, is_generic_alias
from strawberry.utils.typing import eval_type, is_generic_alias, is_optional, is_union

from .types import Connection, GlobalID, Node, NodeIterableType, NodeType
from .types import Connection, GlobalID, Node

if TYPE_CHECKING:
from typing_extensions import Literal
Expand Down Expand Up @@ -233,7 +232,11 @@ def apply(self, field: StrawberryField) -> None:
f_type = f_type.resolve_type()
field.type = f_type

if isinstance(f_type, StrawberryOptional):
f_type = f_type.of_type

type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type

if not isinstance(type_origin, type) or not issubclass(type_origin, Connection):
raise RelayWrongAnnotationError(field.name, cast(type, field.origin))

Expand All @@ -253,13 +256,19 @@ def apply(self, field: StrawberryField) -> None:
None,
)

if is_union(resolver_type):
assert is_optional(resolver_type)

resolver_type = get_args(resolver_type)[0]

origin = get_origin(resolver_type)

if origin is None or not issubclass(
origin, (Iterator, Iterable, AsyncIterator, AsyncIterable)
):
raise RelayWrongResolverAnnotationError(field.name, field.base_resolver)

self.connection_type = cast(Type[Connection[Node]], field.type)
self.connection_type = cast(Type[Connection[Node]], f_type)

def resolve(
self,
Expand Down Expand Up @@ -327,44 +336,17 @@ def node(*args: Any, **kwargs: Any) -> StrawberryField:
return field(*args, **kwargs)


@overload
def connection(
graphql_type: Optional[Type[Connection[NodeType]]] = None,
*,
resolver: Optional[_RESOLVER_TYPE[NodeIterableType[Any]]] = None,
name: Optional[str] = None,
is_subscription: bool = False,
description: Optional[str] = None,
init: Literal[True] = True,
permission_classes: Optional[List[Type[BasePermission]]] = None,
deprecation_reason: Optional[str] = None,
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: List[FieldExtension] = (), # type: ignore
) -> Any: ...


@overload
def connection(
graphql_type: Optional[Type[Connection[NodeType]]] = None,
*,
name: Optional[str] = None,
is_subscription: bool = False,
description: Optional[str] = None,
permission_classes: Optional[List[Type[BasePermission]]] = None,
deprecation_reason: Optional[str] = None,
default: Any = dataclasses.MISSING,
default_factory: Union[Callable[..., object], object] = dataclasses.MISSING,
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: List[FieldExtension] = (), # type: ignore
) -> StrawberryField: ...
# we used to have `Type[Connection[NodeType]]` here, but that when we added
# support for making the Connection type optional, we had to change it to
# `Any` because otherwise it wouldn't be type check since `Optional[Connection[Something]]`
# is not a `Type`, but a special form, see https://discuss.python.org/t/is-annotated-compatible-with-type-t/43898/46
# for more information, and also https://peps.python.org/pep-0747/, which is currently
# in draft status (and no type checker supports it yet)
ConnectionGraphQLType = Any


def connection(
graphql_type: Optional[Type[Connection[NodeType]]] = None,
graphql_type: Optional[ConnectionGraphQLType] = None,
*,
resolver: Optional[_RESOLVER_TYPE[Any]] = None,
name: Optional[str] = None,
Expand All @@ -379,7 +361,7 @@ def connection(
extensions: List[FieldExtension] = (), # type: ignore
# This init parameter is used by pyright to determine whether this field
# is added in the constructor or not. It is not used to change
# any behavior at the moment.
# any behaviour at the moment.
init: Literal[True, False, None] = None,
) -> Any:
"""Annotate a property or a method to create a relay connection field.
Expand Down
2 changes: 1 addition & 1 deletion strawberry/schema/subscribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def subscribe(
middleware_manager,
execution_context_class,
)
# GrapQL-core might return an initial error result instead of an async iterator.
# GraphQL-core might return an initial error result instead of an async iterator.
# This happens when "there was an immediate error" i.e resolver is not an async iterator.
# To overcome this while maintaining the extension contexts we do this trick.
first = await asyncgen.__anext__()
Expand Down
125 changes: 125 additions & 0 deletions tests/relay/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import sys
from typing import Any, Iterable, List, Optional
from typing_extensions import Self

import pytest

import strawberry
from strawberry.permission import BasePermission
from strawberry.relay import Connection, Node


@strawberry.type
class User(Node):
id: strawberry.relay.NodeID
name: str = "John"

@classmethod
def resolve_nodes(
cls, *, info: strawberry.Info, node_ids: List[Any], required: bool
) -> List[Self]:
return [cls() for _ in node_ids]


@strawberry.type
class UserConnection(Connection[User]):
@classmethod
def resolve_connection(
cls,
nodes: Iterable[User],
*,
info: Any,
after: Optional[str] = None,
before: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
) -> Optional[Self]:
return None


class TestPermission(BasePermission):
message = "Not allowed"

def has_permission(self, source, info, **kwargs: Any):
return False


def test_nullable_connection_with_optional():
@strawberry.type
class Query:
@strawberry.relay.connection(Optional[UserConnection])
def users(self) -> Optional[List[User]]:
return None

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert not result.errors


@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="pipe syntax for union is only available on python 3.10+",
)
def test_nullable_connection_with_pipe():
@strawberry.type
class Query:
@strawberry.relay.connection(UserConnection | None)
def users(self) -> List[User] | None:
return None

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert not result.errors


def test_nullable_connection_with_permission():
@strawberry.type
class Query:
@strawberry.relay.connection(
Optional[UserConnection], permission_classes=[TestPermission]
)
def users(self) -> Optional[List[User]]: # pragma: no cover
pytest.fail("Should not have been called...")

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert result.errors[0].message == "Not allowed"

0 comments on commit 8a8e3aa

Please sign in to comment.