Skip to content

Commit

Permalink
Implemented scoping for DIContextMiddleware.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexanderlazarev0 committed Jan 29, 2025
1 parent 8c7dcd7 commit a370bbf
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
15 changes: 14 additions & 1 deletion tests/providers/test_context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
import typing
import uuid
from contextlib import AsyncExitStack, ExitStack
from unittest.mock import Mock

import pytest

from that_depends import BaseContainer, Provide, fetch_context_item, inject, providers
from that_depends.entities.resource_context import ResourceContext
from that_depends.meta import DefaultScopeNotDefinedError
from that_depends.providers import container_context
from that_depends.providers.context_resources import ContextScope, _enter_named_scope, get_current_scope
from that_depends.providers.context_resources import (
ContextScope,
DIContextMiddleware,
_enter_named_scope,
get_current_scope,
)


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -797,3 +803,10 @@ def test_container_context_does_not_support_scope_any() -> None:
pytest.raises(ValueError, match=f"{ContextScope.ANY} cannot be entered!"),
):
container_context(scope=ContextScope.ANY)


def test_di_middleware_does_not_support_scope_any() -> None:
with (
pytest.raises(ValueError, match=f"{ContextScope.ANY} cannot be entered!"),
):
DIContextMiddleware(Mock(), scope=ContextScope.ANY)
9 changes: 7 additions & 2 deletions that_depends/providers/context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,9 @@ def __init__(
self._context_items: set[SupportsContext[typing.Any]] = set(context_items)
self._global_context: dict[str, typing.Any] | None = global_context
self._reset_all_containers: bool = reset_all_containers
if scope == ContextScope.ANY:
msg = f"{scope} cannot be entered!"
raise ValueError(msg)
self._scope = scope

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand All @@ -608,8 +611,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
async with (
container_context(*self._context_items, global_context=self._global_context)
container_context(*self._context_items, global_context=self._global_context, scope=self._scope)
if self._context_items
else container_context(global_context=self._global_context, reset_all_containers=self._reset_all_containers)
else container_context(
global_context=self._global_context, reset_all_containers=self._reset_all_containers, scope=self._scope
)
):
return await self.app(scope, receive, send)

0 comments on commit a370bbf

Please sign in to comment.