Skip to content

Commit

Permalink
allow dynamic resolving within the graph (#95)
Browse files Browse the repository at this point in the history
* allow dynamic resolving within the graph

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more testing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* only run relevant nodes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shorten line

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update description

* check connections are not altered

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* default immutable_nodes=True, e.g. assuming Nodes are not updated after they are created

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make `immutable_nodes` are graph property

* test loop

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Feb 26, 2024
1 parent 14ced99 commit b9bbb3d
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 9 deletions.
112 changes: 112 additions & 0 deletions tests/test_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import dataclasses

import znflow


@dataclasses.dataclass
class AddOne(znflow.Node):
inputs: int
outputs: int = None

def run(self):
if self.outputs is not None:
raise ValueError("Node has already been run")
self.outputs = self.inputs + 1


def test_break_loop():
"""Test loop breaking when output exceeds 5."""
graph = znflow.DiGraph()
with graph:
node1 = AddOne(inputs=1)
for _ in range(10):
node1 = AddOne(inputs=node1.outputs)
if znflow.resolve(node1.outputs) > 5:
break

graph.run()

# Assert the correct number of nodes in the graph
assert len(graph) == 5

# Assert the final output value
assert node1.outputs == 6


def test_break_loop_multiple():
"""Test loop breaking with multiple nodes and different conditions."""
graph = znflow.DiGraph()
with graph:
node1 = AddOne(inputs=1)
node2 = AddOne(inputs=node1.outputs) # Add another node in the loop

for _ in range(10):
node1 = AddOne(inputs=node1.outputs)
node2 = AddOne(inputs=node2.outputs)

# Break if either node's output exceeds 5 or both reach 3
if (
znflow.resolve(node1.outputs) > 5
or znflow.resolve(node2.outputs) > 5
or znflow.resolve(node1.outputs) == 3
and znflow.resolve(node2.outputs) == 3
):
break

graph.run()

# Assert the correct number of nodes in the graph
assert len(graph) <= 10 # Maximum number of iterations allowed

# Assert that at least one node's output exceeds 5 or both reach 3
assert (
znflow.resolve(node1.outputs) > 5
or znflow.resolve(node2.outputs) > 5
or znflow.resolve(node1.outputs) == 3
and znflow.resolve(node2.outputs) == 3
)


def test_resolvce_only_run_relevant_nodes():
"""Test that when using resolve only nodes that are direct predecessors are run."""
# Check by asserting None to the output of the second node
graph = znflow.DiGraph()
with graph:
node1 = AddOne(inputs=1)
node2 = AddOne(inputs=1234)
for _ in range(10):
node1 = AddOne(inputs=node1.outputs)
if znflow.resolve(node1.outputs) > 5:
break

# this has to be executed, because of the resolve
assert node1.outputs == 6

# this should not be executed, because it is not relevant to the resolve
assert node2.outputs is None

graph.run()
assert node2.outputs == 1235
assert node1.outputs == 6


def test_connections_remain():
graph = znflow.DiGraph()
with graph:
node1 = AddOne(inputs=1)
result = znflow.resolve(node1.outputs)
assert isinstance(result, int)
assert isinstance(node1.outputs, znflow.Connection)


def test_loop_over_results():
graph = znflow.DiGraph()
with graph:
node1 = AddOne(inputs=5)
nodes = []
for idx in range(znflow.resolve(node1.outputs)):
nodes.append(AddOne(inputs=idx))

graph.run()
assert len(nodes) == 6
assert [node.outputs for node in nodes] == [1, 2, 3, 4, 5, 6]
40 changes: 40 additions & 0 deletions tests/test_late_updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import dataclasses

import znflow


@dataclasses.dataclass
class AddOne(znflow.Node):
inputs: int
outputs: int = None

def run(self):
self.outputs = self.inputs + 1


def test_update_after_exit():
graph = znflow.DiGraph(immutable_nodes=False)
with graph:
node1 = AddOne(inputs=1)

node1.inputs = 2
graph.run()
assert node1.outputs == 3

node1.inputs = 3
graph.run()
assert node1.outputs == 4


def test_update_after_exit_immutable():
graph = znflow.DiGraph()
with graph:
node1 = AddOne(inputs=1)

node1.inputs = 2
graph.run()
assert node1.outputs == 3

node1.inputs = 3
graph.run()
assert node1.outputs == 3
2 changes: 2 additions & 0 deletions znflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_graph,
)
from znflow.combine import combine
from znflow.dynamic import resolve
from znflow.graph import DiGraph
from znflow.node import Node, nodify
from znflow.visualize import draw
Expand All @@ -37,6 +38,7 @@
"exceptions",
"get_graph",
"empty_graph",
"resolve",
]

