-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
allow dynamic resolving within the graph (#95)
* allow dynamic resolving within the graph * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more testing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * only run relevant nodes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * shorten line * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update description * check connections are not altered * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * default immutable_nodes=True, e.g. assuming Nodes are not updated after they are created * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make `immutable_nodes` are graph property * test loop * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
14ced99
commit b9bbb3d
Showing
6 changed files
with
251 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import dataclasses | ||
|
||
import znflow | ||
|
||
|
||
@dataclasses.dataclass | ||
class AddOne(znflow.Node): | ||
inputs: int | ||
outputs: int = None | ||
|
||
def run(self): | ||
if self.outputs is not None: | ||
raise ValueError("Node has already been run") | ||
self.outputs = self.inputs + 1 | ||
|
||
|
||
def test_break_loop(): | ||
"""Test loop breaking when output exceeds 5.""" | ||
graph = znflow.DiGraph() | ||
with graph: | ||
node1 = AddOne(inputs=1) | ||
for _ in range(10): | ||
node1 = AddOne(inputs=node1.outputs) | ||
if znflow.resolve(node1.outputs) > 5: | ||
break | ||
|
||
graph.run() | ||
|
||
# Assert the correct number of nodes in the graph | ||
assert len(graph) == 5 | ||
|
||
# Assert the final output value | ||
assert node1.outputs == 6 | ||
|
||
|
||
def test_break_loop_multiple(): | ||
"""Test loop breaking with multiple nodes and different conditions.""" | ||
graph = znflow.DiGraph() | ||
with graph: | ||
node1 = AddOne(inputs=1) | ||
node2 = AddOne(inputs=node1.outputs) # Add another node in the loop | ||
|
||
for _ in range(10): | ||
node1 = AddOne(inputs=node1.outputs) | ||
node2 = AddOne(inputs=node2.outputs) | ||
|
||
# Break if either node's output exceeds 5 or both reach 3 | ||
if ( | ||
znflow.resolve(node1.outputs) > 5 | ||
or znflow.resolve(node2.outputs) > 5 | ||
or znflow.resolve(node1.outputs) == 3 | ||
and znflow.resolve(node2.outputs) == 3 | ||
): | ||
break | ||
|
||
graph.run() | ||
|
||
# Assert the correct number of nodes in the graph | ||
assert len(graph) <= 10 # Maximum number of iterations allowed | ||
|
||
# Assert that at least one node's output exceeds 5 or both reach 3 | ||
assert ( | ||
znflow.resolve(node1.outputs) > 5 | ||
or znflow.resolve(node2.outputs) > 5 | ||
or znflow.resolve(node1.outputs) == 3 | ||
and znflow.resolve(node2.outputs) == 3 | ||
) | ||
|
||
|
||
def test_resolvce_only_run_relevant_nodes(): | ||
"""Test that when using resolve only nodes that are direct predecessors are run.""" | ||
# Check by asserting None to the output of the second node | ||
graph = znflow.DiGraph() | ||
with graph: | ||
node1 = AddOne(inputs=1) | ||
node2 = AddOne(inputs=1234) | ||
for _ in range(10): | ||
node1 = AddOne(inputs=node1.outputs) | ||
if znflow.resolve(node1.outputs) > 5: | ||
break | ||
|
||
# this has to be executed, because of the resolve | ||
assert node1.outputs == 6 | ||
|
||
# this should not be executed, because it is not relevant to the resolve | ||
assert node2.outputs is None | ||
|
||
graph.run() | ||
assert node2.outputs == 1235 | ||
assert node1.outputs == 6 | ||
|
||
|
||
def test_connections_remain(): | ||
graph = znflow.DiGraph() | ||
with graph: | ||
node1 = AddOne(inputs=1) | ||
result = znflow.resolve(node1.outputs) | ||
assert isinstance(result, int) | ||
assert isinstance(node1.outputs, znflow.Connection) | ||
|
||
|
||
def test_loop_over_results(): | ||
graph = znflow.DiGraph() | ||
with graph: | ||
node1 = AddOne(inputs=5) | ||
nodes = [] | ||
for idx in range(znflow.resolve(node1.outputs)): | ||
nodes.append(AddOne(inputs=idx)) | ||
|
||
graph.run() | ||
assert len(nodes) == 6 | ||
assert [node.outputs for node in nodes] == [1, 2, 3, 4, 5, 6] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import dataclasses | ||
|
||
import znflow | ||
|
||
|
||
@dataclasses.dataclass | ||
class AddOne(znflow.Node): | ||
inputs: int | ||
outputs: int = None | ||
|
||
def run(self): | ||
self.outputs = self.inputs + 1 | ||
|
||
|
||
def test_update_after_exit(): | ||
graph = znflow.DiGraph(immutable_nodes=False) | ||
with graph: | ||
node1 = AddOne(inputs=1) | ||
|
||
node1.inputs = 2 | ||
graph.run() | ||
assert node1.outputs == 3 | ||
|
||
node1.inputs = 3 | ||
graph.run() | ||
assert node1.outputs == 4 | ||
|
||
|
||
def test_update_after_exit_immutable(): | ||
graph = znflow.DiGraph() | ||
with graph: | ||
node1 = AddOne(inputs=1) | ||
|
||
node1.inputs = 2 | ||
graph.run() | ||
assert node1.outputs == 3 | ||
|
||
node1.inputs = 3 | ||
graph.run() | ||
assert node1.outputs == 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import typing as t | ||
|
||
from znflow.base import Connection, disable_graph, get_graph | ||
|
||
|
||
def resolve(value: t.Union[Connection, t.Any]) -> t.Any: | ||
"""Resolve a Connection to its actual value. | ||
Allows dynamic resolution of connections to their actual values | ||
within a graph context. This will run all Nodes up to this connection. | ||
Attributes | ||
---------- | ||
value : Connection | ||
The connection to resolve. | ||
Returns | ||
------- | ||
t.Any | ||
The actual value of the connection. | ||
""" | ||
# TODO: support nodify as well | ||
if not isinstance(value, (Connection)): | ||
return value | ||
# get the actual value | ||
with disable_graph(): | ||
result = value.result | ||
if result is not None: | ||
return result | ||
# we assume, that if the result is None, the node has not been run yet | ||
graph = get_graph() | ||
|
||
with disable_graph(): | ||
graph.run(nodes=[value.instance]) | ||
result = value.result | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters