From c7f1629e0782dca93dbb0af04f7da72dd7ef0035 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 26 Dec 2024 09:55:23 +0100 Subject: [PATCH 1/5] Raise exception from background task on BaseHTTPMiddleware --- starlette/middleware/base.py | 7 +++---- tests/middleware/test_base.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index f51b13f73..f2eeb4571 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -104,9 +104,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request = _CachedRequest(scope, receive) wrapped_receive = request.wrapped_receive response_sent = anyio.Event() + app_exc: Exception | None = None async def call_next(request: Request) -> Response: - app_exc: Exception | None = None send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]] recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] send_stream, recv_stream = anyio.create_memory_object_stream() @@ -175,9 +175,6 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: if not message.get("more_body", False): break - if app_exc is not None: - raise app_exc - response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info) response.raw_headers = message["headers"] return response @@ -187,6 +184,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response = await self.dispatch_func(request, call_next) await response(scope, wrapped_receive, send) response_sent.set() + if app_exc is not None: + raise app_exc async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: raise NotImplementedError() # pragma: no cover diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 449624e2f..f26f3ad76 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -296,6 +296,29 @@ async def send(message: Message) -> None: assert background_task_run.is_set() +def test_run_background_tasks_raise_exceptions() -> None: + # test for https://github.com/encode/starlette/issues/2625 + + async def sleep_and_set() -> None: + await anyio.sleep(0.1) + raise ValueError("TEST") + + async def endpoint_with_background_task(_: Request) -> PlainTextResponse: + return PlainTextResponse(background=BackgroundTask(sleep_and_set)) + + async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response: + return await call_next(request) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], + routes=[Route("/", endpoint_with_background_task)], + ) + + client = TestClient(app) + with pytest.raises(ValueError, match="TEST"): + client.get("/") + + @pytest.mark.anyio async def test_do_not_block_on_background_tasks() -> None: response_complete = anyio.Event() From 7ef64c9a62e540809b54681215ea35332f8f7fd7 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 26 Dec 2024 09:58:19 +0100 Subject: [PATCH 2/5] Use test client factory --- tests/middleware/test_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index f26f3ad76..ef9a09daa 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -296,7 +296,7 @@ async def send(message: Message) -> None: assert background_task_run.is_set() -def test_run_background_tasks_raise_exceptions() -> None: +def test_run_background_tasks_raise_exceptions(test_client_factory: TestClientFactory) -> None: # test for https://github.com/encode/starlette/issues/2625 async def sleep_and_set() -> None: @@ -314,7 +314,7 @@ async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> R routes=[Route("/", endpoint_with_background_task)], ) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(ValueError, match="TEST"): client.get("/") From a48bcda0af1ce0b7683f1b33ddb61340de1c2e69 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 12:36:51 +0000 Subject: [PATCH 3/5] Update starlette/middleware/base.py --- starlette/middleware/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 8325763c7..b813c0c8d 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -106,7 +106,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: app_exc: Exception | None = None async def call_next(request: Request) -> Response: - app_exc: Exception | None = None async def receive_or_disconnect() -> Message: if response_sent.is_set(): From 069ad2e1ed9cb1800850f981d161534cd5ea140f Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 12:39:14 +0000 Subject: [PATCH 4/5] lint --- starlette/middleware/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index b813c0c8d..36a514f21 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -106,7 +106,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: app_exc: Exception | None = None async def call_next(request: Request) -> Response: - async def receive_or_disconnect() -> Message: if response_sent.is_set(): return {"type": "http.disconnect"} From 558b389fc962cc9789c682ecdb7f21ce84a75ea7 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 12:46:15 +0000 Subject: [PATCH 5/5] collapse exceptions groups from streaming response --- starlette/responses.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 31874f655..c522e7f23 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -18,6 +18,7 @@ import anyio import anyio.to_thread +from starlette._utils import collapse_excgroups from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders @@ -258,14 +259,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: except OSError: raise ClientDisconnect() else: - async with anyio.create_task_group() as task_group: + with collapse_excgroups(): + async with anyio.create_task_group() as task_group: - async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: - await func() - task_group.cancel_scope.cancel() + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() - task_group.start_soon(wrap, partial(self.stream_response, send)) - await wrap(partial(self.listen_for_disconnect, receive)) + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background()