Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default values to Context and RootValue type vars #3732

Merged
merged 13 commits into from
Dec 20, 2024
32 changes: 11 additions & 21 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,31 +117,21 @@ jobs:

steps:
- uses: actions/checkout@v4
- run: pipx install poetry
- run: pipx install coverage
- uses: actions/setup-python@v5
id: setup-python
with:
python-version: |
3.8
3.9
3.10
3.11
3.12
python-version: "3.12"
cache: "poetry"

- name: Pip and nox cache
id: cache
uses: actions/cache@v4
with:
path: |
~/.cache
~/.nox
.nox
key:
${{ runner.os }}-nox-lint-${{ env.pythonLocation }}-${{
hashFiles('**/poetry.lock') }}-${{ hashFiles('**/noxfile.py') }}
restore-keys: |
${{ runner.os }}-nox-lint-${{ env.pythonLocation }}
- run: poetry install --with integrations
if: steps.setup-python.outputs.cache-hit != 'true'

- run: pip install poetry nox nox-poetry uv
- run: nox -r -t lint
- run: |
mkdir .mypy_cache

poetry run mypy --install-types --non-interactive --cache-dir=.mypy_cache/ --config-file mypy.ini

unit-tests-on-windows:
name: 🪟 Tests on Windows
Expand Down
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Release type: patch

This release updates the Context and RootValue vars to have
a default value of `None`, this makes it easier to use the views
without having to pass in a value for these vars.
10 changes: 1 addition & 9 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def tests_typecheckers(session: Session) -> None:

session.install("pyright")
session.install("pydantic")
session.install("git+https://github.com/python/mypy.git#master")
session.install("mypy")

session.run(
"pytest",
Expand All @@ -181,11 +181,3 @@ def tests_cli(session: Session) -> None:
"tests/cli",
"-vv",
)


@session(name="Mypy", tags=["lint"])
def mypy(session: Session) -> None:
session.run_always("poetry", "install", "--with", "integrations", external=True)
session.install("mypy")

session.run("mypy", "--config-file", "mypy.ini")
405 changes: 246 additions & 159 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ poetry-plugin-export = "^1.6.0"
urllib3 = "<2"
graphlib_backport = {version = "*", python = "<3.9", optional = false}
inline-snapshot = "^0.10.1"
types-deprecated = "^1.2.15.20241117"
types-six = "^1.17.0.20241205"
types-pyyaml = "^6.0.12.20240917"
mypy = "^1.13.0"

[tool.poetry.group.integrations]
optional = true
Expand Down
23 changes: 16 additions & 7 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
overload,
Expand Down Expand Up @@ -116,11 +117,19 @@ class AsyncBaseHTTPView(
connection_init_wait_timeout: timedelta = timedelta(minutes=1)
request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter]
websocket_adapter_class: Callable[
["AsyncBaseHTTPView", WebSocketRequest, WebSocketResponse],
[
"AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue]",
WebSocketRequest,
WebSocketResponse,
],
AsyncWebSocketAdapter,
]
graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler
graphql_ws_handler_class = BaseGraphQLWSHandler
graphql_transport_ws_handler_class: Type[
BaseGraphQLTransportWSHandler[Context, RootValue]
] = BaseGraphQLTransportWSHandler[Context, RootValue]
graphql_ws_handler_class: Type[BaseGraphQLWSHandler[Context, RootValue]] = (
BaseGraphQLWSHandler[Context, RootValue]
)

