From 62258da6e1720384c266ef22023c995d5ea3b5e3 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 20 Dec 2024 09:59:25 +0100 Subject: [PATCH] Add default values to Context and RootValue type vars --- strawberry/http/typevars.py | 10 +++++----- tests/fastapi/test_context.py | 16 ++++++++-------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/strawberry/http/typevars.py b/strawberry/http/typevars.py index 53a5d5ac33..a1f6020e83 100644 --- a/strawberry/http/typevars.py +++ b/strawberry/http/typevars.py @@ -1,20 +1,20 @@ -from typing import TypeVar +from typing_extensions import TypeVar Request = TypeVar("Request", contravariant=True) Response = TypeVar("Response") SubResponse = TypeVar("SubResponse") WebSocketRequest = TypeVar("WebSocketRequest") WebSocketResponse = TypeVar("WebSocketResponse") -Context = TypeVar("Context") -RootValue = TypeVar("RootValue") +Context = TypeVar("Context", default=None) +RootValue = TypeVar("RootValue", default=None) __all__ = [ + "Context", "Request", "Response", + "RootValue", "SubResponse", "WebSocketRequest", "WebSocketResponse", - "Context", - "RootValue", ] diff --git a/tests/fastapi/test_context.py b/tests/fastapi/test_context.py index 48eebc9550..16530b4e09 100644 --- a/tests/fastapi/test_context.py +++ b/tests/fastapi/test_context.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, AsyncGenerator, Dict +from typing import AsyncGenerator, Dict import pytest @@ -47,7 +47,7 @@ def get_context(custom_context: CustomContext = Depends(custom_context_dependenc app = FastAPI() schema = strawberry.Schema(query=Query) - graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context) + graphql_app = GraphQLRouter(schema=schema, context_getter=get_context) app.include_router(graphql_app, prefix="/graphql") test_client = TestClient(app) @@ -81,7 +81,7 @@ def get_context(custom_context: CustomContext = Depends()): app = FastAPI() schema = strawberry.Schema(query=Query) - graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context) + graphql_app = GraphQLRouter(schema=schema, context_getter=get_context) app.include_router(graphql_app, prefix="/graphql") test_client = TestClient(app) @@ -113,7 +113,7 @@ def get_context(value: str = Depends(custom_context_dependency)) -> Dict[str, st app = FastAPI() schema = strawberry.Schema(query=Query) - graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context) + graphql_app = GraphQLRouter(schema=schema, context_getter=get_context) app.include_router(graphql_app, prefix="/graphql") test_client = TestClient(app) @@ -138,7 +138,7 @@ def abc(self, info: strawberry.Info) -> str: app = FastAPI() schema = strawberry.Schema(query=Query) - graphql_app = GraphQLRouter[None, None](schema, context_getter=None) + graphql_app = GraphQLRouter(schema, context_getter=None) app.include_router(graphql_app, prefix="/graphql") test_client = TestClient(app) @@ -169,7 +169,7 @@ def get_context(value: str = Depends(custom_context_dependency)) -> str: app = FastAPI() schema = strawberry.Schema(query=Query) - graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context) + graphql_app = GraphQLRouter(schema=schema, context_getter=get_context) app.include_router(graphql_app, prefix="/graphql") test_client = TestClient(app) @@ -213,7 +213,7 @@ def get_context(context: Context = Depends()) -> Context: app = FastAPI() schema = strawberry.Schema(query=Query, subscription=Subscription) - graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context) + graphql_app = GraphQLRouter(schema=schema, context_getter=get_context) app.include_router(graphql_app, prefix="/graphql") test_client = TestClient(app) @@ -287,7 +287,7 @@ def get_context(context: Context = Depends()) -> Context: app = FastAPI() schema = strawberry.Schema(query=Query, subscription=Subscription) - graphql_app = GraphQLRouter[Any, None](schema=schema, context_getter=get_context) + graphql_app = GraphQLRouter(schema=schema, context_getter=get_context) app.include_router(graphql_app, prefix="/graphql") test_client = TestClient(app)