Skip to content

Commit

Permalink
Handle websocket disconnects more gracefully (#3685)
Browse files Browse the repository at this point in the history
  • Loading branch information
DoctorJohn authored Nov 1, 2024
1 parent 3e2b9bf commit c3b8135
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 40 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

Starting with this release, both websocket-based protocols will handle unexpected socket disconnections more gracefully.
6 changes: 5 additions & 1 deletion strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
WebSocketDisconnected,
)
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Expand Down Expand Up @@ -105,7 +106,10 @@ async def iter_json(
raise NonTextMessageReceived()

async def send_json(self, message: Mapping[str, object]) -> None:
await self.ws.send_json(message)
try:
await self.ws.send_json(message)
except RuntimeError as exc:
raise WebSocketDisconnected from exc

async def close(self, code: int, reason: str) -> None:
await self.ws.close(code=code, message=reason.encode())
Expand Down
6 changes: 5 additions & 1 deletion strawberry/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
WebSocketDisconnected,
)
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Expand Down Expand Up @@ -105,7 +106,10 @@ async def iter_json(
pass

async def send_json(self, message: Mapping[str, object]) -> None:
await self.ws.send_json(message)
try:
await self.ws.send_json(message)
except WebSocketDisconnect as exc:
raise WebSocketDisconnected from exc

async def close(self, code: int, reason: str) -> None:
await self.ws.close(code=code, reason=reason)
Expand Down
4 changes: 4 additions & 0 deletions strawberry/http/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@ class NonJsonMessageReceived(Exception):
pass


class WebSocketDisconnected(Exception):
pass


__all__ = ["HTTPException"]
6 changes: 5 additions & 1 deletion strawberry/litestar/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
WebSocketDisconnected,
)
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import Context, RootValue
Expand Down Expand Up @@ -216,7 +217,10 @@ async def iter_json(
pass

async def send_json(self, message: Mapping[str, object]) -> None:
await self.ws.send_json(message)
try:
await self.ws.send_json(message)
except WebSocketDisconnect as exc:
raise WebSocketDisconnected from exc

async def close(self, code: int, reason: str) -> None:
await self.ws.close(code=code, reason=reason)
Expand Down
57 changes: 28 additions & 29 deletions strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
)

from graphql import GraphQLError, GraphQLSyntaxError, parse

from strawberry.http.exceptions import NonJsonMessageReceived, NonTextMessageReceived
from strawberry.http.exceptions import (
NonJsonMessageReceived,
NonTextMessageReceived,
WebSocketDisconnected,
)
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionAckMessage,
Expand Down Expand Up @@ -76,12 +79,17 @@ async def handle(self) -> Any:
self.on_request_accepted()

try:
async for message in self.websocket.iter_json():
await self.handle_message(message)
except NonTextMessageReceived:
await self.handle_invalid_message("WebSocket message type must be text")
except NonJsonMessageReceived:
await self.handle_invalid_message("WebSocket message must be valid JSON")
try:
async for message in self.websocket.iter_json():
await self.handle_message(message)
except NonTextMessageReceived:
await self.handle_invalid_message("WebSocket message type must be text")
except NonJsonMessageReceived:
await self.handle_invalid_message(
"WebSocket message must be valid JSON"
)
except WebSocketDisconnected:
pass
finally:
await self.shutdown()

Expand Down Expand Up @@ -127,50 +135,41 @@ async def handle_task_exception(self, error: Exception) -> None: # pragma: no c
self.task_logger.exception("Exception in worker task", exc_info=error)

async def handle_message(self, message: dict) -> None:
handler: Callable
handler_arg: Any
try:
message_type = message.pop("type")

if message_type == ConnectionInitMessage.type:
handler = self.handle_connection_init
handler_arg = ConnectionInitMessage(**message)
await self.handle_connection_init(ConnectionInitMessage(**message))

elif message_type == PingMessage.type:
handler = self.handle_ping
handler_arg = PingMessage(**message)
await self.handle_ping(PingMessage(**message))

elif message_type == PongMessage.type:
handler = self.handle_pong
handler_arg = PongMessage(**message)
await self.handle_pong(PongMessage(**message))

