Skip to content

Commit

Permalink
Implement model equality (#2057)
Browse files Browse the repository at this point in the history
Fixes #928
Partially solves Deltares/Ribasim-NL#229

Differences, and such will follow in a separate PR.

---------

Co-authored-by: Martijn Visser <mgvisser@gmail.com>
  • Loading branch information
evetion and visr authored Feb 13, 2025
1 parent 7f083a2 commit eb81f0e
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 3 deletions.
70 changes: 68 additions & 2 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import operator
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
Expand All @@ -15,6 +16,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
import pydantic
from pandera.typing import DataFrame
from pandera.typing.geopandas import GeoDataFrame
from pydantic import BaseModel as PydanticBaseModel
Expand All @@ -23,6 +25,7 @@
DirectoryPath,
Field,
PrivateAttr,
SerializationInfo,
ValidationInfo,
field_validator,
model_serializer,
Expand Down Expand Up @@ -90,6 +93,54 @@ def _fields(cls) -> list[str]:
"""Return the names of the fields contained in the Model."""
return list(cls.model_fields.keys())

def model_dump(self, **kwargs) -> dict[str, Any]:
return super().model_dump(serialize_as_any=True, **kwargs)

# __eq__ from Pydantic BaseModel itself, edited to remove the comparison of private attrs
# https://github.com/pydantic/pydantic/blob/ff3789d4cc06ee024b7253b919d3e36748a72829/pydantic/main.py#L1069
# The MIT License (MIT) | Copyright (c) 2017 to present Pydantic Services Inc. and individual contributors.
def __eq__(self, other: Any) -> bool:
if isinstance(other, BaseModel):
self_type = self.__pydantic_generic_metadata__["origin"] or self.__class__
other_type = (
other.__pydantic_generic_metadata__["origin"] or other.__class__
)

if not (
self_type == other_type
# This comparison has been removed, otherwise we recurse because
# we store the parent of the model in a private attribute
# and getattr(self, "__pydantic_private__", None)
# == getattr(other, "__pydantic_private__", None)
and self.__pydantic_extra__ == other.__pydantic_extra__
):
return False

if self.__dict__ == other.__dict__:
return True

model_fields = type(self).__pydantic_fields__.keys()
if (
self.__dict__.keys() <= model_fields
and other.__dict__.keys() <= model_fields
):
return False

getter = (
operator.itemgetter(*model_fields)
if model_fields
else lambda _: pydantic._utils._SENTINEL # type: ignore
)
try:
return getter(self.__dict__) == getter(other.__dict__)
except KeyError:
self_fields_proxy = pydantic._utils.SafeGetItemProxy(self.__dict__) # type: ignore
other_fields_proxy = pydantic._utils.SafeGetItemProxy(other.__dict__) # type: ignore
return getter(self_fields_proxy) == getter(other_fields_proxy)

else:
return NotImplemented


class FileModel(BaseModel, ABC):
"""Base class to represent models with a file representation.
Expand Down Expand Up @@ -165,6 +216,17 @@ class TableModel(FileModel, Generic[TableT]):
df: DataFrame[TableT] | None = Field(default=None, exclude=True, repr=False)
_sort_keys: list[str] = PrivateAttr(default=[])

def __eq__(self, other: Any) -> bool:
if isinstance(other, TableModel):
if self.df is None and other.df is None:
return True
if self.df is None or other.df is None:
return False
else:
return self.df.equals(other.df)

return NotImplemented

@field_validator("df")
@classmethod
def _check_schema(cls, v: DataFrame[TableT]):
Expand All @@ -184,8 +246,12 @@ def _check_schema(cls, v: DataFrame[TableT]):
return v

@model_serializer
def _set_model(self) -> str | None:
return str(self.filepath.name) if self.filepath is not None else None
def _set_model(self, info: SerializationInfo) -> "str | TableModel[TableT] | None":
# When writing, only return the filename.
if info.context == "write":
return str(self.filepath.name) if self.filepath is not None else None
else:
return self

@classmethod
def tablename(cls) -> str:
Expand Down
4 changes: 3 additions & 1 deletion python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def _write_toml(self, fn: Path) -> Path:
Path
The file path of the written TOML file.
"""
content = self.model_dump(exclude_unset=True, exclude_none=True, by_alias=True)
content = self.model_dump(
exclude_unset=True, exclude_none=True, by_alias=True, context="write"
)
# Filter empty dicts (default Nodes)
content = dict(filter(lambda x: x[1], content.items()))
content["ribasim_version"] = ribasim.__version__
Expand Down
26 changes: 26 additions & 0 deletions python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ribasim.geometry.link import NodeData
from ribasim.input_base import esc_id
from ribasim.model import Model
from ribasim.nodes import basin
from ribasim_testmodels import (
basic_model,
outlet_model,
Expand Down Expand Up @@ -245,3 +246,28 @@ def test_non_existent_files(tmp_path):

with pytest.raises(FileNotFoundError, match=r"Database file .* does not exist\."):
Model.read(toml_path)


def test_model_equals(basic):
nbasic = basic.model_copy(deep=True)

assert nbasic.basin.static == basic.basin.static
assert nbasic.basin == basic.basin
assert nbasic == basic

nbasic.solver.saveat = 0
assert nbasic.solver.saveat != basic.solver.saveat
assert nbasic.solver != basic.solver
assert nbasic.basin == basic.basin
assert nbasic != basic

nbasic.solver.saveat = basic.solver.saveat
nbasic.basin.add(
Node(None, Point(-1.5, -1), name="confluence"),
[
basin.Static(precipitation=[4]),
],
)
assert nbasic.basin.static != basic.basin.static
assert nbasic.basin != basic.basin
assert nbasic != basic

0 comments on commit eb81f0e

Please sign in to comment.