diff --git a/rodi/__init__.py b/rodi/__init__.py index ffead96..d367dde 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -20,6 +20,11 @@ get_type_hints, ) +if sys.version_info >= (3, 9): # pragma: no cover + from typing import _no_init_or_replace_init as _no_init +elif sys.version_info >= (3, 8): # pragma: no cover + from typing import _no_init + try: from typing import Protocol except ImportError: # pragma: no cover @@ -579,6 +584,17 @@ def _ignore_class_attribute(self, key: str, value) -> bool: return is_classvar or is_initialized + def _has_default_init(self): + init = getattr(self.concrete_type, "__init__", None) + + if init is object.__init__: + return True + + if sys.version_info >= (3, 8): # pragma: no cover + if init is _no_init: + return True + return False + def _resolve_by_annotations( self, context: ResolutionContext, annotations: Dict[str, Type] ): @@ -605,7 +621,7 @@ def __call__(self, context: ResolutionContext): chain = context.dynamic_chain chain.append(concrete_type) - if getattr(concrete_type, "__init__") is object.__init__: + if self._has_default_init(): annotations = get_type_hints( concrete_type, vars(sys.modules[concrete_type.__module__]), diff --git a/tests/test_services.py b/tests/test_services.py index 9430813..1015049 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -2,6 +2,7 @@ from abc import ABC from dataclasses import dataclass from typing import ( + Any, ClassVar, Dict, Generic, @@ -2474,6 +2475,62 @@ class B: assert isinstance(value, DynamicResolver) +def test_provide_protocol_with_attribute_dependency() -> None: + class P(Protocol): + def foo(self) -> Any: + ... + + class Dependency: + pass + + class Impl(P): + # attribute dependency + dependency: Dependency + + def foo(self) -> Any: + pass + + container = Container() + container.register(Dependency) + container.register(Impl) + + try: + resolved = container.resolve(Impl) + except CannotResolveParameterException as e: + pytest.fail(str(e)) + + assert isinstance(resolved, Impl) + assert isinstance(resolved.dependency, Dependency) + + +def test_provide_protocol_with_init_dependency() -> None: + class P(Protocol): + def foo(self) -> Any: + ... + + class Dependency: + pass + + class Impl(P): + def __init__(self, dependency: Dependency) -> None: + self.dependency = dependency + + def foo(self) -> Any: + pass + + container = Container() + container.register(Dependency) + container.register(Impl) + + try: + resolved = container.resolve(Impl) + except CannotResolveParameterException as e: + pytest.fail(str(e)) + + assert isinstance(resolved, Impl) + assert isinstance(resolved.dependency, Dependency) + + def test_provide_protocol_generic() -> None: T = TypeVar("T") @@ -2500,6 +2557,39 @@ def foo(self, t: A) -> A: assert isinstance(resolved, Impl) +def test_provide_protocol_generic_with_inner_dependency() -> None: + T = TypeVar("T") + + class P(Protocol[T]): + def foo(self, t: T) -> T: + ... + + class A: + ... + + class Dependency: + pass + + class Impl(P[A]): + dependency: Dependency + + def foo(self, t: A) -> A: + return t + + container = Container() + + container.register(Impl) + container.register(Dependency) + + try: + resolved = container.resolve(Impl) + except CannotResolveParameterException as e: + pytest.fail(str(e)) + + assert isinstance(resolved, Impl) + assert isinstance(resolved.dependency, Dependency) + + def test_ignore_class_var(): """ ClassVar attributes must be ignored, because they are not instance attributes.