diff --git a/docs/index.md b/docs/index.md index c31d7fed..20fdf6d4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,10 +34,20 @@ testing/fixture testing/provider-overriding +.. toctree:: + :maxdepth: 1 + :caption: Migration + + migration/v2 + + .. toctree:: :maxdepth: 1 :caption: For developers dev/main-decisions dev/contributing + + + ``` diff --git a/docs/migration/v2.md b/docs/migration/v2.md new file mode 100644 index 00000000..b90976a4 --- /dev/null +++ b/docs/migration/v2.md @@ -0,0 +1,150 @@ +# Migrating from 1.* to 2.* + +## How to Read This Guide + +This guide is intended to help you migrate existing functionality from `that-depends` version `1.*` to `2.*`. +The goal is to enable you to migrate as quickly as possible while making only the minimal necessary changes to your codebase. + +If you want to learn more about the new features introduced in `2.*`, please refer to the [documentation](https://that-depends.readthedocs.io/) and the [release notes](https://github.com/modern-python/that-depends/releases). + +--- + +## Deprecated Features + +1. **`BaseContainer.init_async_resources()` removed** + The method `BaseContainer.init_async_resources()` has been removed. Use `BaseContainer.init_resources()` instead. + + **Example:** + If you are using containers, your setup might look like this: + + ```python + from that_depends import BaseContainer + + class MyContainer(BaseContainer): + # Define your providers here + ... + ``` + Replace all instances of: + ```python + await MyContainer.init_async_resources() + ``` + With: + ```python + await MyContainer.init_resources() + ``` + +2. **`that_depends.providers.AsyncResource` removed** + The `AsyncResource` class has been removed. Use `providers.Resource` instead. + + **Example:** + Replace all instances of: + ```python + from that_depends.providers import AsyncResource + my_provider = providers.AsyncResource(some_async_function) + ``` + With: + ```python + from that_depends.providers import Resource + my_provider = providers.Resource(some_async_function) + ``` + +--- + +## Changes in the API + +1. **`container_context()` now requires a keyword argument for initial Context** + Previously, a global context could be initialized by passing a dictionary to the `container_context()` context manager: + + ```python + my_global_context = {"some_key": "some_value"} + async with container_context(my_global_context): + assert fetch_context_item("some_key") == "some_value" + ``` + + In `2.*`, use the `global_context` keyword argument instead: + + ```python + my_global_context = {"some_key": "some_value"} + async with container_context(global_context=my_global_context): + assert fetch_context_item("some_key") == "some_value" + ``` + +2. **Context reset behavior changed in `container_context()`** + Previously, calling `container_context(my_global_context)` would: + - Set the global context to `my_global_context`, allowing values to be resolved using `fetch_context_item()`. This behavior remains the same. + - Reset the context for all `providers.ContextResource` instances globally. This behavior has changed. + + In `2.*`, if you want to reset the context for all resources in addition to setting a global context, you need to use the `reset_all_containers=True` argument: + + ```python + async with container_context(global_context=my_global_context, reset_all_containers=True): + assert fetch_context_item("some_key") == "some_value" + ``` + + > **Note:** `reset_all_containers=True` only reinitializes the context for `ContextResource` instances defined within containers (i.e., classes inheriting from `BaseContainer`). If you also need to reset contexts for resources defined outside containers, you must handle these explicitly. See the [ContextResource documentation](../providers/context-resources.md) for more details. + +--- + +## Potential Issues with `container_context()` + +If you have migrated the functionality as described above but still experience issues managing context resources, it might be due to improperly initializing resources when entering `container_context()`. + +Here’s an example of an incompatibility with `1.*`: + +```python +from that_depends import container_context + +async def some_async_function(): + # Enter a new context but import `MyContainer` later + async with container_context(): + from some_other_module import MyContainer + # Attempt to resolve a `ContextResource` resource + my_resource = await MyContainer.my_context_resource.async_resolve() # ❌ Error! +``` + +To resolve such issues in `2.*`, consider the following suggestions: + +1. **Pass explicit arguments to `DIContextMiddleware`** + If you are using `DIContextMiddleware` with your ASGI application, you can now pass additional arguments. + + **Example with `FastAPI`:** + + ```python + import fastapi + from that_depends.providers import DIContextMiddleware, ContextResource + from that_depends import BaseContainer + + MyContainer: BaseContainer + my_context_resource_provider: ContextResource + my_app: fastapi.FastAPI + + my_app.add_middleware(DIContextMiddleware, MyContainer, my_context_resource_provider) + ``` + + This middleware will automatically initialize the context for the provided resources when an endpoint is called. + +2. **Avoid entering `container_context()` without arguments** + Pass all resources supporting context initialization (e.g., `providers.ContextResource` instances and `BaseContainer` subclasses) explicitly. + + **Example:** + + ```python + from that_depends import container_context + + MyContainer: BaseContainer + my_context_resource_provider: ContextResource + + async with container_context(MyContainer, my_context_resource_provider): + # Resolve resources + my_container_instance = MyContainer.my_context_resource.sync_resolve() + my_provider_instance = my_context_resource_provider.sync_resolve() + ``` + + Explicit initialization of container context is recommended to prevent unexpected behavior and improve performance. + +--- + +## Further Help + +If you continue to experience issues during migration, consider creating a [discussion](https://github.com/modern-python/that-depends/discussions) or opening an [issue](https://github.com/modern-python/that-depends/issues). +``` diff --git a/docs/providers/context-resources.md b/docs/providers/context-resources.md index bac0b8b5..64ff900f 100644 --- a/docs/providers/context-resources.md +++ b/docs/providers/context-resources.md @@ -1,10 +1,26 @@ -# ContextResource -Instances injected with the `ContextResource` provider have a managed lifecycle. +# Context-Dependent Resources +`that_depends` provides a way to manage two types of contexts: + +- A *global context* (a dictionary) where you can store objects for later retrieval. +- *Resource-specific contexts*, which are managed by the `ContextResource` provider. + +To interact with both types of contexts, there are two separate interfaces: + +1. Use the `container_context()` context manager to interact with the global context and manage `ContextResource` providers. +2. Directly manage a `ContextResource` context by using the `SupportsContext` interface, which both containers + and `ContextResource` providers implement. + +--- +## Quick Start + +You must initialize a context before you can resolve a `ContextResource`. + +**Setup:** ```python import typing -from that_depends import BaseContainer, providers +from that_depends import BaseContainer, providers, inject, Provide async def my_async_resource() -> typing.AsyncIterator[str]: @@ -14,7 +30,6 @@ async def my_async_resource() -> typing.AsyncIterator[str]: finally: print("Teardown of async resource") - def my_sync_resource() -> typing.Iterator[str]: print("Initializing sync resource") try: @@ -22,34 +37,105 @@ def my_sync_resource() -> typing.Iterator[str]: finally: print("Teardown of sync resource") - class MyContainer(BaseContainer): async_resource = providers.ContextResource(my_async_resource) sync_resource = providers.ContextResource(my_sync_resource) ``` -To be able to resolve `ContextResource` one must first enter `container_context`: +Then, you can resolve the resource by initializing its context: ```python -async with container_context(): - await MyContainer.async_resource.async_resolve() # "async resource" - MyContainer.sync_resource.sync_resolve() # "sync resource" +@MyContainer.async_resource.context +@inject +async def func(dep: str = Provide[MyContainer.async_resource]): + return dep + +await func() # returns "async resource" +``` +This will initialize a new context for `async_resource` each time `func` is called. + +--- +## Global Context + +A global context can be initialized by using the `container_context` context manager. + +```python +from that_depends import container_context, fetch_context_item + +async with container_context(global_context={"key": "value"}): + # run some code + fetch_context_item("key") # returns 'value' +``` + +You can also use `container_context` as a decorator: +```python +@container_context(global_context={"key": "value"}) +async def func(): + # run some code + fetch_context_item("key") +``` + +The values stored in the `global_context` can be resolved as long as: +1. You are still within the scope of the context manager. +2. You have not initialized a new context: +```python +async with container_context(global_context={"key": "value"}): + # run some code + fetch_context_item("key") + async with container_context(): # this will reset all contexts, including the global context. + fetch_context_item("key") # Error! key not found ``` - Trying to resolve `ContextResource` without first entering `container_context` will yield `RuntimeError`: +If you want to maintain the global context, you can initialize a new context with the `preserve_global_context` argument: +```python +async with container_context(global_context={"key": "value"}): + # run some code + fetch_context_item("key") + async with container_context(preserve_global_context=True): # preserves the global context + fetch_context_item("key") # returns 'value' +``` + +Additionally, you can use the `global_context` argument in combination with `preserve_global_context` to +extend the global context. This merges the two contexts together by key, with the new `global_context` taking precedence: +```python +async with container_context(global_context={"key_1": "value_1", "key_2": "value_2"}): + # run some code + fetch_context_item("key_1") # returns 'value_1' + async with container_context( + global_context={"key_2": "new_value", "key_3": "value_3"}, + preserve_global_context=True + ): + fetch_context_item("key_1") # returns 'value_1' + fetch_context_item("key_2") # returns 'new_value' + fetch_context_item("key_3") # returns 'value_3' +``` + +--- + +## Context Resources + +To resolve a `ContextResource`, you must first initialize a new context for that resource. The simplest way to do this is by entering `container_context()` without passing any arguments: +```python +async with container_context(): # this will make all containers initialize a new context + await MyContainer.async_resource.async_resolve() # "async resource" + MyContainer.sync_resource.sync_resolve() # "sync resource" +``` + +Trying to resolve a `ContextResource` without first entering `container_context` will yield a `RuntimeError`: ```python value = MyContainer.sync_resource.sync_resolve() > RuntimeError: Context is not set. Use container_context ``` -### Resolving async and sync dependencies: -``container_context`` implements both ``AsyncContextManager`` and ``ContextManager``. -This means that you can enter an async context with: +### Resolving async and sync dependencies + +``container_context`` implements both ``AsyncContextManager`` and ``ContextManager``. +This means you can enter an async context with: ```python async with container_context(): ... ``` -An async context will allow resolution of both sync and async dependencies. +An async context allows resolution of both sync and async dependencies. A sync context can be entered using: ```python @@ -59,34 +145,91 @@ with container_context(): A sync context will only allow resolution of sync dependencies: ```python async def my_func(): - with container_context(): # enter sync context - # try to resolve async dependency. + with container_context(): # enter sync context + # trying to resolve async dependency await MyContainer.async_resource.async_resolve() -> RuntimeError: AsyncResource cannot be resolved in an sync context. +> RuntimeError: AsyncResource cannot be resolved in a sync context. ``` +### More granular context initialization + +If you do not wish to simply reinitialize the context for all containers, you can initialize a context for a specific container: +```python +# this will init a new context for all ContextResources in MyContainer and any connected containers. +async with container_context(MyContainer): + ... +``` +Or for a specific resource: +```python +# this will init a new context for the specific resource only. +async with container_context(MyContainer.async_resource): + ... +``` + +It is not necessary to use `container_context()` to do this. Instead, you can use the `SupportsContext` interface described +[here](#quick-reference). + ### Context Hierarchy -Each time you enter `container_context` a new context is created in the background. -Resources are cached in the context after first resolution. -Resources created in a context are torn down again when `container_context` exits. + +Resources are cached in the context after their first resolution. +They are torn down when `container_context` exits: ```python async with container_context(): value_outer = await MyContainer.resource.async_resolve() async with container_context(): - # new context -> resource will be resolved a new + # new context -> resource will be resolved anew value_inner = await MyContainer.resource.async_resolve() assert value_inner != value_outer - # previously resolved value is cached in context. + # previously resolved value is cached in the outer context assert value_outer == await MyContainer.resource.async_resolve() ``` -### Resolving resources whenever function is called -`container_context` can be used as decorator: +### Resolving resources whenever a function is called + +`container_context` can be used as a decorator: ```python -@container_context() +@MyContainer.session.context # wrap with a session-specific context @inject -async def insert_into_database(session = Provide[MyContainer.session]): +async def insert_into_database(session=Provide[MyContainer.session]): ... ``` -Each time ``await insert_into_database()`` is called new instance of ``session`` will be injected. +Each time you call `await insert_into_database()`, a new instance of `session` will be injected. + +### Quick reference + +| Intention | Using `container_context()` | Using `SupportsContext` explicitly | Using `SupportsContext` decorator | +|-------------------------------------------------------|-----------------------------------------------|--------------------------------------------|-----------------------------------| +| Reset context for all containers in scope | `async with container_context():` | Not supported. | Not supported. | +| Reset only sync contexts for all containers in scope. | `with container_context():` | Not supported. | Not supported. | +| Reset a `provider.ContextResource` context | `async with container_context(my_provider):` | `async with my_provider.async_context():` | `@my_provider.context` | +| Reset a sync `provider.ContextResource` context | `with container_context(my_provider):` | `with my_provider.sync_context():` | `@my_provider.context` | +| Reset all resources in a container | `async with container_context(my_container):` | `async with my_container.async_context():` | `@my_container.context` | +| Reset all sync resources in a container | `with container_context(my_container):` | `with my_container.sync_context():` | `@my_container.context` | + +--- + +## Middleware + +For `ASGI` applications, `that_depends` provides the `DIContextMiddleware` to manage context resources. + +The `DIContextMiddleware` accepts containers and resources as arguments and automatically initializes the context for the provided resources when an endpoint is called. + +**Example with `FastAPI`:** +```python +import fastapi +from that_depends.providers import DIContextMiddleware, ContextResource +from that_depends import BaseContainer + +MyContainer: BaseContainer +my_context_resource_provider: ContextResource +my_app: fastapi.FastAPI + +# This will initialize the context for `my_context_resource_provider` and `MyContainer` whenever an endpoint is called. +my_app.add_middleware(DIContextMiddleware, MyContainer, my_context_resource_provider) + +# This will initialize the context for all containers when an endpoint is called. +my_app.add_middleware(DIContextMiddleware) +``` + +> `DIContextMiddleware` also supports the `global_context` and `preserve_global_context` arguments. diff --git a/tests/container.py b/tests/container.py index 69b1277e..a2f53507 100644 --- a/tests/container.py +++ b/tests/container.py @@ -67,3 +67,4 @@ class DIContainer(BaseContainer): ) singleton = providers.Singleton(SingletonFactory, dep1=True) object = providers.Object(object()) + context_resource = providers.ContextResource(create_async_resource) diff --git a/tests/integrations/fastapi/test_fastapi_di.py b/tests/integrations/fastapi/test_fastapi_di.py index 2bcabbc8..08c78961 100644 --- a/tests/integrations/fastapi/test_fastapi_di.py +++ b/tests/integrations/fastapi/test_fastapi_di.py @@ -2,45 +2,65 @@ import typing import fastapi +import pytest from starlette import status from starlette.testclient import TestClient from tests import container +from that_depends import fetch_context_item from that_depends.providers import DIContextMiddleware -app = fastapi.FastAPI() -app.add_middleware(DIContextMiddleware) +_GLOBAL_CONTEXT: typing.Final[dict[str, str]] = {"test2": "value2", "test1": "value1"} -@app.get("/") -async def read_root( - dependency: typing.Annotated[ - container.DependentFactory, - fastapi.Depends(container.DIContainer.dependent_factory), - ], - free_dependency: typing.Annotated[ - container.FreeFactory, - fastapi.Depends(container.DIContainer.resolver(container.FreeFactory)), - ], - singleton: typing.Annotated[ - container.SingletonFactory, - fastapi.Depends(container.DIContainer.singleton), - ], - singleton_attribute: typing.Annotated[bool, fastapi.Depends(container.DIContainer.singleton.dep1)], -) -> datetime.datetime: - assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource - assert dependency.async_resource == free_dependency.dependent_factory.async_resource - assert singleton.dep1 is True - assert singleton_attribute is True - return dependency.async_resource +@pytest.fixture(params=[None, container.DIContainer]) +def fastapi_app(request: pytest.FixtureRequest) -> fastapi.FastAPI: + app = fastapi.FastAPI() + if request.param: + app.add_middleware(DIContextMiddleware, request.param, global_context=_GLOBAL_CONTEXT) + else: + app.add_middleware( + DIContextMiddleware, + global_context=_GLOBAL_CONTEXT, + ) + @app.get("/") + async def read_root( + dependency: typing.Annotated[ + container.DependentFactory, + fastapi.Depends(container.DIContainer.dependent_factory), + ], + free_dependency: typing.Annotated[ + container.FreeFactory, + fastapi.Depends(container.DIContainer.resolver(container.FreeFactory)), + ], + singleton: typing.Annotated[ + container.SingletonFactory, + fastapi.Depends(container.DIContainer.singleton), + ], + singleton_attribute: typing.Annotated[bool, fastapi.Depends(container.DIContainer.singleton.dep1)], + context_resource: typing.Annotated[datetime.datetime, fastapi.Depends(container.DIContainer.context_resource)], + ) -> datetime.datetime: + assert dependency.sync_resource == free_dependency.dependent_factory.sync_resource + assert dependency.async_resource == free_dependency.dependent_factory.async_resource + assert singleton.dep1 is True + assert singleton_attribute is True + assert context_resource == await container.DIContainer.context_resource.async_resolve() + for key, value in _GLOBAL_CONTEXT.items(): + assert fetch_context_item(key) == value + return dependency.async_resource -client = TestClient(app) + return app -async def test_read_main() -> None: - response = client.get("/") +@pytest.fixture +def fastapi_client(fastapi_app: fastapi.FastAPI) -> TestClient: + return TestClient(fastapi_app) + + +async def test_read_main(fastapi_client: TestClient) -> None: + response = fastapi_client.get("/") assert response.status_code == status.HTTP_200_OK assert ( datetime.datetime.fromisoformat(response.json().replace("Z", "+00:00")) diff --git a/tests/integrations/fastapi/test_fastapi_di_pass_request.py b/tests/integrations/fastapi/test_fastapi_di_pass_request.py index 133bfc27..1ac273de 100644 --- a/tests/integrations/fastapi/test_fastapi_di_pass_request.py +++ b/tests/integrations/fastapi/test_fastapi_di_pass_request.py @@ -9,7 +9,7 @@ async def init_di_context(request: fastapi.Request) -> typing.AsyncIterator[None]: - async with container_context({"request": request}): + async with container_context(global_context={"request": request}): yield diff --git a/tests/integrations/faststream/test_faststream_di_pass_message.py b/tests/integrations/faststream/test_faststream_di_pass_message.py index fc07835e..cc84c69b 100644 --- a/tests/integrations/faststream/test_faststream_di_pass_message.py +++ b/tests/integrations/faststream/test_faststream_di_pass_message.py @@ -14,7 +14,7 @@ async def consume_scope( call_next: typing.Callable[..., typing.Awaitable[typing.Any]], msg: StreamMessage[typing.Any], ) -> typing.Any: # noqa: ANN401 - async with container_context({"request": msg}): + async with container_context(global_context={"request": msg}): return await call_next(msg) @@ -30,7 +30,7 @@ class DIContainer(BaseContainer): @broker.subscriber(TEST_SUBJECT) -async def index_subscruber( +async def index_subscriber( context_request: typing.Annotated[ NatsMessage, Depends(DIContainer.context_request, cast=False), diff --git a/tests/providers/test_attr_getter.py b/tests/providers/test_attr_getter.py index b09deda2..5105d3c4 100644 --- a/tests/providers/test_attr_getter.py +++ b/tests/providers/test_attr_getter.py @@ -68,36 +68,50 @@ def some_async_settings_provider(request: pytest.FixtureRequest) -> providers.Ab return typing.cast(providers.AbstractProvider[Settings], request.param) -@container_context() def test_attr_getter_with_zero_attribute_depth_sync( some_sync_settings_provider: providers.AbstractProvider[Settings], ) -> None: attr_getter = some_sync_settings_provider.some_str_value - assert attr_getter.sync_resolve() == Settings().some_str_value + if isinstance(some_sync_settings_provider, providers.ContextResource): + with container_context(some_sync_settings_provider): + assert attr_getter.sync_resolve() == Settings().some_str_value + else: + assert attr_getter.sync_resolve() == Settings().some_str_value -@container_context() async def test_attr_getter_with_zero_attribute_depth_async( some_async_settings_provider: providers.AbstractProvider[Settings], ) -> None: attr_getter = some_async_settings_provider.some_str_value - assert await attr_getter.async_resolve() == Settings().some_str_value + if isinstance(some_async_settings_provider, providers.ContextResource): + async with container_context(some_async_settings_provider): + assert await attr_getter.async_resolve() == Settings().some_str_value + else: + assert await attr_getter.async_resolve() == Settings().some_str_value -@container_context() def test_attr_getter_with_more_than_zero_attribute_depth_sync( some_sync_settings_provider: providers.AbstractProvider[Settings], ) -> None: - attr_getter = some_sync_settings_provider.nested1_attr.nested2_attr.some_const - assert attr_getter.sync_resolve() == Nested2().some_const + with ( + container_context(some_sync_settings_provider) + if isinstance(some_sync_settings_provider, providers.ContextResource) + else container_context() + ): + attr_getter = some_sync_settings_provider.nested1_attr.nested2_attr.some_const + assert attr_getter.sync_resolve() == Nested2().some_const -@container_context() async def test_attr_getter_with_more_than_zero_attribute_depth_async( some_async_settings_provider: providers.AbstractProvider[Settings], ) -> None: - attr_getter = some_async_settings_provider.nested1_attr.nested2_attr.some_const - assert await attr_getter.async_resolve() == Nested2().some_const + async with ( + container_context(some_async_settings_provider) + if isinstance(some_async_settings_provider, providers.ContextResource) + else container_context() + ): + attr_getter = some_async_settings_provider.nested1_attr.nested2_attr.some_const + assert await attr_getter.async_resolve() == Nested2().some_const @pytest.mark.parametrize( diff --git a/tests/providers/test_context_resources.py b/tests/providers/test_context_resources.py index 2a9562bb..c17241e1 100644 --- a/tests/providers/test_context_resources.py +++ b/tests/providers/test_context_resources.py @@ -3,7 +3,7 @@ import logging import typing import uuid -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, ExitStack import pytest @@ -37,6 +37,14 @@ class DIContainer(BaseContainer): ) +class DependentDiContainer(BaseContainer): + dependent_sync_context_resource = providers.ContextResource(create_sync_context_resource) + dependent_async_context_resource = providers.ContextResource(create_async_context_resource) + + +DIContainer.connect_containers(DependentDiContainer) + + @pytest.fixture(autouse=True) async def _clear_di_container() -> typing.AsyncIterator[None]: try: @@ -85,7 +93,7 @@ def test_sync_context_resource(sync_context_resource: providers.ContextResource[ async def test_async_context_resource_in_sync_context(async_context_resource: providers.ContextResource[str]) -> None: - with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved in an sync context."), container_context(): + with pytest.raises(RuntimeError, match="Context is not set. Use container_context"), container_context(): await async_context_resource() @@ -139,10 +147,10 @@ def test_context_resources_wrong_providers_init() -> None: async def test_context_resource_with_dynamic_resource() -> None: - async with container_context({"resource_type": "sync"}): + async with container_context(global_context={"resource_type": "sync"}, reset_all_containers=True): assert (await DIContainer.dynamic_context_resource()).startswith("sync") - async with container_context({"resource_type": "async_"}): + async with container_context(global_context={"resource_type": "async_"}, reset_all_containers=True): assert (await DIContainer.dynamic_context_resource()).startswith("async") async with container_context(): @@ -150,9 +158,9 @@ async def test_context_resource_with_dynamic_resource() -> None: async def test_early_exit_of_container_context() -> None: - with pytest.raises(RuntimeError, match="Context is not set, call ``__aenter__`` first"): + with pytest.raises(RuntimeError, match="No context token set for global vars, use __enter__ or __aenter__ first."): await container_context().__aexit__(None, None, None) - with pytest.raises(RuntimeError, match="Context is not set, call ``__enter__`` first"): + with pytest.raises(RuntimeError, match="No context token set for global vars, use __enter__ or __aenter__ first."): container_context().__exit__(None, None, None) @@ -192,6 +200,47 @@ async def some_injected(depth: int, val: str = Provide[DIContainer.async_context await some_injected(1) +async def test_async_injection_when_resetting_resource_specific_context( + async_context_resource: providers.ContextResource[str], +) -> None: + """Async context resources should be able to reset the context for themselves.""" + + @async_context_resource.context + @inject + async def _async_injected(val: str = Provide[async_context_resource]) -> str: + assert isinstance(async_context_resource._fetch_context().context_stack, AsyncExitStack) # noqa: SLF001 + return val + + async_result = await _async_injected() + assert async_result != await _async_injected() + assert isinstance(async_result, str) + + +async def test_sync_injection_when_resetting_resource_specific_context( + sync_context_resource: providers.ContextResource[str], +) -> None: + """Sync context resources should be able to reset the context for themselves.""" + + @sync_context_resource.context + @inject + async def _async_injected(val: str = Provide[sync_context_resource]) -> str: + assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001 + return val + + @sync_context_resource.context + @inject + def _sync_injected(val: str = Provide[sync_context_resource]) -> str: + assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001 + return val + + async_result = await _async_injected() + assert async_result != await _async_injected() + assert isinstance(async_result, str) + sync_result = _sync_injected() + assert sync_result != _sync_injected() + assert isinstance(sync_result, str) + + @pytest.mark.repeat(10) async def test_async_context_resource_asyncio_concurrency() -> None: calls: int = 0 @@ -207,7 +256,355 @@ async def create_client() -> typing.AsyncIterator[str]: async def resolve_resource() -> str: return await resource.async_resolve() - async with container_context(): + async with resource.async_context(): await asyncio.gather(resolve_resource(), resolve_resource()) assert calls == 1 + + +@pytest.mark.repeat(10) +async def test_sync_context_resource_asyncio_concurrency() -> None: + calls: int = 0 + + def create_client() -> typing.Iterator[str]: + nonlocal calls + calls += 1 + yield "" + + resource = providers.ContextResource(create_client) + + async def resolve_resource() -> str: + return resource.sync_resolve() + + with resource.sync_context(): + await asyncio.gather(resolve_resource(), resolve_resource()) + + assert calls == 1 + + +async def test_async_injection_when_explicitly_resetting_resource_specific_context( + async_context_resource: providers.ContextResource[str], +) -> None: + """Async context resources should be able to reset the context for themselves explicitly.""" + + @async_context_resource.async_context() + @inject + async def _async_injected(val: str = Provide[async_context_resource]) -> str: + assert isinstance(async_context_resource._fetch_context().context_stack, AsyncExitStack) # noqa: SLF001 + return val + + async_result = await _async_injected() + assert async_result != await _async_injected() + assert isinstance(async_result, str) + + +async def test_sync_injection_when_explicitly_resetting_resource_specific_context( + sync_context_resource: providers.ContextResource[str], +) -> None: + """Sync context resources should be able to reset the context for themselves explicitly.""" + + @sync_context_resource.async_context() + @inject + async def _async_injected(val: str = Provide[sync_context_resource]) -> str: + assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001 + return val + + @sync_context_resource.sync_context() + @inject + def _sync_injected(val: str = Provide[sync_context_resource]) -> str: + assert isinstance(sync_context_resource._fetch_context().context_stack, ExitStack) # noqa: SLF001 + return val + + async_result = await _async_injected() + assert async_result != await _async_injected() + assert isinstance(async_result, str) + sync_result = _sync_injected() + assert sync_result != _sync_injected() + assert isinstance(sync_result, str) + + +async def test_async_resolution_when_explicitly_resolving( + async_context_resource: providers.ContextResource[str], +) -> None: + """Async context should cache resources until a new one is created.""" + async with async_context_resource.async_context(): + val_1 = await async_context_resource.async_resolve() + val_2 = await async_context_resource.async_resolve() + assert val_1 == val_2 + async with async_context_resource.async_context(): + val_3 = await async_context_resource.async_resolve() + assert val_1 != val_3 + async with async_context_resource.async_context(): + val_4 = await async_context_resource.async_resolve() + assert val_1 != val_4 != val_3 + val_5 = await async_context_resource.async_resolve() + assert val_5 == val_3 + val_6 = await async_context_resource.async_resolve() + assert val_6 == val_1 + + +def test_sync_resolution_when_explicitly_resolving( + sync_context_resource: providers.ContextResource[str], +) -> None: + """Sync context should cache resources until a new one is created.""" + with sync_context_resource.sync_context(): + val_1 = sync_context_resource.sync_resolve() + val_2 = sync_context_resource.sync_resolve() + assert val_1 == val_2 + with sync_context_resource.sync_context(): + val_3 = sync_context_resource.sync_resolve() + assert val_1 != val_3 + with sync_context_resource.sync_context(): + val_4 = sync_context_resource.sync_resolve() + assert val_1 != val_4 != val_3 + val_5 = sync_context_resource.sync_resolve() + assert val_5 == val_3 + val_6 = sync_context_resource.sync_resolve() + assert val_6 == val_1 + + +def test_sync_container_context_resolution( + sync_context_resource: providers.ContextResource[str], +) -> None: + """container_context should reset context for sync provider.""" + with container_context(sync_context_resource): + val_1 = sync_context_resource.sync_resolve() + val_2 = sync_context_resource.sync_resolve() + assert val_1 == val_2 + with container_context(sync_context_resource): + val_3 = sync_context_resource.sync_resolve() + assert val_3 != val_1 + val_4 = sync_context_resource.sync_resolve() + assert val_4 == val_1 + with pytest.raises(RuntimeError): + sync_context_resource.sync_resolve() + + +async def test_async_container_context_resolution( + async_context_resource: providers.ContextResource[str], +) -> None: + """container_context should reset context for async provider.""" + async with container_context(async_context_resource): + val_1 = await async_context_resource.async_resolve() + val_2 = await async_context_resource.async_resolve() + assert val_1 == val_2 + async with container_context(async_context_resource): + val_3 = await async_context_resource.async_resolve() + assert val_3 != val_1 + val_4 = await async_context_resource.async_resolve() + assert val_4 == val_1 + with pytest.raises(RuntimeError): + await async_context_resource.async_resolve() + + +async def test_async_global_context_resolution() -> None: + with pytest.raises(RuntimeError): + async with AsyncExitStack() as stack: + await stack.enter_async_context(container_context(preserve_global_context=True)) + my_global_resources = {"test_1": "test_1", "test_2": "test_2"} + + async with container_context(global_context=my_global_resources): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + async with container_context(preserve_global_context=True): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + async with container_context(preserve_global_context=False): + for key in my_global_resources: + assert fetch_context_item(key) is None + + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + with pytest.raises(RuntimeError): + fetch_context_item("test_1") + + +def test_sync_global_context_resolution() -> None: + with pytest.raises(RuntimeError), ExitStack() as stack: + stack.enter_context(container_context(preserve_global_context=True)) + my_global_resources = {"test_1": "test_1", "test_2": "test_2"} + with container_context(global_context=my_global_resources): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + with container_context(preserve_global_context=True): + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + with container_context(preserve_global_context=False): + for key in my_global_resources: + assert fetch_context_item(key) is None + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + for key, item in my_global_resources.items(): + assert fetch_context_item(key) == item + + with pytest.raises(RuntimeError): + fetch_context_item("test_1") + + +async def test_async_global_context_reset(async_context_resource: providers.ContextResource[str]) -> None: + """container_context should reset async providers.""" + async with container_context(): + val_1 = await async_context_resource.async_resolve() + val_2 = await async_context_resource.async_resolve() + assert val_1 == val_2 + async with container_context(): + val_3 = await async_context_resource.async_resolve() + assert val_3 != val_1 + val_4 = await async_context_resource.async_resolve() + assert val_4 == val_1 + + +def test_sync_global_context_reset(sync_context_resource: providers.ContextResource[str]) -> None: + """container_context should reset sync providers.""" + with container_context(): + val_1 = sync_context_resource.sync_resolve() + val_2 = sync_context_resource.sync_resolve() + assert val_1 == val_2 + with container_context(): + val_3 = sync_context_resource.sync_resolve() + assert val_3 != val_1 + val_4 = sync_context_resource.sync_resolve() + assert val_4 == val_1 + + +async def test_async_context_with_container( + async_context_resource: providers.ContextResource[str], + sync_context_resource: providers.ContextResource[str], +) -> None: + """Containers should enter async context for all its providers.""" + async with DIContainer.async_context(): + val_1 = await async_context_resource.async_resolve() + val_2 = await async_context_resource.async_resolve() + assert val_1 == val_2 + val_1_sync = sync_context_resource.sync_resolve() + val_2_sync = sync_context_resource.sync_resolve() + assert val_1_sync == val_2_sync + async with DIContainer.async_context(): + val_3 = await async_context_resource.async_resolve() + val_3_sync = sync_context_resource.sync_resolve() + assert val_3 != val_1 + assert val_3_sync != val_1_sync + val_4 = await async_context_resource.async_resolve() + val_4_sync = sync_context_resource.sync_resolve() + assert val_4 == val_1 + assert val_4_sync == val_1_sync + + +def test_sync_context_with_container( + sync_context_resource: providers.ContextResource[str], +) -> None: + """Containers should enter sync context for all its providers.""" + with DIContainer.sync_context(): + val_1 = sync_context_resource.sync_resolve() + val_2 = sync_context_resource.sync_resolve() + assert val_1 == val_2 + with DIContainer.sync_context(): + val_3 = sync_context_resource.sync_resolve() + assert val_3 != val_1 + val_4 = sync_context_resource.sync_resolve() + assert val_4 == val_1 + + +async def test_async_container_context_wrapper(async_context_resource: providers.ContextResource[str]) -> None: + """Container context wrapper should correctly enter async context for wrapped function.""" + + @DIContainer.context + @inject + async def _injected(val: str = Provide[async_context_resource]) -> str: + return val + + assert await _injected() != await _injected() + + @DIContainer.async_context() + @inject + async def _explicit_injected(val: str = Provide[async_context_resource]) -> str: + return val + + assert await _explicit_injected() != await _explicit_injected() + + +def test_sync_container_context_wrapper(sync_context_resource: providers.ContextResource[str]) -> None: + """Container context wrapper should correctly enter sync context for wrapped function.""" + + @DIContainer.context + @inject + def _injected(val: str = Provide[sync_context_resource]) -> str: + return val + + assert _injected() != _injected() + + @DIContainer.sync_context() + @inject + def _explicit_injected(val: str = Provide[sync_context_resource]) -> str: + return val + + assert _explicit_injected() != _explicit_injected() + + +async def test_async_context_resource_with_dependent_container() -> None: + """Container should initialize async context resource for dependent containers.""" + async with DIContainer.async_context(): + val_1 = await DependentDiContainer.dependent_async_context_resource.async_resolve() + val_2 = await DependentDiContainer.dependent_async_context_resource.async_resolve() + assert val_1 == val_2 + + +def test_sync_context_resource_with_dependent_container() -> None: + """Container should initialize sync context resource for dependent containers.""" + with DIContainer.sync_context(): + val_1 = DependentDiContainer.dependent_sync_context_resource.sync_resolve() + val_2 = DependentDiContainer.dependent_sync_context_resource.sync_resolve() + assert val_1 == val_2 + + +def test_containers_support_sync_context() -> None: + assert DIContainer.supports_sync_context() + + +def test_enter_sync_context_for_async_resource_should_throw( + async_context_resource: providers.ContextResource[str], +) -> None: + with pytest.raises(RuntimeError): + async_context_resource.__enter__() + + +def test_exit_sync_context_before_enter_should_throw(sync_context_resource: providers.ContextResource[str]) -> None: + with pytest.raises(RuntimeError): + sync_context_resource.__exit__(None, None, None) + + +async def test_exit_async_context_before_enter_should_throw( + async_context_resource: providers.ContextResource[str], +) -> None: + with pytest.raises(RuntimeError): + await async_context_resource.__aexit__(None, None, None) + + +def test_enter_sync_context_from_async_resource_should_throw( + async_context_resource: providers.ContextResource[str], +) -> None: + with pytest.raises(RuntimeError), ExitStack() as stack: + stack.enter_context(async_context_resource.sync_context()) + + +async def test_preserve_globals_and_initial_context() -> None: + initial_context = {"test_1": "test_1", "test_2": "test_2"} + + async with container_context(global_context=initial_context): + for key, item in initial_context.items(): + assert fetch_context_item(key) == item + new_context = {"test_3": "test_3"} + async with container_context(global_context=new_context, preserve_global_context=True): + for key, item in new_context.items(): + assert fetch_context_item(key) == item + for key, item in initial_context.items(): + assert fetch_context_item(key) == item + for key, item in initial_context.items(): + assert fetch_context_item(key) == item + for key in new_context: + assert fetch_context_item(key) is None diff --git a/that_depends/container.py b/that_depends/container.py index 45b3d16b..7fe71c65 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -1,8 +1,10 @@ import inspect import typing -from contextlib import contextmanager +from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager +from that_depends.meta import BaseContainerMeta from that_depends.providers import AbstractProvider, Resource, Singleton +from that_depends.providers.context_resources import ContextResource, SupportsContext if typing.TYPE_CHECKING: @@ -13,7 +15,7 @@ P = typing.ParamSpec("P") -class BaseContainer: +class BaseContainer(SupportsContext[None], metaclass=BaseContainerMeta): providers: dict[str, AbstractProvider[typing.Any]] containers: list[type["BaseContainer"]] @@ -21,6 +23,48 @@ def __new__(cls, *_: typing.Any, **__: typing.Any) -> "typing_extensions.Self": msg = f"{cls.__name__} should not be instantiated" raise RuntimeError(msg) + @classmethod + def supports_sync_context(cls) -> bool: + return True + + @classmethod + @contextmanager + def sync_context(cls) -> typing.Iterator[None]: + with ExitStack() as stack: + for container in cls.get_containers(): + stack.enter_context(container.sync_context()) + for provider in cls.get_providers().values(): + if isinstance(provider, ContextResource) and not provider.is_async: + stack.enter_context(provider.sync_context()) + yield + + @classmethod + @asynccontextmanager + async def async_context(cls) -> typing.AsyncIterator[None]: + async with AsyncExitStack() as stack: + for container in cls.get_containers(): + await stack.enter_async_context(container.async_context()) + for provider in cls.get_providers().values(): + if isinstance(provider, ContextResource): + await stack.enter_async_context(provider.async_context()) + yield + + @classmethod + def context(cls, func: typing.Callable[P, T]) -> typing.Callable[P, T]: + if inspect.iscoroutinefunction(func): + + async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + async with cls.async_context(): + return await func(*args, **kwargs) # type: ignore[no-any-return] + + return typing.cast(typing.Callable[P, T], _async_wrapper) + + def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + with cls.sync_context(): + return func(*args, **kwargs) + + return _sync_wrapper + @classmethod def connect_containers(cls, *containers: type["BaseContainer"]) -> None: """Connect containers. @@ -37,7 +81,6 @@ def connect_containers(cls, *containers: type["BaseContainer"]) -> None: def get_providers(cls) -> dict[str, AbstractProvider[typing.Any]]: if not hasattr(cls, "providers"): cls.providers = {k: v for k, v in cls.__dict__.items() if isinstance(v, AbstractProvider)} - return cls.providers @classmethod diff --git a/that_depends/meta.py b/that_depends/meta.py new file mode 100644 index 00000000..7ac08b8f --- /dev/null +++ b/that_depends/meta.py @@ -0,0 +1,24 @@ +import abc +import typing +from threading import Lock + + +if typing.TYPE_CHECKING: + from that_depends.container import BaseContainer + + +class BaseContainerMeta(abc.ABCMeta): + _instances: typing.ClassVar[list[type["BaseContainer"]]] = [] + + _lock: Lock = Lock() + + def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, typing.Any]) -> type: + new_cls = super().__new__(cls, name, bases, namespace) + with cls._lock: + if name != "BaseContainer": + cls._instances.append(new_cls) # type: ignore[arg-type] + return new_cls + + @classmethod + def get_instances(cls) -> list[type["BaseContainer"]]: + return cls._instances diff --git a/that_depends/providers/__init__.py b/that_depends/providers/__init__.py index 78d88523..dd6ed8dc 100644 --- a/that_depends/providers/__init__.py +++ b/that_depends/providers/__init__.py @@ -1,5 +1,5 @@ from that_depends.providers.base import AbstractProvider, AttrGetter -from that_depends.providers.collections import Dict, List +from that_depends.providers.collection import Dict, List from that_depends.providers.context_resources import ( ContextResource, DIContextMiddleware, diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 142510d7..4c66d7c7 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -91,16 +91,16 @@ def __init__( super().__init__() self._creator: typing.Any if inspect.isasyncgenfunction(creator): - self._is_async = True + self.is_async = True self._creator = contextlib.asynccontextmanager(creator) elif inspect.isgeneratorfunction(creator): - self._is_async = False + self.is_async = False self._creator = contextlib.contextmanager(creator) elif isinstance(creator, type) and issubclass(creator, typing.AsyncContextManager): - self._is_async = True + self.is_async = True self._creator = creator elif isinstance(creator, type) and issubclass(creator, typing.ContextManager): - self._is_async = False + self.is_async = False self._creator = creator else: msg = "Unsupported resource type" @@ -121,10 +121,6 @@ async def async_resolve(self) -> T_co: if context.instance is not None: return context.instance - if not context.is_async and self._is_async: - msg = "AsyncResource cannot be resolved in an sync context." - raise RuntimeError(msg) - # lock to prevent race condition while resolving async with context.asyncio_lock: if context.instance is not None: @@ -159,7 +155,7 @@ def sync_resolve(self) -> T_co: if context.instance is not None: return context.instance - if self._is_async: + if self.is_async: msg = "AsyncResource cannot be resolved synchronously" raise RuntimeError(msg) diff --git a/that_depends/providers/collections.py b/that_depends/providers/collection.py similarity index 98% rename from that_depends/providers/collections.py rename to that_depends/providers/collection.py index cb9d4dde..c976a561 100644 --- a/that_depends/providers/collections.py +++ b/that_depends/providers/collection.py @@ -1,4 +1,4 @@ -import typing # noqa: A005 +import typing from that_depends.providers.base import AbstractProvider diff --git a/that_depends/providers/context_resources.py b/that_depends/providers/context_resources.py index 6d865b62..7d479039 100644 --- a/that_depends/providers/context_resources.py +++ b/that_depends/providers/context_resources.py @@ -1,16 +1,25 @@ +import abc +import contextlib import inspect import logging import typing -import uuid +from abc import abstractmethod from contextlib import AbstractAsyncContextManager, AbstractContextManager from contextvars import ContextVar, Token from functools import wraps from types import TracebackType +from typing_extensions import TypeIs + from that_depends.entities.resource_context import ResourceContext +from that_depends.meta import BaseContainerMeta from that_depends.providers.base import AbstractResource +if typing.TYPE_CHECKING: + from that_depends.container import BaseContainer + + logger: typing.Final = logging.getLogger(__name__) T_co = typing.TypeVar("T_co", covariant=True) P = typing.ParamSpec("P") @@ -26,142 +35,313 @@ ContextType = dict[str, typing.Any] -class container_context( # noqa: N801 - AbstractAsyncContextManager[ContextType], AbstractContextManager[ContextType] -): - """Manage the context of ContextResources. +def _get_container_context() -> dict[str, typing.Any]: + try: + return _CONTAINER_CONTEXT.get() + except LookupError as exc: + msg = "Context is not set. Use container_context" + raise RuntimeError(msg) from exc - Can be entered using ``async with container_context()`` or with ``with container_context()`` - as async-context-manager or context-manager respectively. - When used as async-context-manager, it will allow setup & teardown of both sync and async resources. - When used as sync-context-manager, it will only allow setup & teardown of sync resources. - """ - __slots__ = "_context_token", "_initial_context" +def fetch_context_item(key: str, default: typing.Any = None) -> typing.Any: # noqa: ANN401 + return _get_container_context().get(key, default) - def __init__(self, initial_context: ContextType | None = None) -> None: - self._initial_context: ContextType = initial_context or {} - self._context_token: Token[ContextType] | None = None - def __enter__(self) -> ContextType: - self._initial_context[_ASYNC_CONTEXT_KEY] = False +T = typing.TypeVar("T") +CT = typing.TypeVar("CT") + + +class SupportsContext(typing.Generic[CT], abc.ABC): + @abstractmethod + def context(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]: + """Initialize context for the given function. + + :param func: function to wrap. + :return: wrapped function with context. + """ + + @abstractmethod + def async_context(self) -> typing.AsyncContextManager[CT]: + """Initialize async context.""" + + @abstractmethod + def sync_context(self) -> typing.ContextManager[CT]: + """Initialize sync context.""" + + @abstractmethod + def supports_sync_context(self) -> bool: + """Check if the resource supports sync context.""" + + +class ContextResource( + AbstractResource[T_co], + AbstractAsyncContextManager[ResourceContext[T_co]], + AbstractContextManager[ResourceContext[T_co]], + SupportsContext[ResourceContext[T_co]], +): + __slots__ = ( + "_args", + "_context", + "_creator", + "_internal_name", + "_kwargs", + "_override", + "_token", + "is_async", + ) + + def __init__( + self, + creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + super().__init__(creator, *args, **kwargs) + self._context: ContextVar[ResourceContext[T_co]] = ContextVar(f"{self._creator.__name__}-context") + self._token: Token[ResourceContext[T_co]] | None = None + + def supports_sync_context(self) -> bool: + return not self.is_async + + def __enter__(self) -> ResourceContext[T_co]: + if self.is_async: + msg = "You must enter async context for async creators." + raise RuntimeError(msg) return self._enter() - async def __aenter__(self) -> ContextType: - self._initial_context[_ASYNC_CONTEXT_KEY] = True + async def __aenter__(self) -> ResourceContext[T_co]: return self._enter() - def _enter(self) -> ContextType: - self._context_token = _CONTAINER_CONTEXT.set({**self._initial_context}) - return _CONTAINER_CONTEXT.get() + def _enter(self) -> ResourceContext[T_co]: + self._token = self._context.set(ResourceContext(is_async=self.is_async)) + return self._context.get() def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None ) -> None: - if self._context_token is None: + if not self._token: msg = "Context is not set, call ``__enter__`` first" raise RuntimeError(msg) try: - for context_item in reversed(_CONTAINER_CONTEXT.get().values()): - if isinstance(context_item, ResourceContext): - # we don't need to handle the case where the ResourceContext is async - context_item.sync_tear_down() + context_item = self._context.get() + context_item.sync_tear_down() finally: - _CONTAINER_CONTEXT.reset(self._context_token) + self._context.reset(self._token) async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, traceback: TracebackType | None ) -> None: - if self._context_token is None: + if self._token is None: msg = "Context is not set, call ``__aenter__`` first" raise RuntimeError(msg) try: - for context_item in reversed(_CONTAINER_CONTEXT.get().values()): - if not isinstance(context_item, ResourceContext): - continue - - if context_item.is_context_stack_async(context_item.context_stack): - await context_item.tear_down() - else: - context_item.sync_tear_down() + context_item = self._context.get() + if context_item.is_context_stack_async(context_item.context_stack): + await context_item.tear_down() + else: + context_item.sync_tear_down() finally: - _CONTAINER_CONTEXT.reset(self._context_token) + self._context.reset(self._token) - def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]: + @contextlib.contextmanager + def sync_context(self) -> typing.Iterator[ResourceContext[T_co]]: + if self.is_async: + msg = "Please use async context instead." + raise RuntimeError(msg) + token = self._token + with self as val: + yield val + self._token = token + + @contextlib.asynccontextmanager + async def async_context(self) -> typing.AsyncIterator[ResourceContext[T_co]]: + token = self._token + async with self as val: + yield val + self._token = token + + def context(self, func: typing.Callable[P, T]) -> typing.Callable[P, T]: + """Create a new context manager for the resource, the context manager will be async if the resource is async. + + :return: A context manager for the resource. + :rtype: typing.ContextManager[ResourceContext[T_co]] | typing.AsyncContextManager[ResourceContext[T_co]] + """ if inspect.iscoroutinefunction(func): @wraps(func) - async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: - async with container_context(self._initial_context): + async def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + async with self.async_context(): return await func(*args, **kwargs) # type: ignore[no-any-return] - return typing.cast(typing.Callable[P, T_co], _async_inner) + return typing.cast(typing.Callable[P, T], _async_wrapper) + # wrapped function is sync @wraps(func) - def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: - with container_context(self._initial_context): + def _sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + with self.sync_context(): return func(*args, **kwargs) - return _sync_inner + return typing.cast(typing.Callable[P, T], _sync_wrapper) + + def _fetch_context(self) -> ResourceContext[T_co]: + try: + return self._context.get() + except LookupError as e: + msg = "Context is not set. Use container_context" + raise RuntimeError(msg) from e -class DIContextMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app: typing.Final = app +ContainerType = typing.TypeVar("ContainerType", bound="type[BaseContainer]") - @container_context() - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - return await self.app(scope, receive, send) +class container_context(AbstractContextManager[ContextType], AbstractAsyncContextManager[ContextType]): # noqa: N801 + ___slots__ = ( + "_providers", + "_context_stack", + "_containers", + "_initial_context", + "_context_token", + "_reset_resource_context", + ) -def _get_container_context() -> dict[str, typing.Any]: - try: + def __init__( + self, + *context_items: SupportsContext[typing.Any], + global_context: ContextType | None = None, + preserve_global_context: bool = False, + reset_all_containers: bool = False, + ) -> None: + """Initialize a new container context. + + :param context_items: context items to initialize new context for. + :param global_context: existing context to use + :param preserve_global_context: whether to preserve old global context. + Will merge old context with the new context if this option is set to True. + :param reset_all_containers: Create a new context for all containers. + """ + if preserve_global_context and global_context: + self._initial_context = {**_get_container_context(), **global_context} + else: + self._initial_context: ContextType = ( # type: ignore[no-redef] + _get_container_context() if preserve_global_context else global_context or {} + ) + self._context_token: Token[ContextType] | None = None + self._context_items: set[SupportsContext[typing.Any]] = set(context_items) + self._reset_resource_context: typing.Final[bool] = ( + not context_items and not global_context + ) or reset_all_containers + if self._reset_resource_context: + self._add_providers_from_containers(BaseContainerMeta.get_instances()) + + self._context_stack: contextlib.AsyncExitStack | contextlib.ExitStack | None = None + + def _add_providers_from_containers(self, containers: list[ContainerType]) -> None: + for container in containers: + for container_provider in container.get_providers().values(): + if isinstance(container_provider, ContextResource): + self._context_items.add(container_provider) + + def __enter__(self) -> ContextType: + self._context_stack = contextlib.ExitStack() + for item in self._context_items: + if item.supports_sync_context(): + self._context_stack.enter_context(item.sync_context()) + return self._enter_globals() + + async def __aenter__(self) -> ContextType: + self._context_stack = contextlib.AsyncExitStack() + for item in self._context_items: + await self._context_stack.enter_async_context(item.async_context()) + return self._enter_globals() + + def _enter_globals(self) -> ContextType: + self._context_token = _CONTAINER_CONTEXT.set(self._initial_context) return _CONTAINER_CONTEXT.get() - except LookupError as exc: - msg = "Context is not set. Use container_context" - raise RuntimeError(msg) from exc + def _is_context_token(self, _: Token[ContextType] | None) -> TypeIs[Token[ContextType]]: + return isinstance(_, Token) -def _is_container_context_async() -> bool: - """Check if the current container context is async. + def _exit_globals(self) -> None: + if self._is_context_token(self._context_token): + return _CONTAINER_CONTEXT.reset(self._context_token) + msg = "No context token set for global vars, use __enter__ or __aenter__ first." + raise RuntimeError(msg) - :return: Whether the current container context is async. - :rtype: bool - """ - return typing.cast(bool, _get_container_context().get(_ASYNC_CONTEXT_KEY, False)) + def _has_async_exit_stack( + self, + _: contextlib.AsyncExitStack | contextlib.ExitStack | None, + ) -> typing.TypeGuard[contextlib.AsyncExitStack]: + return isinstance(_, contextlib.AsyncExitStack) + def _has_sync_exit_stack( + self, _: contextlib.AsyncExitStack | contextlib.ExitStack | None + ) -> typing.TypeGuard[contextlib.ExitStack]: + return isinstance(_, contextlib.ExitStack) -def fetch_context_item(key: str, default: typing.Any = None) -> typing.Any: # noqa: ANN401 - return _get_container_context().get(key, default) + def __exit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> None: + try: + if self._has_sync_exit_stack(self._context_stack): + self._context_stack.close() + else: + msg = "Context is not set, call ``__enter__`` first" + raise RuntimeError(msg) + finally: + self._exit_globals() + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, traceback: TracebackType | None + ) -> None: + try: + if self._has_async_exit_stack(self._context_stack): + await self._context_stack.aclose() + else: + msg = "Context is not set, call ``__aenter__`` first" + raise RuntimeError(msg) + finally: + self._exit_globals() -class ContextResource(AbstractResource[T_co]): - __slots__ = ( - "_args", - "_creator", - "_internal_name", - "_is_async", - "_kwargs", - "_override", - ) + def __call__(self, func: typing.Callable[P, T_co]) -> typing.Callable[P, T_co]: + if inspect.iscoroutinefunction(func): + @wraps(func) + async def _async_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: + async with container_context(*self._context_items, reset_all_containers=self._reset_resource_context): + return await func(*args, **kwargs) # type: ignore[no-any-return] + + return typing.cast(typing.Callable[P, T_co], _async_inner) + + @wraps(func) + def _sync_inner(*args: P.args, **kwargs: P.kwargs) -> T_co: + with container_context(*self._context_items, reset_all_containers=self._reset_resource_context): + return func(*args, **kwargs) + + return _sync_inner + + +class DIContextMiddleware: def __init__( self, - creator: typing.Callable[P, typing.Iterator[T_co] | typing.AsyncIterator[T_co]], - *args: P.args, - **kwargs: P.kwargs, + app: ASGIApp, + *context_items: SupportsContext[typing.Any], + global_context: dict[str, typing.Any] | None = None, + reset_all_containers: bool = True, ) -> None: - super().__init__(creator, *args, **kwargs) - self._internal_name: typing.Final = f"{creator.__name__}-{uuid.uuid4()}" - - def _fetch_context(self) -> ResourceContext[T_co]: - container_context = _get_container_context() - if resource_context := container_context.get(self._internal_name): - return typing.cast(ResourceContext[T_co], resource_context) + self.app: typing.Final = app + self._context_items: set[SupportsContext[typing.Any]] = set(context_items) + self._global_context: dict[str, typing.Any] | None = global_context + self._reset_all_containers: bool = reset_all_containers - resource_context = ResourceContext(is_async=_is_container_context_async()) - container_context.setdefault(self._internal_name, resource_context) - return resource_context + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self._context_items: + pass + async with ( + container_context(*self._context_items, global_context=self._global_context) + if self._context_items + else container_context(global_context=self._global_context, reset_all_containers=self._reset_all_containers) + ): + return await self.app(scope, receive, send) diff --git a/that_depends/providers/resources.py b/that_depends/providers/resources.py index 02f3aea6..238ab5b7 100644 --- a/that_depends/providers/resources.py +++ b/that_depends/providers/resources.py @@ -13,9 +13,11 @@ class Resource(AbstractResource[T_co]): "_args", "_context", "_creator", + "_creator", "_is_async", "_kwargs", "_override", + "is_async", ) def __init__( @@ -25,7 +27,7 @@ def __init__( **kwargs: P.kwargs, ) -> None: super().__init__(creator, *args, **kwargs) - self._context: typing.Final[ResourceContext[T_co]] = ResourceContext(is_async=self._is_async) + self._context: typing.Final[ResourceContext[T_co]] = ResourceContext(is_async=self.is_async) def _fetch_context(self) -> ResourceContext[T_co]: return self._context