Skip to content

Commit

Permalink
String provider resolution (#150)
Browse files Browse the repository at this point in the history
* Setup scheme for string based provider definition

* Implemented string resolution functionality.

* Added tests.

* Deferred provider resolution to inject wrapper.

* Updated migration guide to mention that dynamic ontainer was removed.

* Wrote string provider documentation.
  • Loading branch information
alexanderlazarev0 authored Feb 5, 2025
1 parent 884653f commit b429870
Show file tree
Hide file tree
Showing 19 changed files with 325 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
introduction/faststream
introduction/inject-factories
introduction/scopes
introduction/string-injection
introduction/multiple-containers
introduction/application-settings
Expand Down
79 changes: 79 additions & 0 deletions docs/introduction/string-injection.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
## Provider resolution by name

The `@inject` decorator can be used to inject a provider by its name. This is useful when you want to inject a provider that is not directly imported in the current module.


This serves two primary purposes:

- A higher level of decoupling between container and your code.
- Avoiding circular imports.

---
## Usage

To inject a provider by name, use the `Provide` marker with a string argument that has the following format:

```
Container.Provider[.attribute.attribute...]
```

The string will be validated when it is passed to `Provide[]`, thus will raise an exception
immediately.

**For example**:

```python
from that_depends import BaseContainer, inject, Provide

class Config(BaseSettings):
name: str = "Damian"

class A(BaseContainer):
b = providers.Factory(Config)

@inject
def read(val = Provide["A.b.name"]):
return val

assert read() == "Damian"
```

### Container alias

Containers support aliases:

```python

class A(BaseContainer):
alias = "C" # replaces the container name.
b = providers.Factory(Config)

@inject
def read(val = Provide["C.b.name"]): # `A` can no longer be used.
return val

assert read() == "Damian"
```
---
## Considerations

This feature is primarily intended as a fallback when other options are not optimal or
simply not available, thus is recommended to be used sparingly.

If you do decide to use injection by name, consider the following:

- In order for this type of injection to work, your container must be in scope when the injected function is called:
```python
from that_depends import BaseContainer, inject, Provide

@inject
def injected(f = Provide["MyContainer.my_provider"]): ...

injected() # will raise an Exception

class MyContainer(BaseContainer):
my_provider = providers.Factory(some_creator)

injected() # will resolve
```
- Validation of whether you have provided a correct container name and provider name will only happen when the function is called.
22 changes: 21 additions & 1 deletion docs/migration/v2.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ If you want to learn more about the new features introduced in `2.*`, please ref

---

## Deprecated Features
## Deprecated or Removed Features

1. **`BaseContainer.init_async_resources()` removed**
The method `BaseContainer.init_async_resources()` has been removed. Use `BaseContainer.init_resources()` instead.
Expand Down Expand Up @@ -48,6 +48,26 @@ If you want to learn more about the new features introduced in `2.*`, please ref
from that_depends.providers import Resource
my_provider = providers.Resource(some_async_function)
```

3. **`BaseContainer` and its subclasses are no longer dynamic.**

In `1.*`, you could define a container class and add providers to it dynamically. This feature has been removed in `2.*`. You must now define all providers in the container class itself.

**Example:**
In `1.*`, you could define a container and then dynamically set providers:
```python
from that_depends import BaseContainer

class MyContainer(BaseContainer):
pass

MyContainer.my_provider = providers.Resource(some_function)
```
In `2.*`, this will raise an `AttributeError`. Instead, define the provider directly in the container class:
```python
class MyContainer(BaseContainer):
my_provider = providers.Resource(some_function)
```

---

Expand Down
1 change: 1 addition & 0 deletions tests/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class SingletonFactory:

class DIContainer(BaseContainer):
default_scope = None
alias = "test_container"
sync_resource = providers.Resource(create_sync_resource)
async_resource = providers.Resource(create_async_resource)

Expand Down
1 change: 1 addition & 0 deletions tests/integrations/fastapi/test_fastapi_di_pass_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ async def init_di_context(request: fastapi.Request) -> typing.AsyncIterator[None


class DIContainer(BaseContainer):
alias = "fastapi_container"
context_request = providers.Factory(
lambda: fetch_context_item("request"),
)
Expand Down
1 change: 1 addition & 0 deletions tests/integrations/litestar/test_litestar_di.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class SomeService:


class DIContainer(BaseContainer):
alias = "litestar_container"
bool_fn = providers.Factory(bool_fn, value=False)
str_fn = providers.Factory(str_fn)
list_fn = providers.Factory(list_fn)
Expand Down
1 change: 1 addition & 0 deletions tests/providers/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


class DIContainer(BaseContainer):
alias = "collection_container"
sync_resource = providers.Resource(create_sync_resource)
async_resource = providers.Resource(create_async_resource)
sequence = providers.List(sync_resource, async_resource)
Expand Down
1 change: 1 addition & 0 deletions tests/providers/test_context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ async def create_async_context_resource() -> typing.AsyncIterator[str]:


class DIContainer(BaseContainer):
alias = "context_resource_container"
default_scope = ContextScopes.ANY
sync_context_resource = providers.ContextResource(create_sync_context_resource)
async_context_resource = providers.ContextResource(create_async_context_resource)
Expand Down
1 change: 1 addition & 0 deletions tests/providers/test_inject_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class InjectedFactories:


class DIContainer(BaseContainer):
alias = "inject_factories_container"
sync_resource = providers.Resource(container.create_sync_resource)
async_resource = providers.Resource(container.create_async_resource)

Expand Down
1 change: 1 addition & 0 deletions tests/providers/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


class DIContainer(BaseContainer):
alias = "object_container"
instance = providers.Object(instance)


Expand Down
1 change: 1 addition & 0 deletions tests/providers/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


class DIContainer(BaseContainer):
alias = "resources_container"
async_resource = providers.Resource(create_async_resource)
sync_resource = providers.Resource(create_sync_resource)
async_resource_from_class = providers.Resource(AsyncContextManagerResource)
Expand Down
1 change: 1 addition & 0 deletions tests/providers/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def get_selector_state(self) -> typing.Literal["sync_resource", "async_resource"


class DIContainer(BaseContainer):
alias = "selector_container"
sync_resource = providers.Resource(create_sync_resource)
async_resource = providers.Resource(create_async_resource)
selector: providers.Selector[datetime.datetime] = providers.Selector(
Expand Down
1 change: 1 addition & 0 deletions tests/providers/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def _sync_creator_with_dependency(dep: int) -> str:


class DIContainer(BaseContainer):
alias = "singleton_container"
factory: providers.AsyncFactory[int] = providers.AsyncFactory(_async_creator)
settings: Settings = providers.Singleton(Settings).cast
singleton = providers.Singleton(SingletonFactory, dep1=settings.some_setting)
Expand Down
1 change: 1 addition & 0 deletions tests/test_dynamic_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


class DIContainer(BaseContainer):
alias = "dynamic_container"
sync_resource: providers.Resource[datetime.datetime]


Expand Down
128 changes: 128 additions & 0 deletions tests/test_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import datetime
import typing
import warnings
from unittest.mock import Mock

import pytest

from tests import container
from that_depends import BaseContainer, Provide, inject, providers
from that_depends.injection import StringProviderDefinition
from that_depends.providers.context_resources import ContextScopes


Expand Down Expand Up @@ -140,3 +142,129 @@ def _injected(val: int = Provide[_Container.p_inject]) -> int:
def test_inject_decorator_should_not_allow_any_scope() -> None:
with pytest.raises(ValueError, match=f"{ContextScopes.ANY} is not allowed in inject decorator."):
inject(scope=ContextScopes.ANY)


@pytest.mark.parametrize(
("definition", "expected"),
[
("container.provider", ("container", "provider", [])),
("container.provider.attr", ("container", "provider", ["attr"])),
("container.provider.attr1.attr2", ("container", "provider", ["attr1", "attr2"])),
("some.long.container.provider", ("some", "long", ["container", "provider"])),
],
)
def test_validate_and_extract_provider_definition_valid(definition: str, expected: tuple[str, str, list[str]]) -> None:
"""Test valid definitions and ensure the function returns the correct tuple."""
parsed_definition = StringProviderDefinition(definition)
result = parsed_definition._container_name, parsed_definition._provider_name, parsed_definition._attrs
assert result == expected


@pytest.mark.parametrize(
"definition",
[
"",
"container",
".provider",
"container.",
"container..provider",
"container.provider.",
],
)
def test_validate_and_extract_provider_definition_invalid(definition: str) -> None:
"""Test invalid definitions and ensure the function raises ValueError."""
with pytest.raises(ValueError, match=f"Invalid provider definition: {definition}"):
StringProviderDefinition(definition)


async def test_async_injection_with_string_provider_definition() -> None:
return_value = 321321

class _Container(BaseContainer):
async_resource = providers.Factory(lambda: return_value)

@inject
async def _injected(val: int = Provide["_Container.async_resource"]) -> int:
return val

assert await _injected() == return_value


def test_sync_injection_with_string_provider_definition() -> None:
return_value = 312312421

class _Container(BaseContainer):
sync_resource = providers.Factory(lambda: return_value)

@inject
def _injected(val: int = Provide["_Container.sync_resource"]) -> int:
return val

assert _injected() == return_value


def test_provider_string_definition_with_alias() -> None:
return_value = 321

class _Container(BaseContainer):
alias = "ALIAS"
sync_resource = providers.Factory(lambda: return_value)

@inject
def _injected(val: int = Provide["ALIAS.sync_resource"]) -> int:
return val

assert _injected() == return_value


def test_provider_string_definition_with_attr_getter() -> None:
expected_value = 123123
return_value = Mock()
return_value.a = expected_value

class _Container(BaseContainer):
sync_resource = providers.Factory(lambda: return_value)

@inject
def _injected(val: int = Provide["_Container.sync_resource.a"]) -> int:
return val

assert _injected() == expected_value


def test_inject_with_non_existing_container() -> None:
provider_name = "DOESNOTEXIST"

@inject
def _injected(val: int = Provide[f"{provider_name}.provider"]) -> None: ...

with pytest.raises(ValueError, match=f"Container {provider_name} not found in scope!"):
_injected()


def test_inject_with_non_existing_provider() -> None:
container_alias = "EXIST"

class _Container(BaseContainer):
alias = container_alias

provider_name = "DOESNOTEXIST"

@inject
def _injected(val: int = Provide[f"EXIST.{provider_name}"]) -> None: ...

with pytest.raises(ValueError, match=f"Provider {provider_name} not found in container {container_alias}"):
_injected()


def test_provider_resolution_with_string_definition_happens_at_runtime() -> None:
return_value = 321

@inject
def _injected(val: int = Provide["_Container.sync_resource"]) -> int:
return val

class _Container(BaseContainer):
sync_resource = providers.Factory(lambda: return_value)

assert _injected() == return_value
1 change: 1 addition & 0 deletions that_depends/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
class BaseContainer(SupportsContext[None], metaclass=BaseContainerMeta):
"""Base container class."""

alias: str | None = None
providers: dict[str, AbstractProvider[typing.Any]]
containers: list[type["BaseContainer"]]
default_scope: ContextScope | None = ContextScopes.ANY
Expand Down
Loading

0 comments on commit b429870

Please sign in to comment.