@property
@abc.abstractmethod
Expand Down Expand Up @@ -281,8 +290,8 @@ async def run(
await self.graphql_transport_ws_handler_class(
view=self,
websocket=websocket,
context=context,
root_value=root_value,
context=context, # type: ignore
root_value=root_value, # type: ignore
schema=self.schema,
debug=self.debug,
connection_init_wait_timeout=self.connection_init_wait_timeout,
Expand All @@ -291,8 +300,8 @@ async def run(
await self.graphql_ws_handler_class(
view=self,
websocket=websocket,
context=context,
root_value=root_value,
context=context, # type: ignore
root_value=root_value, # type: ignore
schema=self.schema,
debug=self.debug,
keep_alive=self.keep_alive,
Expand Down
10 changes: 5 additions & 5 deletions strawberry/http/typevars.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from typing import TypeVar
from typing_extensions import TypeVar

Request = TypeVar("Request", contravariant=True)
Response = TypeVar("Response")
SubResponse = TypeVar("SubResponse")
WebSocketRequest = TypeVar("WebSocketRequest")
WebSocketResponse = TypeVar("WebSocketResponse")
Context = TypeVar("Context")
RootValue = TypeVar("RootValue")
Context = TypeVar("Context", default=None)
RootValue = TypeVar("RootValue", default=None)


__all__ = [
"Context",
"Request",
"Response",
"RootValue",
"SubResponse",
"WebSocketRequest",
"WebSocketResponse",
"Context",
"RootValue",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
Generic,
List,
Optional,
cast,
Expand All @@ -20,6 +22,7 @@
NonTextMessageReceived,
WebSocketDisconnected,
)
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionInitMessage,
Expand All @@ -44,15 +47,15 @@
from strawberry.schema.subscribe import SubscriptionResult


class BaseGraphQLTransportWSHandler:
class BaseGraphQLTransportWSHandler(Generic[Context, RootValue]):
task_logger: logging.Logger = logging.getLogger("strawberry.ws.task")

def __init__(
self,
view: AsyncBaseHTTPView,
view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue],
websocket: AsyncWebSocketAdapter,
context: object,
root_value: object,
context: Context,
root_value: RootValue,
schema: BaseSchema,
debug: bool,
connection_init_wait_timeout: timedelta,
Expand All @@ -68,7 +71,7 @@ def __init__(
self.connection_init_received = False
self.connection_acknowledged = False
self.connection_timed_out = False
self.operations: Dict[str, Operation] = {}
self.operations: Dict[str, Operation[Context, RootValue]] = {}
self.completed_tasks: List[asyncio.Task] = []

async def handle(self) -> None:
Expand Down Expand Up @@ -184,6 +187,8 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
elif hasattr(self.context, "connection_params"):
self.context.connection_params = payload

self.context = cast(Context, self.context)

try:
connection_ack_payload = await self.view.on_ws_connect(self.context)
except ConnectionRejectionError:
Expand Down Expand Up @@ -250,7 +255,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
operation.task = asyncio.create_task(self.run_operation(operation))
self.operations[message["id"]] = operation

async def run_operation(self, operation: Operation) -> None:
async def run_operation(self, operation: Operation[Context, RootValue]) -> None:
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
# TODO: Handle errors in this method using self.handle_task_exception()

Expand Down Expand Up @@ -334,7 +339,7 @@ async def reap_completed_tasks(self) -> None:
await task


class Operation:
class Operation(Generic[Context, RootValue]):
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition."""

__slots__ = [
Expand All @@ -350,7 +355,7 @@ class Operation:

def __init__(
self,
handler: BaseGraphQLTransportWSHandler,
handler: BaseGraphQLTransportWSHandler[Context, RootValue],
id: str,
operation_type: OperationType,
query: str,
Expand Down
16 changes: 12 additions & 4 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
Generic,
Optional,
cast,
)

from strawberry.exceptions import ConnectionRejectionError
from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionInitMessage,
ConnectionTerminateMessage,
Expand All @@ -29,13 +32,16 @@
from strawberry.schema import BaseSchema


class BaseGraphQLWSHandler:
class BaseGraphQLWSHandler(Generic[Context, RootValue]):
context: Context
root_value: RootValue

def __init__(
self,
view: AsyncBaseHTTPView,
view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue],
websocket: AsyncWebSocketAdapter,
context: object,
root_value: object,
context: Context,
root_value: RootValue,
schema: BaseSchema,
debug: bool,
keep_alive: bool,
Expand Down Expand Up @@ -100,6 +106,8 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
elif hasattr(self.context, "connection_params"):
self.context.connection_params = payload

self.context = cast(Context, self.context)

try:
connection_ack_payload = await self.view.on_ws_connect(self.context)
except ConnectionRejectionError as e:
Expand Down
16 changes: 8 additions & 8 deletions tests/fastapi/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, AsyncGenerator, Dict
from typing import AsyncGenerator, Dict

import pytest

Expand Down Expand Up @@ -47,7 +47,7 @@ def get_context(custom_context: CustomContext = Depends(custom_context_dependenc

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand Down Expand Up @@ -81,7 +81,7 @@ def get_context(custom_context: CustomContext = Depends()):

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_context(value: str = Depends(custom_context_dependency)) -> Dict[str, st

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand All @@ -138,7 +138,7 @@ def abc(self, info: strawberry.Info) -> str:

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[None, None](schema, context_getter=None)
graphql_app = GraphQLRouter(schema, context_getter=None)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand Down Expand Up @@ -169,7 +169,7 @@ def get_context(value: str = Depends(custom_context_dependency)) -> str:

app = FastAPI()
schema = strawberry.Schema(query=Query)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")

test_client = TestClient(app)
Expand Down Expand Up @@ -213,7 +213,7 @@ def get_context(context: Context = Depends()) -> Context:

app = FastAPI()
schema = strawberry.Schema(query=Query, subscription=Subscription)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")
test_client = TestClient(app)

Expand Down Expand Up @@ -287,7 +287,7 @@ def get_context(context: Context = Depends()) -> Context:

app = FastAPI()
schema = strawberry.Schema(query=Query, subscription=Subscription)
graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema=schema, context_getter=get_context)
app.include_router(graphql_app, prefix="/graphql")
test_client = TestClient(app)

Expand Down
Loading