diff --git a/rodi/__init__.py b/rodi/__init__.py index 7f27091..b240411 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -509,7 +509,7 @@ def _get_resolvers_for_parameters( # but at least Optional could be supported in the future raise UnsupportedUnionTypeException(param_name, concrete_type) - if param_type is _empty: + if param_type is _empty or param_type not in services._map: if services.strict: raise CannotResolveParameterException(param_name, concrete_type) @@ -521,6 +521,14 @@ def _get_resolvers_for_parameters( else: aliases = services._aliases[param_name] + if not aliases: + cls_name = class_name(param_type) + aliases = ( + services._aliases[cls_name] + or services._aliases[cls_name.lower()] + or services._aliases[to_standard_param_name(cls_name)] + ) + if aliases: assert ( len(aliases) == 1 @@ -736,6 +744,13 @@ def get( scope = ActivationScope(self) resolver = self._map.get(desired_type) + if not resolver: + cls_name = class_name(desired_type) + resolver = ( + self._map.get(cls_name) + or self._map.get(cls_name.lower()) + or self._map.get(to_standard_param_name(cls_name)) + ) scoped_service = scope.scoped_services.get(desired_type) if scope else None if not resolver and not scoped_service: diff --git a/tests/test_fn_exec.py b/tests/test_fn_exec.py index 541e800..91ff62e 100644 --- a/tests/test_fn_exec.py +++ b/tests/test_fn_exec.py @@ -2,6 +2,7 @@ Functions exec tests. exec functions are designed to enable executing any function injecting parameters. """ + import pytest from rodi import Container, inject diff --git a/tests/test_services.py b/tests/test_services.py index 1015049..942b5b2 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -687,6 +687,36 @@ def __init__(self, cats_controller, service_settings): assert isinstance(u.cats_controller.cat_request_handler, GetCatRequestHandler) +def test_alias_dep_resolving(): + container = arrange_cats_example() + + class BaseClass: + pass + + class DerivedClass(BaseClass): + pass + + class UsingAliasByType: + def __init__(self, example: BaseClass): + self.example = example + + def resolve_derived_class(_) -> DerivedClass: + return DerivedClass() + + container.add_scoped_by_factory(resolve_derived_class, DerivedClass) + container.add_alias("BaseClass", DerivedClass) + container.add_scoped(UsingAliasByType) + + provider = container.build_provider() + u = provider.get(UsingAliasByType) + + assert isinstance(u, UsingAliasByType) + assert isinstance(u.example, DerivedClass) + + b = provider.get(BaseClass) + assert isinstance(b, DerivedClass) + + def test_get_service_by_name_or_alias(): container = arrange_cats_example() container.add_alias("k", CatsController) @@ -2323,7 +2353,7 @@ def factory() -> annotation: def test_factory_without_locals_raises(): def factory_without_context() -> None: - ... + pass with pytest.raises(FactoryMissingContextException): _get_factory_annotations_or_throw(factory_without_context) @@ -2332,7 +2362,7 @@ def factory_without_context() -> None: def test_factory_with_locals_get_annotations(): @inject() def factory_without_context() -> "Cat": - ... + pass annotations = _get_factory_annotations_or_throw(factory_without_context) @@ -2350,20 +2380,20 @@ def test_deps_github_scenario(): """ class HTTPClient: - ... + pass class CommentsService: - ... + pass class ChecksService: - ... + pass class CLAHandler: comments_service: CommentsService checks_service: ChecksService class GitHubSettings: - ... + pass class GitHubAuthHandler: settings: GitHubSettings @@ -2478,7 +2508,7 @@ class B: def test_provide_protocol_with_attribute_dependency() -> None: class P(Protocol): def foo(self) -> Any: - ... + pass class Dependency: pass @@ -2506,7 +2536,7 @@ def foo(self) -> Any: def test_provide_protocol_with_init_dependency() -> None: class P(Protocol): def foo(self) -> Any: - ... + pass class Dependency: pass @@ -2536,10 +2566,10 @@ def test_provide_protocol_generic() -> None: class P(Protocol[T]): def foo(self, t: T) -> T: - ... + pass class A: - ... + pass class Impl(P[A]): def foo(self, t: A) -> A: @@ -2562,10 +2592,10 @@ def test_provide_protocol_generic_with_inner_dependency() -> None: class P(Protocol[T]): def foo(self, t: T) -> T: - ... + pass class A: - ... + pass class Dependency: pass