Skip to content

Commit

Permalink
first idea for a bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Feb 1, 2025
1 parent 3433220 commit 87be424
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 12 deletions.
3 changes: 2 additions & 1 deletion tests/files/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def test_deps(proj_path):
_ = NodeC(deps=a.results)
_ = NodeC(deps=a.metrics)
_ = NodeC(deps=a.plots)
_ = NodeC(deps=a.outs_path)
_ = NodeC(deps=a.outs_path)
# TODO: do we want to allow `x_path` as a `deps` or should it go into `deps_path`?
_ = NodeC(deps=a.metrics_paths)
_ = NodeC(deps=a.plots_path)

Expand Down
16 changes: 16 additions & 0 deletions tests/integration/test_outs_path_deps_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import zntrack.examples
from pathlib import Path

def test_outs_path_to_deps_path(proj_path):
with zntrack.Project() as proj:
a = zntrack.examples.WriteDVCOuts(params=10)
# assert a.outs == Path("nodes/WriteDVCOuts/output.txt") # uses znflow.resolve
b = zntrack.examples.ReadFile(path=a.outs)
# b = zntrack.examples.ReadFile(path=znflow.resolve(a.outs))
# b = zntrack.examples.ReadFile(path=Path("nodes/WriteDVCOuts/output.txt")) # works

proj.repro()

assert a.outs == Path("nodes/WriteDVCOuts/output.txt")
assert b.path == Path("nodes/WriteDVCOuts/output.txt")
assert b.content == "10"
37 changes: 27 additions & 10 deletions zntrack/fields/x_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,26 @@
from zntrack.utils.node_wd import NWDReplaceHandler


def _paths_getter(self: Node, name: str):
def _paths_getter_input(self: Node, name: str):
"""Resolve the paths for data the Node consumes."""
if name in self.__dict__ and self.__dict__[name] is not ZNTRACK_LAZY_VALUE:
return self.__dict__[name]
try:
with self.state.fs.open(ZNTRACK_FILE_PATH) as f:
content = json.load(f)[self.name][name]
content = znjson.loads(json.dumps(content))

if self.state.tmp_path is not None:
loader = TempPathLoader()
loader(content, instance=self)

return content
except FileNotFoundError:
return NOT_AVAILABLE


def _paths_getter_output(self: Node, name: str):
"""Resolve the paths for data the Node produces."""
# TODO: if self._external_: try looking into
# external/self.uuid/...
# this works for everything except node-meta.json because that
Expand Down Expand Up @@ -59,15 +78,14 @@ def outs_path(
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.OUTS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_output
return znfields.field(default=default, getter=plugin_getter, **kwargs)


def params_path(default=dataclasses.MISSING, *, cache: bool = True, **kwargs):
def params_path(default=dataclasses.MISSING, **kwargs):
kwargs["metadata"] = kwargs.get("metadata", {})
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.PARAMS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_input
return znfields.field(default=default, getter=plugin_getter, **kwargs)


Expand All @@ -82,7 +100,7 @@ def plots_path(
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.PLOTS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_output
return znfields.field(default=default, getter=plugin_getter, **kwargs)


Expand All @@ -99,13 +117,12 @@ def metrics_path(
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.METRICS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_output
return znfields.field(default=default, getter=plugin_getter, **kwargs)


def deps_path(default=dataclasses.MISSING, *, cache: bool = True, **kwargs):
def deps_path(default=dataclasses.MISSING, **kwargs):
kwargs["metadata"] = kwargs.get("metadata", {})
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.DEPS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_input
return znfields.field(default=default, getter=plugin_getter, **kwargs)
19 changes: 18 additions & 1 deletion zntrack/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
from zntrack.state import NodeStatus
from zntrack.utils.misc import get_plugins_from_env

from .config import NOT_AVAILABLE, ZNTRACK_LAZY_VALUE, NodeStatusEnum
from .config import (
NOT_AVAILABLE,
ZNTRACK_LAZY_VALUE,
NodeStatusEnum,
ZNTRACK_OPTION,
ZnTrackOptionEnum,
)

try:
from typing import dataclass_transform
Expand Down Expand Up @@ -78,6 +84,17 @@ def __post_init__(self):
log.warning(
"Node name should not contain '_'. This character is used for defining groups."
)
for field in dataclasses.fields(self):
# X_Path should be resolved instead of passing
# a connection. They are known at runtime.
if field.metadata.get(ZNTRACK_OPTION, None) in [
ZnTrackOptionEnum.PARAMS_PATH,
ZnTrackOptionEnum.DEPS_PATH,
ZnTrackOptionEnum.OUTS_PATH,
ZnTrackOptionEnum.PLOTS_PATH,
ZnTrackOptionEnum.METRICS_PATH,
]:
self._protected_.append(field.name)

def _post_load_(self):
"""Called after `from_rev` is called."""
Expand Down

0 comments on commit 87be424

Please sign in to comment.