diff --git a/rodi/__init__.py b/rodi/__init__.py index 21f4f0e..ffead96 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -475,7 +475,12 @@ def _get_resolver(self, desired_type, context: ResolutionContext): assert ( reg is not None ), f"A resolver for type {class_name(desired_type)} is not configured" - return reg(context) + resolver = reg(context) + + # add the resolver to the context, so we can find it + # next time we need it + context.resolved[desired_type] = resolver + return resolver def _get_resolvers_for_parameters( self, @@ -567,11 +572,12 @@ def _ignore_class_attribute(self, key: str, value) -> bool: """ Returns a value indicating whether a class attribute should be ignored for dependency resolution, by name and value. + It's ignored if it's a ClassVar or if it's already initialized explicitly. """ - try: - return value.__origin__ is ClassVar - except AttributeError: - return False + is_classvar = getattr(value, "__origin__", None) is ClassVar + is_initialized = getattr(self.concrete_type, key, None) is not None + + return is_classvar or is_initialized def _resolve_by_annotations( self, context: ResolutionContext, annotations: Dict[str, Type] @@ -1146,14 +1152,24 @@ def build_provider(self) -> Services: _map: Dict[Union[str, Type], Type] = {} for _type, resolver in self._map.items(): - # NB: do not call resolver if one was already prepared for the type - assert _type not in context.resolved, "_map keys must be unique" - if isinstance(resolver, DynamicResolver): context.dynamic_chain.clear() - _map[_type] = resolver(context) - context.resolved[_type] = _map[_type] + if _type in context.resolved: + # assert _type not in context.resolved, "_map keys must be unique" + # check if its in the map + if _type in _map: + # NB: do not call resolver if one was already prepared for the + # type + raise OverridingServiceException(_type, resolver) + else: + resolved = context.resolved[_type] + else: + # add to context so that we don't repeat operations + resolved = resolver(context) + context.resolved[_type] = resolved + + _map[_type] = resolved type_name = class_name(_type) if "." not in type_name: diff --git a/tests/test_examples.py b/tests/test_examples.py index acaae27..0a7b3c6 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -12,6 +12,11 @@ @pytest.mark.parametrize("file_path", examples) def test_example(file_path: str): - module_name = file_path.replace("./examples/", "").replace(".py", "") + module_name = ( + # Windows + file_path.replace("./examples\\", "") + # Unix + .replace("./examples/", "").replace(".py", "") + ) # assertions are in imported modules importlib.import_module(module_name) diff --git a/tests/test_services.py b/tests/test_services.py index f93fc69..9430813 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -2539,3 +2539,92 @@ class A: a = container.resolve(A) assert a.foo == "foo" + + +def test_singleton_register_order_last(): + """ + The registration order of singletons should not matter. + Check that singletons are not registered twice when they are registered + after their dependents. + """ + + class Bar: + foo: Foo + + class Bar2: + foo: Foo + + container = Container() + container.register(Bar) + container.register(Bar2) + container._add_exact_singleton(Foo) + + bar = container.resolve(Bar) + bar2 = container.resolve(Bar2) + foo = container.resolve(Foo) + + # check that singletons are always the same instance + assert bar.foo is bar2.foo is foo + + +def test_singleton_register_order_first(): + """ + The registration order of singletons should not matter. + Check that singletons are not registered twice when they are registered + before their dependents. + """ + + class Bar: + foo: Foo + + class Bar2: + foo: Foo + + container = Container() + container._add_exact_singleton(Foo) + container.register(Bar) + container.register(Bar2) + + bar = container.resolve(Bar) + bar2 = container.resolve(Bar2) + foo = container.resolve(Foo) + + # check that singletons are always the same instance + assert bar.foo is bar2.foo is foo + + +def test_ignore_class_variable_if_already_initialized(): + """ + if a class variable is already initialized, it should not be overridden by + resolving a new instance nor fail if rodi can't resolve it. + """ + + foo_instance = Foo() + + class A: + foo: Foo = foo_instance + + class B: + example: ClassVar[str] = "example" + dependency: A + + container = Container() + + container.register(A) + container.register(B) + container._add_exact_singleton(Foo) + + b = container.resolve(B) + a = container.resolve(A) + foo = container.resolve(Foo) + + assert isinstance(a, A) + assert isinstance(a.foo, Foo) + assert foo_instance is a.foo + + assert isinstance(b, B) + assert b.example == "example" + assert b.dependency.foo is foo_instance + + # check that is not being overridden by resolving a new instance + assert foo is not a.foo