diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 8a8c7406..3d523ff6 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -6,6 +6,7 @@ import typing import uuid from contextlib import AsyncExitStack, ExitStack +from unittest.mock import Mock import pytest @@ -13,7 +14,12 @@ from that_depends.entities.resource_context import ResourceContext from that_depends.meta import DefaultScopeNotDefinedError from that_depends.providers import container_context -from that_depends.providers.context_resources import ContextScope, _enter_named_scope, get_current_scope +from that_depends.providers.context_resources import ( + ContextScope, + DIContextMiddleware, + _enter_named_scope, + get_current_scope, +) logger = logging.getLogger(__name__) @@ -797,3 +803,10 @@ def test_container_context_does_not_support_scope_any() -> None: pytest.raises(ValueError, match=f"{ContextScope.ANY} cannot be entered!"), ): container_context(scope=ContextScope.ANY) + + +def test_di_middleware_does_not_support_scope_any() -> None: + with ( + pytest.raises(ValueError, match=f"{ContextScope.ANY} cannot be entered!"), + ): + DIContextMiddleware(Mock(), scope=ContextScope.ANY) diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index b73180d1..33c7f393 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -590,6 +590,9 @@ def __init__( self._context_items: set[SupportsContext[typing.Any]] = set(context_items) self._global_context: dict[str, typing.Any] | None = global_context self._reset_all_containers: bool = reset_all_containers + if scope == ContextScope.ANY: + msg = f"{scope} cannot be entered!" + raise ValueError(msg) self._scope = scope async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -608,8 +611,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """ async with ( - container_context(*self._context_items, global_context=self._global_context) + container_context(*self._context_items, global_context=self._global_context, scope=self._scope) if self._context_items - else container_context(global_context=self._global_context, reset_all_containers=self._reset_all_containers) + else container_context( + global_context=self._global_context, reset_all_containers=self._reset_all_containers, scope=self._scope + ) ): return await self.app(scope, receive, send)