Skip to content

Commit

Permalink
parser.py: fix merge_assets, bug fixes, refactor, add tests
Browse files Browse the repository at this point in the history
- silence SyntaxWarning for match string that's actually a regex in
  test suite
  • Loading branch information
suvayu committed Jul 10, 2024
1 parent 69136f1 commit 565904a
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 98 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ version_scheme = "release-branch-semver"
[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = ["--import-mode=importlib", "--cov=src", "-q"]
filterwarnings = ['ignore:invalid escape sequence:SyntaxWarning']

[tool.black]
include = '\.pyi?$'
Expand Down
8 changes: 5 additions & 3 deletions src/esdl4tulipa/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import fields
from dataclasses import is_dataclass
from typing import Type
from typing import TypeAlias
from typing import TypeVar
from typing import Union

Expand Down Expand Up @@ -41,7 +42,6 @@ def fields(cls) -> list[str]:
def __post_init__(self): # noqa: D105
for key, field_t in self.__annotations__.items():
value = getattr(self, key)
print(key, field_t, unguarded_is_dataclass(field_t), value)
if unguarded_is_dataclass(field_t) and isinstance(value, dict):
setattr(self, key, field_t(**value))

Expand Down Expand Up @@ -151,8 +151,10 @@ def esdl_key(cls, key: str) -> str: # noqa: D102
return _esdl_key.get(key, "")


ESDLAssets = Union[hub_t, consumer_t, producer_t, conversion_t, storage_t, flow_t]
asset_types: dict[str, type[ESDLAssets]] = {
TAssets: TypeAlias = Union[
hub_t, consumer_t, producer_t, conversion_t, storage_t, flow_t
]
asset_types: dict[str, type[TAssets]] = {
"energynetwork": hub_t,
"consumer": consumer_t,
"producer": producer_t,
Expand Down
168 changes: 106 additions & 62 deletions src/esdl4tulipa/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
from functools import reduce
from io import StringIO
from itertools import chain
from typing import Callable, Generator
from typing import Callable
from typing import Generator
from typing import TypeAlias
from typing import TypeVar
from esdl import esdl
from esdl.esdl_handler import EnergySystemHandler
from pyecore.ecore import EOrderedSet
from tabulate import tabulate
from .mapping import ESDLAssets
from .mapping import TAssets
from .mapping import asset_types

from .mapping import flow_t

_HANDLER = EnergySystemHandler()

Expand Down Expand Up @@ -104,7 +105,7 @@ def batched(
pass


def fill_asset(asset: esdl.EnergyAsset, kind: str = "", **overrides) -> ESDLAssets:
def fill_asset(asset: esdl.EnergyAsset, kind: str = "", **overrides) -> TAssets:
"""Fill asset dataclasses.
Parameters
Expand All @@ -121,7 +122,7 @@ def fill_asset(asset: esdl.EnergyAsset, kind: str = "", **overrides) -> ESDLAsse
Returns
-------
ESDLAssets
TAssets
"""
if not kind:
Expand All @@ -139,77 +140,86 @@ def fill_asset(asset: esdl.EnergyAsset, kind: str = "", **overrides) -> ESDLAsse
)
)
}
return asset_types[kind](**args, **overrides)
# NOTE: don't merge by unpacking together:
#
# dataclass_t(**args, **overrides)
#
# as it may create duplicate keyword arguments (TypeError)
args.update(overrides)
return asset_types[kind](**args)
else:
raise ValueError(f"unsupported {kind=}, not one of {list(asset_types)}")


def merge_assets(asset1: ESDLAssets, asset2: ESDLAssets, **overrides) -> ESDLAssets:
def merge_assets(asset1: TAssets, asset2: TAssets, **overrides) -> flow_t:
"""Merge two assets (from & to) into a flow asset (transport in esdl).
Logic: merge and discard.
Logic: discard invalid, and merge.
- Discard if not a valid :ref:`flow_t` attribute
- Merge
- if attribute is missing asset 1 but present in asset 2, accept the 2nd value
- if attribute is missing in asset 1 but present in asset 2, accept the value
- if both assets have set the same value, accept it
- if both assets set the value but they aren't equal, raise an error
- Discard if not a valid :ref:`flow_t` attribute (FIXME: should this precede merge?)
Parameters
----------
asset1: ESDLAssets
asset1: TAssets
From asset
asset2: ESDLAssets
asset2: TAssets
To asset
**overrides
Any attribute overrides
Returns
-------
ESDLAssets
TAssets
Raises
------
ValueError
When attributes have mismatched values
"""
merged, errs = {}, []
_fields = [f.name for f in fields(flow_t)]
errs = []
# NOTE: use vars instead of asdict, I think asdict does a
# copy. Some pyecore types like EEnumLiteral doesn't like
# this. The copied values compare unequal even if they are equal
for (k1, v1), (k2, v2) in zip(vars(asset1).items(), vars(asset2).items()):
assert k1 == k2, "merging assumes both assets are of the same kind"
merged = {
"from_asset": asset1.name,
"to_asset": asset2.name,
**{k: v for k, v in vars(asset1).items() if k in _fields},
}
for key, val in vars(asset2).items():
if key in ("name", "id"):
continue

if k1 in ("name", "id"):
if key not in _fields:
continue

match v1, v2:
case (None, None):
pass
case (_, None) if v1 is not None:
merged[k1] = v1
case (None, _) if v2 is not None:
merged[k1] = v2
case _ if v1 == v2:
merged[k1] = v1
case _ if v1 != v2:
errs.append((k1, v1, v2))
# NOTE: empty values are `None` or `""`
if (v1 := merged.get(key)) and v1 != val:
errs.append((key, v1, val))
else:
merged[key] = val

if len(errs) > 0:
tbl = tabulate(errs, headers=("column", "from", "to"))
raise ValueError(f"mismatching assets: {asset1.name} != {asset2.name}\n {tbl}")

flow_t = asset_types["transport"]
_fields = {f.name for f in fields(flow_t)}
merged = {k: v for k, v in merged.items() if k in _fields}
return flow_t(**merged, **overrides)
merged.pop("name")
merged.pop("id")

# see NOTE: about merging in `fill_asset`
merged.update(overrides)
return flow_t(**merged)

def edge(*assets: esdl.EnergyAsset) -> tuple[ESDLAssets, ESDLAssets, ESDLAssets]:

def edge(*assets: esdl.EnergyAsset) -> tuple[TAssets, TAssets, TAssets]:
"""Create a Tulipa flow, and assets from ESDL assets.
Parameters
Expand All @@ -219,7 +229,7 @@ def edge(*assets: esdl.EnergyAsset) -> tuple[ESDLAssets, ESDLAssets, ESDLAssets]
Returns
-------
tuple[ESDLAssets, ...]
tuple[TAssets, ...]
Tuple of (flow, from_asset, to_asset)
Raises
Expand Down Expand Up @@ -252,9 +262,7 @@ def edge(*assets: esdl.EnergyAsset) -> tuple[ESDLAssets, ESDLAssets, ESDLAssets]
| esdl.EnergyNetwork() as a2,
) if len(kinds(a1, a2)) == 2:
from_asset, to_asset = map(fill_asset, assets)
flow = merge_assets(
from_asset, to_asset, from_asset=from_asset.name, to_asset=to_asset.name
)
flow = merge_assets(from_asset, to_asset)
case (
esdl.Producer()
| esdl.Conversion()
Expand All @@ -265,10 +273,13 @@ def edge(*assets: esdl.EnergyAsset) -> tuple[ESDLAssets, ESDLAssets, ESDLAssets]
| esdl.Conversion()
| esdl.Storage()
| esdl.EnergyNetwork() as a2,
) if not isinstance(link, esdl.EnergyNetwork) and len(kinds(a1, a2)) == 2:
) if not isinstance(link, esdl.EnergyNetwork) and len(kinds(a1, a2)) <= 2:
# NOTE: len(kinds(...)) <= 2 to support EnergyNetwork -> EnergyNetwork
from_asset = fill_asset(a1)
to_asset = fill_asset(a2)
flow = fill_asset(link, from_asset=a1.name, to_asset=a2.name)
flow = fill_asset(
link, from_asset=a1.name, to_asset=a2.name, name="", id=""
) # reset 'name' & 'id' for flow
case _:
# NOTE: unhandled case: asset, transport, ..., asset
raise ValueError(f"{assets=}: uncharted territory!")
Expand Down Expand Up @@ -352,7 +363,10 @@ def hop_nodes(
"""
match asset:
case (
esdl.Producer() | esdl.Conversion() | esdl.Storage() | esdl.EnergyNetwork()
esdl.Producer()
| esdl.Conversion()
| esdl.Storage()
| esdl.EnergyNetwork()
) if depth == 1:
edges.append(asset)
itr_nodes(asset, edges, depth)
Expand All @@ -362,31 +376,54 @@ def hop_nodes(
edges.append(asset)
itr_nodes(asset, edges, depth)
case (
esdl.Consumer() | esdl.Conversion() | esdl.Storage() | esdl.EnergyNetwork()
esdl.Consumer()
| esdl.Conversion()
| esdl.Storage()
| esdl.EnergyNetwork()
) if depth > 1:
edges.append(asset)
case _:
raise ValueError(f"{asset}: why am I here? {depth=}")
return edges


def find_edges(asset: esdl.EnergyAsset) -> list[tuple[ESDLAssets, ...]]:
res_t = TypeVar("res_t", tuple, list, set, dict, esdl.EnergyAsset)


def find_edges(
asset: esdl.EnergyAsset,
process: Callable[..., res_t] = edge, # type: ignore[assignment]
) -> list[res_t]:
"""Find all out going flows from the provided asset."""
if (edges := hop_nodes(asset, [])) and len(edges) > 1:
from_asset, *rest = edges
return [edge(from_asset, *assets) for assets in batched(rest)]
return [process(from_asset, *assets) for assets in batched(rest)]
else:
return []


def _apply(
predicate: Callable[[esdl.EnergyAsset], list[res_t] | res_t],
asset: esdl.EnergyAsset,
res: list[res_t],
) -> list[res_t]:
if hasattr(asset, "name"):
match predicate(asset):
case list() as _interim:
res.extend(_interim)
case _interim:
res.append(_interim)
return res


ESDLNode: TypeAlias = esdl.EnergySystem | esdl.Area | esdl.EnergyAsset
res_t = TypeVar("res_t", tuple, list, set, dict, esdl.EnergyAsset)


def parse_graph(
obj: ESDLNode | EOrderedSet,
predicate: Callable[[esdl.EnergyAsset], list[res_t] | res_t],
res: list[res_t],
visit_all: bool = False,
) -> list[res_t]:
"""Parse ecore object to extract node attributes and connections.
Expand All @@ -405,6 +442,10 @@ def parse_graph(
dataclass has edge properties determined by combining the
properties of the nodes connected by the edge.
visit_all : bool (default: False)
If `True`, visit all supported nodes instead of limiting to
nodes with outgoing edges.
Returns
-------
list[res_t]
Expand All @@ -414,37 +455,40 @@ def parse_graph(
match obj:
case EOrderedSet():
for el in obj:
parse_graph(el, predicate, res)
parse_graph(el, predicate, res, visit_all=visit_all)
case esdl.EnergySystem() if isinstance(obj.instance, esdl.EOrderedSet):
parse_graph(obj.instance, predicate, res)
parse_graph(obj.instance, predicate, res, visit_all=visit_all)
case esdl.Instance():
parse_graph(obj.area, predicate, res)
parse_graph(obj.area, predicate, res, visit_all=visit_all)
case esdl.Area():
if isinstance(obj.area, EOrderedSet): # may contain sub-areas
parse_graph(obj.area, predicate, res)
parse_graph(obj.area, predicate, res, visit_all=visit_all)
if isinstance(obj.asset, EOrderedSet): # may also contain assets
parse_graph(obj.asset, predicate, res)
case (esdl.Producer() | esdl.Conversion() | esdl.Storage()) as asset:
if hasattr(asset, "name"):
match predicate(asset):
case list() as _interim:
res.extend(_interim)
case _interim:
res.append(_interim)
case esdl.Transport() | esdl.Consumer():
pass # only following out going flows
parse_graph(obj.asset, predicate, res, visit_all=visit_all)
case (
(
esdl.Producer()
| esdl.Conversion()
| esdl.Storage()
| esdl.EnergyNetwork()
) as asset
):
_apply(predicate, asset, res)
case (esdl.Transport() | esdl.Consumer()) as asset:
if visit_all: # only follow out going flows by default
_apply(predicate, asset, res)
case _:
raise ValueError(f"{obj}: unsupported value")
return res


def load(path: str) -> tuple[tuple[ESDLAssets, ...], tuple[ESDLAssets, ...]]:
def load(path: str) -> tuple[tuple[flow_t, ...], tuple[TAssets, ...]]:
"""Load ESDL file and parse nodes."""
with contextlib.redirect_stdout(StringIO()):
ensys = _HANDLER.load_file(path)
edges = parse_graph(ensys, find_edges, [])
flows = tuple(edge[0] for edge in edges)
assets: tuple[ESDLAssets, ...] = tuple(set(chain(*(edge[1:] for edge in edges))))
flows: tuple[flow_t] = tuple(edge[0] for edge in edges)
assets: tuple[TAssets, ...] = tuple(set(chain(*(edge[1:] for edge in edges))))
return flows, assets


Expand Down
Loading

0 comments on commit 565904a

Please sign in to comment.