From e21d92a6d8a445178b29870f8732b9453341871a Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Mon, 26 Feb 2024 09:02:50 +0100 Subject: [PATCH 01/16] allow dynamic resolving within the graph --- tests/test_dynamic.py | 24 ++++++++++++++++++++++++ znflow/__init__.py | 2 ++ znflow/dynamic.py | 20 ++++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 tests/test_dynamic.py create mode 100644 znflow/dynamic.py diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py new file mode 100644 index 0000000..6f72628 --- /dev/null +++ b/tests/test_dynamic.py @@ -0,0 +1,24 @@ +import znflow +import dataclasses + + +@dataclasses.dataclass +class AddOne(znflow.Node): + inputs: float + outputs: float = None + + def run(self): + self.outputs = self.inputs + 1 + +def test_break_loop(): + graph = znflow.DiGraph() + with graph: + node1 = AddOne(inputs=1) + for _ in range(10): + node1 = AddOne(inputs=node1.outputs) + if znflow.resolve(node1.outputs) > 5: + break + + graph.run() + assert len(graph) == 5 + assert node1.outputs == 6 diff --git a/znflow/__init__.py b/znflow/__init__.py index 2d00af5..ecc59d8 100644 --- a/znflow/__init__.py +++ b/znflow/__init__.py @@ -19,6 +19,7 @@ from znflow.graph import DiGraph from znflow.node import Node, nodify from znflow.visualize import draw +from znflow.dynamic import resolve __version__ = importlib.metadata.version(__name__) @@ -37,6 +38,7 @@ "exceptions", "get_graph", "empty_graph", + "resolve", ] with contextlib.suppress(ImportError): diff --git a/znflow/dynamic.py b/znflow/dynamic.py new file mode 100644 index 0000000..135c555 --- /dev/null +++ b/znflow/dynamic.py @@ -0,0 +1,20 @@ +import dis +from znflow.node import Node +from znflow.base import Connection, disable_graph, get_graph +import typing as t + +def resolve(value: Connection| t.Any): + # TODO: support nodify as well + if not isinstance(value, (Connection)): + raise ValueError(f"Expected a Node, got {value}") + # get the actual value + with disable_graph(): + # if the node has not been run yet, run it + result = value.result + if result is None: + graph = get_graph() + + with disable_graph(): + graph.run() + result = value.result + return result From 00bdd41a9a5ddd3e493248c4d3f608c7a2bc5646 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 08:03:05 +0000 Subject: [PATCH 02/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_dynamic.py | 4 +++- znflow/__init__.py | 2 +- znflow/dynamic.py | 10 +++++----- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index 6f72628..6819df4 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -1,6 +1,7 @@ -import znflow import dataclasses +import znflow + @dataclasses.dataclass class AddOne(znflow.Node): @@ -10,6 +11,7 @@ class AddOne(znflow.Node): def run(self): self.outputs = self.inputs + 1 + def test_break_loop(): graph = znflow.DiGraph() with graph: diff --git a/znflow/__init__.py b/znflow/__init__.py index ecc59d8..85c41e4 100644 --- a/znflow/__init__.py +++ b/znflow/__init__.py @@ -16,10 +16,10 @@ get_graph, ) from znflow.combine import combine +from znflow.dynamic import resolve from znflow.graph import DiGraph from znflow.node import Node, nodify from znflow.visualize import draw -from znflow.dynamic import resolve __version__ = importlib.metadata.version(__name__) diff --git a/znflow/dynamic.py b/znflow/dynamic.py index 135c555..17986c1 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -1,9 +1,9 @@ -import dis -from znflow.node import Node -from znflow.base import Connection, disable_graph, get_graph import typing as t -def resolve(value: Connection| t.Any): +from znflow.base import Connection, disable_graph, get_graph + + +def resolve(value: Connection | t.Any): # TODO: support nodify as well if not isinstance(value, (Connection)): raise ValueError(f"Expected a Node, got {value}") @@ -13,7 +13,7 @@ def resolve(value: Connection| t.Any): result = value.result if result is None: graph = get_graph() - + with disable_graph(): graph.run() result = value.result From b3a21cf6b14037549d19b36f6934c81d22511d7c Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Mon, 26 Feb 2024 09:11:05 +0100 Subject: [PATCH 03/16] more testing --- tests/test_dynamic.py | 34 ++++++++++++++++++++++++++++++++++ znflow/dynamic.py | 4 +++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index 6f72628..98a50cb 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -10,7 +10,9 @@ class AddOne(znflow.Node): def run(self): self.outputs = self.inputs + 1 + def test_break_loop(): + """Test loop breaking when output exceeds 5.""" graph = znflow.DiGraph() with graph: node1 = AddOne(inputs=1) @@ -20,5 +22,37 @@ def test_break_loop(): break graph.run() + + # Assert the correct number of nodes in the graph assert len(graph) == 5 + + # Assert the final output value assert node1.outputs == 6 + + +def test_break_loop_multiple(): + """Test loop breaking with multiple nodes and different conditions.""" + graph = znflow.DiGraph() + with graph: + node1 = AddOne(inputs=1) + node2 = AddOne(inputs=node1.outputs) # Add another node in the loop + + for _ in range(10): + node1 = AddOne(inputs=node1.outputs) + node2 = AddOne(inputs=node2.outputs) + + # Break if either node's output exceeds 5 or both reach 3 + if (znflow.resolve(node1.outputs) > 5 or + znflow.resolve(node2.outputs) > 5 or + znflow.resolve(node1.outputs) == 3 and znflow.resolve(node2.outputs) == 3): + break + + graph.run() + + # Assert the correct number of nodes in the graph + assert len(graph) <= 10 # Maximum number of iterations allowed + + # Assert that at least one node's output exceeds 5 or both reach 3 + assert (znflow.resolve(node1.outputs) > 5 or + znflow.resolve(node2.outputs) > 5 or + znflow.resolve(node1.outputs) == 3 and znflow.resolve(node2.outputs) == 3) diff --git a/znflow/dynamic.py b/znflow/dynamic.py index 135c555..8cc3b3e 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -6,13 +6,15 @@ def resolve(value: Connection| t.Any): # TODO: support nodify as well if not isinstance(value, (Connection)): - raise ValueError(f"Expected a Node, got {value}") + return value # get the actual value with disable_graph(): # if the node has not been run yet, run it result = value.result if result is None: graph = get_graph() + else: + return result with disable_graph(): graph.run() From aa9b6543a098d8d3907edf8e155fbec508d232e9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 08:12:02 +0000 Subject: [PATCH 04/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_dynamic.py | 18 ++++++++++++------ znflow/dynamic.py | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index d186047..87caa05 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -43,9 +43,12 @@ def test_break_loop_multiple(): node2 = AddOne(inputs=node2.outputs) # Break if either node's output exceeds 5 or both reach 3 - if (znflow.resolve(node1.outputs) > 5 or - znflow.resolve(node2.outputs) > 5 or - znflow.resolve(node1.outputs) == 3 and znflow.resolve(node2.outputs) == 3): + if ( + znflow.resolve(node1.outputs) > 5 + or znflow.resolve(node2.outputs) > 5 + or znflow.resolve(node1.outputs) == 3 + and znflow.resolve(node2.outputs) == 3 + ): break graph.run() @@ -54,6 +57,9 @@ def test_break_loop_multiple(): assert len(graph) <= 10 # Maximum number of iterations allowed # Assert that at least one node's output exceeds 5 or both reach 3 - assert (znflow.resolve(node1.outputs) > 5 or - znflow.resolve(node2.outputs) > 5 or - znflow.resolve(node1.outputs) == 3 and znflow.resolve(node2.outputs) == 3) + assert ( + znflow.resolve(node1.outputs) > 5 + or znflow.resolve(node2.outputs) > 5 + or znflow.resolve(node1.outputs) == 3 + and znflow.resolve(node2.outputs) == 3 + ) diff --git a/znflow/dynamic.py b/znflow/dynamic.py index 9016c58..49fb7d2 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -15,7 +15,7 @@ def resolve(value: Connection | t.Any): graph = get_graph() else: return result - + with disable_graph(): graph.run() result = value.result From fcd4c98e8b2e0f2a2df60e1b41232ad87cc83a6e Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Mon, 26 Feb 2024 09:29:12 +0100 Subject: [PATCH 05/16] only run relevant nodes --- tests/test_dynamic.py | 25 +++++++++++++++++++++++++ znflow/base.py | 5 ++++- znflow/dynamic.py | 20 ++++++++++++++++++-- znflow/graph.py | 30 +++++++++++++++++++++++------- 4 files changed, 70 insertions(+), 10 deletions(-) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index d186047..b8cfa1d 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -9,6 +9,8 @@ class AddOne(znflow.Node): outputs: float = None def run(self): + # if self.outputs is not None: + # raise ValueError("Node has already been run") self.outputs = self.inputs + 1 @@ -57,3 +59,26 @@ def test_break_loop_multiple(): assert (znflow.resolve(node1.outputs) > 5 or znflow.resolve(node2.outputs) > 5 or znflow.resolve(node1.outputs) == 3 and znflow.resolve(node2.outputs) == 3) + + +def test_resolvce_only_run_relevant_nodes(): + """Test that when using resolve only nodes that are direct predecessors are run.""" + # Check by asserting None to the output of the second node + graph = znflow.DiGraph() + with graph: + node1 = AddOne(inputs=1) + node2 = AddOne(inputs=1234) + for _ in range(10): + node1 = AddOne(inputs=node1.outputs) + if znflow.resolve(node1.outputs) > 5: + break + + # this has to be executed, because of the resolve + assert node1.outputs == 6 + + # this should not be executed, because it is not relevant to the resolve + assert node2.outputs is None + + graph.run() + assert node2.outputs == 1235 + assert node1.outputs == 6 diff --git a/znflow/base.py b/znflow/base.py index 7c46a2a..c30fff3 100644 --- a/znflow/base.py +++ b/znflow/base.py @@ -8,6 +8,9 @@ from znflow import exceptions +if typing.TYPE_CHECKING: + from znflow.graph import DiGraph + @contextlib.contextmanager def disable_graph(*args, **kwargs): @@ -126,7 +129,7 @@ def run(self): raise NotImplementedError -def get_graph(): +def get_graph() -> DiGraph: return NodeBaseMixin._graph_ diff --git a/znflow/dynamic.py b/znflow/dynamic.py index 9016c58..890033a 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -3,7 +3,23 @@ from znflow.base import Connection, disable_graph, get_graph -def resolve(value: Connection | t.Any): +def resolve(value: t.Union[Connection, t.Any]): + """Resolve a Connection to its actual value. + + Allows dynamic resolution of connections to their actual values within a graph context. + This will run all Nodes up to this connection. + + Attributes + ---------- + value : Connection + The connection to resolve. + + Returns + ------- + t.Any + The actual value of the connection. + + """ # TODO: support nodify as well if not isinstance(value, (Connection)): return value @@ -17,6 +33,6 @@ def resolve(value: Connection | t.Any): return result with disable_graph(): - graph.run() + graph.run(nodes=[value.instance]) result = value.result return result diff --git a/znflow/graph.py b/znflow/graph.py index 5519078..81678d1 100644 --- a/znflow/graph.py +++ b/znflow/graph.py @@ -142,13 +142,29 @@ def get_sorted_nodes(self): all_pipelines += nx.dfs_postorder_nodes(reverse, stage) return list(dict.fromkeys(all_pipelines)) # remove duplicates but keep order - def run(self): - for node_uuid in self.get_sorted_nodes(): - node = self.nodes[node_uuid]["value"] - if not node._external_: - # update connectors - self._update_node_attributes(node, handler.UpdateConnectors()) - node.run() + def run(self, nodes: typing.Optional[typing.List[NodeBaseMixin]] = None): + if nodes is not None: + for node_uuid in self.reverse(): + node = self.nodes[node_uuid]["value"] + if node in nodes: + predecessors = list(self.predecessors(node.uuid)) + for predecessor in predecessors: + predecessor_node = self.nodes[predecessor]["value"] + print(f"Predecessor: {predecessor_node}") + self._update_node_attributes( + predecessor_node, handler.UpdateConnectors() + ) + predecessor_node.run() + self._update_node_attributes(node, handler.UpdateConnectors()) + print(f"Node: {node}") + node.run() + else: + for node_uuid in self.get_sorted_nodes(): + node = self.nodes[node_uuid]["value"] + if not node._external_: + # update connectors + self._update_node_attributes(node, handler.UpdateConnectors()) + node.run() def write_graph(self, *args): for node in args: From 114600c3eccfd7ba6ef75e284e74774101dcc1cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 08:30:06 +0000 Subject: [PATCH 06/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_dynamic.py | 17 ++++++++++------- znflow/dynamic.py | 6 +++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index d0e270d..ca9dd44 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -59,9 +59,12 @@ def test_break_loop_multiple(): assert len(graph) <= 10 # Maximum number of iterations allowed # Assert that at least one node's output exceeds 5 or both reach 3 - assert (znflow.resolve(node1.outputs) > 5 or - znflow.resolve(node2.outputs) > 5 or - znflow.resolve(node1.outputs) == 3 and znflow.resolve(node2.outputs) == 3) + assert ( + znflow.resolve(node1.outputs) > 5 + or znflow.resolve(node2.outputs) > 5 + or znflow.resolve(node1.outputs) == 3 + and znflow.resolve(node2.outputs) == 3 + ) def test_resolvce_only_run_relevant_nodes(): @@ -75,12 +78,12 @@ def test_resolvce_only_run_relevant_nodes(): node1 = AddOne(inputs=node1.outputs) if znflow.resolve(node1.outputs) > 5: break - + # this has to be executed, because of the resolve - assert node1.outputs == 6 - + assert node1.outputs == 6 + # this should not be executed, because it is not relevant to the resolve - assert node2.outputs is None + assert node2.outputs is None graph.run() assert node2.outputs == 1235 diff --git a/znflow/dynamic.py b/znflow/dynamic.py index 73bac85..b61841e 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -8,17 +8,17 @@ def resolve(value: t.Union[Connection, t.Any]): Allows dynamic resolution of connections to their actual values within a graph context. This will run all Nodes up to this connection. - + Attributes ---------- value : Connection The connection to resolve. - + Returns ------- t.Any The actual value of the connection. - + """ # TODO: support nodify as well if not isinstance(value, (Connection)): From c69708541297258276c0173d9cdb90f568a6a7bb Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Mon, 26 Feb 2024 09:31:32 +0100 Subject: [PATCH 07/16] shorten line --- znflow/dynamic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/znflow/dynamic.py b/znflow/dynamic.py index b61841e..3ccece0 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -6,8 +6,8 @@ def resolve(value: t.Union[Connection, t.Any]): """Resolve a Connection to its actual value. - Allows dynamic resolution of connections to their actual values within a graph context. - This will run all Nodes up to this connection. + Allows dynamic resolution of connections to their actual values + within a graph context. This will run all Nodes up to this connection. Attributes ---------- From 569901dfe36d0e61a096fcee1021e68c607b1250 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 08:31:43 +0000 Subject: [PATCH 08/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- znflow/dynamic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/znflow/dynamic.py b/znflow/dynamic.py index 3ccece0..688152c 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -6,7 +6,7 @@ def resolve(value: t.Union[Connection, t.Any]): """Resolve a Connection to its actual value. - Allows dynamic resolution of connections to their actual values + Allows dynamic resolution of connections to their actual values within a graph context. This will run all Nodes up to this connection. Attributes From e81b923b86c41d91f161116a5de9d8e17dc938ed Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Mon, 26 Feb 2024 09:35:12 +0100 Subject: [PATCH 09/16] update description --- znflow/dynamic.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/znflow/dynamic.py b/znflow/dynamic.py index 3ccece0..5bf1f61 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -3,7 +3,7 @@ from znflow.base import Connection, disable_graph, get_graph -def resolve(value: t.Union[Connection, t.Any]): +def resolve(value: t.Union[Connection, t.Any]) -> t.Any: """Resolve a Connection to its actual value. Allows dynamic resolution of connections to their actual values @@ -25,12 +25,11 @@ def resolve(value: t.Union[Connection, t.Any]): return value # get the actual value with disable_graph(): - # if the node has not been run yet, run it result = value.result - if result is None: - graph = get_graph() - else: + if result is not None: return result + # we assume, that if the result is None, the node has not been run yet + graph = get_graph() with disable_graph(): graph.run(nodes=[value.instance]) From 60e2ffd3df3683431ec738ca73b47e99ddfc2712 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Mon, 26 Feb 2024 09:37:59 +0100 Subject: [PATCH 10/16] check connections are not altered --- tests/test_dynamic.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index ca9dd44..4e9c608 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -5,8 +5,8 @@ @dataclasses.dataclass class AddOne(znflow.Node): - inputs: float - outputs: float = None + inputs: int + outputs: int = None def run(self): # if self.outputs is not None: @@ -88,3 +88,12 @@ def test_resolvce_only_run_relevant_nodes(): graph.run() assert node2.outputs == 1235 assert node1.outputs == 6 + +def test_connections_remain(): + graph = znflow.DiGraph() + with graph: + node1 = AddOne(inputs=1) + result = znflow.resolve(node1.outputs) + assert isinstance(result, int) + assert isinstance(node1.outputs, znflow.Connection) + \ No newline at end of file From 0f06e58ffd390b2ef608a9557f102ac777070954 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 08:38:08 +0000 Subject: [PATCH 11/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_dynamic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index 4e9c608..d67ffa8 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -89,6 +89,7 @@ def test_resolvce_only_run_relevant_nodes(): assert node2.outputs == 1235 assert node1.outputs == 6 + def test_connections_remain(): graph = znflow.DiGraph() with graph: @@ -96,4 +97,3 @@ def test_connections_remain(): result = znflow.resolve(node1.outputs) assert isinstance(result, int) assert isinstance(node1.outputs, znflow.Connection) - \ No newline at end of file From d357805bdae64d4710fd85267e06d74eade7d68d Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Mon, 26 Feb 2024 09:57:35 +0100 Subject: [PATCH 12/16] default immutable_nodes=True, e.g. assuming Nodes are not updated after they are created --- tests/test_dynamic.py | 4 ++-- tests/test_late_updates.py | 37 +++++++++++++++++++++++++++++++++++++ znflow/dynamic.py | 8 ++++++-- znflow/graph.py | 27 ++++++++++++++++++++++++--- 4 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 tests/test_late_updates.py diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index d67ffa8..8983005 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -9,8 +9,8 @@ class AddOne(znflow.Node): outputs: int = None def run(self): - # if self.outputs is not None: - # raise ValueError("Node has already been run") + if self.outputs is not None: + raise ValueError("Node has already been run") self.outputs = self.inputs + 1 diff --git a/tests/test_late_updates.py b/tests/test_late_updates.py new file mode 100644 index 0000000..236bae3 --- /dev/null +++ b/tests/test_late_updates.py @@ -0,0 +1,37 @@ +import dataclasses +import znflow +from znflow import node + +@dataclasses.dataclass +class AddOne(znflow.Node): + inputs: int + outputs: int = None + + def run(self): + self.outputs = self.inputs + 1 + +def test_update_after_exit(): + graph = znflow.DiGraph() + with graph: + node1 = AddOne(inputs=1) + + node1.inputs = 2 + graph.run(immutable_nodes=False) + assert node1.outputs == 3 + + node1.inputs = 3 + graph.run(immutable_nodes=False) + assert node1.outputs == 4 + +def test_update_after_exit_immutable(): + graph = znflow.DiGraph() + with graph: + node1 = AddOne(inputs=1) + + node1.inputs = 2 + graph.run() + assert node1.outputs == 3 + + node1.inputs = 3 + graph.run() + assert node1.outputs == 3 diff --git a/znflow/dynamic.py b/znflow/dynamic.py index 34bae2b..21d571c 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -3,7 +3,7 @@ from znflow.base import Connection, disable_graph, get_graph -def resolve(value: t.Union[Connection, t.Any]) -> t.Any: +def resolve(value: t.Union[Connection, t.Any], immutable_nodes: bool = True) -> t.Any: """Resolve a Connection to its actual value. Allows dynamic resolution of connections to their actual values @@ -13,6 +13,10 @@ def resolve(value: t.Union[Connection, t.Any]) -> t.Any: ---------- value : Connection The connection to resolve. + immutable_nodes : bool + If True, the nodes are assumed to be immutable and + will not be rerun. If you change the inputs of a node + after it has been run, the outputs will not be updated. Returns ------- @@ -32,6 +36,6 @@ def resolve(value: t.Union[Connection, t.Any]) -> t.Any: graph = get_graph() with disable_graph(): - graph.run(nodes=[value.instance]) + graph.run(nodes=[value.instance], immutable_nodes=immutable_nodes) result = value.result return result diff --git a/znflow/graph.py b/znflow/graph.py index 81678d1..b82379b 100644 --- a/znflow/graph.py +++ b/znflow/graph.py @@ -142,29 +142,50 @@ def get_sorted_nodes(self): all_pipelines += nx.dfs_postorder_nodes(reverse, stage) return list(dict.fromkeys(all_pipelines)) # remove duplicates but keep order - def run(self, nodes: typing.Optional[typing.List[NodeBaseMixin]] = None): + def run(self, nodes: typing.Optional[typing.List[NodeBaseMixin]] = None, immutable_nodes: bool = True): + """Run the graph. + + Attributes + ---------- + nodes : list[Node] + The nodes to run. If None, all nodes are run. + immutable_nodes : bool + If True, the nodes are assumed to be immutable and + will not be rerun. If you change the inputs of a node + after it has been run, the outputs will not be updated. + """ if nodes is not None: for node_uuid in self.reverse(): + if immutable_nodes and self.nodes[node_uuid].get("available", False): + continue node = self.nodes[node_uuid]["value"] if node in nodes: predecessors = list(self.predecessors(node.uuid)) for predecessor in predecessors: predecessor_node = self.nodes[predecessor]["value"] - print(f"Predecessor: {predecessor_node}") + if immutable_nodes and self.nodes[predecessor].get("available", False): + continue self._update_node_attributes( predecessor_node, handler.UpdateConnectors() ) predecessor_node.run() + if immutable_nodes: + self.nodes[predecessor]["available"] = True self._update_node_attributes(node, handler.UpdateConnectors()) - print(f"Node: {node}") node.run() + if immutable_nodes: + self.nodes[node_uuid]["available"] = True else: for node_uuid in self.get_sorted_nodes(): + if immutable_nodes and self.nodes[node_uuid].get("available", False): + continue node = self.nodes[node_uuid]["value"] if not node._external_: # update connectors self._update_node_attributes(node, handler.UpdateConnectors()) node.run() + if immutable_nodes: + self.nodes[node_uuid]["available"] = True def write_graph(self, *args): for node in args: From 563c656a55883943441b01accc43ff460e9ada86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 08:57:46 +0000 Subject: [PATCH 13/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_late_updates.py | 5 ++++- znflow/dynamic.py | 2 +- znflow/graph.py | 14 ++++++++++---- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/test_late_updates.py b/tests/test_late_updates.py index 236bae3..29908dc 100644 --- a/tests/test_late_updates.py +++ b/tests/test_late_updates.py @@ -1,6 +1,7 @@ import dataclasses + import znflow -from znflow import node + @dataclasses.dataclass class AddOne(znflow.Node): @@ -10,6 +11,7 @@ class AddOne(znflow.Node): def run(self): self.outputs = self.inputs + 1 + def test_update_after_exit(): graph = znflow.DiGraph() with graph: @@ -23,6 +25,7 @@ def test_update_after_exit(): graph.run(immutable_nodes=False) assert node1.outputs == 4 + def test_update_after_exit_immutable(): graph = znflow.DiGraph() with graph: diff --git a/znflow/dynamic.py b/znflow/dynamic.py index 21d571c..b9c39ed 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -14,7 +14,7 @@ def resolve(value: t.Union[Connection, t.Any], immutable_nodes: bool = True) -> value : Connection The connection to resolve. immutable_nodes : bool - If True, the nodes are assumed to be immutable and + If True, the nodes are assumed to be immutable and will not be rerun. If you change the inputs of a node after it has been run, the outputs will not be updated. diff --git a/znflow/graph.py b/znflow/graph.py index b82379b..14cfa8e 100644 --- a/znflow/graph.py +++ b/znflow/graph.py @@ -142,15 +142,19 @@ def get_sorted_nodes(self): all_pipelines += nx.dfs_postorder_nodes(reverse, stage) return list(dict.fromkeys(all_pipelines)) # remove duplicates but keep order - def run(self, nodes: typing.Optional[typing.List[NodeBaseMixin]] = None, immutable_nodes: bool = True): + def run( + self, + nodes: typing.Optional[typing.List[NodeBaseMixin]] = None, + immutable_nodes: bool = True, + ): """Run the graph. - + Attributes ---------- nodes : list[Node] The nodes to run. If None, all nodes are run. immutable_nodes : bool - If True, the nodes are assumed to be immutable and + If True, the nodes are assumed to be immutable and will not be rerun. If you change the inputs of a node after it has been run, the outputs will not be updated. """ @@ -163,7 +167,9 @@ def run(self, nodes: typing.Optional[typing.List[NodeBaseMixin]] = None, immutab predecessors = list(self.predecessors(node.uuid)) for predecessor in predecessors: predecessor_node = self.nodes[predecessor]["value"] - if immutable_nodes and self.nodes[predecessor].get("available", False): + if immutable_nodes and self.nodes[predecessor].get( + "available", False + ): continue self._update_node_attributes( predecessor_node, handler.UpdateConnectors() From 0ae08a1b373d365054ea658b6d6f4d9b1d89c15d Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Mon, 26 Feb 2024 10:03:54 +0100 Subject: [PATCH 14/16] make `immutable_nodes` are graph property --- tests/test_late_updates.py | 6 +++--- znflow/dynamic.py | 8 ++------ znflow/graph.py | 29 +++++++++++++++++------------ 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/tests/test_late_updates.py b/tests/test_late_updates.py index 29908dc..f105b22 100644 --- a/tests/test_late_updates.py +++ b/tests/test_late_updates.py @@ -13,16 +13,16 @@ def run(self): def test_update_after_exit(): - graph = znflow.DiGraph() + graph = znflow.DiGraph(immutable_nodes=False) with graph: node1 = AddOne(inputs=1) node1.inputs = 2 - graph.run(immutable_nodes=False) + graph.run() assert node1.outputs == 3 node1.inputs = 3 - graph.run(immutable_nodes=False) + graph.run() assert node1.outputs == 4 diff --git a/znflow/dynamic.py b/znflow/dynamic.py index b9c39ed..34bae2b 100644 --- a/znflow/dynamic.py +++ b/znflow/dynamic.py @@ -3,7 +3,7 @@ from znflow.base import Connection, disable_graph, get_graph -def resolve(value: t.Union[Connection, t.Any], immutable_nodes: bool = True) -> t.Any: +def resolve(value: t.Union[Connection, t.Any]) -> t.Any: """Resolve a Connection to its actual value. Allows dynamic resolution of connections to their actual values @@ -13,10 +13,6 @@ def resolve(value: t.Union[Connection, t.Any], immutable_nodes: bool = True) -> ---------- value : Connection The connection to resolve. - immutable_nodes : bool - If True, the nodes are assumed to be immutable and - will not be rerun. If you change the inputs of a node - after it has been run, the outputs will not be updated. Returns ------- @@ -36,6 +32,6 @@ def resolve(value: t.Union[Connection, t.Any], immutable_nodes: bool = True) -> graph = get_graph() with disable_graph(): - graph.run(nodes=[value.instance], immutable_nodes=immutable_nodes) + graph.run(nodes=[value.instance]) result = value.result return result diff --git a/znflow/graph.py b/znflow/graph.py index 14cfa8e..e24a24e 100644 --- a/znflow/graph.py +++ b/znflow/graph.py @@ -20,10 +20,20 @@ class DiGraph(nx.MultiDiGraph): - def __init__(self, *args, disable=False, **kwargs): + def __init__(self, *args, disable=False, immutable_nodes=True, **kwargs): + """ + Attributes + ---------- + immutable_nodes : bool + If True, the nodes are assumed to be immutable and + will not be rerun. If you change the inputs of a node + after it has been run, the outputs will not be updated. + """ self.disable = disable + self.immutable_nodes = immutable_nodes self._groups = {} self.active_group = None + super().__init__(*args, **kwargs) @property @@ -145,7 +155,6 @@ def get_sorted_nodes(self): def run( self, nodes: typing.Optional[typing.List[NodeBaseMixin]] = None, - immutable_nodes: bool = True, ): """Run the graph. @@ -153,21 +162,17 @@ def run( ---------- nodes : list[Node] The nodes to run. If None, all nodes are run. - immutable_nodes : bool - If True, the nodes are assumed to be immutable and - will not be rerun. If you change the inputs of a node - after it has been run, the outputs will not be updated. """ if nodes is not None: for node_uuid in self.reverse(): - if immutable_nodes and self.nodes[node_uuid].get("available", False): + if self.immutable_nodes and self.nodes[node_uuid].get("available", False): continue node = self.nodes[node_uuid]["value"] if node in nodes: predecessors = list(self.predecessors(node.uuid)) for predecessor in predecessors: predecessor_node = self.nodes[predecessor]["value"] - if immutable_nodes and self.nodes[predecessor].get( + if self.immutable_nodes and self.nodes[predecessor].get( "available", False ): continue @@ -175,22 +180,22 @@ def run( predecessor_node, handler.UpdateConnectors() ) predecessor_node.run() - if immutable_nodes: + if self.immutable_nodes: self.nodes[predecessor]["available"] = True self._update_node_attributes(node, handler.UpdateConnectors()) node.run() - if immutable_nodes: + if self.immutable_nodes: self.nodes[node_uuid]["available"] = True else: for node_uuid in self.get_sorted_nodes(): - if immutable_nodes and self.nodes[node_uuid].get("available", False): + if self.immutable_nodes and self.nodes[node_uuid].get("available", False): continue node = self.nodes[node_uuid]["value"] if not node._external_: # update connectors self._update_node_attributes(node, handler.UpdateConnectors()) node.run() - if immutable_nodes: + if self.immutable_nodes: self.nodes[node_uuid]["available"] = True def write_graph(self, *args): From f6d9f41ad7e2a81b3485ef6e4a91c437bca0829f Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Mon, 26 Feb 2024 10:12:59 +0100 Subject: [PATCH 15/16] test loop --- tests/test_dynamic.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index 8983005..dd3907d 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -97,3 +97,16 @@ def test_connections_remain(): result = znflow.resolve(node1.outputs) assert isinstance(result, int) assert isinstance(node1.outputs, znflow.Connection) + + +def test_loop_over_results(): + graph = znflow.DiGraph() + with graph: + node1 = AddOne(inputs=5) + nodes = [] + for idx in range(znflow.resolve(node1.outputs)): + nodes.append(AddOne(inputs=idx)) + + graph.run() + assert len(nodes) == 6 + assert [node.outputs for node in nodes] == [1, 2, 3, 4, 5, 6] From 9624a94a2f4515d2f8902fb2f2a8a6e9c22828e6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:13:10 +0000 Subject: [PATCH 16/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_dynamic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index dd3907d..db15991 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -106,7 +106,7 @@ def test_loop_over_results(): nodes = [] for idx in range(znflow.resolve(node1.outputs)): nodes.append(AddOne(inputs=idx)) - + graph.run() assert len(nodes) == 6 assert [node.outputs for node in nodes] == [1, 2, 3, 4, 5, 6]