with contextlib.suppress(ImportError):
Expand Down
5 changes: 4 additions & 1 deletion znflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

from znflow import exceptions

if typing.TYPE_CHECKING:
from znflow.graph import DiGraph


@contextlib.contextmanager
def disable_graph(*args, **kwargs):
Expand Down Expand Up @@ -126,7 +129,7 @@ def run(self):
raise NotImplementedError


def get_graph():
def get_graph() -> DiGraph:
return NodeBaseMixin._graph_


Expand Down
37 changes: 37 additions & 0 deletions znflow/dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import typing as t

from znflow.base import Connection, disable_graph, get_graph


def resolve(value: t.Union[Connection, t.Any]) -> t.Any:
"""Resolve a Connection to its actual value.
Allows dynamic resolution of connections to their actual values
within a graph context. This will run all Nodes up to this connection.
Attributes
----------
value : Connection
The connection to resolve.
Returns
-------
t.Any
The actual value of the connection.
"""
# TODO: support nodify as well
if not isinstance(value, (Connection)):
return value
# get the actual value
with disable_graph():
result = value.result
if result is not None:
return result
# we assume, that if the result is None, the node has not been run yet
graph = get_graph()

with disable_graph():
graph.run(nodes=[value.instance])
result = value.result
return result
64 changes: 56 additions & 8 deletions znflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,20 @@


class DiGraph(nx.MultiDiGraph):
def __init__(self, *args, disable=False, **kwargs):
def __init__(self, *args, disable=False, immutable_nodes=True, **kwargs):
"""
Attributes
----------
immutable_nodes : bool
If True, the nodes are assumed to be immutable and
will not be rerun. If you change the inputs of a node
after it has been run, the outputs will not be updated.
"""
self.disable = disable
self.immutable_nodes = immutable_nodes
self._groups = {}
self.active_group = None

super().__init__(*args, **kwargs)

@property
Expand Down Expand Up @@ -142,13 +152,51 @@ def get_sorted_nodes(self):
all_pipelines += nx.dfs_postorder_nodes(reverse, stage)
return list(dict.fromkeys(all_pipelines)) # remove duplicates but keep order

def run(self):
for node_uuid in self.get_sorted_nodes():
node = self.nodes[node_uuid]["value"]
if not node._external_:
# update connectors
self._update_node_attributes(node, handler.UpdateConnectors())
node.run()
def run(
self,
nodes: typing.Optional[typing.List[NodeBaseMixin]] = None,
):
"""Run the graph.
Attributes
----------
nodes : list[Node]
The nodes to run. If None, all nodes are run.
"""
if nodes is not None:
for node_uuid in self.reverse():
if self.immutable_nodes and self.nodes[node_uuid].get("available", False):
continue
node = self.nodes[node_uuid]["value"]
if node in nodes:
predecessors = list(self.predecessors(node.uuid))
for predecessor in predecessors:
predecessor_node = self.nodes[predecessor]["value"]
if self.immutable_nodes and self.nodes[predecessor].get(
"available", False
):
continue
self._update_node_attributes(
predecessor_node, handler.UpdateConnectors()
)
predecessor_node.run()
if self.immutable_nodes:
self.nodes[predecessor]["available"] = True
self._update_node_attributes(node, handler.UpdateConnectors())
node.run()
if self.immutable_nodes:
self.nodes[node_uuid]["available"] = True
else:
for node_uuid in self.get_sorted_nodes():
if self.immutable_nodes and self.nodes[node_uuid].get("available", False):
continue
node = self.nodes[node_uuid]["value"]
if not node._external_:
# update connectors
self._update_node_attributes(node, handler.UpdateConnectors())
node.run()
if self.immutable_nodes:
self.nodes[node_uuid]["available"] = True

def write_graph(self, *args):
for node in args:
Expand Down

0 comments on commit b9bbb3d

Please sign in to comment.