Skip to content

Commit

Permalink
Resolve protocols without needing to define init (#46)
Browse files Browse the repository at this point in the history
* fix: resolve protocols without needing to define init

* fix: support python 3.8

* fix: support _no_init on python 3.8
  • Loading branch information
lucas-labs authored Nov 25, 2023
1 parent 6043b5b commit 41037c7
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 1 deletion.
18 changes: 17 additions & 1 deletion rodi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
get_type_hints,
)

if sys.version_info >= (3, 9): # pragma: no cover
from typing import _no_init_or_replace_init as _no_init
elif sys.version_info >= (3, 8): # pragma: no cover
from typing import _no_init

try:
from typing import Protocol
except ImportError: # pragma: no cover
Expand Down Expand Up @@ -579,6 +584,17 @@ def _ignore_class_attribute(self, key: str, value) -> bool:

return is_classvar or is_initialized

def _has_default_init(self):
init = getattr(self.concrete_type, "__init__", None)

if init is object.__init__:
return True

if sys.version_info >= (3, 8): # pragma: no cover
if init is _no_init:
return True
return False

def _resolve_by_annotations(
self, context: ResolutionContext, annotations: Dict[str, Type]
):
Expand All @@ -605,7 +621,7 @@ def __call__(self, context: ResolutionContext):
chain = context.dynamic_chain
chain.append(concrete_type)

if getattr(concrete_type, "__init__") is object.__init__:
if self._has_default_init():
annotations = get_type_hints(
concrete_type,
vars(sys.modules[concrete_type.__module__]),
Expand Down
90 changes: 90 additions & 0 deletions tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC
from dataclasses import dataclass
from typing import (
Any,
ClassVar,
Dict,
Generic,
Expand Down Expand Up @@ -2474,6 +2475,62 @@ class B:
assert isinstance(value, DynamicResolver)


def test_provide_protocol_with_attribute_dependency() -> None:
class P(Protocol):
def foo(self) -> Any:
...

class Dependency:
pass

class Impl(P):
# attribute dependency
dependency: Dependency

def foo(self) -> Any:
pass

container = Container()
container.register(Dependency)
container.register(Impl)

try:
resolved = container.resolve(Impl)
except CannotResolveParameterException as e:
pytest.fail(str(e))

assert isinstance(resolved, Impl)
assert isinstance(resolved.dependency, Dependency)


def test_provide_protocol_with_init_dependency() -> None:
class P(Protocol):
def foo(self) -> Any:
...

class Dependency:
pass

class Impl(P):
def __init__(self, dependency: Dependency) -> None:
self.dependency = dependency

def foo(self) -> Any:
pass

container = Container()
container.register(Dependency)
container.register(Impl)

try:
resolved = container.resolve(Impl)
except CannotResolveParameterException as e:
pytest.fail(str(e))

assert isinstance(resolved, Impl)
assert isinstance(resolved.dependency, Dependency)


def test_provide_protocol_generic() -> None:
T = TypeVar("T")

Expand All @@ -2500,6 +2557,39 @@ def foo(self, t: A) -> A:
assert isinstance(resolved, Impl)


def test_provide_protocol_generic_with_inner_dependency() -> None:
T = TypeVar("T")

class P(Protocol[T]):
def foo(self, t: T) -> T:
...

class A:
...

class Dependency:
pass

class Impl(P[A]):
dependency: Dependency

def foo(self, t: A) -> A:
return t

container = Container()

container.register(Impl)
container.register(Dependency)

try:
resolved = container.resolve(Impl)
except CannotResolveParameterException as e:
pytest.fail(str(e))

assert isinstance(resolved, Impl)
assert isinstance(resolved.dependency, Dependency)


def test_ignore_class_var():
"""
ClassVar attributes must be ignored, because they are not instance attributes.
Expand Down

0 comments on commit 41037c7

Please sign in to comment.