elif message_type == SubscribeMessage.type:
handler = self.handle_subscribe

payload_args = message.pop("payload")

payload = SubscribeMessagePayload(
query=payload_args["query"],
operationName=payload_args.get("operationName"),
variables=payload_args.get("variables"),
extensions=payload_args.get("extensions"),
)
handler_arg = SubscribeMessage(payload=payload, **message)
await self.handle_subscribe(
SubscribeMessage(payload=payload, **message)
)

elif message_type == CompleteMessage.type:
handler = self.handle_complete
handler_arg = CompleteMessage(**message)
await self.handle_complete(CompleteMessage(**message))

else:
handler = self.handle_invalid_message
handler_arg = f"Unknown message type: {message_type}"
error_message = f"Unknown message type: {message_type}"
await self.handle_invalid_message(error_message)

except (KeyError, TypeError):
handler = self.handle_invalid_message
handler_arg = "Failed to parse message"

await handler(handler_arg)
await self.reap_completed_tasks()
await self.handle_invalid_message("Failed to parse message")
finally:
await self.reap_completed_tasks()

async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
if self.connection_timed_out:
Expand Down
19 changes: 12 additions & 7 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
cast,
)

from strawberry.http.exceptions import NonTextMessageReceived
from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected
from strawberry.subscriptions.protocols.graphql_ws import (
GQL_COMPLETE,
GQL_CONNECTION_ACK,
Expand Down Expand Up @@ -65,12 +65,17 @@ def __init__(

async def handle(self) -> None:
try:
async for message in self.websocket.iter_json(ignore_parsing_errors=True):
await self.handle_message(cast(OperationMessage, message))
except NonTextMessageReceived:
await self.websocket.close(
code=1002, reason="WebSocket message type must be text"
)
try:
async for message in self.websocket.iter_json(
ignore_parsing_errors=True
):
await self.handle_message(cast(OperationMessage, message))
except NonTextMessageReceived:
await self.websocket.close(
code=1002, reason="WebSocket message type must be text"
)
except WebSocketDisconnected:
pass
finally:
if self.keep_alive_task:
self.keep_alive_task.cancel()
Expand Down
20 changes: 20 additions & 0 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,23 @@ async def test_no_extensions_results_wont_send_extensions_in_payload(
mock.assert_called_once()
assert_next(response, "sub1", {"echo": "Hi"})
assert "extensions" not in response["payload"]


async def test_unexpected_client_disconnects_are_gracefully_handled(
ws: WebSocketClient,
):
process_errors = Mock()

with patch.object(Schema, "process_errors", process_errors):
await ws.send_json(
SubscribeMessage(
id="sub1",
payload=SubscribeMessagePayload(
query='subscription { echo(message: "Hi", delay: 0.5) }'
),
).as_dict()
)

await ws.close()
await asyncio.sleep(1)
assert not process_errors.called
28 changes: 27 additions & 1 deletion tests/websockets/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
GQL_START,
GQL_STOP,
)
from tests.views.schema import MyExtension
from tests.views.schema import MyExtension, Schema

if TYPE_CHECKING:
from ..http.clients.aiohttp import HttpClient, WebSocketClient
Expand Down Expand Up @@ -630,3 +630,29 @@ async def test_no_extensions_results_wont_send_extensions_in_payload(

await ws.send_json({"type": GQL_STOP, "id": "demo"})
response = await ws.receive_json()


async def test_unexpected_client_disconnects_are_gracefully_handled(
ws_raw: WebSocketClient,
):
ws = ws_raw
process_errors = mock.Mock()

with mock.patch.object(Schema, "process_errors", process_errors):
await ws.send_json({"type": GQL_CONNECTION_INIT})
response = await ws.receive_json()
assert response["type"] == GQL_CONNECTION_ACK

await ws.send_json(
{
"type": GQL_START,
"id": "sub1",
"payload": {
"query": 'subscription { echo(message: "Hi", delay: 0.5) }',
},
}
)

await ws.close()
await asyncio.sleep(1)
assert not process_errors.called

0 comments on commit c3b8135

Please sign in to comment.