diff --git a/tests/test_graph_group.py b/tests/test_graph_group.py index b5d2396..ba036e6 100644 --- a/tests/test_graph_group.py +++ b/tests/test_graph_group.py @@ -14,7 +14,7 @@ def run(self): def test_empty_grp_name(): graph = znflow.DiGraph() - with pytest.raises(TypeError): + with pytest.raises(ValueError): with graph.group(): # name required pass @@ -24,19 +24,19 @@ def test_grp(): assert graph.active_group is None - with graph.group("my_grp") as grp_name: - assert graph.active_group == grp_name + with graph.group("my_grp") as grp: + assert graph.active_group == grp node = PlainNode(1) assert graph.active_group is None graph.run() - assert grp_name == "my_grp" + assert grp.names == ("my_grp",) assert node.value == 2 assert node.uuid in graph.nodes - assert grp_name in graph._groups - assert graph.get_group(grp_name) == [node.uuid] + assert grp.names in graph._groups + assert graph.get_group("my_grp").uuids == [node.uuid] assert len(graph._groups) == 1 assert len(graph) == 1 @@ -47,15 +47,15 @@ def test_muliple_grps(): assert graph.active_group is None - with graph.group("my_grp") as grp_name: - assert graph.active_group == grp_name + with graph.group("my_grp") as grp: + assert graph.active_group == grp node = PlainNode(1) assert graph.active_group is None - with graph.group("my_grp2") as grp_name2: - assert graph.active_group == grp_name2 + with graph.group("my_grp2") as grp2: + assert graph.active_group == grp2 node2 = PlainNode(2) @@ -63,8 +63,8 @@ def test_muliple_grps(): graph.run() - assert grp_name == "my_grp" - assert grp_name2 == "my_grp2" + assert grp.names == ("my_grp",) + assert grp2.names == ("my_grp2",) assert node.value == 2 assert node2.value == 3 @@ -72,11 +72,11 @@ def test_muliple_grps(): assert node.uuid in graph.nodes assert node2.uuid in graph.nodes - assert grp_name in graph._groups - assert grp_name2 in graph._groups + assert grp.names in graph._groups + assert grp2.names in graph._groups - assert graph.get_group(grp_name) == [node.uuid] - assert graph.get_group(grp_name2) == [node2.uuid] + assert graph.get_group(*grp.names).uuids == [node.uuid] + assert graph.get_group(*grp2.names).uuids == [node2.uuid] assert len(graph._groups) == 2 assert len(graph) == 2 @@ -85,8 +85,8 @@ def test_muliple_grps(): def test_nested_grps(): graph = znflow.DiGraph() - with graph.group("my_grp") as grp_name: - assert graph.active_group == grp_name + with graph.group("my_grp") as grp: + assert graph.active_group == grp with pytest.raises(TypeError): with graph.group("my_grp2"): pass @@ -96,8 +96,8 @@ def test_grp_with_existing_nodes(): with znflow.DiGraph() as graph: node = PlainNode(1) - with graph.group("my_grp") as grp_name: - assert graph.active_group == grp_name + with graph.group("my_grp") as grp: + assert graph.active_group == grp node2 = PlainNode(2) @@ -105,7 +105,7 @@ def test_grp_with_existing_nodes(): graph.run() - assert grp_name == "my_grp" + assert grp.names == ("my_grp",) assert node.value == 2 assert node2.value == 3 @@ -113,9 +113,9 @@ def test_grp_with_existing_nodes(): assert node.uuid in graph.nodes assert node2.uuid in graph.nodes - assert grp_name in graph._groups + assert grp.names in graph._groups - assert graph.get_group(grp_name) == [node2.uuid] + assert graph.get_group(*grp.names).uuids == [node2.uuid] assert len(graph._groups) == 1 assert len(graph) == 2 @@ -126,8 +126,8 @@ def test_grp_with_multiple_nodes(): node = PlainNode(1) node2 = PlainNode(2) - with graph.group("my_grp") as grp_name: - assert graph.active_group == grp_name + with graph.group("my_grp") as grp: + assert graph.active_group == grp node3 = PlainNode(3) node4 = PlainNode(4) @@ -136,7 +136,7 @@ def test_grp_with_multiple_nodes(): graph.run() - assert grp_name == "my_grp" + assert grp.names == ("my_grp",) assert node.value == 2 assert node2.value == 3 @@ -148,9 +148,9 @@ def test_grp_with_multiple_nodes(): assert node3.uuid in graph.nodes assert node4.uuid in graph.nodes - assert grp_name in graph._groups + assert grp.names in graph._groups - assert graph.get_group(grp_name) == [node3.uuid, node4.uuid] + assert graph.get_group(*grp.names).uuids == [node3.uuid, node4.uuid] assert len(graph._groups) == 1 assert len(graph) == 4 @@ -158,13 +158,13 @@ def test_grp_with_multiple_nodes(): def test_reopen_grps(): with znflow.DiGraph() as graph: - with graph.group("my_grp") as grp_name: - assert graph.active_group == grp_name + with graph.group("my_grp") as grp: + assert graph.active_group == grp node = PlainNode(1) - with graph.group("my_grp") as grp_name2: - assert graph.active_group == grp_name2 + with graph.group("my_grp") as grp2: + assert graph.active_group == grp2 node2 = PlainNode(2) @@ -172,8 +172,8 @@ def test_reopen_grps(): graph.run() - assert grp_name == "my_grp" - assert grp_name2 == grp_name + assert grp.names == ("my_grp",) + assert grp.names == grp2.names assert node.value == 2 assert node2.value == 3 @@ -181,9 +181,9 @@ def test_reopen_grps(): assert node.uuid in graph.nodes assert node2.uuid in graph.nodes - assert grp_name in graph._groups + assert grp.names in graph._groups - assert graph.get_group(grp_name) == [node.uuid, node2.uuid] + assert graph.get_group(*grp.names).uuids == [node.uuid, node2.uuid] assert len(graph._groups) == 1 assert len(graph) == 2 @@ -193,19 +193,19 @@ def test_tuple_grp_names(): graph = znflow.DiGraph() assert graph.active_group is None - with graph.group(("grp", "1")) as grp_name: - assert graph.active_group == grp_name + with graph.group("grp", "1") as grp: + assert graph.active_group == grp node = PlainNode(1) assert graph.active_group is None graph.run() - assert grp_name == ("grp", "1") + assert grp.names == ("grp", "1") assert node.value == 2 assert node.uuid in graph.nodes - assert grp_name in graph._groups - assert graph.get_group(grp_name) == [node.uuid] + assert grp.names in graph._groups + assert graph.get_group(*grp.names).uuids == [node.uuid] def test_grp_nodify(): @@ -218,4 +218,61 @@ def compute_mean(x, y): with graph.group("grp1"): n1 = compute_mean(2, 4) - assert n1.uuid in graph.get_group("grp1") + assert n1.uuid in graph.get_group("grp1").uuids + + +def test_grp_iter(): + graph = znflow.DiGraph() + + with graph.group("grp1") as grp: + n1 = PlainNode(1) + n2 = PlainNode(2) + + assert list(grp) == [n1.uuid, n2.uuid] + + +def test_grp_contains(): + graph = znflow.DiGraph() + + with graph.group("grp1") as grp: + n1 = PlainNode(1) + n2 = PlainNode(2) + + assert n1.uuid in grp + assert n2.uuid in grp + assert "foo" not in grp + + +def test_grp_len(): + graph = znflow.DiGraph() + + with graph.group("grp1") as grp: + PlainNode(1) + PlainNode(2) + + assert len(grp) == 2 + + +def test_grp_getitem(): + graph = znflow.DiGraph() + + with graph.group("grp1") as grp: + n1 = PlainNode(1) + n2 = PlainNode(2) + + assert grp[n1.uuid] == n1 + assert grp[n2.uuid] == n2 + with pytest.raises(KeyError): + grp["foo"] + + +def test_grp_nodes(): + graph = znflow.DiGraph() + + with graph.group("grp1") as grp: + n1 = PlainNode(1) + n2 = PlainNode(2) + + assert grp.nodes == [n1, n2] + assert grp.uuids == [n1.uuid, n2.uuid] + assert grp.names == ("grp1",) diff --git a/znflow/__init__.py b/znflow/__init__.py index 85c41e4..b60b866 100644 --- a/znflow/__init__.py +++ b/znflow/__init__.py @@ -17,7 +17,7 @@ ) from znflow.combine import combine from znflow.dynamic import resolve -from znflow.graph import DiGraph +from znflow.graph import DiGraph, Group from znflow.node import Node, nodify from znflow.visualize import draw @@ -39,6 +39,7 @@ "get_graph", "empty_graph", "resolve", + "Group", ] with contextlib.suppress(ImportError): diff --git a/znflow/graph.py b/znflow/graph.py index 2d89ade..e432fb4 100644 --- a/znflow/graph.py +++ b/znflow/graph.py @@ -1,7 +1,9 @@ import contextlib +import dataclasses import functools import logging import typing +import uuid import networkx as nx @@ -19,6 +21,29 @@ log = logging.getLogger(__name__) +@dataclasses.dataclass +class Group: + names: tuple[str, ...] + uuids: list[uuid.UUID] + graph: "DiGraph" + + def __iter__(self) -> typing.Iterator[uuid.UUID]: + return iter(self.uuids) + + def __len__(self) -> int: + return len(self.uuids) + + def __contains__(self, item) -> bool: + return item in self.uuids + + def __getitem__(self, item) -> NodeBaseMixin: + return self.graph.nodes[item]["value"] + + @property + def nodes(self) -> typing.List[NodeBaseMixin]: + return [self.graph.nodes[uuid_]["value"] for uuid_ in self.uuids] + + class DiGraph(nx.MultiDiGraph): def __init__(self, *args, disable=False, immutable_nodes=True, **kwargs): """ @@ -32,7 +57,7 @@ def __init__(self, *args, disable=False, immutable_nodes=True, **kwargs): self.disable = disable self.immutable_nodes = immutable_nodes self._groups = {} - self.active_group = None + self.active_group: typing.Union[Group, None] = None super().__init__(*args, **kwargs) @@ -210,9 +235,7 @@ def write_graph(self, *args): pass @contextlib.contextmanager - def group( - self, name: typing.Union[str, typing.Tuple[str]] - ) -> typing.Generator[str, None, None]: + def group(self, *names: str) -> typing.Generator[Group, None, None]: """Create a group of nodes. Allows to group nodes together, independent of their order in the graph. @@ -223,9 +246,10 @@ def group( Attributes ---------- - name : str|tuple[str] + *names : str Name of the group. If the name is already used, the nodes will be added - to the existing group. + to the existing group. Multiple names can be provided to create nested + groups. Raises ------ @@ -235,9 +259,11 @@ def group( Yields ------ - str - Name of the group. + Group: + A group of containing the nodes that are added within the context manager. """ + if len(names) == 0: + raise ValueError("At least one name must be provided.") if self.active_group is not None: raise TypeError( f"Nested groups are not supported. Group with name '{self.active_group}'" @@ -246,18 +272,21 @@ def group( existing_nodes = self.get_sorted_nodes() + group = self._groups.get(names, Group(names=names, uuids=[], graph=self)) + try: - self.active_group = name + self.active_group = group if get_graph() is empty_graph: with self: - yield name + yield group else: - yield name + yield group finally: self.active_group = None for node_uuid in self.nodes: if node_uuid not in existing_nodes: - self._groups.setdefault(name, []).append(node_uuid) + self._groups[group.names] = group + group.uuids.append(node_uuid) - def get_group(self, name: str) -> typing.List[str]: - return self._groups[name] + def get_group(self, *names: str) -> Group: + return self._groups[names]