Skip to content

Commit

Permalink
fix: config loading after module reload
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Mar 3, 2025
1 parent 75ef7e4 commit 6820700
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,4 @@

- Updated documentation to include new custom exceptions.
- Improved the use of Pydantic for input data validation for retriever objects.
- Fixed config loading after module reload (usage in Jupyter notebooks).
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/experimental/pipeline/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pydantic import BaseModel

from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.utils.validation import issubclass_safe


class DataModel(BaseModel):
Expand Down Expand Up @@ -52,7 +53,7 @@ def __new__(
f"The run method return type must be annotated in {name}"
)
# the type hint must be a subclass of DataModel
if not issubclass(return_model, DataModel):
if not issubclass_safe(return_model, DataModel):
raise PipelineDefinitionError(
f"The run method must return a subclass of DataModel in {name}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
ParamConfig,
)
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.utils.validation import issubclass_safe


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -131,9 +133,9 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> T:
self._global_data = resolved_data or {}
logger.debug(f"OBJECT_CONFIG: parsing {self} using {resolved_data}")
if self.class_ is None:
raise ValueError(f"`class_` is not required to parse object {self}")
raise ValueError(f"`class_` is required to parse object {self}")
klass = self._get_class(self.class_, self.get_module())
if not issubclass(klass, self.get_interface()):
if not issubclass_safe(klass, self.get_interface()):
raise ValueError(
f"Invalid class '{klass}'. Expected a subclass of '{self.get_interface()}'"
)
Expand Down
20 changes: 19 additions & 1 deletion src/neo4j_graphrag/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,29 @@
# limitations under the License.
from __future__ import annotations

from typing import Optional
import importlib
from typing import Optional, Tuple, Type, Union


def validate_search_query_input(
    query_text: Optional[str] = None, query_vector: Optional[list[float]] = None
) -> None:
    """Ensure exactly one of ``query_text`` / ``query_vector`` is supplied.

    An argument counts as supplied when it is truthy, so ``None``, an empty
    string and an empty list are all treated as missing.

    Args:
        query_text: The textual query, if any.
        query_vector: The query embedding, if any.

    Raises:
        ValueError: If both arguments are supplied or neither is.
    """
    has_vector = bool(query_vector)
    has_text = bool(query_text)
    # Equal flags mean "both" or "neither"; only differing flags are valid.
    if has_vector == has_text:
        raise ValueError("You must provide exactly one of query_vector or query_text.")


def issubclass_safe(
    cls: Type, class_or_tuple: Union[Type, Tuple[Type, ...]]
) -> bool:
    """Return whether ``cls`` is a subclass of ``class_or_tuple``, tolerating reloads.

    After ``importlib.reload`` a module holds *new* class objects, so a class
    obtained after the reload is no longer a subclass of a base class object
    captured before it. When the plain ``issubclass`` check fails, the base is
    looked up again by name in its defining module (and in ``cls``'s module)
    and the check is retried against that current object.

    Args:
        cls: The class to test.
        class_or_tuple: A base class, or a (possibly nested) tuple of base
            classes; a match against any of them is enough.

    Returns:
        ``True`` if ``cls`` is a subclass of one of the given bases or of
        their current same-named counterparts, ``False`` otherwise.
    """
    if isinstance(class_or_tuple, tuple):
        return any(issubclass_safe(cls, base) for base in class_or_tuple)

    if issubclass(cls, class_or_tuple):
        return True

    # Handle the case where a module was reloaded: resolve the base by name
    # again. The base's own module is the authoritative place to find it;
    # cls's module is kept as a fallback for bases re-exported there.
    # dict.fromkeys de-duplicates while preserving order.
    candidate_modules = dict.fromkeys((class_or_tuple.__module__, cls.__module__))
    for module_name in candidate_modules:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Best-effort fallback: a module that cannot be imported simply
            # provides no candidate.
            continue
        latest_base = getattr(module, class_or_tuple.__name__, None)
        # The name may be absent (None) or bound to a non-class object;
        # passing either to issubclass would raise TypeError.
        if isinstance(latest_base, type) and issubclass(cls, latest_base):
            return True

    return False
30 changes: 30 additions & 0 deletions tests/unit/experimental/pipeline/config/test_object_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import sys
from abc import ABC
from typing import ClassVar
from unittest.mock import patch

import neo4j
Expand Down Expand Up @@ -58,6 +62,32 @@ def test_get_class_wrong_path() -> None:
c._get_class("MyClass")


class _MyClass:
    """Minimal fixture class that stores a single string parameter."""

    def __init__(self, param: str) -> None:
        # Exposed as a plain attribute so tests can read it back directly.
        self.param = param


class _MyInterface(ABC):
    """Empty abstract base class fixture; it declares no members."""


def test_parse_after_module_reload() -> None:
    """Check that a config built before a module reload can still be parsed.

    The config object is created first, then this test module is reloaded so
    that its classes are rebound to new objects, and only then is ``parse``
    called.
    """

    class MyClassConfig(ObjectConfig[_MyClass]):
        DEFAULT_MODULE: ClassVar[str] = __name__
        INTERFACE: ClassVar[type] = _MyClass

    expected_param = "value"
    raw_config = {
        "class_": f"{__name__}.{_MyClass.__name__}",
        "params_": {"param": expected_param},
    }
    object_config = MyClassConfig(**raw_config)

    # Reload this very module so the classes defined in it are re-created
    # after the config (and its INTERFACE reference) already exists.
    this_module = sys.modules[__name__]
    importlib.reload(this_module)

    parsed = object_config.parse()
    assert isinstance(parsed, _MyClass)
    assert parsed.param == expected_param


def test_neo4j_driver_config() -> None:
config = Neo4jDriverConfig.model_validate(
{
Expand Down

0 comments on commit 6820700

Please sign in to comment.