Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

collect errors more reliably from websocket test client #2814

Merged
merged 17 commits into from
Dec 29, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 37 additions & 25 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import enum
import inspect
import io
import json
Expand All @@ -16,10 +17,9 @@
import anyio
import anyio.abc
import anyio.from_thread
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.stapled import StapledObjectStream
graingert marked this conversation as resolved.
Show resolved Hide resolved

from starlette._utils import is_async_callable
from starlette._utils import collapse_excgroups, is_async_callable
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

Expand Down Expand Up @@ -85,6 +85,14 @@ class WebSocketDenialResponse( # type: ignore[misc]
"""


class _Eof(enum.Enum):
EOF = enum.auto()


EOF: typing.Final = _Eof.EOF
Eof = typing.Literal[_Eof.EOF]


class WebSocketTestSession:
def __init__(
self,
Expand All @@ -97,24 +105,24 @@ def __init__(
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
self.extra_headers = None

def __enter__(self) -> WebSocketTestSession:
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(self.portal_factory())
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())

try:
_: Future[None] = self.portal.start_task_soon(self._run)
fut: Future[None] = self.portal.start_task_soon(self._run)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
except Exception:
self.exit_stack.close()
raise
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
return self
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(fut.result)
stack.callback(portal.call, self._notify_close)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self

@cached_property
def should_close(self) -> anyio.Event:
Expand All @@ -124,15 +132,14 @@ async def _notify_close(self) -> None:
self.should_close.set()

def __exit__(self, *args: typing.Any) -> None:
try:
self.close(1000)
finally:
self.portal.start_task_soon(self._notify_close)
self.exit_stack.close()
while not self._send_queue.empty():
self.exit_stack.close()

while True:
message = self._send_queue.get()
if message is EOF:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an analogous to EOF from the standard library on 3.13?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It raises an exception

break
if isinstance(message, BaseException):
raise message
raise message # pragma: no cover (defensive, should be impossible)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why it should be impossible?

The except BaseException as exc below doesn't have a pragma: no cover, so I assume it's being hit?

Copy link
Member Author

@graingert graingert Dec 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is impossible because the exit stack will raise the exception out of fut.result() and so the queue won't be consumed.

This is only possible to be hit if ws.receive() is interrupted (eg with a KI) while waiting for an exception or message to be placed on the queue.

I'm currently sketching out another slight refactor that uses MemoryObjectStreams here instead that should clean this up a bit


async def _run(self) -> None:
"""
Expand All @@ -143,17 +150,21 @@ async def run_app(tg: anyio.abc.TaskGroup) -> None:
try:
await self.app(self.scope, self._asgi_receive, self._asgi_send)
except anyio.get_cancelled_exc_class():
...
raise
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
tg.cancel_scope.cancel()

async with anyio.create_task_group() as tg:
tg.start_soon(run_app, tg)
await self.should_close.wait()
tg.cancel_scope.cancel()
try:
with collapse_excgroups():
async with anyio.create_task_group() as tg:
tg.start_soon(run_app, tg)
await self.should_close.wait()
tg.cancel_scope.cancel()
finally:
self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use the if sys.version_info here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to stick to the EOF approach until someone puts up a Queue.shutdown backport, or we can use a MemoryObjectStream with portal


async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
Expand Down Expand Up @@ -202,6 +213,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None:

def receive(self) -> Message:
message = self._send_queue.get()
assert message is not EOF
if isinstance(message, BaseException):
raise message
return message
Expand Down
Loading