Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for magic node names using varnames #776

Merged
merged 14 commits into from
Feb 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ dot4dict = "^0.1"
zninit = "^0.1"
znjson = "^0.2"
znflow = "^0.1"
varname = "^0.13"
# for Python3.12 compatibliity
pyzmq = "^25"

32 changes: 32 additions & 0 deletions tests/integration/test_project.py
Original file line number Diff line number Diff line change
@@ -552,3 +552,35 @@ def test_auto_remove(proj_path):
n1 = zntrack.examples.ParamsToOuts.from_rev(n1.name)
with pytest.raises(zntrack.exceptions.NodeNotAvailableError):
n2 = zntrack.examples.ParamsToOuts.from_rev(n2.name)


def test_magic_names(proj_path):
node = zntrack.examples.ParamsToOuts(params="Lorem Ipsum")
assert node.name == "ParamsToOuts"
with pytest.raises(ValueError):
project = zntrack.Project(magic_names=True, automatic_node_names=True)

project = zntrack.Project(magic_names=True, automatic_node_names=False)
with project:
node01 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum")
node02 = zntrack.examples.ParamsToOuts(params="Dolor Sit")
node03 = zntrack.examples.ParamsToOuts(params="Test01")
assert node01.name == "node01"
assert node02.name == "node02"
assert node03.name == "node03"

with project.group("Grp01"):
node01 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum")
node02 = zntrack.examples.ParamsToOuts(params="Dolor Sit")
grp_node03 = zntrack.examples.ParamsToOuts(params="Test02")

assert node01.name == "Grp01_node01"
assert node02.name == "Grp01_node02"
assert grp_node03.name == "Grp01_grp_node03"

project.run()

zntrack.from_rev(node01.name).outs == "Lorem Ipsum"
zntrack.from_rev(node02.name).outs == "Dolor Sit"
zntrack.from_rev(node03.name).outs == "Test01"
zntrack.from_rev(grp_node03.name).outs == "Test02"
5 changes: 5 additions & 0 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
import znflow
import zninit
import znjson
from varname import VarnameException, varname

from zntrack import exceptions
from zntrack.notebooks.jupyter import jupyter_class_to_file
@@ -161,8 +162,12 @@ def __set__(self, instance, value):
if isinstance(value, NodeName):
if not instance._external_:
value.update_suffix(instance._graph_.project, instance)
with contextlib.suppress(VarnameException):
value.varname = varname(frame=4)
instance._name_ = value
elif isinstance(getattr(instance, "_name_"), NodeName):
with contextlib.suppress(VarnameException):
instance._name_.varname = varname(frame=4)
instance._name_.name = value
instance._name_.suffix = 0
instance._name_.update_suffix(instance._graph_.project, instance)
11 changes: 11 additions & 0 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
@@ -82,6 +82,11 @@ class Project:
This will require a DVC remote to be setup.
force : bool, default = False
overwrite existing nodes.
magic_names : bool, default = False
If True, use magic names for the nodes. This will use the variable name of the
node as the node name. E.g. `node = Node()` will result in a node name of 'node'.
If used within a group, the group name will be added to the node name. E.g.
`group.name = Grp1` and `model = Node()` will result in a name of 'Grp1_model'.
"""

graph: ZnTrackGraph = dataclasses.field(default_factory=ZnTrackGraph, init=False)
@@ -90,6 +95,7 @@ class Project:
automatic_node_names: bool = True
git_only_repo: bool = True
force: bool = False
magic_names: bool = False

_groups: dict[str, NodeGroup] = dataclasses.field(
default_factory=dict, init=False, repr=False
@@ -116,6 +122,11 @@ def __post_init__(self):
config.files.params.unlink(missing_ok=True)
shutil.rmtree("nodes", ignore_errors=True)

if self.automatic_node_names and self.magic_names:
raise ValueError(
"automatic_node_names and magic_names can not be True at the same time"
)

def __enter__(self, *args, **kwargs):
"""Enter the graph context."""
self.graph.__enter__(*args, **kwargs)
12 changes: 10 additions & 2 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -227,28 +227,36 @@ class NodeName:

groups: list[str]
name: str
varname: str = None
suffix: int = 0
use_varname: bool = False

def __str__(self) -> str:
"""Get the node name."""
name = []
if self.groups is not None:
name.extend(self.groups)
name.append(self.name)
if self.use_varname:
name.append(self.varname)
else:
name.append(self.name)
if self.suffix > 0 and self.use_varname:
raise ValueError("Suffixes are not supported for magic names (varnames).")
if self.suffix > 0:
name.append(str(self.suffix))
return "_".join(name)

def get_name_without_groups(self) -> str:
"""Get the node name without the groups."""
name = self.name
name = self.varname if self.use_varname else self.name
if self.suffix > 0:
name += f"_{self.suffix}"
return name

def update_suffix(self, project: "Project", node: "Node") -> None:
"""Update the suffix."""
node_names = [x["value"].name for x in project.graph.nodes.values()]
self.use_varname = project.magic_names

node_names = []
for node_uuid in project.graph.nodes: