From 565904a4f7911463c4141095010e7b6b43b1285a Mon Sep 17 00:00:00 2001 From: Suvayu Ali Date: Wed, 10 Jul 2024 16:33:51 +0200 Subject: [PATCH] parser.py: fix merge_assets, bug fixes, refactor, add tests - silence SyntaxWarning for match string that's actually a regex in test suite --- pyproject.toml | 1 + src/esdl4tulipa/mapping.py | 8 +- src/esdl4tulipa/parser.py | 168 +++++++++++++++++++++------------- tests/test_parser.py | 180 ++++++++++++++++++++++++++++++------- 4 files changed, 259 insertions(+), 98 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd86b8f..06129a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ version_scheme = "release-branch-semver" [tool.pytest.ini_options] testpaths = ["tests"] addopts = ["--import-mode=importlib", "--cov=src", "-q"] +filterwarnings = ['ignore:invalid escape sequence:SyntaxWarning'] [tool.black] include = '\.pyi?$' diff --git a/src/esdl4tulipa/mapping.py b/src/esdl4tulipa/mapping.py index fe98302..a51b667 100644 --- a/src/esdl4tulipa/mapping.py +++ b/src/esdl4tulipa/mapping.py @@ -4,6 +4,7 @@ from dataclasses import fields from dataclasses import is_dataclass from typing import Type +from typing import TypeAlias from typing import TypeVar from typing import Union @@ -41,7 +42,6 @@ def fields(cls) -> list[str]: def __post_init__(self): # noqa: D105 for key, field_t in self.__annotations__.items(): value = getattr(self, key) - print(key, field_t, unguarded_is_dataclass(field_t), value) if unguarded_is_dataclass(field_t) and isinstance(value, dict): setattr(self, key, field_t(**value)) @@ -151,8 +151,10 @@ def esdl_key(cls, key: str) -> str: # noqa: D102 return _esdl_key.get(key, "") -ESDLAssets = Union[hub_t, consumer_t, producer_t, conversion_t, storage_t, flow_t] -asset_types: dict[str, type[ESDLAssets]] = { +TAssets: TypeAlias = Union[ + hub_t, consumer_t, producer_t, conversion_t, storage_t, flow_t +] +asset_types: dict[str, type[TAssets]] = { "energynetwork": hub_t, "consumer": consumer_t, "producer": producer_t, diff --git a/src/esdl4tulipa/parser.py b/src/esdl4tulipa/parser.py index 45d0c8c..5525d48 100644 --- a/src/esdl4tulipa/parser.py +++ b/src/esdl4tulipa/parser.py @@ -5,16 +5,17 @@ from functools import reduce from io import StringIO from itertools import chain -from typing import Callable, Generator +from typing import Callable +from typing import Generator from typing import TypeAlias from typing import TypeVar from esdl import esdl from esdl.esdl_handler import EnergySystemHandler from pyecore.ecore import EOrderedSet from tabulate import tabulate -from .mapping import ESDLAssets +from .mapping import TAssets from .mapping import asset_types - +from .mapping import flow_t _HANDLER = EnergySystemHandler() @@ -104,7 +105,7 @@ def batched( pass -def fill_asset(asset: esdl.EnergyAsset, kind: str = "", **overrides) -> ESDLAssets: +def fill_asset(asset: esdl.EnergyAsset, kind: str = "", **overrides) -> TAssets: """Fill asset dataclasses. Parameters @@ -121,7 +122,7 @@ def fill_asset(asset: esdl.EnergyAsset, kind: str = "", **overrides) -> ESDLAsse Returns ------- - ESDLAssets + TAssets """ if not kind: @@ -139,29 +140,35 @@ def fill_asset(asset: esdl.EnergyAsset, kind: str = "", **overrides) -> ESDLAsse ) ) } - return asset_types[kind](**args, **overrides) + # NOTE: don't merge by unpacking together: + # + # dataclass_t(**args, **overrides) + # + # as it may create duplicate keyword arguments (TypeError) + args.update(overrides) + return asset_types[kind](**args) else: raise ValueError(f"unsupported {kind=}, not one of {list(asset_types)}") -def merge_assets(asset1: ESDLAssets, asset2: ESDLAssets, **overrides) -> ESDLAssets: +def merge_assets(asset1: TAssets, asset2: TAssets, **overrides) -> flow_t: """Merge two assets (from & to) into a flow asset (transport in esdl). - Logic: merge and discard. + Logic: discard invalid, and merge. + + - Discard if not a valid :ref:`flow_t` attribute - Merge - - if attribute is missing asset 1 but present in asset 2, accept the 2nd value + - if attribute is missing in asset 1 but present in asset 2, accept the value - if both assets have set the same value, accept it - if both assets set the value but they aren't equal, raise an error - - Discard if not a valid :ref:`flow_t` attribute (FIXME: should this precede merge?) - Parameters ---------- - asset1: ESDLAssets + asset1: TAssets From asset - asset2: ESDLAssets + asset2: TAssets To asset **overrides @@ -169,7 +176,7 @@ def merge_assets(asset1: ESDLAssets, asset2: ESDLAssets, **overrides) -> ESDLAss Returns ------- - ESDLAssets + TAssets Raises ------ @@ -177,39 +184,42 @@ def merge_assets(asset1: ESDLAssets, asset2: ESDLAssets, **overrides) -> ESDLAss When attributes have mismatched values """ - merged, errs = {}, [] + _fields = [f.name for f in fields(flow_t)] + errs = [] # NOTE: use vars instead of asdict, I think asdict does a # copy. Some pyecore types like EEnumLiteral doesn't like # this. The copied values compare unequal even if they are equal - for (k1, v1), (k2, v2) in zip(vars(asset1).items(), vars(asset2).items()): - assert k1 == k2, "merging assumes both assets are of the same kind" + merged = { + "from_asset": asset1.name, + "to_asset": asset2.name, + **{k: v for k, v in vars(asset1).items() if k in _fields}, + } + for key, val in vars(asset2).items(): + if key in ("name", "id"): + continue - if k1 in ("name", "id"): + if key not in _fields: continue - match v1, v2: - case (None, None): - pass - case (_, None) if v1 is not None: - merged[k1] = v1 - case (None, _) if v2 is not None: - merged[k1] = v2 - case _ if v1 == v2: - merged[k1] = v1 - case _ if v1 != v2: - errs.append((k1, v1, v2)) + # NOTE: empty values are `None` or `""` + if (v1 := merged.get(key)) and v1 != val: + errs.append((key, v1, val)) + else: + merged[key] = val if len(errs) > 0: tbl = tabulate(errs, headers=("column", "from", "to")) raise ValueError(f"mismatching assets: {asset1.name} != {asset2.name}\n {tbl}") - flow_t = asset_types["transport"] - _fields = {f.name for f in fields(flow_t)} - merged = {k: v for k, v in merged.items() if k in _fields} - return flow_t(**merged, **overrides) + merged.pop("name") + merged.pop("id") + # see NOTE: about merging in `fill_asset` + merged.update(overrides) + return flow_t(**merged) -def edge(*assets: esdl.EnergyAsset) -> tuple[ESDLAssets, ESDLAssets, ESDLAssets]: + +def edge(*assets: esdl.EnergyAsset) -> tuple[TAssets, TAssets, TAssets]: """Create a Tulipa flow, and assets from ESDL assets. Parameters @@ -219,7 +229,7 @@ def edge(*assets: esdl.EnergyAsset) -> tuple[ESDLAssets, ESDLAssets, ESDLAssets] Returns ------- - tuple[ESDLAssets, ...] + tuple[TAssets, ...] Tuple of (flow, from_asset, to_asset) Raises @@ -252,9 +262,7 @@ def edge(*assets: esdl.EnergyAsset) -> tuple[ESDLAssets, ESDLAssets, ESDLAssets] | esdl.EnergyNetwork() as a2, ) if len(kinds(a1, a2)) == 2: from_asset, to_asset = map(fill_asset, assets) - flow = merge_assets( - from_asset, to_asset, from_asset=from_asset.name, to_asset=to_asset.name - ) + flow = merge_assets(from_asset, to_asset) case ( esdl.Producer() | esdl.Conversion() @@ -265,10 +273,13 @@ def edge(*assets: esdl.EnergyAsset) -> tuple[ESDLAssets, ESDLAssets, ESDLAssets] | esdl.Conversion() | esdl.Storage() | esdl.EnergyNetwork() as a2, - ) if not isinstance(link, esdl.EnergyNetwork) and len(kinds(a1, a2)) == 2: + ) if not isinstance(link, esdl.EnergyNetwork) and len(kinds(a1, a2)) <= 2: + # NOTE: len(kinds(...)) <= 2 to support EnergyNetwork -> EnergyNetwork from_asset = fill_asset(a1) to_asset = fill_asset(a2) - flow = fill_asset(link, from_asset=a1.name, to_asset=a2.name) + flow = fill_asset( + link, from_asset=a1.name, to_asset=a2.name, name="", id="" + ) # reset 'name' & 'id' for flow case _: # NOTE: unhandled case: asset, transport, ..., asset raise ValueError(f"{assets=}: uncharted territory!") @@ -352,7 +363,10 @@ def hop_nodes( """ match asset: case ( - esdl.Producer() | esdl.Conversion() | esdl.Storage() | esdl.EnergyNetwork() + esdl.Producer() + | esdl.Conversion() + | esdl.Storage() + | esdl.EnergyNetwork() ) if depth == 1: edges.append(asset) itr_nodes(asset, edges, depth) @@ -362,7 +376,10 @@ def hop_nodes( edges.append(asset) itr_nodes(asset, edges, depth) case ( - esdl.Consumer() | esdl.Conversion() | esdl.Storage() | esdl.EnergyNetwork() + esdl.Consumer() + | esdl.Conversion() + | esdl.Storage() + | esdl.EnergyNetwork() ) if depth > 1: edges.append(asset) case _: @@ -370,23 +387,43 @@ def hop_nodes( return edges -def find_edges(asset: esdl.EnergyAsset) -> list[tuple[ESDLAssets, ...]]: +res_t = TypeVar("res_t", tuple, list, set, dict, esdl.EnergyAsset) + + +def find_edges( + asset: esdl.EnergyAsset, + process: Callable[..., res_t] = edge, # type: ignore[assignment] +) -> list[res_t]: """Find all out going flows from the provided asset.""" if (edges := hop_nodes(asset, [])) and len(edges) > 1: from_asset, *rest = edges - return [edge(from_asset, *assets) for assets in batched(rest)] + return [process(from_asset, *assets) for assets in batched(rest)] else: return [] +def _apply( + predicate: Callable[[esdl.EnergyAsset], list[res_t] | res_t], + asset: esdl.EnergyAsset, + res: list[res_t], +) -> list[res_t]: + if hasattr(asset, "name"): + match predicate(asset): + case list() as _interim: + res.extend(_interim) + case _interim: + res.append(_interim) + return res + + ESDLNode: TypeAlias = esdl.EnergySystem | esdl.Area | esdl.EnergyAsset -res_t = TypeVar("res_t", tuple, list, set, dict, esdl.EnergyAsset) def parse_graph( obj: ESDLNode | EOrderedSet, predicate: Callable[[esdl.EnergyAsset], list[res_t] | res_t], res: list[res_t], + visit_all: bool = False, ) -> list[res_t]: """Parse ecore object to extract node attributes and connections. @@ -405,6 +442,10 @@ def parse_graph( dataclass has edge properties determined by combining the properties of the nodes connected by the edge. + visit_all : bool (default: False) + If `True`, visit all supported nodes instead of limiting to + nodes with outgoing edges. + Returns ------- list[res_t] @@ -414,37 +455,40 @@ def parse_graph( match obj: case EOrderedSet(): for el in obj: - parse_graph(el, predicate, res) + parse_graph(el, predicate, res, visit_all=visit_all) case esdl.EnergySystem() if isinstance(obj.instance, esdl.EOrderedSet): - parse_graph(obj.instance, predicate, res) + parse_graph(obj.instance, predicate, res, visit_all=visit_all) case esdl.Instance(): - parse_graph(obj.area, predicate, res) + parse_graph(obj.area, predicate, res, visit_all=visit_all) case esdl.Area(): if isinstance(obj.area, EOrderedSet): # may contain sub-areas - parse_graph(obj.area, predicate, res) + parse_graph(obj.area, predicate, res, visit_all=visit_all) if isinstance(obj.asset, EOrderedSet): # may also contain assets - parse_graph(obj.asset, predicate, res) - case (esdl.Producer() | esdl.Conversion() | esdl.Storage()) as asset: - if hasattr(asset, "name"): - match predicate(asset): - case list() as _interim: - res.extend(_interim) - case _interim: - res.append(_interim) - case esdl.Transport() | esdl.Consumer(): - pass # only following out going flows + parse_graph(obj.asset, predicate, res, visit_all=visit_all) + case ( + ( + esdl.Producer() + | esdl.Conversion() + | esdl.Storage() + | esdl.EnergyNetwork() + ) as asset + ): + _apply(predicate, asset, res) + case (esdl.Transport() | esdl.Consumer()) as asset: + if visit_all: # only follow out going flows by default + _apply(predicate, asset, res) case _: raise ValueError(f"{obj}: unsupported value") return res -def load(path: str) -> tuple[tuple[ESDLAssets, ...], tuple[ESDLAssets, ...]]: +def load(path: str) -> tuple[tuple[flow_t, ...], tuple[TAssets, ...]]: """Load ESDL file and parse nodes.""" with contextlib.redirect_stdout(StringIO()): ensys = _HANDLER.load_file(path) edges = parse_graph(ensys, find_edges, []) - flows = tuple(edge[0] for edge in edges) - assets: tuple[ESDLAssets, ...] = tuple(set(chain(*(edge[1:] for edge in edges)))) + flows: tuple[flow_t] = tuple(edge[0] for edge in edges) + assets: tuple[TAssets, ...] = tuple(set(chain(*(edge[1:] for edge in edges)))) return flows, assets diff --git a/tests/test_parser.py b/tests/test_parser.py index ccc8ce1..e0fe5b2 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,54 +1,168 @@ -from dataclasses import make_dataclass +from itertools import chain import pytest from esdl import esdl -from esdl.esdl_handler import EnergySystemHandler -import esdl4tulipa +from esdl4tulipa.mapping import asset_types +from esdl4tulipa.parser import batched, find_edges, load +from esdl4tulipa.parser import debug +from esdl4tulipa.parser import edge from esdl4tulipa.parser import fill_asset +from esdl4tulipa.parser import kinds from esdl4tulipa.parser import merge_assets from esdl4tulipa.parser import parse_graph +import typing + + +@pytest.fixture +def all_assets(): + ensys = debug("tests/data/esdl/norse-mythology-good.esdl") + return parse_graph(ensys, lambda i: i, [], visit_all=True) + @pytest.fixture -def assets(): - handler = EnergySystemHandler() - ensys = handler.load_file("tests/data/esdl/norse-mythology-good.esdl") - return parse_graph(ensys, lambda i: i, []) +def edges(): + ensys = debug("tests/data/esdl/norse-mythology-good.esdl") + assets = parse_graph(ensys, lambda i: i, [], visit_all=True) + electrolyzer = assets[1] + gasnetwork = electrolyzer.port[1].connectedTo[0].energyasset + gasstorage = gasnetwork.port[1].connectedTo[0].energyasset + fuelcell = gasnetwork.port[1].connectedTo[1].energyasset + gasdemand = gasnetwork.port[1].connectedTo[2].energyasset + _edges = [ + (electrolyzer, gasnetwork), + (gasnetwork, gasstorage), + (gasnetwork, fuelcell), + (gasnetwork, gasdemand), + ] + kinds = [ + ("conversion", "energynetwork"), + ("energynetwork", "storage"), + ("energynetwork", "conversion"), + ("energynetwork", "consumer"), + ] + return _edges, kinds + + +@pytest.fixture +def empty_edges(): + _edges = [ + [esdl.GasNetwork(), esdl.Pipe(), esdl.GasDemand()], + [esdl.GasNetwork(), esdl.Electrolyzer()], + [esdl.GasNetwork(), esdl.FuelCell()], + [esdl.GasNetwork(), esdl.Pipe(), esdl.GasStorage()], + ] + _kinds = [ + ["energynetwork", "transport", "consumer"], + ["energynetwork", "conversion"], + ["energynetwork", "conversion"], + ["energynetwork", "transport", "storage"], + ] + return _edges, _kinds + + +def test_asset_kinds(edges): + _edges = edges[0] + _kinds = list(chain.from_iterable(edges[1])) + inferred = kinds(*chain.from_iterable(_edges), unique=True) + assert inferred == set(_kinds) + + inferred = kinds(*chain.from_iterable(_edges), unique=False) + assert inferred == _kinds + +def test_batched(): + seq = [ + esdl.Pipe(), + esdl.GasDemand(), + esdl.Electrolyzer(), + esdl.FuelCell(), + esdl.Pipe(), + esdl.GasStorage(), + ] + expect = [seq[:2], seq[2:3], seq[3:4], seq[4:]] + assert expect == list(batched(seq)) -def test_fill_asset(assets): - for asset in assets: - if isinstance(asset, esdl.Transport): - with pytest.raises(ValueError): - fill_asset(asset) - else: - assert fill_asset(asset) +def test_fill_asset(all_assets): + for asset in all_assets: + assert fill_asset(asset) -def test_merge_assets(): - fields = [("name", str), ("num", int), ("flag", bool)] - my_type = make_dataclass("my_type", fields) - values = [("foo", 42, None), ("bar", 42, False), ("baz", 42, None)] - expect = [(42, False), (42, False)] +def test_fill_assets_w_override(edges): + flow = fill_asset(edges[0][0][0], name="foo") + assert flow.name == "foo" - npairs = len(values) - 1 - for i in range(npairs): - data1 = my_type(*values[i]) - data2 = my_type(*values[i + 1]) - assert tuple(merge_assets(data1, data2).values()) == expect[i] +def test_fill_assets_bad_kind(edges): + kind = "notthere" + with pytest.raises(ValueError, match=f"unsupported {kind=}"): + fill_asset(edges[0][0][0], kind=kind) -def test_merge_assets_err(): - fields = [("name", str), ("num", int), ("flag", bool)] - my_type = make_dataclass("my_type", fields) - values = [("foo", 42, None), ("bar", 0, None)] +def test_merge_assets(edges): + for a1, a2 in edges[0]: + flow = merge_assets(fill_asset(a1), fill_asset(a2)) + assert hash(flow) + assert flow.from_asset and flow.to_asset - data1 = my_type(*values[0]) - data2 = my_type(*values[1]) - hdr = f"{values[0][0]} != {values[1][0]}" - body = f"{fields[1][0]}.+42.+0" + +def test_merge_assets_w_override(edges): + a1, a2 = map(fill_asset, edges[0][0]) + flow = merge_assets(a1, a2, from_asset="foo", to_asset="bar") + assert flow.from_asset == "foo" + assert flow.to_asset == "bar" + + +@typing.no_type_check +def test_merge_assets_err(edges): + a1, a2 = map(fill_asset, edges[0][0]) + a1.lifetime = 10 + a2.lifetime = 20 + hdr = f"{a1.name} != {a2.name}" + body = f"lifetime.+{a1.lifetime}.+{a2.lifetime}" with pytest.raises(ValueError, match=f"{hdr}[\s\S]+{body}"): - assert merge_assets(data1, data2) + assert merge_assets(a1, a2) + + +def test_edge(empty_edges): + _edges, _kinds = empty_edges + for _edge, _kind in zip(_edges, _kinds): + a = edge(*_edge) + assert all( + [ + isinstance(a[0], asset_types["transport"]), + isinstance(a[1], asset_types[_kind[0]]), + isinstance(a[2], asset_types[_kind[-1]]), + ] + ) + + +def test_find_edges(edges): + _, *gns = edges[0] + _, *expect = edges[1] + gasnetwork = gns[0][0] + for _edge, _kind in zip(find_edges(gasnetwork, edge), expect): + assert all( + [ + isinstance(_edge[0], asset_types["transport"]), + isinstance(_edge[1], asset_types[_kind[0]]), + isinstance(_edge[2], asset_types[_kind[-1]]), + ] + ) + + +def test_parse_graph(): + # TODO: test more meaningful things + ensys = debug("tests/data/esdl/norse-mythology-good.esdl") + assets_all = parse_graph(ensys, lambda i: i, [], visit_all=True) + assets = parse_graph(ensys, lambda i: i, [], visit_all=False) + assert len(assets) < len(assets_all) + + +def test_load(): + # TODO: test more meaningful things + flows, assets = load("tests/data/esdl/norse-mythology-good.esdl") + assert len(flows) == 35 + assert len(assets) == 29