Skip to content

Commit

Permalink
inject factories (#36)
Browse files Browse the repository at this point in the history
* inject factories

* docs on injecting factories
  • Loading branch information
lesnik512 authored Jun 15, 2024
1 parent 634fee4 commit fa7044e
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 79 deletions.
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
introduction/fastapi
introduction/litestar
introduction/multiple-containers
introduction/inject-factories
.. toctree::
:maxdepth: 1
Expand Down
56 changes: 56 additions & 0 deletions docs/introduction/inject-factories.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Injecting factories

When you need to inject the factory itself, but not the result of its call, use:
1. `.provider` attribute for async resolver
2. `.sync_provider` attribute for sync resolver

Let's first define providers with container:
```python
import dataclasses
import datetime
import typing

from that_depends import BaseContainer, providers


async def create_async_resource() -> typing.AsyncIterator[datetime.datetime]:
yield datetime.datetime.now(tz=datetime.timezone.utc)


@dataclasses.dataclass(kw_only=True, slots=True)
class SomeFactory:
start_at: datetime.datetime


@dataclasses.dataclass(kw_only=True, slots=True)
class FactoryWithFactories:
sync_factory: typing.Callable[..., SomeFactory]
async_factory: typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, SomeFactory]]


class DIContainer(BaseContainer):
async_resource = providers.AsyncResource(create_async_resource)
dependent_factory = providers.Factory(SomeFactory, start_at=async_resource.cast)
factory_with_factories = providers.Factory(
FactoryWithFactories,
sync_factory=dependent_factory.sync_provider,
async_factory=dependent_factory.provider,
)
```

Async factory from `.provider` attribute can be used like this:
```python
factory_with_factories = await DIContainer.factory_with_factories()
instance1 = await factory_with_factories.async_factory()
instance2 = await factory_with_factories.async_factory()
assert instance1 is not instance2
```

Sync factory from `.sync_provider` attribute can be used like this:
```python
await DIContainer.init_async_resources()
factory_with_factories = await DIContainer.factory_with_factories()
instance1 = factory_with_factories.sync_factory()
instance2 = factory_with_factories.sync_factory()
assert instance1 is not instance2
```
150 changes: 75 additions & 75 deletions poetry.lock

Large diffs are not rendered by default.

Empty file added tests/integrations/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
59 changes: 59 additions & 0 deletions tests/test_inject_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import dataclasses
import typing

import pytest

from tests import container
from that_depends import BaseContainer, providers


@dataclasses.dataclass(kw_only=True, slots=True)
class InjectedFactories:
sync_factory: typing.Callable[..., container.DependentFactory]
async_factory: typing.Callable[..., typing.Coroutine[typing.Any, typing.Any, container.DependentFactory]]


class DIContainer(BaseContainer):
sync_resource = providers.Resource(container.create_sync_resource)
async_resource = providers.AsyncResource(container.create_async_resource)

simple_factory = providers.Factory(container.SimpleFactory, dep1="text", dep2=123)
dependent_factory = providers.Factory(
container.DependentFactory,
simple_factory=simple_factory.cast,
sync_resource=sync_resource.cast,
async_resource=async_resource.cast,
)
injected_factories = providers.Factory(
InjectedFactories,
sync_factory=dependent_factory.sync_provider,
async_factory=dependent_factory.provider,
)


async def test_async_provider() -> None:
injected_factories = await DIContainer.injected_factories()
instance1 = await injected_factories.async_factory()
instance2 = await injected_factories.async_factory()

assert isinstance(instance1, container.DependentFactory)
assert isinstance(instance2, container.DependentFactory)
assert instance1 is not instance2

await DIContainer.tear_down()


async def test_sync_provider() -> None:
injected_factories = await DIContainer.injected_factories()
with pytest.raises(RuntimeError, match="AsyncResource cannot be resolved synchronously"):
injected_factories.sync_factory()

await DIContainer.init_async_resources()
instance1 = injected_factories.sync_factory()
instance2 = injected_factories.sync_factory()

assert isinstance(instance1, container.DependentFactory)
assert isinstance(instance2, container.DependentFactory)
assert instance1 is not instance2

await DIContainer.tear_down()
12 changes: 12 additions & 0 deletions that_depends/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,15 @@ class AbstractResource(AbstractProvider[T], abc.ABC):
@abc.abstractmethod
async def tear_down(self) -> None:
"""Tear down dependency."""


class AbstractFactory(AbstractProvider[T], abc.ABC):
"""Abstract Factory Class."""

@property
def provider(self) -> typing.Callable[[], typing.Coroutine[typing.Any, typing.Any, T]]:
return self.async_resolve

@property
def sync_provider(self) -> typing.Callable[[], T]:
return self.sync_resolve
8 changes: 4 additions & 4 deletions that_depends/providers/factories.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import typing

from that_depends.providers.base import AbstractProvider
from that_depends.providers.base import AbstractFactory, AbstractProvider


T = typing.TypeVar("T")
P = typing.ParamSpec("P")


class Factory(AbstractProvider[T]):
class Factory(AbstractFactory[T]):
def __init__(self, factory: type[T] | typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None:
self._factory = factory
self._args = args
Expand All @@ -33,7 +33,7 @@ def sync_resolve(self) -> T:
)


class AsyncFactory(AbstractProvider[T]):
class AsyncFactory(AbstractFactory[T]):
def __init__(self, factory: typing.Callable[P, typing.Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> None:
self._factory = factory
self._args = args
Expand All @@ -49,6 +49,6 @@ async def async_resolve(self) -> T:
**{k: await v.async_resolve() if isinstance(v, AbstractProvider) else v for k, v in self._kwargs.items()},
)

def sync_resolve(self) -> T:
def sync_resolve(self) -> typing.NoReturn:
msg = "AsyncFactory cannot be resolved synchronously"
raise RuntimeError(msg)

0 comments on commit fa7044e

Please sign in to comment.