From cdde09c00b4d54218aa9892d391c6d000d6a7e21 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 25 Oct 2024 23:01:44 +0200 Subject: [PATCH 01/24] Fixed some context-dependent tests --- test.py | 195 +++++++++++ tests/providers/test_context_resources.py | 4 +- that_depends/container.py | 4 +- that_depends/meta.py | 19 ++ that_depends/providers/__init__.py | 2 +- .../{collections.py => collection.py} | 0 that_depends/providers/context_resources.py | 314 +++++++++++++----- 7 files changed, 441 insertions(+), 97 deletions(-) create mode 100644 test.py create mode 100644 that_depends/meta.py rename that_depends/providers/{collections.py => collection.py} (100%) diff --git a/test.py b/test.py new file mode 100644 index 00000000..6c668c25 --- /dev/null +++ b/test.py @@ -0,0 +1,195 @@ +import asyncio +import random +import typing + +import pytest + +from that_depends.container import BaseContainer +from that_depends.injection import Provide, inject +from that_depends.providers.context_resources import ContextResource, container_context, fetch_context_item + + +random.seed(1) + + +async def async_yields_string() -> typing.AsyncIterator[str]: + yield str(random.random()) # noqa: S311 + + +def sync_yields_string() -> typing.Iterator[str]: + yield str(random.random()) # noqa: S311 + + +class MyContainer(BaseContainer): + async_resource: ContextResource[str] = ContextResource(async_yields_string) + sync_resource: ContextResource[str] = ContextResource(sync_yields_string) + + +@MyContainer.sync_resource.sync_context() +@inject +def sync_injected(val: str = Provide[MyContainer.sync_resource]) -> str: + return val + + +@MyContainer.async_resource.async_context() +@inject +async def async_injected(val: str = Provide[MyContainer.async_resource]) -> str: + return val + + +@MyContainer.async_resource.context() +@inject +async def async_injected_implicit(val: str = Provide[MyContainer.async_resource]) -> str: + return val + + +@MyContainer.sync_resource.context() +@inject +def sync_injected_implicit(val: str = Provide[MyContainer.sync_resource]) -> str: + return val + + +async def test_injected() -> None: + async_result = await async_injected() + sync_result = sync_injected() + async_result_implicit = await async_injected_implicit() + sync_result_implicit = sync_injected_implicit() + assert isinstance(async_result, str) + assert isinstance(sync_result, str) + assert isinstance(async_result_implicit, str) + assert isinstance(sync_result_implicit, str) + + +async def async_main() -> None: + """Test async resolution.""" + async with MyContainer.async_resource.async_context(): + val_1 = await MyContainer.async_resource.async_resolve() + val_2 = await MyContainer.async_resource.async_resolve() + assert val_1 == val_2 + async with MyContainer.async_resource.async_context(): + val_3 = await MyContainer.async_resource.async_resolve() + assert val_1 != val_3 + async with MyContainer.async_resource.async_context(): + val_4 = await MyContainer.async_resource.async_resolve() + assert val_1 != val_4 != val_3 + val_5 = await MyContainer.async_resource.async_resolve() + assert val_5 == val_3 + val_6 = await MyContainer.async_resource.async_resolve() + assert val_6 == val_1 + + +def sync_main() -> None: + """Test sync resolution.""" + with MyContainer.sync_resource.sync_context(): + val_1 = MyContainer.sync_resource.sync_resolve() + val_2 = MyContainer.sync_resource.sync_resolve() + assert val_1 == val_2 + with MyContainer.sync_resource.sync_context(): + val_3 = MyContainer.sync_resource.sync_resolve() + assert val_1 != val_3 + with MyContainer.sync_resource.sync_context(): + val_4 = MyContainer.sync_resource.sync_resolve() + assert val_1 != val_4 != val_3 + val_5 = MyContainer.sync_resource.sync_resolve() + assert val_5 == val_1 + + +def check_sync_container_context() -> None: + """Test sync provider resolution container_context.""" + with container_context(providers=[MyContainer.sync_resource]): + val_1 = MyContainer.sync_resource.sync_resolve() + val_2 = MyContainer.sync_resource.sync_resolve() + assert val_1 == val_2 + with container_context(providers=[MyContainer.sync_resource]): + val_3 = MyContainer.sync_resource.sync_resolve() + assert val_3 != val_1 + + val_4 = MyContainer.sync_resource.sync_resolve() + assert val_4 == val_1 + with pytest.raises(RuntimeError): + MyContainer.sync_resource.sync_resolve() + + +async def check_async_container_context() -> None: + """Test async provider resolution container_context.""" + async with container_context(providers=[MyContainer.async_resource]): + val_1 = await MyContainer.async_resource.async_resolve() + val_2 = await MyContainer.async_resource.async_resolve() + assert val_1 == val_2 + async with container_context(providers=[MyContainer.async_resource]): + val_3 = await MyContainer.async_resource.async_resolve() + assert val_3 != val_1 + + val_4 = await MyContainer.async_resource.async_resolve() + assert val_4 == val_1 + with pytest.raises(RuntimeError): + await MyContainer.async_resource.async_resolve() + + +async def check_async_global_passing() -> None: + with pytest.raises(RuntimeError): + # TODO: This perhaps should not throw an exception because to just set an empty dict if no context was entered before. + async with container_context(preserve_globals=True) as gs: + assert gs + my_global_resources = {"test_1": "test_1", "test_2": "test_2"} + + async with container_context(initial_context=my_global_resources): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + async with container_context(preserve_globals=True): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + async with container_context(preserve_globals=False): + for key in my_global_resources: + assert fetch_context_item(key) is None + + +def check_sync_global_passing() -> None: + with pytest.raises(RuntimeError): + # TODO: This perhaps should not throw an exception because to just set an empty dict if no context was entered before. + with container_context(preserve_globals=True) as gs: + assert gs + my_global_resources = {"test_1": "test_1", "test_2": "test_2"} + + with container_context(initial_context=my_global_resources): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + with container_context(preserve_globals=True): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + with container_context(preserve_globals=False): + for key in my_global_resources: + assert fetch_context_item(key) is None + + +async def test_reset_context_async() -> None: + async with container_context(): + val_1 = await MyContainer.async_resource.async_resolve() + + async with container_context(): + val_2 = await MyContainer.async_resource.async_resolve() + assert val_1 != val_2 + + +def test_reset_context_sync() -> None: + with container_context(): + val_1 = MyContainer.sync_resource.sync_resolve() + with container_context(): + val_2 = MyContainer.sync_resource.sync_resolve() + assert val_1 != val_2 + + +if __name__ == "__main__": + asyncio.run(async_main()) + sync_main() + check_sync_container_context() + asyncio.run(check_async_container_context()) + asyncio.run(test_injected()) + asyncio.run(check_async_global_passing()) + check_sync_global_passing() + asyncio.run(test_reset_context_async()) + test_reset_context_sync() diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 42dba8a0..cb8b0750 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -149,9 +149,9 @@ async def test_context_resource_with_dynamic_resource() -> None: async def test_early_exit_of_container_context() -> None: - with pytest.raises(RuntimeError, match="Context is not set, call ``__aenter__`` first"): + with pytest.raises(RuntimeError, match="No context token set for global vars, use __enter__ or __aenter__ first."): await container_context().__aexit__(None, None, None) - with pytest.raises(RuntimeError, match="Context is not set, call ``__enter__`` first"): + with pytest.raises(RuntimeError, match="No context token set for global vars, use __enter__ or __aenter__ first."): container_context().__exit__(None, None, None) diff --git a/that_depends/container.py b/that_depends/container.py index 30e183e4..450199b1 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -3,6 +3,7 @@ import warnings from contextlib import contextmanager +from that_depends.meta import BaseContainerMeta from that_depends.providers import AbstractProvider, Resource, Singleton @@ -14,7 +15,7 @@ P = typing.ParamSpec("P") -class BaseContainer: +class BaseContainer(metaclass=BaseContainerMeta): providers: dict[str, AbstractProvider[typing.Any]] containers: list[type["BaseContainer"]] @@ -38,7 +39,6 @@ def connect_containers(cls, *containers: type["BaseContainer"]) -> None: def get_providers(cls) -> dict[str, AbstractProvider[typing.Any]]: if not hasattr(cls, "providers"): cls.providers = {k: v for k, v in cls.__dict__.items() if isinstance(v, AbstractProvider)} - return cls.providers @classmethod diff --git a/that_depends/meta.py b/that_depends/meta.py new file mode 100644 index 00000000..9a4672ce --- /dev/null +++ b/that_depends/meta.py @@ -0,0 +1,19 @@ +import typing + + +if typing.TYPE_CHECKING: + from that_depends.container import BaseContainer + + +class BaseContainerMeta(type): + instances: typing.ClassVar[list[type["BaseContainer"]]] = [] + + def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]) -> type: + new_cls = super().__new__(cls, name, bases, namespace) + if name != "BaseContainer": + cls.instances.append(new_cls) # type: ignore[arg-type] + return new_cls + + @classmethod + def get_instances(cls) -> list[type["BaseContainer"]]: + return cls.instances diff --git a/that_depends/providers/__init__.py b/that_depends/providers/__init__.py index c9f7b866..97733b54 100644 --- a/that_depends/providers/__init__.py +++ b/that_depends/providers/__init__.py @@ -1,5 +1,5 @@ from that_depends.providers.base import AbstractProvider, AttrGetter -from that_depends.providers.collections import Dict, List +from that_depends.providers.collection import Dict, List from that_depends.providers.context_resources import ( AsyncContextResource, ContextResource, diff --git a/that_depends/providers/collections.py b/that_depends/providers/collection.py similarity index 100% rename from that_depends/providers/collections.py rename to that_depends/providers/collection.py diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 3ec061ef..ee272d9e 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -1,16 +1,23 @@ +import contextlib import inspect import logging import typing -import uuid import warnings from contextlib import AbstractAsyncContextManager, AbstractContextManager from contextvars import ContextVar, Token from functools import wraps from types import TracebackType +from typing_extensions import TypeIs + +from that_depends.meta import BaseContainerMeta from that_depends.providers.base import AbstractResource, ResourceContext +if typing.TYPE_CHECKING: + from that_depends.container import BaseContainer + + logger: typing.Final = logging.getLogger(__name__) T_co = typing.TypeVar("T_co", covariant=True) P = typing.ParamSpec("P") @@ -26,145 +33,259 @@ ContextType = dict[str, typing.Any] -class container_context( # noqa: N801 - AbstractAsyncContextManager[ContextType], AbstractContextManager[ContextType] -): - """Manage the context of ContextResources. +def _get_container_context() -> dict[str, typing.Any]: + try: + return _CONTAINER_CONTEXT.get() + except LookupError as exc: + msg = "Context is not set. Use container_context" + raise RuntimeError(msg) from exc + + +def fetch_context_item(key: str, default: typing.Any = None) -> typing.Any: # noqa: ANN401 + return _get_container_context().get(key, default) - Can be entered using ``async with container_context()`` or with ``with container_context()`` - as async-context-manager or context-manager respectively. - When used as async-context-manager, it will allow setup & teardown of both sync and async resources. - When used as sync-context-manager, it will only allow setup & teardown of sync resources. - """ - __slots__ = "_initial_context", "_context_token" +T = typing.TypeVar("T") - def __init__(self, initial_context: ContextType | None = None) -> None: - self._initial_context: ContextType = initial_context or {} - self._context_token: Token[ContextType] | None = None - def __enter__(self) -> ContextType: - self._initial_context[_ASYNC_CONTEXT_KEY] = False +class ContextResource( + AbstractResource[T_co], + AbstractAsyncContextManager[ResourceContext[T_co]], + AbstractContextManager[ResourceContext[T_co]], +): + __slots__ = ( + "_is_async", + "_creator", + "_args", + "_kwargs", + "_override", + "_internal_name", + "_context", + "_token", + ) + + def __init__( + self, + creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + super().__init__(creator, *args, **kwargs) + self._context: ContextVar[ResourceContext[T_co]] = ContextVar(f"{self._creator.__name__}-context") + self._token: Token[ResourceContext[T_co]] | None = None + + def __enter__(self) -> ResourceContext[T_co]: + if self._is_creator_async(self._creator): + msg = "You must enter async context for async creators." + raise RuntimeError(msg) return self._enter() - async def __aenter__(self) -> ContextType: - self._initial_context[_ASYNC_CONTEXT_KEY] = True + async def __aenter__(self) -> ResourceContext[T_co]: return self._enter() - def _enter(self) -> ContextType: - self._context_token = _CONTAINER_CONTEXT.set({**self._initial_context}) - return _CONTAINER_CONTEXT.get() + def _enter(self) -> ResourceContext[T_co]: + self._token = self._context.set(ResourceContext(is_async=self._is_creator_async(self._creator))) + return self._context.get() def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None ) -> None: - if self._context_token is None: + if not self._token: msg = "Context is not set, call ``__enter__`` first" raise RuntimeError(msg) try: - for context_item in reversed(_CONTAINER_CONTEXT.get().values()): - if isinstance(context_item, ResourceContext): - # we don't need to handle the case where the ResourceContext is async - context_item.sync_tear_down() + context_item = self._context.get() + context_item.sync_tear_down() finally: - _CONTAINER_CONTEXT.reset(self._context_token) + self._context.reset(self._token) async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, traceback: TracebackType | None ) -> None: - if self._context_token is None: + if self._token is None: msg = "Context is not set, call ``__aenter__`` first" raise RuntimeError(msg) try: - for context_item in reversed(_CONTAINER_CONTEXT.get().values()): - if not isinstance(context_item, ResourceContext): - continue - - if context_item.is_context_stack_async(context_item.context_stack): - await context_item.tear_down() - else: - context_item.sync_tear_down() + context_item = self._context.get() + if context_item.is_context_stack_async(context_item.context_stack): + await context_item.tear_down() + else: + context_item.sync_tear_down() finally: - _CONTAINER_CONTEXT.reset(self._context_token) + self._context.reset(self._token) - def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]: - if inspect.iscoroutinefunction(func): + @contextlib.contextmanager + def sync_context(self) -> typing.Iterator[ResourceContext[T_co]]: + if self._is_creator_async(self._creator): + msg = "Please use async context instead." + raise RuntimeError(msg) + token = self._token + with self as val: + yield val + self._token = token + + @contextlib.asynccontextmanager + async def async_context(self) -> typing.AsyncIterator[ResourceContext[T_co]]: + token = self._token + async with self as val: + yield val + self._token = token + + def context( + self, + ) -> typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: + """Create a new context manager for the resource, the context manager will be async if the resource is async. - @wraps(func) - async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: - async with container_context(self._initial_context): - return await func(*args, **kwargs) # type: ignore[no-any-return] + :return: A context manager for the resource. + :rtype: typing.ContextManager[ResourceContext[T_co]] | typing.AsyncContextManager[ResourceContext[T_co]] + """ + if self._is_creator_async(self._creator): + return typing.cast(typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]], self.async_context()) + return typing.cast(typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]], self.sync_context()) - return typing.cast(typing.Callable[P, T_co], _async_inner) + @property + def is_async(self) -> bool: + return self._is_creator_async(self._creator) - @wraps(func) - def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: - with container_context(self._initial_context): - return func(*args, **kwargs) + def _fetch_context(self) -> ResourceContext[T_co]: + try: + return self._context.get() + except LookupError as e: + msg = "Context is not set. Use container_context" + raise RuntimeError(msg) from e - return _sync_inner +ContainerType = typing.TypeVar("ContainerType", bound="type[BaseContainer]") -class DIContextMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app: typing.Final = app - - @container_context() - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - return await self.app(scope, receive, send) +class container_context(AbstractContextManager[ContextType], AbstractAsyncContextManager[ContextType]): # noqa: N801 + ___slots__ = ( + "_providers", + "_context_stack", + "_containers", + "_initial_context", + "_context_token", + "_reset_resource_context", + ) -def _get_container_context() -> dict[str, typing.Any]: - try: - return _CONTAINER_CONTEXT.get() - except LookupError as exc: - msg = "Context is not set. Use container_context" - raise RuntimeError(msg) from exc + def __init__( + self, + initial_context: ContextType | None = None, + providers: list[ContextResource[typing.Any]] | None = None, + containers: list[ContainerType] | None = None, + preserve_globals: bool = False, + reset_resource_context: bool = False + ) -> None: + if preserve_globals and initial_context: + self._initial_context = {**_get_container_context(), **initial_context} + else: + self._initial_context: ContextType = _get_container_context() if preserve_globals else initial_context or {} # type: ignore[no-redef] + self._context_token: Token[ContextType] | None = None + self._providers: set[ContextResource[typing.Any]] = set() + self._reset_resource_context: typing.Final[bool] = (not containers and not providers) or reset_resource_context + if providers: + for provider in providers: + if isinstance(provider, ContextResource): + self._providers.add(provider) + else: + msg = "Provider is not a ContextResource" + raise TypeError(msg) + if containers: + self._add_providers_from_containers(containers) + if self._reset_resource_context: + self._add_providers_from_containers(BaseContainerMeta.get_instances()) + self._context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None -def _is_container_context_async() -> bool: - """Check if the current container context is async. + def _add_providers_from_containers(self, containers: list[ContainerType]) -> None: + for container in containers: + for container_provider in container.get_providers().values(): + if isinstance(container_provider, ContextResource): + self._providers.add(container_provider) - :return: Whether the current container context is async. - :rtype: bool - """ - return typing.cast(bool, _get_container_context().get(_ASYNC_CONTEXT_KEY, False)) + def __enter__(self) -> ContextType: + self._context_stack = contextlib.ExitStack() + for provider in self._providers: + if self._reset_resource_context: + if not provider.is_async: + self._context_stack.enter_context(provider.sync_context()) + else: + self._context_stack.enter_context(provider.sync_context()) + return self._enter_globals() + async def __aenter__(self) -> ContextType: + self._context_stack = contextlib.AsyncExitStack() + for provider in self._providers: + await self._context_stack.enter_async_context(provider.async_context()) + return self._enter_globals() -def fetch_context_item(key: str, default: typing.Any = None) -> typing.Any: # noqa: ANN401 - return _get_container_context().get(key, default) + def _enter_globals(self) -> ContextType: + self._context_token = _CONTAINER_CONTEXT.set(self._initial_context) + return _CONTAINER_CONTEXT.get() + def _is_context_token(self, _: Token[ContextType] | None) -> TypeIs[Token[ContextType]]: + return isinstance(_, Token) -class ContextResource(AbstractResource[T_co]): - __slots__ = ( - "_is_async", - "_creator", - "_args", - "_kwargs", - "_override", - "_internal_name", - ) + def _exit_globals(self) -> None: + if self._is_context_token(self._context_token): + return _CONTAINER_CONTEXT.reset(self._context_token) + msg = "No context token set for global vars, use __enter__ or __aenter__ first." + raise RuntimeError(msg) - def __init__( + def _has_async_exit_stack( self, - creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]], - *args: P.args, - **kwargs: P.kwargs, + _: contextlib.AsyncExitStack | contextlib.ExitStack | None, + ) -> typing.TypeGuard[contextlib.AsyncExitStack]: + return isinstance(_, contextlib.AsyncExitStack) + + def _has_sync_exit_stack( + self, _: contextlib.AsyncExitStack | contextlib.ExitStack | None + ) -> typing.TypeGuard[contextlib.ExitStack]: + return isinstance(_, contextlib.ExitStack) + + def __exit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None ) -> None: - super().__init__(creator, *args, **kwargs) - self._internal_name: typing.Final = f"{creator.__name__}-{uuid.uuid4()}" + try: + if self._has_sync_exit_stack(self._context_stack): + self._context_stack.close() + else: + msg = "Context is not set, call ``__enter__`` first" + raise RuntimeError(msg) + finally: + self._exit_globals() - def _fetch_context(self) -> ResourceContext[T_co]: - container_context = _get_container_context() - if resource_context := container_context.get(self._internal_name): - return typing.cast(ResourceContext[T_co], resource_context) + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, traceback: TracebackType | None + ) -> None: + try: + if self._has_async_exit_stack(self._context_stack): + await self._context_stack.aclose() + else: + msg = "Context is not set, call ``__aenter__`` first" + raise RuntimeError(msg) + finally: + self._exit_globals() - resource_context = ResourceContext(is_async=_is_container_context_async()) - container_context[self._internal_name] = resource_context - return resource_context + def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]: + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: + async with container_context(providers=list(self._providers), reset_resource_context=self._reset_resource_context): + return await func(*args, **kwargs) # type: ignore[no-any-return] + + return typing.cast(typing.Callable[P, T_co], _async_inner) + + @wraps(func) + def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: + with container_context(providers=list(self._providers), reset_resource_context=self._reset_resource_context): + return func(*args, **kwargs) + + return _sync_inner class AsyncContextResource(ContextResource[T_co]): @@ -176,3 +297,12 @@ def __init__( ) -> None: warnings.warn("AsyncContextResource is deprecated, use ContextResource instead", RuntimeWarning, stacklevel=1) super().__init__(creator, *args, **kwargs) + + +class DIContextMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app: typing.Final = app + + @container_context() + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + return await self.app(scope, receive, send) From c66aa9bc14c2bb350353c77f61b8aa25c20dfe7e Mon Sep 17 00:00:00 2001 From: Alexander Date: Sun, 10 Nov 2024 00:51:50 +0100 Subject: [PATCH 02/24] Fixed existing pytests. --- test.py | 16 +++++++++ tests/providers/test_attr_getter.py | 40 ++++++++++++++------- tests/providers/test_context_resources.py | 2 +- that_depends/providers/base.py | 12 +++---- that_depends/providers/context_resources.py | 17 ++++++--- that_depends/providers/resources.py | 4 +-- 6 files changed, 65 insertions(+), 26 deletions(-) diff --git a/test.py b/test.py index f90825de..ff491481 100644 --- a/test.py +++ b/test.py @@ -3,6 +3,7 @@ import typing import pytest +from pydantic import BaseModel from that_depends.container import BaseContainer from that_depends.injection import Provide, inject @@ -12,6 +13,10 @@ random.seed(1) +class Config(BaseModel): + some_str_value: str = "some_string_value" + + async def async_yields_string() -> typing.AsyncIterator[str]: yield str(random.random()) # noqa: S311 @@ -20,9 +25,14 @@ def sync_yields_string() -> typing.Iterator[str]: yield str(random.random()) # noqa: S311 +def sync_yields_config() -> typing.Iterator[Config]: + yield Config() + + class MyContainer(BaseContainer): async_resource: ContextResource[str] = ContextResource(async_yields_string) sync_resource: ContextResource[str] = ContextResource(sync_yields_string) + sync_config: ContextResource[Config] = ContextResource(sync_yields_config) @MyContainer.sync_resource.sync_context() @@ -180,6 +190,11 @@ def test_reset_context_sync() -> None: assert val_1 != val_2 +@container_context() +def test_attr_getter_sync() -> None: + assert MyContainer.sync_config.sync_resolve().some_str_value + + if __name__ == "__main__": asyncio.run(async_main()) sync_main() @@ -190,3 +205,4 @@ def test_reset_context_sync() -> None: check_sync_global_passing() asyncio.run(test_reset_context_async()) test_reset_context_sync() + test_attr_getter_sync() diff --git a/tests/providers/test_attr_getter.py b/tests/providers/test_attr_getter.py index b09deda2..1af8991d 100644 --- a/tests/providers/test_attr_getter.py +++ b/tests/providers/test_attr_getter.py @@ -68,36 +68,52 @@ def some_async_settings_provider(request: pytest.FixtureRequest) -> providers.Ab return typing.cast(providers.AbstractProvider[Settings], request.param) -@container_context() def test_attr_getter_with_zero_attribute_depth_sync( some_sync_settings_provider: providers.AbstractProvider[Settings], ) -> None: - attr_getter = some_sync_settings_provider.some_str_value - assert attr_getter.sync_resolve() == Settings().some_str_value + with container_context( + providers=[some_sync_settings_provider] + if isinstance(some_sync_settings_provider, providers.ContextResource) + else [] + ): + attr_getter = some_sync_settings_provider.some_str_value + assert attr_getter.sync_resolve() == Settings().some_str_value -@container_context() async def test_attr_getter_with_zero_attribute_depth_async( some_async_settings_provider: providers.AbstractProvider[Settings], ) -> None: - attr_getter = some_async_settings_provider.some_str_value - assert await attr_getter.async_resolve() == Settings().some_str_value + async with container_context( + providers=[some_async_settings_provider] + if isinstance(some_async_settings_provider, providers.ContextResource) + else [] + ): + attr_getter = some_async_settings_provider.some_str_value + assert await attr_getter.async_resolve() == Settings().some_str_value -@container_context() def test_attr_getter_with_more_than_zero_attribute_depth_sync( some_sync_settings_provider: providers.AbstractProvider[Settings], ) -> None: - attr_getter = some_sync_settings_provider.nested1_attr.nested2_attr.some_const - assert attr_getter.sync_resolve() == Nested2().some_const + with container_context( + providers=[some_sync_settings_provider] + if isinstance(some_sync_settings_provider, providers.ContextResource) + else [] + ): + attr_getter = some_sync_settings_provider.nested1_attr.nested2_attr.some_const + assert attr_getter.sync_resolve() == Nested2().some_const -@container_context() async def test_attr_getter_with_more_than_zero_attribute_depth_async( some_async_settings_provider: providers.AbstractProvider[Settings], ) -> None: - attr_getter = some_async_settings_provider.nested1_attr.nested2_attr.some_const - assert await attr_getter.async_resolve() == Nested2().some_const + async with container_context( + providers=[some_async_settings_provider] + if isinstance(some_async_settings_provider, providers.ContextResource) + else [] + ): + attr_getter = some_async_settings_provider.nested1_attr.nested2_attr.some_const + assert await attr_getter.async_resolve() == Nested2().some_const @pytest.mark.parametrize( diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 8629db1c..cfd6e9be 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -84,7 +84,7 @@ def test_sync_context_resource(sync_context_resource: providers.ContextResource[ async def test_async_context_resource_in_sync_context(async_context_resource: providers.ContextResource[str]) -> None: - with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved in an sync context."), container_context(): + with pytest.raises(RuntimeError, match="Context is not set. Use container_context"), container_context(): await async_context_resource() diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index aafe2046..216b0d03 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -82,16 +82,16 @@ def __init__( super().__init__() self._creator: typing.Any if inspect.isasyncgenfunction(creator): - self._is_async = True + self.is_async = True self._creator = contextlib.asynccontextmanager(creator) elif inspect.isgeneratorfunction(creator): - self._is_async = False + self.is_async = False self._creator = contextlib.contextmanager(creator) elif isinstance(creator, type) and issubclass(creator, typing.AsyncContextManager): - self._is_async = True + self.is_async = True self._creator = creator elif isinstance(creator, type) and issubclass(creator, typing.ContextManager): - self._is_async = False + self.is_async = False self._creator = creator else: msg = "Unsupported resource type" @@ -112,7 +112,7 @@ async def async_resolve(self) -> T_co: if context.instance is not None: return context.instance - if not context.is_async and self._is_async: + if not context.is_async and self.is_async: msg = "AsyncResource cannot be resolved in an sync context." raise RuntimeError(msg) @@ -150,7 +150,7 @@ def sync_resolve(self) -> T_co: if context.instance is not None: return context.instance - if self._is_async: + if self.is_async: msg = "AsyncResource cannot be resolved synchronously" raise RuntimeError(msg) diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 46ea61e1..e7eccd96 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -55,7 +55,7 @@ class ContextResource( AbstractContextManager[ResourceContext[T_co]], ): __slots__ = ( - "_is_async", + "is_async", "_creator", "_args", "_kwargs", @@ -65,6 +65,9 @@ class ContextResource( "_token", ) + def __repr__(self) -> str: + return f"ContextResource({self._creator.__name__})" + def __init__( self, creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]], @@ -147,10 +150,6 @@ def context( return typing.cast(typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]], self.async_context()) return typing.cast(typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]], self.sync_context()) - @property - def is_async(self) -> bool: - return inspect.iscoroutinefunction(self._creator) - def _fetch_context(self) -> ResourceContext[T_co]: try: return self._context.get() @@ -180,6 +179,14 @@ def __init__( preserve_globals: bool = False, reset_resource_context: bool = False, ) -> None: + """Initialize a container context. + + :param initial_context: existing context to use + :param providers: providers to reset context of. + :param containers: containers to reset context of. + :param preserve_globals: whether to preserve global context vars. + :param reset_resource_context: whether to reset resource context. + """ if preserve_globals and initial_context: self._initial_context = {**_get_container_context(), **initial_context} else: diff --git a/that_depends/providers/resources.py b/that_depends/providers/resources.py index 5d7393bc..70b38aa8 100644 --- a/that_depends/providers/resources.py +++ b/that_depends/providers/resources.py @@ -11,7 +11,7 @@ class Resource(AbstractResource[T_co]): __slots__ = ( - "_is_async", + "is_async", "_creator", "_args", "_kwargs", @@ -26,7 +26,7 @@ def __init__( **kwargs: P.kwargs, ) -> None: super().__init__(creator, *args, **kwargs) - self._context: typing.Final[ResourceContext[T_co]] = ResourceContext(is_async=self._is_async) + self._context: typing.Final[ResourceContext[T_co]] = ResourceContext(is_async=self.is_async) def _fetch_context(self) -> ResourceContext[T_co]: return self._context From 32477fee5c95430af6486d5bdccd43ec00eb8ad4 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 6 Jan 2025 18:15:25 +0100 Subject: [PATCH 03/24] Updated ContextResource.context to only be usable as a wrapper. --- test.py | 19 +++++- tests/providers/test_context_resources.py | 65 ++++++++++++++++++++- that_depends/providers/context_resources.py | 23 ++++++-- 3 files changed, 97 insertions(+), 10 deletions(-) diff --git a/test.py b/test.py index ff491481..353535af 100644 --- a/test.py +++ b/test.py @@ -47,13 +47,13 @@ async def async_injected(val: str = Provide[MyContainer.async_resource]) -> str: return val -@MyContainer.async_resource.context() +@MyContainer.async_resource.context @inject async def async_injected_implicit(val: str = Provide[MyContainer.async_resource]) -> str: return val -@MyContainer.sync_resource.context() +@MyContainer.sync_resource.context @inject def sync_injected_implicit(val: str = Provide[MyContainer.sync_resource]) -> str: return val @@ -195,6 +195,20 @@ def test_attr_getter_sync() -> None: assert MyContainer.sync_config.sync_resolve().some_str_value +async def test_inject_sync_into_async() -> None: + @MyContainer.async_resource.context + async def _inner() -> str: + return await MyContainer.async_resource.async_resolve() + + @MyContainer.sync_resource.context + async def _sync_injected() -> str: + return MyContainer.sync_resource.sync_resolve() + + value_1 = await _inner() + value_2 = await _sync_injected() + assert value_1 != value_2 + + if __name__ == "__main__": asyncio.run(async_main()) sync_main() @@ -206,3 +220,4 @@ def test_attr_getter_sync() -> None: asyncio.run(test_reset_context_async()) test_reset_context_sync() test_attr_getter_sync() + asyncio.run(test_inject_sync_into_async()) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index c1e1b508..ea035431 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -3,7 +3,7 @@ import logging import typing import uuid -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, ExitStack import pytest @@ -192,6 +192,47 @@ async def some_injected(depth: int, val: str = Provide[DIContainer.async_context await some_injected(1) +async def test_async_injection_when_resetting_resource_specific_context( + async_context_resource: providers.ContextResource[str], +) -> None: + """Async context resources should be able to reset the context for themselves.""" + + @async_context_resource.context + @inject + async def _async_injected(val: str = Provide[async_context_resource]) -> str: + assert isinstance(async_context_resource._fetch_context().context_stack, AsyncExitStack) # noqa: SLF001 + return val + + async_result = await _async_injected() + assert async_result != await _async_injected() + assert isinstance(async_result, str) + + +async def test_sync_injection_when_resetting_resource_specific_context( + sync_context_resource: providers.ContextResource[str], +) -> None: + """Sync context resources should be able to reset the context for themselves.""" + + @sync_context_resource.context + @inject + async def _async_injected(val: str = Provide[sync_context_resource]) -> str: + assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001 + return val + + @sync_context_resource.context + @inject + def _sync_injected(val: str = Provide[sync_context_resource]) -> str: + assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001 + return val + + async_result = await _async_injected() + assert async_result != await _async_injected() + assert isinstance(async_result, str) + sync_result = _sync_injected() + assert sync_result != _sync_injected() + assert isinstance(sync_result, str) + + @pytest.mark.repeat(10) async def test_async_context_resource_asyncio_concurrency() -> None: calls: int = 0 @@ -207,7 +248,27 @@ async def create_client() -> typing.AsyncIterator[str]: async def resolve_resource() -> str: return await resource.async_resolve() - async with container_context(): + async with resource.async_context(): + await asyncio.gather(resolve_resource(), resolve_resource()) + + assert calls == 1 + + +@pytest.mark.repeat(10) +async def test_sync_context_resource_asyncio_concurrency() -> None: + calls: int = 0 + + def create_client() -> typing.Iterator[str]: + nonlocal calls + calls += 1 + yield "" + + resource = providers.ContextResource(create_client) + + async def resolve_resource() -> str: + return resource.sync_resolve() + + with resource.sync_context(): await asyncio.gather(resolve_resource(), resolve_resource()) assert calls == 1 diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index e7eccd96..805a9dfc 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -138,17 +138,28 @@ async def async_context(self) -> typing.AsyncIterator[ResourceContext[T_co]]: yield val self._token = token - def context( - self, - ) -> typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]]: + def context(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]: """Create a new context manager for the resource, the context manager will be async if the resource is async. :return: A context manager for the resource. :rtype: typing.ContextManager[ResourceContext[T_co]] | typing.AsyncContextManager[ResourceContext[T_co]] """ - if self.is_async: - return typing.cast(typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]], self.async_context()) - return typing.cast(typing.Callable[[typing.Callable[P, T]], typing.Callable[P, T]], self.sync_context()) + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + async with self.async_context(): + return await func(*args, **kwargs) # type: ignore[no-any-return] + + return typing.cast(typing.Callable[P, T], _async_wrapper) + + # wrapped function is sync + @wraps(func) + def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + with self.sync_context(): + return func(*args, **kwargs) + + return typing.cast(typing.Callable[P, T], _sync_wrapper) def _fetch_context(self) -> ResourceContext[T_co]: try: From 9720f0fa70f5470c9d9b486dd391723fd9d31051 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 6 Jan 2025 21:53:21 +0100 Subject: [PATCH 04/24] Made BaseContainerMeta thread safe. --- that_depends/meta.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/that_depends/meta.py b/that_depends/meta.py index 9a4672ce..db5e8aa7 100644 --- a/that_depends/meta.py +++ b/that_depends/meta.py @@ -1,4 +1,5 @@ import typing +from threading import Lock if typing.TYPE_CHECKING: @@ -6,14 +7,17 @@ class BaseContainerMeta(type): - instances: typing.ClassVar[list[type["BaseContainer"]]] = [] + _instances: typing.ClassVar[list[type["BaseContainer"]]] = [] + + _lock: Lock = Lock() def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]) -> type: new_cls = super().__new__(cls, name, bases, namespace) - if name != "BaseContainer": - cls.instances.append(new_cls) # type: ignore[arg-type] + with cls._lock: + if name != "BaseContainer": + cls._instances.append(new_cls) # type: ignore[arg-type] return new_cls @classmethod def get_instances(cls) -> list[type["BaseContainer"]]: - return cls.instances + return cls._instances From 0913b96c097696c3686900900bb61a459b452445 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 8 Jan 2025 09:52:38 +0100 Subject: [PATCH 05/24] Added tests and enabled context for containers. --- test.py | 223 ------------------- tests/providers/test_context_resources.py | 228 ++++++++++++++++++++ that_depends/container.py | 41 +++- that_depends/providers/context_resources.py | 8 +- that_depends/providers/resources.py | 4 +- 5 files changed, 274 insertions(+), 230 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index 353535af..00000000 --- a/test.py +++ /dev/null @@ -1,223 +0,0 @@ -import asyncio -import random -import typing - -import pytest -from pydantic import BaseModel - -from that_depends.container import BaseContainer -from that_depends.injection import Provide, inject -from that_depends.providers.context_resources import ContextResource, container_context, fetch_context_item - - -random.seed(1) - - -class Config(BaseModel): - some_str_value: str = "some_string_value" - - -async def async_yields_string() -> typing.AsyncIterator[str]: - yield str(random.random()) # noqa: S311 - - -def sync_yields_string() -> typing.Iterator[str]: - yield str(random.random()) # noqa: S311 - - -def sync_yields_config() -> typing.Iterator[Config]: - yield Config() - - -class MyContainer(BaseContainer): - async_resource: ContextResource[str] = ContextResource(async_yields_string) - sync_resource: ContextResource[str] = ContextResource(sync_yields_string) - sync_config: ContextResource[Config] = ContextResource(sync_yields_config) - - -@MyContainer.sync_resource.sync_context() -@inject -def sync_injected(val: str = Provide[MyContainer.sync_resource]) -> str: - return val - - -@MyContainer.async_resource.async_context() -@inject -async def async_injected(val: str = Provide[MyContainer.async_resource]) -> str: - return val - - -@MyContainer.async_resource.context -@inject -async def async_injected_implicit(val: str = Provide[MyContainer.async_resource]) -> str: - return val - - -@MyContainer.sync_resource.context -@inject -def sync_injected_implicit(val: str = Provide[MyContainer.sync_resource]) -> str: - return val - - -async def test_injected() -> None: - async_result = await async_injected() - sync_result = sync_injected() - async_result_implicit = await async_injected_implicit() - sync_result_implicit = sync_injected_implicit() - assert isinstance(async_result, str) - assert isinstance(sync_result, str) - assert isinstance(async_result_implicit, str) - assert isinstance(sync_result_implicit, str) - - -async def async_main() -> None: - """Test async resolution.""" - async with MyContainer.async_resource.async_context(): - val_1 = await MyContainer.async_resource.async_resolve() - val_2 = await MyContainer.async_resource.async_resolve() - assert val_1 == val_2 - async with MyContainer.async_resource.async_context(): - val_3 = await MyContainer.async_resource.async_resolve() - assert val_1 != val_3 - async with MyContainer.async_resource.async_context(): - val_4 = await MyContainer.async_resource.async_resolve() - assert val_1 != val_4 != val_3 - val_5 = await MyContainer.async_resource.async_resolve() - assert val_5 == val_3 - val_6 = await MyContainer.async_resource.async_resolve() - assert val_6 == val_1 - - -def sync_main() -> None: - """Test sync resolution.""" - with MyContainer.sync_resource.sync_context(): - val_1 = MyContainer.sync_resource.sync_resolve() - val_2 = MyContainer.sync_resource.sync_resolve() - assert val_1 == val_2 - with MyContainer.sync_resource.sync_context(): - val_3 = MyContainer.sync_resource.sync_resolve() - assert val_1 != val_3 - with MyContainer.sync_resource.sync_context(): - val_4 = MyContainer.sync_resource.sync_resolve() - assert val_1 != val_4 != val_3 - val_5 = MyContainer.sync_resource.sync_resolve() - assert val_5 == val_1 - - -def check_sync_container_context() -> None: - """Test sync provider resolution container_context.""" - with container_context(providers=[MyContainer.sync_resource]): - val_1 = MyContainer.sync_resource.sync_resolve() - val_2 = MyContainer.sync_resource.sync_resolve() - assert val_1 == val_2 - with container_context(providers=[MyContainer.sync_resource]): - val_3 = MyContainer.sync_resource.sync_resolve() - assert val_3 != val_1 - - val_4 = MyContainer.sync_resource.sync_resolve() - assert val_4 == val_1 - with pytest.raises(RuntimeError): - MyContainer.sync_resource.sync_resolve() - - -async def check_async_container_context() -> None: - """Test async provider resolution container_context.""" - async with container_context(providers=[MyContainer.async_resource]): - val_1 = await MyContainer.async_resource.async_resolve() - val_2 = await MyContainer.async_resource.async_resolve() - assert val_1 == val_2 - async with container_context(providers=[MyContainer.async_resource]): - val_3 = await MyContainer.async_resource.async_resolve() - assert val_3 != val_1 - - val_4 = await MyContainer.async_resource.async_resolve() - assert val_4 == val_1 - with pytest.raises(RuntimeError): - await MyContainer.async_resource.async_resolve() - - -async def check_async_global_passing() -> None: - with pytest.raises(RuntimeError): - async with container_context(preserve_globals=True) as gs: - assert gs - my_global_resources = {"test_1": "test_1", "test_2": "test_2"} - - async with container_context(initial_context=my_global_resources): - for key, item in my_global_resources.items(): - assert fetch_context_item(key) == item - - async with container_context(preserve_globals=True): - for key, item in my_global_resources.items(): - assert fetch_context_item(key) == item - - async with container_context(preserve_globals=False): - for key in my_global_resources: - assert fetch_context_item(key) is None - - -def check_sync_global_passing() -> None: - with pytest.raises(RuntimeError), container_context(preserve_globals=True) as gs: - assert gs - my_global_resources = {"test_1": "test_1", "test_2": "test_2"} - - with container_context(initial_context=my_global_resources): - for key, item in my_global_resources.items(): - assert fetch_context_item(key) == item - - with container_context(preserve_globals=True): - for key, item in my_global_resources.items(): - assert fetch_context_item(key) == item - - with container_context(preserve_globals=False): - for key in my_global_resources: - assert fetch_context_item(key) is None - - -async def test_reset_context_async() -> None: - async with container_context(): - val_1 = await MyContainer.async_resource.async_resolve() - - async with container_context(): - val_2 = await MyContainer.async_resource.async_resolve() - assert val_1 != val_2 - - -def test_reset_context_sync() -> None: - with container_context(): - val_1 = MyContainer.sync_resource.sync_resolve() - with container_context(): - val_2 = MyContainer.sync_resource.sync_resolve() - assert val_1 != val_2 - - -@container_context() -def test_attr_getter_sync() -> None: - assert MyContainer.sync_config.sync_resolve().some_str_value - - -async def test_inject_sync_into_async() -> None: - @MyContainer.async_resource.context - async def _inner() -> str: - return await MyContainer.async_resource.async_resolve() - - @MyContainer.sync_resource.context - async def _sync_injected() -> str: - return MyContainer.sync_resource.sync_resolve() - - value_1 = await _inner() - value_2 = await _sync_injected() - assert value_1 != value_2 - - -if __name__ == "__main__": - asyncio.run(async_main()) - sync_main() - check_sync_container_context() - asyncio.run(check_async_container_context()) - asyncio.run(test_injected()) - asyncio.run(check_async_global_passing()) - check_sync_global_passing() - asyncio.run(test_reset_context_async()) - test_reset_context_sync() - test_attr_getter_sync() - asyncio.run(test_inject_sync_into_async()) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index ea035431..4b1d6508 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -272,3 +272,231 @@ async def resolve_resource() -> str: await asyncio.gather(resolve_resource(), resolve_resource()) assert calls == 1 + + +async def test_async_injection_when_explicitly_resetting_resource_specific_context( + async_context_resource: providers.ContextResource[str], +) -> None: + """Async context resources should be able to reset the context for themselves explicitly.""" + + @async_context_resource.async_context() + @inject + async def _async_injected(val: str = Provide[async_context_resource]) -> str: + assert isinstance(async_context_resource._fetch_context().context_stack, AsyncExitStack) # noqa: SLF001 + return val + + async_result = await _async_injected() + assert async_result != await _async_injected() + assert isinstance(async_result, str) + + +async def test_sync_injection_when_explicitly_resetting_resource_specific_context( + sync_context_resource: providers.ContextResource[str], +) -> None: + """Sync context resources should be able to reset the context for themselves explicitly.""" + + @sync_context_resource.async_context() + @inject + async def _async_injected(val: str = Provide[sync_context_resource]) -> str: + assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001 + return val + + @sync_context_resource.sync_context() + @inject + def _sync_injected(val: str = Provide[sync_context_resource]) -> str: + assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001 + return val + + async_result = await _async_injected() + assert async_result != await _async_injected() + assert isinstance(async_result, str) + sync_result = _sync_injected() + assert sync_result != _sync_injected() + assert isinstance(sync_result, str) + + +async def test_async_resolution_when_explicitly_resolving( + async_context_resource: providers.ContextResource[str], +) -> None: + """Async context should cache resources until a new one is created.""" + async with async_context_resource.async_context(): + val_1 = await async_context_resource.async_resolve() + val_2 = await async_context_resource.async_resolve() + assert val_1 == val_2 + async with async_context_resource.async_context(): + val_3 = await async_context_resource.async_resolve() + assert val_1 != val_3 + async with async_context_resource.async_context(): + val_4 = await async_context_resource.async_resolve() + assert val_1 != val_4 != val_3 + val_5 = await async_context_resource.async_resolve() + assert val_5 == val_3 + val_6 = await async_context_resource.async_resolve() + assert val_6 == val_1 + + +def test_sync_resolution_when_explicitly_resolving( + sync_context_resource: providers.ContextResource[str], +) -> None: + """Sync context should cache resources until a new one is created.""" + with sync_context_resource.sync_context(): + val_1 = sync_context_resource.sync_resolve() + val_2 = sync_context_resource.sync_resolve() + assert val_1 == val_2 + with sync_context_resource.sync_context(): + val_3 = sync_context_resource.sync_resolve() + assert val_1 != val_3 + with sync_context_resource.sync_context(): + val_4 = sync_context_resource.sync_resolve() + assert val_1 != val_4 != val_3 + val_5 = sync_context_resource.sync_resolve() + assert val_5 == val_3 + val_6 = sync_context_resource.sync_resolve() + assert val_6 == val_1 + + +def test_sync_container_context_resolution( + sync_context_resource: providers.ContextResource[str], +) -> None: + """container_context should reset context for sync provider.""" + with container_context(providers=[sync_context_resource]): + val_1 = sync_context_resource.sync_resolve() + val_2 = sync_context_resource.sync_resolve() + assert val_1 == val_2 + with container_context(providers=[sync_context_resource]): + val_3 = sync_context_resource.sync_resolve() + assert val_3 != val_1 + val_4 = sync_context_resource.sync_resolve() + assert val_4 == val_1 + with pytest.raises(RuntimeError): + sync_context_resource.sync_resolve() + + +async def test_async_container_context_resolution( + async_context_resource: providers.ContextResource[str], +) -> None: + """container_context should reset context for async provider.""" + async with container_context(providers=[async_context_resource]): + val_1 = await async_context_resource.async_resolve() + val_2 = await async_context_resource.async_resolve() + assert val_1 == val_2 + async with container_context(providers=[async_context_resource]): + val_3 = await async_context_resource.async_resolve() + assert val_3 != val_1 + val_4 = await async_context_resource.async_resolve() + assert val_4 == val_1 + with pytest.raises(RuntimeError): + await async_context_resource.async_resolve() + + +async def test_async_global_context_resolution() -> None: + with pytest.raises(RuntimeError): + async with container_context(preserve_globals=True) as gs: + assert gs + my_global_resources = {"test_1": "test_1", "test_2": "test_2"} + + async with container_context(initial_context=my_global_resources): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + async with container_context(preserve_globals=True): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + async with container_context(preserve_globals=False): + for key in my_global_resources: + assert fetch_context_item(key) is None + + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + with pytest.raises(RuntimeError): + fetch_context_item("test_1") + + +def test_sync_global_context_resolution() -> None: + with pytest.raises(RuntimeError), container_context(preserve_globals=True) as gs: + assert gs + my_global_resources = {"test_1": "test_1", "test_2": "test_2"} + with container_context(initial_context=my_global_resources): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + with container_context(preserve_globals=True): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + with container_context(preserve_globals=False): + for key in my_global_resources: + assert fetch_context_item(key) is None + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + with pytest.raises(RuntimeError): + fetch_context_item("test_1") + + +async def test_async_global_context_reset(async_context_resource: providers.ContextResource[str]) -> None: + """container_context should reset async providers.""" + async with container_context(): + val_1 = await async_context_resource.async_resolve() + val_2 = await async_context_resource.async_resolve() + assert val_1 == val_2 + async with container_context(): + val_3 = await async_context_resource.async_resolve() + assert val_3 != val_1 + val_4 = await async_context_resource.async_resolve() + assert val_4 == val_1 + + +def test_sync_global_context_reset(sync_context_resource: providers.ContextResource[str]) -> None: + """container_context should reset sync providers.""" + with container_context(): + val_1 = sync_context_resource.sync_resolve() + val_2 = sync_context_resource.sync_resolve() + assert val_1 == val_2 + with container_context(): + val_3 = sync_context_resource.sync_resolve() + assert val_3 != val_1 + val_4 = sync_context_resource.sync_resolve() + assert val_4 == val_1 + + +async def test_async_context_with_container( + async_context_resource: providers.ContextResource[str], + sync_context_resource: providers.ContextResource[str], +) -> None: + """Containers should enter async context for all its providers.""" + async with DIContainer.async_context(): + val_1 = await async_context_resource.async_resolve() + val_2 = await async_context_resource.async_resolve() + assert val_1 == val_2 + val_1_sync = sync_context_resource.sync_resolve() + val_2_sync = sync_context_resource.sync_resolve() + assert val_1_sync == val_2_sync + async with DIContainer.async_context(): + val_3 = await async_context_resource.async_resolve() + val_3_sync = sync_context_resource.sync_resolve() + assert val_3 != val_1 + assert val_3_sync != val_1_sync + val_4 = await async_context_resource.async_resolve() + val_4_sync = sync_context_resource.sync_resolve() + assert val_4 == val_1 + assert val_4_sync == val_1_sync + + +def test_sync_context_with_container( + sync_context_resource: providers.ContextResource[str], +) -> None: + """Containers should enter sync context for all its providers.""" + with DIContainer.sync_context(): + val_1 = sync_context_resource.sync_resolve() + val_2 = sync_context_resource.sync_resolve() + assert val_1 == val_2 + with DIContainer.sync_context(): + val_3 = sync_context_resource.sync_resolve() + assert val_3 != val_1 + val_4 = sync_context_resource.sync_resolve() + assert val_4 == val_1 diff --git a/that_depends/container.py b/that_depends/container.py index 5349c13f..5d47f853 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -1,10 +1,11 @@ import inspect import typing import warnings -from contextlib import contextmanager +from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager from that_depends.meta import BaseContainerMeta from that_depends.providers import AbstractProvider, Resource, Singleton +from that_depends.providers.context_resources import ContextResource if typing.TYPE_CHECKING: @@ -23,6 +24,44 @@ def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": msg = f"{cls.__name__} should not be instantiated" raise RuntimeError(msg) + @classmethod + @contextmanager + def sync_context(cls) -> typing.Iterator[None]: + with ExitStack() as stack: + for container in cls.get_containers(): + stack.enter_context(container.sync_context()) + for provider in cls.get_providers().values(): + if isinstance(provider, ContextResource) and not provider.is_async: + stack.enter_context(provider.sync_context()) + yield + + @classmethod + @asynccontextmanager + async def async_context(cls) -> typing.AsyncIterator[None]: + async with AsyncExitStack() as stack: + for container in cls.get_containers(): + await stack.enter_async_context(container.async_context()) + for provider in cls.get_providers().values(): + if isinstance(provider, ContextResource): + await stack.enter_async_context(provider.async_context()) + yield + + @classmethod + def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]: + if inspect.iscoroutinefunction(func): + + async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + async with cls.async_context(): + return await func(*args, **kwargs) # type: ignore[no-any-return] + + return typing.cast(typing.Callable[P, T], _async_wrapper) + + def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + with cls.sync_context(): + return func(*args, **kwargs) + + return _sync_wrapper + @classmethod def connect_containers(cls, *containers: type["BaseContainer"]) -> None: """Connect containers. diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 805a9dfc..fef9097d 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -55,14 +55,14 @@ class ContextResource( AbstractContextManager[ResourceContext[T_co]], ): __slots__ = ( - "is_async", - "_creator", "_args", + "_context", + "_creator", + "_internal_name", "_kwargs", "_override", - "_internal_name", - "_context", "_token", + "is_async", ) def __repr__(self) -> str: diff --git a/that_depends/providers/resources.py b/that_depends/providers/resources.py index f049913a..45b0e40e 100644 --- a/that_depends/providers/resources.py +++ b/that_depends/providers/resources.py @@ -11,14 +11,14 @@ class Resource(AbstractResource[T_co]): __slots__ = ( - "is_async", - "_creator", "_args", "_context", "_creator", + "_creator", "_is_async", "_kwargs", "_override", + "is_async", ) def __init__( From a7d0dcb91efaa9fd8e93103d85ad8290108d54f5 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 8 Jan 2025 10:00:15 +0100 Subject: [PATCH 06/24] Added additional tests for container context wrappers. --- tests/providers/test_context_resources.py | 36 +++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 4b1d6508..01dadf7e 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -500,3 +500,39 @@ def test_sync_context_with_container( assert val_3 != val_1 val_4 = sync_context_resource.sync_resolve() assert val_4 == val_1 + + +async def test_async_container_context_wrapper(async_context_resource: providers.ContextResource[str]) -> None: + """Container context wrapper should correctly enter async context for wrapped function.""" + + @DIContainer.context + @inject + async def _injected(val: str = Provide[async_context_resource]) -> str: + return val + + assert await _injected() != await _injected() + + @DIContainer.async_context() + @inject + async def _explicit_injected(val: str = Provide[async_context_resource]) -> str: + return val + + assert await _explicit_injected() != await _explicit_injected() + + +def test_sync_container_context_wrapper(sync_context_resource: providers.ContextResource[str]) -> None: + """Container context wrapper should correctly enter sync context for wrapped function.""" + + @DIContainer.context + @inject + def _injected(val: str = Provide[sync_context_resource]) -> str: + return val + + assert _injected() != _injected() + + @DIContainer.sync_context() + @inject + def _explicit_injected(val: str = Provide[sync_context_resource]) -> str: + return val + + assert _explicit_injected() != _explicit_injected() From 511c7dcc6dd42eb5b7e34c9c54684dc7e8007cc2 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 8 Jan 2025 10:46:02 +0100 Subject: [PATCH 07/24] Added tests for dependent containers. --- tests/providers/test_context_resources.py | 32 ++++++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 01dadf7e..35906f9e 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -37,6 +37,14 @@ class DIContainer(BaseContainer): ) +class DependentDiContainer(BaseContainer): + dependent_sync_context_resource = providers.ContextResource(create_sync_context_resource) + dependent_async_context_resource = providers.ContextResource(create_async_context_resource) + + +DIContainer.connect_containers(DependentDiContainer) + + @pytest.fixture(autouse=True) async def _clear_di_container() -> typing.AsyncIterator[None]: try: @@ -391,8 +399,8 @@ async def test_async_container_context_resolution( async def test_async_global_context_resolution() -> None: with pytest.raises(RuntimeError): - async with container_context(preserve_globals=True) as gs: - assert gs + async with AsyncExitStack() as stack: + await stack.enter_async_context(container_context(preserve_globals=True)) my_global_resources = {"test_1": "test_1", "test_2": "test_2"} async with container_context(initial_context=my_global_resources): @@ -417,8 +425,8 @@ async def test_async_global_context_resolution() -> None: def test_sync_global_context_resolution() -> None: - with pytest.raises(RuntimeError), container_context(preserve_globals=True) as gs: - assert gs + with pytest.raises(RuntimeError), ExitStack() as stack: + stack.enter_context(container_context(preserve_globals=True)) my_global_resources = {"test_1": "test_1", "test_2": "test_2"} with container_context(initial_context=my_global_resources): for key, item in my_global_resources.items(): @@ -536,3 +544,19 @@ def _explicit_injected(val: str = Provide[sync_context_resource]) -> str: return val assert _explicit_injected() != _explicit_injected() + + +async def test_async_context_resource_with_dependent_container() -> None: + """Container should init async context resource for depedent containers.""" + async with DIContainer.async_context(): + val_1 = await DependentDiContainer.dependent_async_context_resource.async_resolve() + val_2 = await DependentDiContainer.dependent_async_context_resource.async_resolve() + assert val_1 == val_2 + + +def test_sync_context_resource_with_dependent_container() -> None: + """Container should init sync context resource for depedent containers.""" + with DIContainer.sync_context(): + val_1 = DependentDiContainer.dependent_sync_context_resource.sync_resolve() + val_2 = DependentDiContainer.dependent_sync_context_resource.sync_resolve() + assert val_1 == val_2 From 57a642a5ccc03bf0b600ff14582b5f22ede0194d Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 8 Jan 2025 13:23:35 +0100 Subject: [PATCH 08/24] Added SupportsContext interface. --- that_depends/container.py | 6 +++--- that_depends/meta.py | 3 ++- that_depends/providers/context_resources.py | 24 ++++++++++++++++++++- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/that_depends/container.py b/that_depends/container.py index 5d47f853..f0035c4e 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -5,7 +5,7 @@ from that_depends.meta import BaseContainerMeta from that_depends.providers import AbstractProvider, Resource, Singleton -from that_depends.providers.context_resources import ContextResource +from that_depends.providers.context_resources import ContextResource, SupportsContext if typing.TYPE_CHECKING: @@ -16,7 +16,7 @@ P = typing.ParamSpec("P") -class BaseContainer(metaclass=BaseContainerMeta): +class BaseContainer(SupportsContext[None], metaclass=BaseContainerMeta): providers: dict[str, AbstractProvider[typing.Any]] containers: list[type["BaseContainer"]] @@ -37,7 +37,7 @@ def sync_context(cls) -> typing.Iterator[None]: @classmethod @asynccontextmanager - async def async_context(cls) -> typing.AsyncIterator[None]: + async def async_context(cls) -> typing.AsyncIterator[None]: # type: ignore[override] async with AsyncExitStack() as stack: for container in cls.get_containers(): await stack.enter_async_context(container.async_context()) diff --git a/that_depends/meta.py b/that_depends/meta.py index db5e8aa7..7ac08b8f 100644 --- a/that_depends/meta.py +++ b/that_depends/meta.py @@ -1,3 +1,4 @@ +import abc import typing from threading import Lock @@ -6,7 +7,7 @@ from that_depends.container import BaseContainer -class BaseContainerMeta(type): +class BaseContainerMeta(abc.ABCMeta): _instances: typing.ClassVar[list[type["BaseContainer"]]] = [] _lock: Lock = Lock() diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index fef9097d..2d0d3bf0 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -1,8 +1,10 @@ +import abc import contextlib import inspect import logging import typing import warnings +from abc import abstractmethod from contextlib import AbstractAsyncContextManager, AbstractContextManager from contextvars import ContextVar, Token from functools import wraps @@ -47,12 +49,32 @@ def fetch_context_item(key: str, default: typing.Any = None) -> typing.Any: # n T = typing.TypeVar("T") +CT = typing.TypeVar("CT") + + +class SupportsContext(typing.Generic[CT], abc.ABC): + @abstractmethod + def context(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]: + """Initialize context for the given function. + + :param func: function to wrap. + :return: wrapped function with context. + """ + + @abstractmethod + async def async_context(self) -> typing.AsyncContextManager[CT]: + """Initialize async context.""" + + @abstractmethod + def sync_context(self) -> typing.ContextManager[CT]: + """Initialize sync context.""" class ContextResource( AbstractResource[T_co], AbstractAsyncContextManager[ResourceContext[T_co]], AbstractContextManager[ResourceContext[T_co]], + SupportsContext[ResourceContext[T_co]], ): __slots__ = ( "_args", @@ -132,7 +154,7 @@ def sync_context(self) -> typing.Iterator[ResourceContext[T_co]]: self._token = token @contextlib.asynccontextmanager - async def async_context(self) -> typing.AsyncIterator[ResourceContext[T_co]]: + async def async_context(self) -> typing.AsyncIterator[ResourceContext[T_co]]: # type: ignore[override] token = self._token async with self as val: yield val From df5ed5ff28a73c211853c8a343eb8e8b45884159 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 8 Jan 2025 14:30:02 +0100 Subject: [PATCH 09/24] Changed container context api to support SupportContext items. --- .../fastapi/test_fastapi_di_pass_request.py | 2 +- .../test_faststream_di_pass_message.py | 4 +- tests/providers/test_attr_getter.py | 24 +++++----- tests/providers/test_context_resources.py | 12 ++--- that_depends/container.py | 3 ++ that_depends/providers/context_resources.py | 46 ++++++++----------- 6 files changed, 42 insertions(+), 49 deletions(-) diff --git a/tests/integrations/fastapi/test_fastapi_di_pass_request.py b/tests/integrations/fastapi/test_fastapi_di_pass_request.py index 133bfc27..9c8176f1 100644 --- a/tests/integrations/fastapi/test_fastapi_di_pass_request.py +++ b/tests/integrations/fastapi/test_fastapi_di_pass_request.py @@ -9,7 +9,7 @@ async def init_di_context(request: fastapi.Request) -> typing.AsyncIterator[None]: - async with container_context({"request": request}): + async with container_context(initial_context={"request": request}): yield diff --git a/tests/integrations/faststream/test_faststream_di_pass_message.py b/tests/integrations/faststream/test_faststream_di_pass_message.py index fc07835e..b54ee551 100644 --- a/tests/integrations/faststream/test_faststream_di_pass_message.py +++ b/tests/integrations/faststream/test_faststream_di_pass_message.py @@ -14,7 +14,7 @@ async def consume_scope( call_next: typing.Callable[..., typing.Awaitable[typing.Any]], msg: StreamMessage[typing.Any], ) -> typing.Any: # noqa: ANN401 - async with container_context({"request": msg}): + async with container_context(initial_context={"request": msg}): return await call_next(msg) @@ -30,7 +30,7 @@ class DIContainer(BaseContainer): @broker.subscriber(TEST_SUBJECT) -async def index_subscruber( +async def index_subscriber( context_request: typing.Annotated[ NatsMessage, Depends(DIContainer.context_request, cast=False), diff --git a/tests/providers/test_attr_getter.py b/tests/providers/test_attr_getter.py index 1af8991d..11185227 100644 --- a/tests/providers/test_attr_getter.py +++ b/tests/providers/test_attr_getter.py @@ -71,10 +71,10 @@ def some_async_settings_provider(request: pytest.FixtureRequest) -> providers.Ab def test_attr_getter_with_zero_attribute_depth_sync( some_sync_settings_provider: providers.AbstractProvider[Settings], ) -> None: - with container_context( - providers=[some_sync_settings_provider] + with ( + container_context(some_sync_settings_provider) if isinstance(some_sync_settings_provider, providers.ContextResource) - else [] + else container_context() ): attr_getter = some_sync_settings_provider.some_str_value assert attr_getter.sync_resolve() == Settings().some_str_value @@ -83,10 +83,10 @@ def test_attr_getter_with_zero_attribute_depth_sync( async def test_attr_getter_with_zero_attribute_depth_async( some_async_settings_provider: providers.AbstractProvider[Settings], ) -> None: - async with container_context( - providers=[some_async_settings_provider] + async with ( + container_context(some_async_settings_provider) if isinstance(some_async_settings_provider, providers.ContextResource) - else [] + else container_context() ): attr_getter = some_async_settings_provider.some_str_value assert await attr_getter.async_resolve() == Settings().some_str_value @@ -95,10 +95,10 @@ async def test_attr_getter_with_zero_attribute_depth_async( def test_attr_getter_with_more_than_zero_attribute_depth_sync( some_sync_settings_provider: providers.AbstractProvider[Settings], ) -> None: - with container_context( - providers=[some_sync_settings_provider] + with ( + container_context(some_sync_settings_provider) if isinstance(some_sync_settings_provider, providers.ContextResource) - else [] + else container_context() ): attr_getter = some_sync_settings_provider.nested1_attr.nested2_attr.some_const assert attr_getter.sync_resolve() == Nested2().some_const @@ -107,10 +107,10 @@ def test_attr_getter_with_more_than_zero_attribute_depth_sync( async def test_attr_getter_with_more_than_zero_attribute_depth_async( some_async_settings_provider: providers.AbstractProvider[Settings], ) -> None: - async with container_context( - providers=[some_async_settings_provider] + async with ( + container_context(some_async_settings_provider) if isinstance(some_async_settings_provider, providers.ContextResource) - else [] + else container_context() ): attr_getter = some_async_settings_provider.nested1_attr.nested2_attr.some_const assert await attr_getter.async_resolve() == Nested2().some_const diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 35906f9e..90ced3f9 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -147,10 +147,10 @@ def test_context_resources_wrong_providers_init() -> None: async def test_context_resource_with_dynamic_resource() -> None: - async with container_context({"resource_type": "sync"}): + async with container_context(initial_context={"resource_type": "sync"}): assert (await DIContainer.dynamic_context_resource()).startswith("sync") - async with container_context({"resource_type": "async_"}): + async with container_context(initial_context={"resource_type": "async_"}): assert (await DIContainer.dynamic_context_resource()).startswith("async") async with container_context(): @@ -367,11 +367,11 @@ def test_sync_container_context_resolution( sync_context_resource: providers.ContextResource[str], ) -> None: """container_context should reset context for sync provider.""" - with container_context(providers=[sync_context_resource]): + with container_context(sync_context_resource): val_1 = sync_context_resource.sync_resolve() val_2 = sync_context_resource.sync_resolve() assert val_1 == val_2 - with container_context(providers=[sync_context_resource]): + with container_context(sync_context_resource): val_3 = sync_context_resource.sync_resolve() assert val_3 != val_1 val_4 = sync_context_resource.sync_resolve() @@ -384,11 +384,11 @@ async def test_async_container_context_resolution( async_context_resource: providers.ContextResource[str], ) -> None: """container_context should reset context for async provider.""" - async with container_context(providers=[async_context_resource]): + async with container_context(async_context_resource): val_1 = await async_context_resource.async_resolve() val_2 = await async_context_resource.async_resolve() assert val_1 == val_2 - async with container_context(providers=[async_context_resource]): + async with container_context(async_context_resource): val_3 = await async_context_resource.async_resolve() assert val_3 != val_1 val_4 = await async_context_resource.async_resolve() diff --git a/that_depends/container.py b/that_depends/container.py index f0035c4e..3d4baecd 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -24,6 +24,9 @@ def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": msg = f"{cls.__name__} should not be instantiated" raise RuntimeError(msg) + def supports_sync_context(self) -> bool: + return True + @classmethod @contextmanager def sync_context(cls) -> typing.Iterator[None]: diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 2d0d3bf0..79138f0c 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -69,6 +69,10 @@ async def async_context(self) -> typing.AsyncContextManager[CT]: def sync_context(self) -> typing.ContextManager[CT]: """Initialize sync context.""" + @abstractmethod + def supports_sync_context(self) -> bool: + """Check if the resource supports sync context.""" + class ContextResource( AbstractResource[T_co], @@ -100,6 +104,9 @@ def __init__( self._context: ContextVar[ResourceContext[T_co]] = ContextVar(f"{self._creator.__name__}-context") self._token: Token[ResourceContext[T_co]] | None = None + def supports_sync_context(self) -> bool: + return not self.is_async + def __enter__(self) -> ResourceContext[T_co]: if self.is_async: msg = "You must enter async context for async creators." @@ -206,9 +213,8 @@ class container_context(AbstractContextManager[ContextType], AbstractAsyncContex def __init__( self, + *args: SupportsContext[typing.Any], initial_context: ContextType | None = None, - providers: list[ContextResource[typing.Any]] | None = None, - containers: list[ContainerType] | None = None, preserve_globals: bool = False, reset_resource_context: bool = False, ) -> None: @@ -225,17 +231,8 @@ def __init__( else: self._initial_context: ContextType = _get_container_context() if preserve_globals else initial_context or {} # type: ignore[no-redef] self._context_token: Token[ContextType] | None = None - self._providers: set[ContextResource[typing.Any]] = set() - self._reset_resource_context: typing.Final[bool] = (not containers and not providers) or reset_resource_context - if providers: - for provider in providers: - if isinstance(provider, ContextResource): - self._providers.add(provider) - else: - msg = "Provider is not a ContextResource" - raise TypeError(msg) - if containers: - self._add_providers_from_containers(containers) + self._context_items: set[SupportsContext[typing.Any]] = set(args) + self._reset_resource_context: typing.Final[bool] = (not args) or reset_resource_context if self._reset_resource_context: self._add_providers_from_containers(BaseContainerMeta.get_instances()) @@ -245,22 +242,19 @@ def _add_providers_from_containers(self, containers: list[ContainerType]) -> Non for container in containers: for container_provider in container.get_providers().values(): if isinstance(container_provider, ContextResource): - self._providers.add(container_provider) + self._context_items.add(container_provider) def __enter__(self) -> ContextType: self._context_stack = contextlib.ExitStack() - for provider in self._providers: - if self._reset_resource_context: - if not provider.is_async: - self._context_stack.enter_context(provider.sync_context()) - else: - self._context_stack.enter_context(provider.sync_context()) + for item in self._context_items: + if item.supports_sync_context(): + self._context_stack.enter_context(item.sync_context()) return self._enter_globals() async def __aenter__(self) -> ContextType: self._context_stack = contextlib.AsyncExitStack() - for provider in self._providers: - await self._context_stack.enter_async_context(provider.async_context()) + for item in self._context_items: + await self._context_stack.enter_async_context(item.async_context()) # type: ignore[arg-type] return self._enter_globals() def _enter_globals(self) -> ContextType: @@ -316,18 +310,14 @@ def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]: @wraps(func) async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: - async with container_context( - providers=list(self._providers), reset_resource_context=self._reset_resource_context - ): + async with container_context(*self._context_items, reset_resource_context=self._reset_resource_context): return await func(*args, **kwargs) # type: ignore[no-any-return] return typing.cast(typing.Callable[P, T_co], _async_inner) @wraps(func) def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: - with container_context( - providers=list(self._providers), reset_resource_context=self._reset_resource_context - ): + with container_context(*self._context_items, reset_resource_context=self._reset_resource_context): return func(*args, **kwargs) return _sync_inner From 11d24bdaf91a874d0c61ca0e7f84691df704ec80 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 8 Jan 2025 14:49:27 +0100 Subject: [PATCH 10/24] Added tests for coverage. --- tests/providers/test_context_resources.py | 52 ++++++++++++++++++++- that_depends/container.py | 3 +- that_depends/providers/base.py | 4 -- that_depends/providers/context_resources.py | 3 -- 4 files changed, 52 insertions(+), 10 deletions(-) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 90ced3f9..4e141e9c 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -547,7 +547,7 @@ def _explicit_injected(val: str = Provide[sync_context_resource]) -> str: async def test_async_context_resource_with_dependent_container() -> None: - """Container should init async context resource for depedent containers.""" + """Container should initialize async context resource for dependent containers.""" async with DIContainer.async_context(): val_1 = await DependentDiContainer.dependent_async_context_resource.async_resolve() val_2 = await DependentDiContainer.dependent_async_context_resource.async_resolve() @@ -555,8 +555,56 @@ async def test_async_context_resource_with_dependent_container() -> None: def test_sync_context_resource_with_dependent_container() -> None: - """Container should init sync context resource for depedent containers.""" + """Container should initialize sync context resource for dependent containers.""" with DIContainer.sync_context(): val_1 = DependentDiContainer.dependent_sync_context_resource.sync_resolve() val_2 = DependentDiContainer.dependent_sync_context_resource.sync_resolve() assert val_1 == val_2 + + +def test_containers_support_sync_context() -> None: + assert DIContainer.supports_sync_context() + + +def test_enter_sync_context_for_async_resource_should_throw( + async_context_resource: providers.ContextResource[str], +) -> None: + with pytest.raises(RuntimeError): + async_context_resource.__enter__() + + +def test_exit_sync_context_before_enter_should_throw(sync_context_resource: providers.ContextResource[str]) -> None: + with pytest.raises(RuntimeError): + sync_context_resource.__exit__(None, None, None) + + +async def test_exit_async_context_before_enter_should_throw( + async_context_resource: providers.ContextResource[str], +) -> None: + with pytest.raises(RuntimeError): + await async_context_resource.__aexit__(None, None, None) + + +def test_enter_sync_context_from_async_resource_should_throw( + async_context_resource: providers.ContextResource[str], +) -> None: + with pytest.raises(RuntimeError), ExitStack() as stack: + stack.enter_context(async_context_resource.sync_context()) + + +async def test_preserve_globals_and_initial_context() -> None: + initial_context = {"test_1": "test_1", "test_2": "test_2"} + + async with container_context(initial_context=initial_context): + for key, item in initial_context.items(): + assert fetch_context_item(key) == item + new_context = {"test_3": "test_3"} + async with container_context(initial_context=new_context, preserve_globals=True): + for key, item in new_context.items(): + assert fetch_context_item(key) == item + for key, item in initial_context.items(): + assert fetch_context_item(key) == item + for key, item in initial_context.items(): + assert fetch_context_item(key) == item + for key in new_context: + assert fetch_context_item(key) is None diff --git a/that_depends/container.py b/that_depends/container.py index 3d4baecd..c584f1a0 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -24,7 +24,8 @@ def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": msg = f"{cls.__name__} should not be instantiated" raise RuntimeError(msg) - def supports_sync_context(self) -> bool: + @classmethod + def supports_sync_context(cls) -> bool: return True @classmethod diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 05d4a662..4c66d7c7 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -121,10 +121,6 @@ async def async_resolve(self) -> T_co: if context.instance is not None: return context.instance - if not context.is_async and self.is_async: - msg = "AsyncResource cannot be resolved in an sync context." - raise RuntimeError(msg) - # lock to prevent race condition while resolving async with context.asyncio_lock: if context.instance is not None: diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 79138f0c..86daf30f 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -91,9 +91,6 @@ class ContextResource( "is_async", ) - def __repr__(self) -> str: - return f"ContextResource({self._creator.__name__})" - def __init__( self, creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]], From 579d04f043b645ab6bca96cb20eb7ef1865d2d28 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Jan 2025 10:53:49 +0100 Subject: [PATCH 11/24] Fixed a method signature & removed so redundant mypy ignores. --- that_depends/container.py | 2 +- that_depends/providers/context_resources.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/that_depends/container.py b/that_depends/container.py index c584f1a0..3935bc91 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -41,7 +41,7 @@ def sync_context(cls) -> typing.Iterator[None]: @classmethod @asynccontextmanager - async def async_context(cls) -> typing.AsyncIterator[None]: # type: ignore[override] + async def async_context(cls) -> typing.AsyncIterator[None]: async with AsyncExitStack() as stack: for container in cls.get_containers(): await stack.enter_async_context(container.async_context()) diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 86daf30f..155f1fad 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -62,7 +62,7 @@ def context(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]: """ @abstractmethod - async def async_context(self) -> typing.AsyncContextManager[CT]: + def async_context(self) -> typing.AsyncContextManager[CT]: """Initialize async context.""" @abstractmethod @@ -158,7 +158,7 @@ def sync_context(self) -> typing.Iterator[ResourceContext[T_co]]: self._token = token @contextlib.asynccontextmanager - async def async_context(self) -> typing.AsyncIterator[ResourceContext[T_co]]: # type: ignore[override] + async def async_context(self) -> typing.AsyncIterator[ResourceContext[T_co]]: token = self._token async with self as val: yield val @@ -251,7 +251,7 @@ def __enter__(self) -> ContextType: async def __aenter__(self) -> ContextType: self._context_stack = contextlib.AsyncExitStack() for item in self._context_items: - await self._context_stack.enter_async_context(item.async_context()) # type: ignore[arg-type] + await self._context_stack.enter_async_context(item.async_context()) return self._enter_globals() def _enter_globals(self) -> ContextType: From 7b3d24e3bd879867c4deb6790a405a3afcd1341d Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Jan 2025 11:20:39 +0100 Subject: [PATCH 12/24] Refactored container_context api. --- .../fastapi/test_fastapi_di_pass_request.py | 2 +- .../test_faststream_di_pass_message.py | 2 +- tests/providers/test_context_resources.py | 24 ++++++------ that_depends/providers/context_resources.py | 38 ++++++++++--------- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/tests/integrations/fastapi/test_fastapi_di_pass_request.py b/tests/integrations/fastapi/test_fastapi_di_pass_request.py index 9c8176f1..1ac273de 100644 --- a/tests/integrations/fastapi/test_fastapi_di_pass_request.py +++ b/tests/integrations/fastapi/test_fastapi_di_pass_request.py @@ -9,7 +9,7 @@ async def init_di_context(request: fastapi.Request) -> typing.AsyncIterator[None]: - async with container_context(initial_context={"request": request}): + async with container_context(global_context={"request": request}): yield diff --git a/tests/integrations/faststream/test_faststream_di_pass_message.py b/tests/integrations/faststream/test_faststream_di_pass_message.py index b54ee551..cc84c69b 100644 --- a/tests/integrations/faststream/test_faststream_di_pass_message.py +++ b/tests/integrations/faststream/test_faststream_di_pass_message.py @@ -14,7 +14,7 @@ async def consume_scope( call_next: typing.Callable[..., typing.Awaitable[typing.Any]], msg: StreamMessage[typing.Any], ) -> typing.Any: # noqa: ANN401 - async with container_context(initial_context={"request": msg}): + async with container_context(global_context={"request": msg}): return await call_next(msg) diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 4e141e9c..c17241e1 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -147,10 +147,10 @@ def test_context_resources_wrong_providers_init() -> None: async def test_context_resource_with_dynamic_resource() -> None: - async with container_context(initial_context={"resource_type": "sync"}): + async with container_context(global_context={"resource_type": "sync"}, reset_all_containers=True): assert (await DIContainer.dynamic_context_resource()).startswith("sync") - async with container_context(initial_context={"resource_type": "async_"}): + async with container_context(global_context={"resource_type": "async_"}, reset_all_containers=True): assert (await DIContainer.dynamic_context_resource()).startswith("async") async with container_context(): @@ -400,18 +400,18 @@ async def test_async_container_context_resolution( async def test_async_global_context_resolution() -> None: with pytest.raises(RuntimeError): async with AsyncExitStack() as stack: - await stack.enter_async_context(container_context(preserve_globals=True)) + await stack.enter_async_context(container_context(preserve_global_context=True)) my_global_resources = {"test_1": "test_1", "test_2": "test_2"} - async with container_context(initial_context=my_global_resources): + async with container_context(global_context=my_global_resources): for key, item in my_global_resources.items(): assert fetch_context_item(key) == item - async with container_context(preserve_globals=True): + async with container_context(preserve_global_context=True): for key, item in my_global_resources.items(): assert fetch_context_item(key) == item - async with container_context(preserve_globals=False): + async with container_context(preserve_global_context=False): for key in my_global_resources: assert fetch_context_item(key) is None @@ -426,15 +426,15 @@ async def test_async_global_context_resolution() -> None: def test_sync_global_context_resolution() -> None: with pytest.raises(RuntimeError), ExitStack() as stack: - stack.enter_context(container_context(preserve_globals=True)) + stack.enter_context(container_context(preserve_global_context=True)) my_global_resources = {"test_1": "test_1", "test_2": "test_2"} - with container_context(initial_context=my_global_resources): + with container_context(global_context=my_global_resources): for key, item in my_global_resources.items(): assert fetch_context_item(key) == item - with container_context(preserve_globals=True): + with container_context(preserve_global_context=True): for key, item in my_global_resources.items(): assert fetch_context_item(key) == item - with container_context(preserve_globals=False): + with container_context(preserve_global_context=False): for key in my_global_resources: assert fetch_context_item(key) is None for key, item in my_global_resources.items(): @@ -595,11 +595,11 @@ def test_enter_sync_context_from_async_resource_should_throw( async def test_preserve_globals_and_initial_context() -> None: initial_context = {"test_1": "test_1", "test_2": "test_2"} - async with container_context(initial_context=initial_context): + async with container_context(global_context=initial_context): for key, item in initial_context.items(): assert fetch_context_item(key) == item new_context = {"test_3": "test_3"} - async with container_context(initial_context=new_context, preserve_globals=True): + async with container_context(global_context=new_context, preserve_global_context=True): for key, item in new_context.items(): assert fetch_context_item(key) == item for key, item in initial_context.items(): diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 155f1fad..466c5001 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -210,26 +210,30 @@ class container_context(AbstractContextManager[ContextType], AbstractAsyncContex def __init__( self, - *args: SupportsContext[typing.Any], - initial_context: ContextType | None = None, - preserve_globals: bool = False, - reset_resource_context: bool = False, + *context_items: SupportsContext[typing.Any], + global_context: ContextType | None = None, + preserve_global_context: bool = False, + reset_all_containers: bool = False, ) -> None: - """Initialize a container context. + """Initialize a new container context. - :param initial_context: existing context to use - :param providers: providers to reset context of. - :param containers: containers to reset context of. - :param preserve_globals: whether to preserve global context vars. - :param reset_resource_context: whether to reset resource context. + :param context_items: context items to initialize new context for. + :param global_context: existing context to use + :param preserve_global_context: whether to preserve old global context. + Will merge old context with the new context if this option is set to True. + :param reset_all_containers: Create a new context for all containers. """ - if preserve_globals and initial_context: - self._initial_context = {**_get_container_context(), **initial_context} + if preserve_global_context and global_context: + self._initial_context = {**_get_container_context(), **global_context} else: - self._initial_context: ContextType = _get_container_context() if preserve_globals else initial_context or {} # type: ignore[no-redef] + self._initial_context: ContextType = ( # type: ignore[no-redef] + _get_container_context() if preserve_global_context else global_context or {} + ) self._context_token: Token[ContextType] | None = None - self._context_items: set[SupportsContext[typing.Any]] = set(args) - self._reset_resource_context: typing.Final[bool] = (not args) or reset_resource_context + self._context_items: set[SupportsContext[typing.Any]] = set(context_items) + self._reset_resource_context: typing.Final[bool] = ( + not context_items and not global_context + ) or reset_all_containers if self._reset_resource_context: self._add_providers_from_containers(BaseContainerMeta.get_instances()) @@ -307,14 +311,14 @@ def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]: @wraps(func) async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: - async with container_context(*self._context_items, reset_resource_context=self._reset_resource_context): + async with container_context(*self._context_items, reset_all_containers=self._reset_resource_context): return await func(*args, **kwargs) # type: ignore[no-any-return] return typing.cast(typing.Callable[P, T_co], _async_inner) @wraps(func) def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: - with container_context(*self._context_items, reset_resource_context=self._reset_resource_context): + with container_context(*self._context_items, reset_all_containers=self._reset_resource_context): return func(*args, **kwargs) return _sync_inner From 0f6c1a76eed6fcf51bd579a970fbaec45f76ee23 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Jan 2025 11:29:12 +0100 Subject: [PATCH 13/24] Reworked attr_getter tests. --- tests/providers/test_attr_getter.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/providers/test_attr_getter.py b/tests/providers/test_attr_getter.py index 11185227..5105d3c4 100644 --- a/tests/providers/test_attr_getter.py +++ b/tests/providers/test_attr_getter.py @@ -71,24 +71,22 @@ def some_async_settings_provider(request: pytest.FixtureRequest) -> providers.Ab def test_attr_getter_with_zero_attribute_depth_sync( some_sync_settings_provider: providers.AbstractProvider[Settings], ) -> None: - with ( - container_context(some_sync_settings_provider) - if isinstance(some_sync_settings_provider, providers.ContextResource) - else container_context() - ): - attr_getter = some_sync_settings_provider.some_str_value + attr_getter = some_sync_settings_provider.some_str_value + if isinstance(some_sync_settings_provider, providers.ContextResource): + with container_context(some_sync_settings_provider): + assert attr_getter.sync_resolve() == Settings().some_str_value + else: assert attr_getter.sync_resolve() == Settings().some_str_value async def test_attr_getter_with_zero_attribute_depth_async( some_async_settings_provider: providers.AbstractProvider[Settings], ) -> None: - async with ( - container_context(some_async_settings_provider) - if isinstance(some_async_settings_provider, providers.ContextResource) - else container_context() - ): - attr_getter = some_async_settings_provider.some_str_value + attr_getter = some_async_settings_provider.some_str_value + if isinstance(some_async_settings_provider, providers.ContextResource): + async with container_context(some_async_settings_provider): + assert await attr_getter.async_resolve() == Settings().some_str_value + else: assert await attr_getter.async_resolve() == Settings().some_str_value From 011807170c65bca5f442cbc534a9cd40c4fe8ebb Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Jan 2025 16:13:49 +0100 Subject: [PATCH 14/24] Extended DIContextMiddleware to support new Context features. --- tests/integrations/fastapi/test_fastapi_di.py | 79 +++++++++++-------- that_depends/providers/context_resources.py | 19 ++++- 2 files changed, 63 insertions(+), 35 deletions(-) diff --git a/tests/integrations/fastapi/test_fastapi_di.py b/tests/integrations/fastapi/test_fastapi_di.py index 2bcabbc8..2e3dcc30 100644 --- a/tests/integrations/fastapi/test_fastapi_di.py +++ b/tests/integrations/fastapi/test_fastapi_di.py @@ -2,45 +2,60 @@ import typing import fastapi +import pytest from starlette import status from starlette.testclient import TestClient from tests import container +from that_depends import fetch_context_item from that_depends.providers import DIContextMiddleware -app = fastapi.FastAPI() -app.add_middleware(DIContextMiddleware) - - -@app.get("/") -async def read_root( - dependency: typing.Annotated[ - container.DependentFactory, - fastapi.Depends(container.DIContainer.dependent_factory), - ], - free_dependency: typing.Annotated[ - container.FreeFactory, - fastapi.Depends(container.DIContainer.resolver(container.FreeFactory)), - ], - singleton: typing.Annotated[ - container.SingletonFactory, - fastapi.Depends(container.DIContainer.singleton), - ], - singleton_attribute: typing.Annotated[bool, fastapi.Depends(container.DIContainer.singleton.dep1)], -) -> datetime.datetime: - assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource - assert dependency.async_resource == free_dependency.dependent_factory.async_resource - assert singleton.dep1 is True - assert singleton_attribute is True - return dependency.async_resource - - -client = TestClient(app) - - -async def test_read_main() -> None: - response = client.get("/") +_GLOBAL_CONTEXT: typing.Final[dict[str, str]] = {"test2": "value2", "test1": "value1"} + + +@pytest.fixture(params=[None, container.DIContainer]) +def fastapi_app(request: pytest.FixtureRequest) -> fastapi.FastAPI: + app = fastapi.FastAPI() + if request.param: + app.add_middleware(DIContextMiddleware, request.param, global_context=_GLOBAL_CONTEXT) + else: + app.add_middleware(DIContextMiddleware, global_context=_GLOBAL_CONTEXT) + + @app.get("/") + async def read_root( + dependency: typing.Annotated[ + container.DependentFactory, + fastapi.Depends(container.DIContainer.dependent_factory), + ], + free_dependency: typing.Annotated[ + container.FreeFactory, + fastapi.Depends(container.DIContainer.resolver(container.FreeFactory)), + ], + singleton: typing.Annotated[ + container.SingletonFactory, + fastapi.Depends(container.DIContainer.singleton), + ], + singleton_attribute: typing.Annotated[bool, fastapi.Depends(container.DIContainer.singleton.dep1)], + ) -> datetime.datetime: + assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource + assert dependency.async_resource == free_dependency.dependent_factory.async_resource + assert singleton.dep1 is True + assert singleton_attribute is True + for key, value in _GLOBAL_CONTEXT.items(): + assert fetch_context_item(key) == value + return dependency.async_resource + + return app + + +@pytest.fixture +def fastapi_client(fastapi_app: fastapi.FastAPI) -> TestClient: + return TestClient(fastapi_app) + + +async def test_read_main(fastapi_client: TestClient) -> None: + response = fastapi_client.get("/") assert response.status_code == status.HTTP_200_OK assert ( datetime.datetime.fromisoformat(response.json().replace("Z", "+00:00")) diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 00a9dc19..e6ee9464 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -324,9 +324,22 @@ def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: class DIContextMiddleware: - def __init__(self, app: ASGIApp) -> None: + def __init__( + self, + *context_items: SupportsContext[typing.Any], + app: ASGIApp, + global_context: dict[str, typing.Any] | None = None, + ) -> None: self.app: typing.Final = app + self._context_items: set[SupportsContext[typing.Any]] = set(context_items) + self._global_context: dict[str, typing.Any] | None = global_context - @container_context() async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - return await self.app(scope, receive, send) + if self._context_items: + pass + async with ( + container_context(*self._context_items, global_context=self._global_context) + if self._context_items + else container_context(global_context=self._global_context) + ): + return await self.app(scope, receive, send) From c78878d8947646530046875a80e1099942045d47 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Jan 2025 17:01:29 +0100 Subject: [PATCH 15/24] Added migrations guide. --- docs/index.md | 10 +++ docs/migration/v2.md | 171 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 docs/migration/v2.md diff --git a/docs/index.md b/docs/index.md index c31d7fed..20fdf6d4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,10 +34,20 @@ testing/fixture testing/provider-overriding +.. toctree:: + :maxdepth: 1 + :caption: Migration + + migration/v2 + + .. toctree:: :maxdepth: 1 :caption: For developers dev/main-decisions dev/contributing + + + ``` diff --git a/docs/migration/v2.md b/docs/migration/v2.md new file mode 100644 index 00000000..08afe3e1 --- /dev/null +++ b/docs/migration/v2.md @@ -0,0 +1,171 @@ +from fastapi import FastAPI + +# Migrating from 1.* to 2.* + + +## How to read this guide: + +This guide is intended to help your migrate existing functionality from `that-depends` version `1.*` to `2.*`. +The goal of this guide is to allow you to migrate as fast as possible, making only the minimal necessary changes +to your codebase. + +If you want to know more about new features introduced in `2.*`, please refer to the [documentation](https://that-depends.readthedocs.io/en/latest/) and +[release notes](https://github.com/modern-python/that-depends/releases). + + + +## Deprecated features + + + +1. `BaseContainer.init_async_resources()` has been removed. Use `BaseContainer.init_resources()` instead. + +**Example:** + +If you are using containers, you likely have a similar setup to this: +```python +from that_depends import BaseContainer + +class MyContainer(BaseContainer): + # here you have defined your providers + ... +``` +Replace all instances of: +```python +await MyContainer.init_async_resources() +``` +With the following: +```python +await MyContainer.init_resources() +``` +2. `that_depends.providers.AsyncResource` has been removed. Use `providers.Resource` instead. + +**Example:** + +Replace all instances of: +```python +from that_depends.providers import AsyncResource +my_provider = providers.AsyncResource(some_async_function) +``` + +With the following: +```python +from that_depends.providers import Resource +my_provider = providers.Resource(some_async_function) +``` + +## Changes in the API + +1. `container_context()` now requires a keyword argument to set initial context. + +Previously, one could initialize a global context by passing a dict to the `container_context()` context manager: + + +```python +my_global_context = {"some_key": "some_value"} +async with container_context(my_global_context): + assert fetch_context_item("some_key") == "some_value" +``` + +In `2.*` instead use the `global_context` keyword argument: + +```python +my_global_context = {"some_key": "some_value"} +async with container_context(global_context=my_global_context): + assert fetch_context_item("some_key") == "some_value" +``` + +2. `container_context(global_context=my_global_context)` no longer resets context for all resources. + +Previously, when calling `container_context(my_global_context)` two things would happen: +- The global context would be set to `my_global_context`, so values from this context could be resolved using `fetch_context_item()`. This behaviour remains the same. +- A new context would be initialized for all `providers.ContextResource` instances (wherever they were defined). This behaviour no longer applies. + +In `2.*` if you wish to both set a `global_context` and reset context for all resources, you also need to set `reset_all_containers=True`: + + +```python +async with container_context(global_context=my_global_context, reset_all_containers=True): + assert fetch_context_item("some_key") == "some_value" +``` + +> Notice that `reset_all_containers=True` re-initializes the context for all `providers.ContextResource` instances defined with containers (classes that inherit from `that_depends.BaseContainer`). Unlike previously, where it would also +> reset context for `ContextResource` providers defined anywhere. If you wish to reproduce this behaviour, you will need to explicitly re-initialize the context for +> providers outside containers. For more details, please read the documentation and the [potential issues section](#potential-issues-when-using-container_context). + +For further details on handling context with `2.*` please refer to the [ContextResource documentation](../providers/context-resources.md). + +## Potential issues when using `container_context()` + +> Please make sure that you have consulted the previous sections of this page before reading this. + +If you have migrated all the previous functionality and are still experiencing issues with managing context +resources, this might be due to the fact that entering `container_context()` does not correctly initialize context for your resources. + +Here is an example of how you might have used `container_context()` in `1.*`: + +```python +from that_depends import container_context + +async def some_async_function(): + # you enter a new context but because `MyContainer` is imported later, + # resources in that container are not re-initialized + async with container_context(): + # now you import your container + from some_other_module import MyContainer + # and attempt to resolve a resource `providers`.ContextResource` + my_resource = await MyContainer.my_context_resource.async_resolve() # ❌ Error! +``` + +In situations like this one and perhaps other similar situations where your container is outside the current scope when +entering `container_context()`, you might have to handle things differently in order to get your resources initialized. + +Here are some potential suggestions: + +- Pass explicit arguments to the `DIContextMiddleware` + +If you are using `DIContextMiddleware` with your ASGI-application, this now accepts additional arguments. + +**Example with `fastapi`:** + +```python +import fastapi +from that_depends.providers import DIContextMiddleware, ContextResource +from that_depends import BaseContainer + +MyContainer: BaseContainer +my_context_resource_provider: ContextResource +my_app: fastapi.FastAPI + +my_app.add_middleware(DIContextMiddleware, MyContainer, my_context_resource_provider) +``` + +The middleware will automatically initialize context for the provided resources when an endpoint is called. + + +- Avoid entering `container_context()` with no arguments passed. + +You can now pass all resources that support context initialization (i.e. `providers.ContextResource` instances and `BaseContainer` subclasses) to `container_context()` explicitly. + +**Example:** + +```python +from that_depends import container_context + +MyContainer: BaseContainer +my_context_resource_provider: ContextResource + +with container_context(MyContainer, my_context_resource_provider): + # now you can resolve resources from `MyContainer` and `my_context_resource_provider` + my_container_instance = MyContainer.my_context_resource.sync_resolve() + my_provider_instance = my_context_resource_provider.sync_resolve() +``` + +In general, it is recommended to explicitly initialize container context, since this can prevent unexpected behaviour +and speed up your code. + + + +## Further help + +If you are still having issues with migration you can either create a [discussion](https://github.com/modern-python/that-depends/discussions) or open an [issue](https://github.com/modern-python/that-depends/issues). From 08c42bfd90d3016280f9a648dbf376abe38ee3dd Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Jan 2025 17:10:46 +0100 Subject: [PATCH 16/24] Improved migration guide. --- docs/migration/v2.md | 218 ++++++++++++++++++++----------------------- 1 file changed, 100 insertions(+), 118 deletions(-) diff --git a/docs/migration/v2.md b/docs/migration/v2.md index 08afe3e1..877f1292 100644 --- a/docs/migration/v2.md +++ b/docs/migration/v2.md @@ -1,171 +1,153 @@ -from fastapi import FastAPI -# Migrating from 1.* to 2.* -## How to read this guide: +# Migrating from 1.* to 2.* -This guide is intended to help your migrate existing functionality from `that-depends` version `1.*` to `2.*`. -The goal of this guide is to allow you to migrate as fast as possible, making only the minimal necessary changes -to your codebase. +## How to Read This Guide -If you want to know more about new features introduced in `2.*`, please refer to the [documentation](https://that-depends.readthedocs.io/en/latest/) and -[release notes](https://github.com/modern-python/that-depends/releases). +This guide is intended to help you migrate existing functionality from `that-depends` version `1.*` to `2.*`. +The goal is to enable you to migrate as quickly as possible while making only the minimal necessary changes to your codebase. +If you want to learn more about the new features introduced in `2.*`, please refer to the [documentation](https://that-depends.readthedocs.io/) and the [release notes](https://github.com/modern-python/that-depends/releases). +--- -## Deprecated features +## Deprecated Features +1. **`BaseContainer.init_async_resources()` removed** + The method `BaseContainer.init_async_resources()` has been removed. Use `BaseContainer.init_resources()` instead. + **Example:** + If you are using containers, your setup might look like this: -1. `BaseContainer.init_async_resources()` has been removed. Use `BaseContainer.init_resources()` instead. + ```python + from that_depends import BaseContainer -**Example:** + class MyContainer(BaseContainer): + # Define your providers here + ... + ``` + Replace all instances of: + ```python + await MyContainer.init_async_resources() + ``` + With: + ```python + await MyContainer.init_resources() + ``` -If you are using containers, you likely have a similar setup to this: -```python -from that_depends import BaseContainer +2. **`that_depends.providers.AsyncResource` removed** + The `AsyncResource` class has been removed. Use `providers.Resource` instead. -class MyContainer(BaseContainer): - # here you have defined your providers - ... -``` -Replace all instances of: -```python -await MyContainer.init_async_resources() -``` -With the following: -```python -await MyContainer.init_resources() -``` -2. `that_depends.providers.AsyncResource` has been removed. Use `providers.Resource` instead. + **Example:** + Replace all instances of: + ```python + from that_depends.providers import AsyncResource + my_provider = providers.AsyncResource(some_async_function) + ``` + With: + ```python + from that_depends.providers import Resource + my_provider = providers.Resource(some_async_function) + ``` -**Example:** - -Replace all instances of: -```python -from that_depends.providers import AsyncResource -my_provider = providers.AsyncResource(some_async_function) -``` - -With the following: -```python -from that_depends.providers import Resource -my_provider = providers.Resource(some_async_function) -``` +--- ## Changes in the API -1. `container_context()` now requires a keyword argument to set initial context. - -Previously, one could initialize a global context by passing a dict to the `container_context()` context manager: - - -```python -my_global_context = {"some_key": "some_value"} -async with container_context(my_global_context): - assert fetch_context_item("some_key") == "some_value" -``` +1. **`container_context()` now requires a keyword argument for initial Context** + Previously, a global context could be initialized by passing a dictionary to the `container_context()` context manager: -In `2.*` instead use the `global_context` keyword argument: + ```python + my_global_context = {"some_key": "some_value"} + async with container_context(my_global_context): + assert fetch_context_item("some_key") == "some_value" + ``` -```python -my_global_context = {"some_key": "some_value"} -async with container_context(global_context=my_global_context): - assert fetch_context_item("some_key") == "some_value" -``` + In `2.*`, use the `global_context` keyword argument instead: -2. `container_context(global_context=my_global_context)` no longer resets context for all resources. + ```python + my_global_context = {"some_key": "some_value"} + async with container_context(global_context=my_global_context): + assert fetch_context_item("some_key") == "some_value" + ``` -Previously, when calling `container_context(my_global_context)` two things would happen: -- The global context would be set to `my_global_context`, so values from this context could be resolved using `fetch_context_item()`. This behaviour remains the same. -- A new context would be initialized for all `providers.ContextResource` instances (wherever they were defined). This behaviour no longer applies. +2. **Context reset behavior changed in `container_context()`** + Previously, calling `container_context(my_global_context)` would: + - Set the global context to `my_global_context`, allowing values to be resolved using `fetch_context_item()`. This behavior remains the same. + - Reset the context for all `providers.ContextResource` instances globally. This behavior has changed. -In `2.*` if you wish to both set a `global_context` and reset context for all resources, you also need to set `reset_all_containers=True`: + In `2.*`, if you want to reset the context for all resources in addition to setting a global context, you need to use the `reset_all_containers=True` argument: + ```python + async with container_context(global_context=my_global_context, reset_all_containers=True): + assert fetch_context_item("some_key") == "some_value" + ``` -```python -async with container_context(global_context=my_global_context, reset_all_containers=True): - assert fetch_context_item("some_key") == "some_value" -``` - -> Notice that `reset_all_containers=True` re-initializes the context for all `providers.ContextResource` instances defined with containers (classes that inherit from `that_depends.BaseContainer`). Unlike previously, where it would also -> reset context for `ContextResource` providers defined anywhere. If you wish to reproduce this behaviour, you will need to explicitly re-initialize the context for -> providers outside containers. For more details, please read the documentation and the [potential issues section](#potential-issues-when-using-container_context). + > **Note:** `reset_all_containers=True` only reinitializes the context for `ContextResource` instances defined within containers (i.e., classes inheriting from `BaseContainer`). If you also need to reset contexts for resources defined outside containers, you must handle these explicitly. See the [ContextResource documentation](../providers/context-resources.md) for more details. -For further details on handling context with `2.*` please refer to the [ContextResource documentation](../providers/context-resources.md). +--- -## Potential issues when using `container_context()` +## Potential Issues with `container_context()` -> Please make sure that you have consulted the previous sections of this page before reading this. +If you have migrated the functionality as described above but still experience issues managing context resources, it might be due to improperly initializing resources when entering `container_context()`. -If you have migrated all the previous functionality and are still experiencing issues with managing context -resources, this might be due to the fact that entering `container_context()` does not correctly initialize context for your resources. - -Here is an example of how you might have used `container_context()` in `1.*`: +Here’s an example of an incompatibility with `1.*`: ```python from that_depends import container_context async def some_async_function(): - # you enter a new context but because `MyContainer` is imported later, - # resources in that container are not re-initialized - async with container_context(): - # now you import your container + # Enter a new context but import `MyContainer` later + async with container_context(): from some_other_module import MyContainer - # and attempt to resolve a resource `providers`.ContextResource` - my_resource = await MyContainer.my_context_resource.async_resolve() # ❌ Error! + # Attempt to resolve a `ContextResource` resource + my_resource = await MyContainer.my_context_resource.async_resolve() # ❌ Error! ``` -In situations like this one and perhaps other similar situations where your container is outside the current scope when -entering `container_context()`, you might have to handle things differently in order to get your resources initialized. +To resolve such issues in `2.*`, consider the following suggestions: -Here are some potential suggestions: +1. **Pass explicit arguments to `DIContextMiddleware`** + If you are using `DIContextMiddleware` with your ASGI application, you can now pass additional arguments. -- Pass explicit arguments to the `DIContextMiddleware` + **Example with `FastAPI`:** -If you are using `DIContextMiddleware` with your ASGI-application, this now accepts additional arguments. + ```python + import fastapi + from that_depends.providers import DIContextMiddleware, ContextResource + from that_depends import BaseContainer -**Example with `fastapi`:** + MyContainer: BaseContainer + my_context_resource_provider: ContextResource + my_app: fastapi.FastAPI -```python -import fastapi -from that_depends.providers import DIContextMiddleware, ContextResource -from that_depends import BaseContainer + my_app.add_middleware(DIContextMiddleware, MyContainer, my_context_resource_provider) + ``` -MyContainer: BaseContainer -my_context_resource_provider: ContextResource -my_app: fastapi.FastAPI + This middleware will automatically initialize the context for the provided resources when an endpoint is called. -my_app.add_middleware(DIContextMiddleware, MyContainer, my_context_resource_provider) -``` +2. **Avoid entering `container_context()` without arguments** + Pass all resources supporting context initialization (e.g., `providers.ContextResource` instances and `BaseContainer` subclasses) explicitly. -The middleware will automatically initialize context for the provided resources when an endpoint is called. + **Example:** + ```python + from that_depends import container_context -- Avoid entering `container_context()` with no arguments passed. + MyContainer: BaseContainer + my_context_resource_provider: ContextResource -You can now pass all resources that support context initialization (i.e. `providers.ContextResource` instances and `BaseContainer` subclasses) to `container_context()` explicitly. + async with container_context(MyContainer, my_context_resource_provider): + # Resolve resources + my_container_instance = MyContainer.my_context_resource.sync_resolve() + my_provider_instance = my_context_resource_provider.sync_resolve() + ``` -**Example:** + Explicit initialization of container context is recommended to prevent unexpected behavior and improve performance. -```python -from that_depends import container_context +--- -MyContainer: BaseContainer -my_context_resource_provider: ContextResource +## Further Help -with container_context(MyContainer, my_context_resource_provider): - # now you can resolve resources from `MyContainer` and `my_context_resource_provider` - my_container_instance = MyContainer.my_context_resource.sync_resolve() - my_provider_instance = my_context_resource_provider.sync_resolve() +If you continue to experience issues during migration, consider creating a [discussion](https://github.com/modern-python/that-depends/discussions) or opening an [issue](https://github.com/modern-python/that-depends/issues). ``` - -In general, it is recommended to explicitly initialize container context, since this can prevent unexpected behaviour -and speed up your code. - - - -## Further help - -If you are still having issues with migration you can either create a [discussion](https://github.com/modern-python/that-depends/discussions) or open an [issue](https://github.com/modern-python/that-depends/issues). From 3db834530c2a0eda1f3f0b6b7b5e9eea47f61c0c Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Jan 2025 17:17:28 +0100 Subject: [PATCH 17/24] Fixed argument order in middleware. --- that_depends/providers/context_resources.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index e6ee9464..dcb03efd 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -326,8 +326,8 @@ def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: class DIContextMiddleware: def __init__( self, - *context_items: SupportsContext[typing.Any], app: ASGIApp, + *context_items: SupportsContext[typing.Any], global_context: dict[str, typing.Any] | None = None, ) -> None: self.app: typing.Final = app From 8701417eb61d88c84384954f0d55924ad4f75294 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Jan 2025 17:24:43 +0100 Subject: [PATCH 18/24] Made middleware tests actually test context resources. --- tests/container.py | 1 + tests/integrations/fastapi/test_fastapi_di.py | 7 ++++++- that_depends/providers/context_resources.py | 4 +++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/container.py b/tests/container.py index 69b1277e..a2f53507 100644 --- a/tests/container.py +++ b/tests/container.py @@ -67,3 +67,4 @@ class DIContainer(BaseContainer): ) singleton = providers.Singleton(SingletonFactory, dep1=True) object = providers.Object(object()) + context_resource = providers.ContextResource(create_async_resource) diff --git a/tests/integrations/fastapi/test_fastapi_di.py b/tests/integrations/fastapi/test_fastapi_di.py index 2e3dcc30..08c78961 100644 --- a/tests/integrations/fastapi/test_fastapi_di.py +++ b/tests/integrations/fastapi/test_fastapi_di.py @@ -20,7 +20,10 @@ def fastapi_app(request: pytest.FixtureRequest) -> fastapi.FastAPI: if request.param: app.add_middleware(DIContextMiddleware, request.param, global_context=_GLOBAL_CONTEXT) else: - app.add_middleware(DIContextMiddleware, global_context=_GLOBAL_CONTEXT) + app.add_middleware( + DIContextMiddleware, + global_context=_GLOBAL_CONTEXT, + ) @app.get("/") async def read_root( @@ -37,11 +40,13 @@ async def read_root( fastapi.Depends(container.DIContainer.singleton), ], singleton_attribute: typing.Annotated[bool, fastapi.Depends(container.DIContainer.singleton.dep1)], + context_resource: typing.Annotated[datetime.datetime, fastapi.Depends(container.DIContainer.context_resource)], ) -> datetime.datetime: assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource assert dependency.async_resource == free_dependency.dependent_factory.async_resource assert singleton.dep1 is True assert singleton_attribute is True + assert context_resource == await container.DIContainer.context_resource.async_resolve() for key, value in _GLOBAL_CONTEXT.items(): assert fetch_context_item(key) == value return dependency.async_resource diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index dcb03efd..7d479039 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -329,10 +329,12 @@ def __init__( app: ASGIApp, *context_items: SupportsContext[typing.Any], global_context: dict[str, typing.Any] | None = None, + reset_all_containers: bool = True, ) -> None: self.app: typing.Final = app self._context_items: set[SupportsContext[typing.Any]] = set(context_items) self._global_context: dict[str, typing.Any] | None = global_context + self._reset_all_containers: bool = reset_all_containers async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self._context_items: @@ -340,6 +342,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with ( container_context(*self._context_items, global_context=self._global_context) if self._context_items - else container_context(global_context=self._global_context) + else container_context(global_context=self._global_context, reset_all_containers=self._reset_all_containers) ): return await self.app(scope, receive, send) From 935ca9128396c05e27e43ba0e7bb4cc648570f79 Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 9 Jan 2025 21:42:16 +0100 Subject: [PATCH 19/24] Rewrote introductions to context-resources.md --- docs/providers/context-resources.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md index bac0b8b5..d16bb98f 100644 --- a/docs/providers/context-resources.md +++ b/docs/providers/context-resources.md @@ -1,4 +1,20 @@ -# ContextResource +# Context Dependent Resources + +`that_depends` provides a way to manage two types of contexts: + +- a *global context* dict where you can store objects to retrieve later. +- *resource specific contexts* which are managed by the `ContextResource` provider. + + +To interact with both types of contexts there are two separate interfaces: + +1. Use the `container_context()` context manager to interact with the global context and manage `ContextResource` providers. +2. Directly manage a `ContextResource` context by using use `SupportsContext` interface, which both containers +and `ContextResource` providers implement. + +## Global Context + + Instances injected with the `ContextResource` provider have a managed lifecycle. ```python From 26c66b5e73f6be1d11febf09925c35c2195a86a4 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 10 Jan 2025 13:18:42 +0100 Subject: [PATCH 20/24] Wrote the global context section. --- docs/providers/context-resources.md | 75 +++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md index d16bb98f..3ffc298f 100644 --- a/docs/providers/context-resources.md +++ b/docs/providers/context-resources.md @@ -1,3 +1,5 @@ +from that_depends import fetch_context_itemfrom that_depends import fetch_context_itemfrom that_depends import Providefrom that_depends import inject + # Context Dependent Resources `that_depends` provides a way to manage two types of contexts: @@ -12,15 +14,16 @@ To interact with both types of contexts there are two separate interfaces: 2. Directly manage a `ContextResource` context by using use `SupportsContext` interface, which both containers and `ContextResource` providers implement. -## Global Context +## Quick Start -Instances injected with the `ContextResource` provider have a managed lifecycle. +You have to initialize a context before being able to resolve a `ContextResource`. +**Setup:** ```python import typing -from that_depends import BaseContainer, providers +from that_depends import BaseContainer, providers, inject, Provide async def my_async_resource() -> typing.AsyncIterator[str]: @@ -44,6 +47,72 @@ class MyContainer(BaseContainer): sync_resource = providers.ContextResource(my_sync_resource) ``` +Then you can resolve the resource by initializing its context: +```python +@MyContainer.async_resource.context +@inject +async func(dep: str = Provide[MyContainer.async_resource]): + return dep + +await func() # returns `async resource` +``` +This will initialize a new context for `async_resource` on each call of our `func`. + +## Global Context + +A global context can be initialized by using `container_context` context manager. + +```python +from that_depends import container_context, fetch_context_item + +async with container_context(global_context={"key": "value"}): + # run some code + fetch_context_item("key") # returns 'value' +``` + +You can also use `container_context` as a decorator: +```python +@container_context(global_context={"key": "value"}) +async def func(): + # run some code + fetch_context_item("key") +``` + +The values stored in the `global_context` can be resolved as long as: +1. You are still within the scope of the context manager. +2. You have not initialized a new context: +```python +async with container_context(global_context={"key": "value"}): + # run some code + fetch_context_item("key") + async with container_context(): # this will reset all contexts including the global context. + fetch_context_item("key") # Error! key not found +``` +If you want to maintain the global context, you can initialize a new context with the `preserve_global_context` argument: +```python + +async with container_context(global_context={"key": "value"}): + # run some code + fetch_context_item("key") + async with container_context(preserve_global_context=True): # this will preserve the global context. + fetch_context_item("key") # returns 'value` +``` +Additionally, you can use the `global_context` arguemnt in combination with `preserve_global_context` to +extend the global context, this will merge the two contexts together by key with the new global_context taking precedence: +```python +async with container_context(global_context={"key_1": "value_1", "key_2": "value_2"}): + # run some code + fetch_context_item("key_1") # returns `value_1` + async with container_context( + global_context={"key_2": "new_value", "key_3": "value_3"}, + preserve_global_context=True): # this will preserve the global context. + + fetch_context_item("key_1") # returns 'value_1` + fetch_context_item("key_2") # returns 'new_value` + fetch_context_item("key_3") # returns 'value_3` +``` +## Context Resources + To be able to resolve `ContextResource` one must first enter `container_context`: ```python async with container_context(): From 0de03622b465df09cbf7ef8133079e8460b9fc08 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 10 Jan 2025 20:31:42 +0100 Subject: [PATCH 21/24] Finished the context-resources documentation. --- docs/providers/context-resources.md | 67 ++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md index 3ffc298f..c92d3e96 100644 --- a/docs/providers/context-resources.md +++ b/docs/providers/context-resources.md @@ -1,4 +1,3 @@ -from that_depends import fetch_context_itemfrom that_depends import fetch_context_itemfrom that_depends import Providefrom that_depends import inject # Context Dependent Resources @@ -33,7 +32,6 @@ async def my_async_resource() -> typing.AsyncIterator[str]: finally: print("Teardown of async resource") - def my_sync_resource() -> typing.Iterator[str]: print("Initializing sync resource") try: @@ -41,7 +39,6 @@ def my_sync_resource() -> typing.Iterator[str]: finally: print("Teardown of sync resource") - class MyContainer(BaseContainer): async_resource = providers.ContextResource(my_async_resource) sync_resource = providers.ContextResource(my_sync_resource) @@ -113,9 +110,10 @@ async with container_context(global_context={"key_1": "value_1", "key_2": "value ``` ## Context Resources -To be able to resolve `ContextResource` one must first enter `container_context`: +To be able to resolve `ContextResource` one must first enter intialize a new context for that resource. +The most simple way to achieve this is by entering `container_context()` without passing any arguments: ```python -async with container_context(): +async with container_context(): # this will make all containers initialize a new context. await MyContainer.async_resource.async_resolve() # "async resource" MyContainer.sync_resource.sync_resolve() # "sync resource" ``` @@ -151,8 +149,24 @@ async def my_func(): > RuntimeError: AsyncResource cannot be resolved in an sync context. ``` +### More granular context initialization + +If you do not wish to simply re-initialize the context for all containers you can either initialize a context for a container: +```python +# this will init a new context for all ContextResources in the container and connected containers. +async with container_context(MyContainer): ... +``` +Or a specific resource: +```python +# this will init a new context for the specific resource. +async with container_context(MyContainer.async_resource): ... +``` + +One does not need to use `container_context()` to achieve this, and instead can use the `SupportsContext` interface described +[here](#quick-reference) + + ### Context Hierarchy -Each time you enter `container_context` a new context is created in the background. Resources are cached in the context after first resolution. Resources created in a context are torn down again when `container_context` exits. ```python @@ -169,9 +183,48 @@ async with container_context(): ### Resolving resources whenever function is called `container_context` can be used as decorator: ```python -@container_context() +@MyContainer.session.context # wrap with session specific context. @inject async def insert_into_database(session = Provide[MyContainer.session]): ... ``` Each time ``await insert_into_database()`` is called new instance of ``session`` will be injected. + + +### Quick reference + +| Intention | Using `container_context()` | Using `SupportsContext` explicit | Using `SupportsContext` decorator | +|-------------------------------------------------------|-----------------------------------------------|--------------------------------------------|-----------------------------------| +| Reset context for all containers in scope | `async with container_context():` | Not supported. | Not supported. | +| Reset only sync contexts for all containers in scope. | `with container_context():` | Not supported. | Not supported. | +| Reset a `provider.ContextResource` context. | `async with container_context(my_provider):` | `async with my_provider.async_context():` | `@my_provider.context` | +| Reset a sync `provider.ContextResource` context. | `with container_context(my_provider):` | `with my_provider.sync_context():` | `@my_provider.context` | +| Reset all resources in a container. | `async with container_context(my_container):` | `async with my_container.async_context():` | `@my_container.context` | +| Reset all sync resources in a container. | `with container_context(my_container):` | `with my_container.sync_context():` | `@my_container.context` | + + +## Middleware + +For `ASGI` applications, `that_depends` provides the `DIContextMiddleware` to manage context resources. + +The `DIContextMiddleware` accepts containers and resources as arguments and will automatically initialize the context for the provided resources when an endpoint is called. + +**Example with `FastAPI`:** + +```python +import fastapi +from that_depends.providers import DIContextMiddleware, ContextResource +from that_depends import BaseContainer + +MyContainer: BaseContainer +my_context_resource_provider: ContextResource +my_app: fastapi.FastAPI + +# this will initialize the context for `my_context_resource_provider` and `MyContainer` when an endpoint is called. +my_app.add_middleware(DIContextMiddleware, MyContainer, my_context_resource_provider) + +# this will initialize the context for all containers when an endpoint is called. +my_app.add_middleware(DIContextMiddleware) +``` + +> `DIContextMiddleware` also supports the `global_context` and `preserve_global_context` arguments. From a195079e6786c2a6b602c6f64df36930a8aa0d39 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 10 Jan 2025 20:34:43 +0100 Subject: [PATCH 22/24] Minor documentation improvements. --- docs/providers/context-resources.md | 148 +++++++++++++++------------- 1 file changed, 77 insertions(+), 71 deletions(-) diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md index c92d3e96..3c29a347 100644 --- a/docs/providers/context-resources.md +++ b/docs/providers/context-resources.md @@ -1,22 +1,21 @@ -# Context Dependent Resources +# Context-Dependent Resources `that_depends` provides a way to manage two types of contexts: -- a *global context* dict where you can store objects to retrieve later. -- *resource specific contexts* which are managed by the `ContextResource` provider. +- A *global context* (a dictionary) where you can store objects for later retrieval. +- *Resource-specific contexts*, which are managed by the `ContextResource` provider. - -To interact with both types of contexts there are two separate interfaces: +To interact with both types of contexts, there are two separate interfaces: 1. Use the `container_context()` context manager to interact with the global context and manage `ContextResource` providers. -2. Directly manage a `ContextResource` context by using use `SupportsContext` interface, which both containers -and `ContextResource` providers implement. - +2. Directly manage a `ContextResource` context by using the `SupportsContext` interface, which both containers + and `ContextResource` providers implement. +--- ## Quick Start -You have to initialize a context before being able to resolve a `ContextResource`. +You must initialize a context before you can resolve a `ContextResource`. **Setup:** ```python @@ -44,27 +43,28 @@ class MyContainer(BaseContainer): sync_resource = providers.ContextResource(my_sync_resource) ``` -Then you can resolve the resource by initializing its context: +Then, you can resolve the resource by initializing its context: ```python @MyContainer.async_resource.context @inject -async func(dep: str = Provide[MyContainer.async_resource]): +async def func(dep: str = Provide[MyContainer.async_resource]): return dep - -await func() # returns `async resource` + +await func() # returns "async resource" ``` -This will initialize a new context for `async_resource` on each call of our `func`. +This will initialize a new context for `async_resource` each time `func` is called. +--- ## Global Context -A global context can be initialized by using `container_context` context manager. +A global context can be initialized by using the `container_context` context manager. ```python from that_depends import container_context, fetch_context_item async with container_context(global_context={"key": "value"}): # run some code - fetch_context_item("key") # returns 'value' + fetch_context_item("key") # returns 'value' ``` You can also use `container_context` as a decorator: @@ -82,57 +82,61 @@ The values stored in the `global_context` can be resolved as long as: async with container_context(global_context={"key": "value"}): # run some code fetch_context_item("key") - async with container_context(): # this will reset all contexts including the global context. - fetch_context_item("key") # Error! key not found + async with container_context(): # this will reset all contexts, including the global context. + fetch_context_item("key") # Error! key not found ``` + If you want to maintain the global context, you can initialize a new context with the `preserve_global_context` argument: ```python - async with container_context(global_context={"key": "value"}): # run some code fetch_context_item("key") - async with container_context(preserve_global_context=True): # this will preserve the global context. - fetch_context_item("key") # returns 'value` + async with container_context(preserve_global_context=True): # preserves the global context + fetch_context_item("key") # returns 'value' ``` -Additionally, you can use the `global_context` arguemnt in combination with `preserve_global_context` to -extend the global context, this will merge the two contexts together by key with the new global_context taking precedence: + +Additionally, you can use the `global_context` argument in combination with `preserve_global_context` to +extend the global context. This merges the two contexts together by key, with the new `global_context` taking precedence: ```python async with container_context(global_context={"key_1": "value_1", "key_2": "value_2"}): # run some code - fetch_context_item("key_1") # returns `value_1` + fetch_context_item("key_1") # returns 'value_1' async with container_context( - global_context={"key_2": "new_value", "key_3": "value_3"}, - preserve_global_context=True): # this will preserve the global context. - - fetch_context_item("key_1") # returns 'value_1` - fetch_context_item("key_2") # returns 'new_value` - fetch_context_item("key_3") # returns 'value_3` + global_context={"key_2": "new_value", "key_3": "value_3"}, + preserve_global_context=True + ): + fetch_context_item("key_1") # returns 'value_1' + fetch_context_item("key_2") # returns 'new_value' + fetch_context_item("key_3") # returns 'value_3' ``` + +--- + ## Context Resources -To be able to resolve `ContextResource` one must first enter intialize a new context for that resource. -The most simple way to achieve this is by entering `container_context()` without passing any arguments: +To resolve a `ContextResource`, you must first initialize a new context for that resource. The simplest way to do this is by entering `container_context()` without passing any arguments: ```python -async with container_context(): # this will make all containers initialize a new context. - await MyContainer.async_resource.async_resolve() # "async resource" - MyContainer.sync_resource.sync_resolve() # "sync resource" +async with container_context(): # this will make all containers initialize a new context + await MyContainer.async_resource.async_resolve() # "async resource" + MyContainer.sync_resource.sync_resolve() # "sync resource" ``` - Trying to resolve `ContextResource` without first entering `container_context` will yield `RuntimeError`: +Trying to resolve a `ContextResource` without first entering `container_context` will yield a `RuntimeError`: ```python value = MyContainer.sync_resource.sync_resolve() > RuntimeError: Context is not set. Use container_context ``` -### Resolving async and sync dependencies: -``container_context`` implements both ``AsyncContextManager`` and ``ContextManager``. -This means that you can enter an async context with: +### Resolving async and sync dependencies + +``container_context`` implements both ``AsyncContextManager`` and ``ContextManager``. +This means you can enter an async context with: ```python async with container_context(): ... ``` -An async context will allow resolution of both sync and async dependencies. +An async context allows resolution of both sync and async dependencies. A sync context can be entered using: ```python @@ -142,75 +146,77 @@ with container_context(): A sync context will only allow resolution of sync dependencies: ```python async def my_func(): - with container_context(): # enter sync context - # try to resolve async dependency. + with container_context(): # enter sync context + # trying to resolve async dependency await MyContainer.async_resource.async_resolve() -> RuntimeError: AsyncResource cannot be resolved in an sync context. +> RuntimeError: AsyncResource cannot be resolved in a sync context. ``` ### More granular context initialization -If you do not wish to simply re-initialize the context for all containers you can either initialize a context for a container: +If you do not wish to simply reinitialize the context for all containers, you can initialize a context for a specific container: ```python -# this will init a new context for all ContextResources in the container and connected containers. -async with container_context(MyContainer): ... +# this will init a new context for all ContextResources in MyContainer and any connected containers. +async with container_context(MyContainer): + ... ``` -Or a specific resource: +Or for a specific resource: ```python -# this will init a new context for the specific resource. -async with container_context(MyContainer.async_resource): ... +# this will init a new context for the specific resource only. +async with container_context(MyContainer.async_resource): + ... ``` -One does not need to use `container_context()` to achieve this, and instead can use the `SupportsContext` interface described -[here](#quick-reference) - +It is not necessary to use `container_context()` to do this. Instead, you can use the `SupportsContext` interface described +[here](#quick-reference). ### Context Hierarchy -Resources are cached in the context after first resolution. -Resources created in a context are torn down again when `container_context` exits. + +Resources are cached in the context after their first resolution. +They are torn down when `container_context` exits: ```python async with container_context(): value_outer = await MyContainer.resource.async_resolve() async with container_context(): - # new context -> resource will be resolved a new + # new context -> resource will be resolved anew value_inner = await MyContainer.resource.async_resolve() assert value_inner != value_outer - # previously resolved value is cached in context. + # previously resolved value is cached in the outer context assert value_outer == await MyContainer.resource.async_resolve() ``` -### Resolving resources whenever function is called -`container_context` can be used as decorator: +### Resolving resources whenever a function is called + +`container_context` can be used as a decorator: ```python -@MyContainer.session.context # wrap with session specific context. +@MyContainer.session.context # wrap with a session-specific context @inject -async def insert_into_database(session = Provide[MyContainer.session]): +async def insert_into_database(session=Provide[MyContainer.session]): ... ``` -Each time ``await insert_into_database()`` is called new instance of ``session`` will be injected. - +Each time you call `await insert_into_database()`, a new instance of `session` will be injected. ### Quick reference -| Intention | Using `container_context()` | Using `SupportsContext` explicit | Using `SupportsContext` decorator | +| Intention | Using `container_context()` | Using `SupportsContext` explicitly | Using `SupportsContext` decorator | |-------------------------------------------------------|-----------------------------------------------|--------------------------------------------|-----------------------------------| | Reset context for all containers in scope | `async with container_context():` | Not supported. | Not supported. | | Reset only sync contexts for all containers in scope. | `with container_context():` | Not supported. | Not supported. | -| Reset a `provider.ContextResource` context. | `async with container_context(my_provider):` | `async with my_provider.async_context():` | `@my_provider.context` | -| Reset a sync `provider.ContextResource` context. | `with container_context(my_provider):` | `with my_provider.sync_context():` | `@my_provider.context` | -| Reset all resources in a container. | `async with container_context(my_container):` | `async with my_container.async_context():` | `@my_container.context` | -| Reset all sync resources in a container. | `with container_context(my_container):` | `with my_container.sync_context():` | `@my_container.context` | +| Reset a `provider.ContextResource` context | `async with container_context(my_provider):` | `async with my_provider.async_context():` | `@my_provider.context` | +| Reset a sync `provider.ContextResource` context | `with container_context(my_provider):` | `with my_provider.sync_context():` | `@my_provider.context` | +| Reset all resources in a container | `async with container_context(my_container):` | `async with my_container.async_context():` | `@my_container.context` | +| Reset all sync resources in a container | `with container_context(my_container):` | `with my_container.sync_context():` | `@my_container.context` | +--- ## Middleware For `ASGI` applications, `that_depends` provides the `DIContextMiddleware` to manage context resources. -The `DIContextMiddleware` accepts containers and resources as arguments and will automatically initialize the context for the provided resources when an endpoint is called. +The `DIContextMiddleware` accepts containers and resources as arguments and automatically initializes the context for the provided resources when an endpoint is called. **Example with `FastAPI`:** - ```python import fastapi from that_depends.providers import DIContextMiddleware, ContextResource @@ -220,10 +226,10 @@ MyContainer: BaseContainer my_context_resource_provider: ContextResource my_app: fastapi.FastAPI -# this will initialize the context for `my_context_resource_provider` and `MyContainer` when an endpoint is called. +# This will initialize the context for `my_context_resource_provider` and `MyContainer` whenever an endpoint is called. my_app.add_middleware(DIContextMiddleware, MyContainer, my_context_resource_provider) -# this will initialize the context for all containers when an endpoint is called. +# This will initialize the context for all containers when an endpoint is called. my_app.add_middleware(DIContextMiddleware) ``` From d542116a1ea6feb5517d64cc709a1ca7f15b2c88 Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 10 Jan 2025 20:37:45 +0100 Subject: [PATCH 23/24] Removed redundant comment. --- that_depends/providers/collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/that_depends/providers/collection.py b/that_depends/providers/collection.py index cb9d4dde..c976a561 100644 --- a/that_depends/providers/collection.py +++ b/that_depends/providers/collection.py @@ -1,4 +1,4 @@ -import typing # noqa: A005 +import typing from that_depends.providers.base import AbstractProvider From bb96e23ae00d1e00c4cf1e859aad8f6dc82febfc Mon Sep 17 00:00:00 2001 From: Alexander Date: Fri, 10 Jan 2025 20:42:31 +0100 Subject: [PATCH 24/24] Removed white-space at the top of files. --- docs/migration/v2.md | 3 --- docs/providers/context-resources.md | 1 - 2 files changed, 4 deletions(-) diff --git a/docs/migration/v2.md b/docs/migration/v2.md index 877f1292..b90976a4 100644 --- a/docs/migration/v2.md +++ b/docs/migration/v2.md @@ -1,6 +1,3 @@ - - - # Migrating from 1.* to 2.* ## How to Read This Guide diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md index 3c29a347..64ff900f 100644 --- a/docs/providers/context-resources.md +++ b/docs/providers/context-resources.md @@ -1,4 +1,3 @@ - # Context-Dependent Resources `that_depends` provides a way to manage two types of contexts: