Skip to content

Commit

Permalink
30 configuration provider (#43)
Browse files Browse the repository at this point in the history
* refactor tests

* add attr getter for singleton

* add __slots__ to providers

* add docs on application settings
  • Loading branch information
lesnik512 authored Jun 22, 2024
1 parent c0683ea commit 28da9b3
Show file tree
Hide file tree
Showing 20 changed files with 234 additions and 79 deletions.
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
introduction/inject-factories
introduction/multiple-containers
introduction/dynamic-container
introduction/application-settings
.. toctree::
:maxdepth: 1
Expand Down
24 changes: 24 additions & 0 deletions docs/introduction/application-settings.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Application settings
For example, you have application settings in `pydantic_settings`
```python
import pydantic_settings


class Settings(pydantic_settings.BaseSettings):
service_name: str = "FastAPI template"
debug: bool = False
...
```

You can register settings as `Singleton` in DI container

```python
from that_depends import BaseContainer, providers


class DIContainer(BaseContainer):
settings: Settings = providers.Singleton(Settings).cast
some_factory = providers.Factory(SomeFactory, service_name=settings.service_name)
```

And when `some_factory` is resolved it will receive `service_name` attribute from `Settings`
2 changes: 1 addition & 1 deletion docs/providers/singleton.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ class SingletonFactory:

class DIContainer(BaseContainer):
singleton = providers.Singleton(SingletonFactory, dep1=True)
```
```
9 changes: 0 additions & 9 deletions tests/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

logger = logging.getLogger(__name__)

global_state_for_selector: typing.Literal["sync_resource", "async_resource", "missing"] = "sync_resource"


def create_sync_resource() -> typing.Iterator[datetime.datetime]:
logger.debug("Resource initiated")
Expand Down Expand Up @@ -58,13 +56,6 @@ class SingletonFactory:
class DIContainer(BaseContainer):
sync_resource = providers.Resource(create_sync_resource)
async_resource = providers.AsyncResource(create_async_resource)
sequence = providers.List(sync_resource, async_resource)
mapping = providers.Dict(sync_resource=sync_resource, async_resource=async_resource)
selector: providers.Selector[datetime.datetime] = providers.Selector(
lambda: global_state_for_selector,
sync_resource=sync_resource,
async_resource=async_resource,
)

simple_factory = providers.Factory(SimpleFactory, dep1="text", dep2=123)
async_factory = providers.AsyncFactory(async_factory, async_resource.cast)
Expand Down
48 changes: 48 additions & 0 deletions tests/providers/test_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import typing

import pytest

from tests.container import create_async_resource, create_sync_resource
from that_depends import BaseContainer, providers


class DIContainer(BaseContainer):
sync_resource = providers.Resource(create_sync_resource)
async_resource = providers.AsyncResource(create_async_resource)
sequence = providers.List(sync_resource, async_resource)
mapping = providers.Dict(sync_resource=sync_resource, async_resource=async_resource)


@pytest.fixture(autouse=True)
async def _clear_di_container() -> typing.AsyncIterator[None]:
try:
yield
finally:
await DIContainer.tear_down()


async def test_list_provider() -> None:
sequence = await DIContainer.sequence()
sync_resource = await DIContainer.sync_resource()
async_resource = await DIContainer.async_resource()

assert sequence == [sync_resource, async_resource]


def test_list_failed_sync_resolve() -> None:
with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
DIContainer.sequence.sync_resolve()


async def test_list_sync_resolve_after_init() -> None:
await DIContainer.init_async_resources()
DIContainer.sequence.sync_resolve()


async def test_dict_provider() -> None:
mapping = await DIContainer.mapping()
sync_resource = await DIContainer.sync_resource()
async_resource = await DIContainer.async_resource()

assert mapping == {"sync_resource": sync_resource, "async_resource": async_resource}
assert mapping == DIContainer.mapping.sync_resolve()
18 changes: 0 additions & 18 deletions tests/providers/test_collections_providers.py

This file was deleted.

8 changes: 8 additions & 0 deletions tests/providers/test_context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ class DIContainer(BaseContainer):
async_context_resource = providers.AsyncContextResource(create_async_context_resource)


@pytest.fixture(autouse=True)
async def _clear_di_container() -> typing.AsyncIterator[None]:
try:
yield
finally:
await DIContainer.tear_down()


@pytest.fixture(params=[DIContainer.sync_context_resource, DIContainer.async_context_resource])
def context_resource(request: pytest.FixtureRequest) -> providers.AbstractResource[typing.Any]:
return typing.cast(providers.AbstractResource[typing.Any], request.param)
Expand Down
19 changes: 0 additions & 19 deletions tests/providers/test_main_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,6 @@ def test_failed_sync_resolve() -> None:
with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
DIContainer.async_resource.sync_resolve()

with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
DIContainer.sequence.sync_resolve()


async def test_sync_resolve_after_init() -> None:
await DIContainer.init_async_resources()
DIContainer.sequence.sync_resolve()


async def test_singleton_provider() -> None:
singleton1 = await DIContainer.singleton()
singleton2 = await DIContainer.singleton()
singleton3 = DIContainer.singleton.sync_resolve()
await DIContainer.singleton.tear_down()
singleton4 = DIContainer.singleton.sync_resolve()

assert singleton1 is singleton2 is singleton3
assert singleton4 is not singleton1


def test_wrong_providers_init() -> None:
with pytest.raises(RuntimeError, match="Resource must be generator function"):
Expand Down
62 changes: 62 additions & 0 deletions tests/providers/test_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import datetime
import logging
import typing

import pytest

from tests.container import create_async_resource, create_sync_resource
from that_depends import BaseContainer, providers


logger = logging.getLogger(__name__)

global_state_for_selector: typing.Literal["sync_resource", "async_resource", "missing"] = "sync_resource"


class SelectorState:
def __init__(self) -> None:
self.selector_state: typing.Literal["sync_resource", "async_resource", "missing"] = "sync_resource"

def get_selector_state(self) -> typing.Literal["sync_resource", "async_resource", "missing"]:
return self.selector_state


selector_state = SelectorState()


class DIContainer(BaseContainer):
sync_resource = providers.Resource(create_sync_resource)
async_resource = providers.AsyncResource(create_async_resource)
selector: providers.Selector[datetime.datetime] = providers.Selector(
selector_state.get_selector_state,
sync_resource=sync_resource,
async_resource=async_resource,
)


async def test_selector_provider_async() -> None:
selector_state.selector_state = "async_resource"
selected = await DIContainer.selector()
async_resource = await DIContainer.async_resource()

assert selected == async_resource


async def test_selector_provider_async_missing() -> None:
selector_state.selector_state = "missing"
with pytest.raises(RuntimeError, match="No provider matches"):
await DIContainer.selector()


async def test_selector_provider_sync() -> None:
selector_state.selector_state = "sync_resource"
selected = DIContainer.selector.sync_resolve()
sync_resource = DIContainer.sync_resource.sync_resolve()

assert selected == sync_resource


async def test_selector_provider_sync_missing() -> None:
selector_state.selector_state = "missing"
with pytest.raises(RuntimeError, match="No provider matches"):
DIContainer.selector.sync_resolve()
32 changes: 0 additions & 32 deletions tests/providers/test_selector_providers.py

This file was deleted.

41 changes: 41 additions & 0 deletions tests/providers/test_singleton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import dataclasses

import pydantic

from that_depends import BaseContainer, providers


@dataclasses.dataclass(kw_only=True, slots=True)
class SingletonFactory:
dep1: str


class Settings(pydantic.BaseModel):
some_setting: str = "some_value"
other_setting: str = "other_value"


class DIContainer(BaseContainer):
settings: Settings = providers.Singleton(Settings).cast
singleton = providers.Singleton(SingletonFactory, dep1=settings.some_setting)


async def test_singleton_provider() -> None:
singleton1 = await DIContainer.singleton()
singleton2 = await DIContainer.singleton()
singleton3 = DIContainer.singleton.sync_resolve()
await DIContainer.singleton.tear_down()
singleton4 = DIContainer.singleton.sync_resolve()

assert singleton1 is singleton2 is singleton3
assert singleton4 is not singleton1

await DIContainer.tear_down()


async def test_singleton_attr_getter() -> None:
singleton1 = await DIContainer.singleton()

assert singleton1.dep1 == Settings().some_setting

await DIContainer.tear_down()
2 changes: 2 additions & 0 deletions tests/test_dynamic_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ async def test_dynamic_container() -> None:

assert isinstance(sync_resource, datetime.datetime)
assert isinstance(async_resource, datetime.datetime)

await DIContainer.tear_down()
2 changes: 2 additions & 0 deletions that_depends/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from that_depends.providers.attr_getter import AttrGetter
from that_depends.providers.base import AbstractProvider, AbstractResource
from that_depends.providers.collections import Dict, List
from that_depends.providers.context_resources import (
Expand All @@ -18,6 +19,7 @@
"AsyncContextResource",
"AsyncFactory",
"AsyncResource",
"AttrGetter",
"ContextResource",
"DIContextMiddleware",
"Factory",
Expand Down
21 changes: 21 additions & 0 deletions that_depends/providers/attr_getter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import typing

from that_depends.providers.base import AbstractProvider


T = typing.TypeVar("T")
P = typing.ParamSpec("P")


class AttrGetter(AbstractProvider[T]):
__slots__ = "_provider", "_attr_name"

def __init__(self, provider: AbstractProvider[T], attr_name: str) -> None:
self._provider = provider
self._attr_name = attr_name

async def async_resolve(self) -> typing.Any: # noqa: ANN401
return getattr(await self._provider.async_resolve(), self._attr_name)

def sync_resolve(self) -> typing.Any: # noqa: ANN401
return getattr(self._provider.sync_resolve(), self._attr_name)
4 changes: 4 additions & 0 deletions that_depends/providers/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@


class List(AbstractProvider[list[T]]):
__slots__ = ("_providers",)

def __init__(self, *providers: AbstractProvider[T]) -> None:
self._providers = providers

Expand All @@ -21,6 +23,8 @@ async def __call__(self) -> list[T]:


class Dict(AbstractProvider[dict[str, T]]):
__slots__ = ("_providers",)

def __init__(self, **providers: AbstractProvider[T]) -> None:
self._providers = providers

Expand Down
4 changes: 4 additions & 0 deletions that_depends/providers/context_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def _get_context() -> dict[str, AbstractResource[typing.Any]]:


class ContextResource(AbstractProvider[T]):
__slots__ = "_creator", "_args", "_kwargs", "_override", "_internal_name"

def __init__(
self,
creator: typing.Callable[P, typing.Iterator[T]],
Expand Down Expand Up @@ -94,6 +96,8 @@ def sync_resolve(self) -> T:


class AsyncContextResource(AbstractProvider[T]):
__slots__ = "_creator", "_args", "_kwargs", "_override", "_internal_name"

def __init__(
self,
creator: typing.Callable[P, typing.AsyncIterator[T]],
Expand Down
Loading

0 comments on commit 28da9b3

Please sign in to comment.