From 15044cd4c4e94035c08eb40c0a5394cc1bd74df5 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Fri, 3 Jan 2025 19:10:53 +0100 Subject: [PATCH] Type internal test clients stricter (#3745) * Type internal test clients stricter * Remove unneeded exception handling --- .../protocols/graphql_ws/handlers.py | 3 - tests/http/clients/aiohttp.py | 14 ++--- tests/http/clients/asgi.py | 24 +++----- tests/http/clients/async_django.py | 15 +++-- tests/http/clients/async_flask.py | 4 +- tests/http/clients/base.py | 12 ++-- tests/http/clients/chalice.py | 7 ++- tests/http/clients/channels.py | 57 +++++++++++-------- tests/http/clients/django.py | 43 +++++++------- tests/http/clients/fastapi.py | 19 ++----- tests/http/clients/flask.py | 6 +- tests/http/clients/litestar.py | 24 +++----- tests/http/clients/quart.py | 4 +- tests/websockets/views.py | 19 ++++++- 14 files changed, 129 insertions(+), 122 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 352d5c5f08..8722618fed 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -34,9 +34,6 @@ class BaseGraphQLWSHandler(Generic[Context, RootValue]): - context: Context - root_value: RootValue - def __init__( self, view: AsyncBaseHTTPView[Any, Any, Any, Any, Any, Context, RootValue], diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 915559a5f0..421fe961a1 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -2,9 +2,9 @@ import contextlib import json -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Mapping from io import BytesIO -from typing import Any, Optional +from typing import Any, Optional, Union from typing_extensions import Literal from aiohttp import web @@ -37,7 +37,7 @@ class GraphQLView(OnWSConnectMixin, BaseGraphQLView[dict[str, object], object]): graphql_ws_handler_class = DebuggableGraphQLWSHandler async def get_context( - self, request: web.Request, response: web.StreamResponse + self, request: web.Request, response: Union[web.Response, web.WebSocketResponse] ) -> dict[str, object]: context = await super().get_context(request, response) @@ -95,7 +95,7 @@ def create_app(self, **kwargs: Any) -> None: async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, @@ -163,7 +163,7 @@ async def post( return Response( status_code=response.status, data=(await response.text()).encode(), - headers=response.headers, + headers=dict(response.headers), ) @contextlib.asynccontextmanager @@ -186,7 +186,7 @@ def __init__(self, ws: ClientWebSocketResponse): async def send_text(self, payload: str) -> None: await self.ws.send_str(payload) - async def send_json(self, payload: dict[str, Any]) -> None: + async def send_json(self, payload: Mapping[str, object]) -> None: await self.ws.send_json(payload) async def send_bytes(self, payload: bytes) -> None: @@ -197,7 +197,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: self._reason = m.extra return Message(type=m.type, data=m.data, extra=m.extra) - async def receive_json(self, timeout: Optional[float] = None) -> Any: + async def receive_json(self, timeout: Optional[float] = None) -> object: m = await self.ws.receive(timeout) assert m.type == WSMsgType.TEXT return json.loads(m.data) diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index b836dcce66..9a8036d688 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -2,7 +2,7 @@ import contextlib import json -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Mapping from io import BytesIO from typing import Any, Optional, Union from typing_extensions import Literal @@ -10,7 +10,7 @@ from starlette.requests import Request from starlette.responses import Response as StarletteResponse from starlette.testclient import TestClient, WebSocketTestSession -from starlette.websockets import WebSocket, WebSocketDisconnect +from starlette.websockets import WebSocket from strawberry.asgi import GraphQL as BaseGraphQLView from strawberry.http import GraphQLHTTPResponse @@ -86,7 +86,7 @@ def create_app(self, **kwargs: Any) -> None: async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, @@ -152,7 +152,7 @@ async def post( return Response( status_code=response.status_code, data=response.content, - headers=response.headers, + headers=dict(response.headers), ) @contextlib.asynccontextmanager @@ -162,13 +162,8 @@ async def ws_connect( *, protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: - try: - with self.client.websocket_connect(url, protocols) as ws: - yield AsgiWebSocketClient(ws) - except WebSocketDisconnect as error: - ws = AsgiWebSocketClient(None) - ws.handle_disconnect(error) - yield ws + with self.client.websocket_connect(url, protocols) as ws: + yield AsgiWebSocketClient(ws) class AsgiWebSocketClient(WebSocketClient): @@ -178,15 +173,10 @@ def __init__(self, ws: WebSocketTestSession): self._close_code: Optional[int] = None self._close_reason: Optional[str] = None - def handle_disconnect(self, exc: WebSocketDisconnect) -> None: - self._closed = True - self._close_code = exc.code - self._close_reason = exc.reason - async def send_text(self, payload: str) -> None: self.ws.send_text(payload) - async def send_json(self, payload: dict[str, Any]) -> None: + async def send_json(self, payload: Mapping[str, object]) -> None: self.ws.send_json(payload) async def send_bytes(self, payload: bytes) -> None: diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index 5f5caf4dd7..fe97a43c8b 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import AsyncIterable + from django.core.exceptions import BadRequest, SuspiciousOperation from django.http import Http404, HttpRequest, HttpResponse, StreamingHttpResponse -from django.test.client import RequestFactory from strawberry.django.views import AsyncGraphQLView as BaseAsyncGraphQLView from strawberry.http import GraphQLHTTPResponse @@ -14,14 +15,16 @@ from .django import DjangoHttpClient -class AsyncGraphQLView(BaseAsyncGraphQLView): +class AsyncGraphQLView(BaseAsyncGraphQLView[dict[str, object], object]): result_override: ResultOverrideFunction = None async def get_root_value(self, request: HttpRequest) -> Query: await super().get_root_value(request) # for coverage return Query() - async def get_context(self, request: HttpRequest, response: HttpResponse) -> object: + async def get_context( + self, request: HttpRequest, response: HttpResponse + ) -> dict[str, object]: context = {"request": request, "response": response} return get_context(context) @@ -36,7 +39,7 @@ async def process_result( class AsyncDjangoHttpClient(DjangoHttpClient): - async def _do_request(self, request: RequestFactory) -> Response: + async def _do_request(self, request: HttpRequest) -> Response: view = AsyncGraphQLView.as_view( schema=schema, graphiql=self.graphiql, @@ -56,14 +59,16 @@ async def _do_request(self, request: RequestFactory) -> Response: data=e.args[0].encode(), headers={}, ) + data = ( response.streaming_content if isinstance(response, StreamingHttpResponse) + and isinstance(response.streaming_content, AsyncIterable) else response.content ) return Response( status_code=response.status_code, data=data, - headers=response.headers, + headers=dict(response.headers), ) diff --git a/tests/http/clients/async_flask.py b/tests/http/clients/async_flask.py index 3b8e45b755..d828f7d929 100644 --- a/tests/http/clients/async_flask.py +++ b/tests/http/clients/async_flask.py @@ -16,12 +16,12 @@ from .flask import FlaskHttpClient -class GraphQLView(BaseAsyncGraphQLView): +class GraphQLView(BaseAsyncGraphQLView[dict[str, object], object]): methods = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"] result_override: ResultOverrideFunction = None - def __init__(self, *args: str, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any): self.result_override = kwargs.pop("result_override") super().__init__(*args, **kwargs) diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 426aa0e19b..795dc60094 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -108,7 +108,7 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, @@ -141,7 +141,7 @@ async def post( async def query( self, - query: Optional[str] = None, + query: str, method: Literal["get", "post"] = "post", variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, @@ -302,7 +302,9 @@ async def send_legacy_message(self, message: OperationMessage) -> None: await self.send_json(message) -class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): +class DebuggableGraphQLTransportWSHandler( + BaseGraphQLTransportWSHandler[dict[str, object], object] +): def on_init(self) -> None: """This method can be patched by unit tests to get the instance of the transport handler when it is initialized. @@ -330,10 +332,10 @@ def context(self, value): self.original_context = value -class DebuggableGraphQLWSHandler(BaseGraphQLWSHandler): +class DebuggableGraphQLWSHandler(BaseGraphQLWSHandler[dict[str, object], object]): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) - self.original_context = self.context + self.original_context = kwargs.get("context", {}) def get_tasks(self) -> list: return list(self.tasks.values()) diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index 8fbb31ff46..e57062bb3b 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -20,7 +20,7 @@ from .base import JSON, HttpClient, Response, ResultOverrideFunction -class GraphQLView(BaseGraphQLView): +class GraphQLView(BaseGraphQLView[dict[str, object], object]): result_override: ResultOverrideFunction = None def get_root_value(self, request: ChaliceRequest) -> Query: @@ -29,7 +29,7 @@ def get_root_value(self, request: ChaliceRequest) -> Query: def get_context( self, request: ChaliceRequest, response: TemporalResponse - ) -> object: + ) -> dict[str, object]: context = super().get_context(request, response) return get_context(context) @@ -66,12 +66,13 @@ def __init__( "/graphql", methods=["GET", "POST"], content_types=["application/json"] ) def handle_graphql(): + assert self.app.current_request is not None return view.execute_request(self.app.current_request) async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index bfec59a1f4..6fe6a135e1 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -2,7 +2,7 @@ import contextlib import json as json_module -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Mapping from io import BytesIO from typing import Any, Optional from typing_extensions import Literal @@ -15,10 +15,11 @@ GraphQLWSConsumer, SyncGraphQLHTTPConsumer, ) -from strawberry.channels.handlers.base import ChannelsConsumer +from strawberry.channels.handlers.http_handler import ChannelsRequest from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE -from strawberry.http.typevars import Context, RootValue +from strawberry.http.temporal_response import TemporalResponse +from strawberry.types import ExecutionResult from tests.views.schema import Query, schema from tests.websockets.views import OnWSConnectMixin @@ -67,23 +68,23 @@ def create_multipart_request_body( return headers, request_body -class DebuggableGraphQLHTTPConsumer(GraphQLHTTPConsumer): +class DebuggableGraphQLHTTPConsumer(GraphQLHTTPConsumer[dict[str, object], object]): result_override: ResultOverrideFunction = None def __init__(self, *args: Any, **kwargs: Any): self.result_override = kwargs.pop("result_override") super().__init__(*args, **kwargs) - async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]: + async def get_root_value(self, request: ChannelsRequest): return Query() - async def get_context(self, request: ChannelsConsumer, response: Any) -> Context: + async def get_context(self, request: ChannelsRequest, response: TemporalResponse): context = await super().get_context(request, response) return get_context(context) async def process_result( - self, request: ChannelsConsumer, result: Any + self, request: ChannelsRequest, result: ExecutionResult ) -> GraphQLHTTPResponse: if self.result_override: return self.result_override(result) @@ -91,23 +92,25 @@ async def process_result( return await super().process_result(request, result) -class DebuggableSyncGraphQLHTTPConsumer(SyncGraphQLHTTPConsumer): +class DebuggableSyncGraphQLHTTPConsumer( + SyncGraphQLHTTPConsumer[dict[str, object], object] +): result_override: ResultOverrideFunction = None def __init__(self, *args: Any, **kwargs: Any): self.result_override = kwargs.pop("result_override") super().__init__(*args, **kwargs) - def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]: + def get_root_value(self, request: ChannelsRequest): return Query() - def get_context(self, request: ChannelsConsumer, response: Any) -> Context: + def get_context(self, request: ChannelsRequest, response: TemporalResponse): context = super().get_context(request, response) return get_context(context) def process_result( - self, request: ChannelsConsumer, result: Any + self, request: ChannelsRequest, result: ExecutionResult ) -> GraphQLHTTPResponse: if self.result_override: return self.result_override(result) @@ -115,11 +118,15 @@ def process_result( return super().process_result(request, result) -class DebuggableGraphQLWSConsumer(OnWSConnectMixin, GraphQLWSConsumer): +class DebuggableGraphQLWSConsumer( + OnWSConnectMixin, GraphQLWSConsumer[dict[str, object], object] +): graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler - async def get_context(self, request, response): + async def get_context( + self, request: GraphQLWSConsumer, response: GraphQLWSConsumer + ): context = await super().get_context(request, response) return get_context(context) @@ -156,7 +163,7 @@ def create_app(self, **kwargs: Any) -> None: async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, @@ -169,10 +176,9 @@ async def _graphql_request( headers = self._get_headers(method=method, headers=headers, files=files) if method == "post": - if files: - new_headers, body = create_multipart_request_body(body, files) - for k, v in new_headers: - headers[k] = v + if body and files: + header_pairs, body = create_multipart_request_body(body, files) + headers = dict(header_pairs) else: body = json_module.dumps(body).encode() endpoint_url = "/graphql" @@ -188,19 +194,20 @@ async def request( self, url: str, method: Literal["get", "post", "patch", "put", "delete"], - body: bytes = b"", headers: Optional[dict[str, str]] = None, + body: bytes = b"", ) -> Response: # HttpCommunicator expects tuples of bytestrings - if headers: - headers = [(k.encode(), v.encode()) for k, v in headers.items()] + header_tuples = ( + [(k.encode(), v.encode()) for k, v in headers.items()] if headers else [] + ) communicator = HttpCommunicator( self.http_app, method.upper(), url, body=body, - headers=headers, + headers=header_tuples, ) response = await communicator.get_response() @@ -286,14 +293,14 @@ def name(self) -> str: async def send_text(self, payload: str) -> None: await self.ws.send_to(text_data=payload) - async def send_json(self, payload: dict[str, Any]) -> None: + async def send_json(self, payload: Mapping[str, object]) -> None: await self.ws.send_json_to(payload) async def send_bytes(self, payload: bytes) -> None: await self.ws.send_to(bytes_data=payload) async def receive(self, timeout: Optional[float] = None) -> Message: - m = await self.ws.receive_output(timeout=timeout) + m = await self.ws.receive_output(timeout=timeout) # type: ignore if m["type"] == "websocket.close": self._closed = True self._close_code = m["code"] @@ -304,7 +311,7 @@ async def receive(self, timeout: Optional[float] = None) -> Message: return Message(type=m["type"], data=m["data"], extra=m["extra"]) async def receive_json(self, timeout: Optional[float] = None) -> Any: - m = await self.ws.receive_output(timeout=timeout) + m = await self.ws.receive_output(timeout=timeout) # type: ignore assert m["type"] == "websocket.send" assert "text" in m return json_module.loads(m["text"]) diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index 10b19893c3..fae0823f0c 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -20,14 +20,16 @@ from .base import JSON, HttpClient, Response, ResultOverrideFunction -class GraphQLView(BaseGraphQLView): +class GraphQLView(BaseGraphQLView[dict[str, object], object]): result_override: ResultOverrideFunction = None def get_root_value(self, request) -> Query: super().get_root_value(request) # for coverage return Query() - def get_context(self, request: HttpRequest, response: HttpResponse) -> object: + def get_context( + self, request: HttpRequest, response: HttpResponse + ) -> dict[str, object]: context = {"request": request, "response": response} return get_context(context) @@ -70,7 +72,7 @@ def _get_headers( return super()._get_headers(method=method, headers=headers, files=files) - async def _do_request(self, request: RequestFactory) -> Response: + async def _do_request(self, request: HttpRequest) -> Response: view = GraphQLView.as_view( schema=schema, graphiql=self.graphiql, @@ -83,24 +85,20 @@ async def _do_request(self, request: RequestFactory) -> Response: try: response = view(request) except Http404: - return Response( - status_code=404, data=b"Not found", headers=response.headers - ) + return Response(status_code=404, data=b"Not found") except (BadRequest, SuspiciousOperation) as e: - return Response( - status_code=400, data=e.args[0].encode(), headers=response.headers - ) + return Response(status_code=400, data=e.args[0].encode()) else: return Response( status_code=response.status_code, data=response.content, - headers=response.headers, + headers=dict(response.headers), ) async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, @@ -116,11 +114,12 @@ async def _graphql_request( data: Union[dict[str, object], str, None] = None if body and files: - files = { - name: SimpleUploadedFile(name, file.read()) - for name, file in files.items() - } - body.update(files) + body.update( + { + name: SimpleUploadedFile(name, file.read()) + for name, file in files.items() + } + ) else: additional_arguments["content_type"] = "application/json" @@ -142,11 +141,7 @@ async def request( method: Literal["get", "post", "patch", "put", "delete"], headers: Optional[dict[str, str]] = None, ) -> Response: - headers = self._get_headers( - method=method, # type: ignore - headers=headers, - files=None, - ) + headers = headers or {} factory = RequestFactory() request = getattr(factory, method)(url, **headers) @@ -158,6 +153,12 @@ async def get( url: str, headers: Optional[dict[str, str]] = None, ) -> Response: + headers = self._get_headers( + method="get", + headers=headers, + files=None, + ) + return await self.request(url, "get", headers=headers) async def post( diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 0f0ed88398..21cf010d54 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -7,8 +7,6 @@ from typing import Any, Optional from typing_extensions import Literal -from starlette.websockets import WebSocketDisconnect - from fastapi import BackgroundTasks, Depends, FastAPI, Request, WebSocket from fastapi.testclient import TestClient from strawberry.fastapi import GraphQLRouter as BaseGraphQLRouter @@ -35,7 +33,7 @@ def custom_context_dependency() -> str: return "Hi!" -async def fastapi_get_context( +def fastapi_get_context( background_tasks: BackgroundTasks, request: Request = None, # type: ignore ws: WebSocket = None, # type: ignore @@ -49,14 +47,14 @@ async def fastapi_get_context( ) -async def get_root_value( +def get_root_value( request: Request = None, # type: ignore - FastAPI ws: WebSocket = None, # type: ignore - FastAPI ) -> Query: return Query() -class GraphQLRouter(OnWSConnectMixin, BaseGraphQLRouter[Any, Any]): +class GraphQLRouter(OnWSConnectMixin, BaseGraphQLRouter[dict[str, object], object]): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler graphql_ws_handler_class = DebuggableGraphQLWSHandler @@ -114,7 +112,7 @@ async def _handle_response(self, response: Any) -> Response: async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, @@ -179,10 +177,5 @@ async def ws_connect( *, protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: - try: - with self.client.websocket_connect(url, protocols) as ws: - yield AsgiWebSocketClient(ws) - except WebSocketDisconnect as error: - ws = AsgiWebSocketClient(None) - ws.handle_disconnect(error) - yield ws + with self.client.websocket_connect(url, protocols) as ws: + yield AsgiWebSocketClient(ws) diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index 7f5ec5cb8f..7509d0f911 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -22,7 +22,7 @@ from .base import JSON, HttpClient, Response, ResultOverrideFunction -class GraphQLView(BaseGraphQLView): +class GraphQLView(BaseGraphQLView[dict[str, object], object]): # this allows to test our code path for checking the request type # TODO: we might want to remove our check since it is done by flask # already @@ -30,7 +30,7 @@ class GraphQLView(BaseGraphQLView): result_override: ResultOverrideFunction = None - def __init__(self, *args: str, **kwargs: Any): + def __init__(self, *args: Any, **kwargs: Any): self.result_override = kwargs.pop("result_override") super().__init__(*args, **kwargs) @@ -84,7 +84,7 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index f357bf75f8..931cc98297 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -2,7 +2,7 @@ import contextlib import json -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Mapping from io import BytesIO from typing import Any, Optional from typing_extensions import Literal @@ -87,7 +87,7 @@ async def process_result( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, @@ -151,7 +151,7 @@ async def post( return Response( status_code=response.status_code, data=response.content, - headers=response.headers, + headers=dict(response.headers), ) @contextlib.asynccontextmanager @@ -161,13 +161,8 @@ async def ws_connect( *, protocols: list[str], ) -> AsyncGenerator[WebSocketClient, None]: - try: - with self.client.websocket_connect(url, protocols) as ws: - yield LitestarWebSocketClient(ws) - except WebSocketDisconnect as error: - ws = LitestarWebSocketClient(None) - ws.handle_disconnect(error) - yield ws + with self.client.websocket_connect(url, protocols) as ws: + yield LitestarWebSocketClient(ws) class LitestarWebSocketClient(WebSocketClient): @@ -177,14 +172,10 @@ def __init__(self, ws: WebSocketTestSession): self._close_code: Optional[int] = None self._close_reason: Optional[str] = None - def handle_disconnect(self, exc: WebSocketDisconnect) -> None: - self._closed = True - self._close_code = exc.code - async def send_text(self, payload: str) -> None: self.ws.send_text(payload) - async def send_json(self, payload: dict[str, Any]) -> None: + async def send_json(self, payload: Mapping[str, object]) -> None: self.ws.send_json(payload) async def send_bytes(self, payload: bytes) -> None: @@ -211,12 +202,15 @@ async def receive(self, timeout: Optional[float] = None) -> Message: return Message(type=m["type"], data=m["code"], extra=m["reason"]) elif m["type"] == "websocket.send": return Message(type=m["type"], data=m["text"]) + + assert "data" in m return Message(type=m["type"], data=m["data"], extra=m["extra"]) async def receive_json(self, timeout: Optional[float] = None) -> Any: m = self.ws.receive() assert m["type"] == "websocket.send" assert "text" in m + assert m["text"] is not None return json.loads(m["text"]) async def close(self) -> None: diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 43d21ee28f..a562aa0f5a 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -18,7 +18,7 @@ from .base import JSON, HttpClient, Response, ResultOverrideFunction -class GraphQLView(BaseGraphQLView): +class GraphQLView(BaseGraphQLView[dict[str, object], object]): methods = ["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"] result_override: ResultOverrideFunction = None @@ -77,7 +77,7 @@ def __init__( async def _graphql_request( self, method: Literal["get", "post"], - query: Optional[str] = None, + query: str, variables: Optional[dict[str, object]] = None, files: Optional[dict[str, BytesIO]] = None, headers: Optional[dict[str, str]] = None, diff --git a/tests/websockets/views.py b/tests/websockets/views.py index 1f7b2eaee2..981ad96354 100644 --- a/tests/websockets/views.py +++ b/tests/websockets/views.py @@ -3,10 +3,27 @@ from strawberry import UNSET from strawberry.exceptions import ConnectionRejectionError from strawberry.http.async_base_view import AsyncBaseHTTPView +from strawberry.http.typevars import ( + Request, + Response, + SubResponse, + WebSocketRequest, + WebSocketResponse, +) from strawberry.types.unset import UnsetType -class OnWSConnectMixin(AsyncBaseHTTPView): +class OnWSConnectMixin( + AsyncBaseHTTPView[ + Request, + Response, + SubResponse, + WebSocketRequest, + WebSocketResponse, + dict[str, object], + object, + ] +): async def on_ws_connect( self, context: dict[str, object] ) -> Union[UnsetType, None, dict[str, object]]: