Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix outs_path in deps_path #875

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/files/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_deps(proj_path):
_ = NodeC(deps=a.metrics)
_ = NodeC(deps=a.plots)
_ = NodeC(deps=a.outs_path)
# TODO: do we want to allow `x_path` as a `deps` or should it go into `deps_path`?
_ = NodeC(deps=a.metrics_paths)
_ = NodeC(deps=a.plots_path)

Expand Down
18 changes: 18 additions & 0 deletions tests/integration/test_outs_path_deps_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pathlib import Path

import zntrack.examples


def test_outs_path_to_deps_path(proj_path):
    """Check that a node's ``outs`` path can be passed as another node's ``path``.

    ``WriteDVCOuts.outs`` is handed to ``ReadFile(path=...)`` inside the
    project context. After ``repro`` both attributes are expected to compare
    equal to the same concrete file path, and the reader's ``content`` is
    expected to be the string form of the writer's ``params``.
    """
    expected_path = Path("nodes/WriteDVCOuts/output.txt")

    with zntrack.Project() as proj:
        writer = zntrack.examples.WriteDVCOuts(params=10)
        # Variants explored while drafting this test (kept for reference):
        # assert writer.outs == Path("nodes/WriteDVCOuts/output.txt") # uses znflow.resolve
        reader = zntrack.examples.ReadFile(path=writer.outs)
        # reader = zntrack.examples.ReadFile(path=znflow.resolve(writer.outs))
        # reader = zntrack.examples.ReadFile(path=Path("nodes/WriteDVCOuts/output.txt")) # works

    proj.repro()

    # Both ends of the connection must resolve to the very same file.
    assert writer.outs == expected_path
    assert reader.path == expected_path
    assert reader.content == "10"
37 changes: 27 additions & 10 deletions zntrack/fields/x_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,26 @@
from zntrack.utils.node_wd import NWDReplaceHandler


def _paths_getter(self: Node, name: str):
def _paths_getter_input(self: Node, name: str):
    """Resolve the paths for data the Node consumes.

    A value already stored in the instance ``__dict__`` is returned as-is,
    unless it is the ``ZNTRACK_LAZY_VALUE`` placeholder. Otherwise the entry
    for this node (``self.name``) and field (``name``) is read from
    ``ZNTRACK_FILE_PATH`` through ``self.state.fs`` and decoded by a
    round-trip through ``json.dumps`` / ``znjson.loads``.

    Returns ``NOT_AVAILABLE`` when a ``FileNotFoundError`` is raised while
    loading. A missing node or field key in the file is not handled here and
    propagates as ``KeyError``.
    """
    stored = self.__dict__.get(name, ZNTRACK_LAZY_VALUE)
    if stored is not ZNTRACK_LAZY_VALUE:
        return stored

    try:
        with self.state.fs.open(ZNTRACK_FILE_PATH) as file:
            entry = json.load(file)[self.name][name]
            paths = znjson.loads(json.dumps(entry))

            if self.state.tmp_path is not None:
                # NOTE(review): the loader's return value is discarded, so it
                # presumably acts through side effects -- confirm.
                TempPathLoader()(paths, instance=self)

            return paths
    except FileNotFoundError:
        return NOT_AVAILABLE


def _paths_getter_output(self: Node, name: str):
"""Resolve the paths for data the Node produces."""
# TODO: if self._external_: try looking into
# external/self.uuid/...
# this works for everything except node-meta.json because that
Expand Down Expand Up @@ -59,15 +78,14 @@ def outs_path(
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.OUTS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_output
return znfields.field(default=default, getter=plugin_getter, **kwargs)


def params_path(default=dataclasses.MISSING, *, cache: bool = True, **kwargs):
def params_path(default=dataclasses.MISSING, **kwargs):
kwargs["metadata"] = kwargs.get("metadata", {})
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.PARAMS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_input
return znfields.field(default=default, getter=plugin_getter, **kwargs)


Expand All @@ -82,7 +100,7 @@ def plots_path(
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.PLOTS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_output
return znfields.field(default=default, getter=plugin_getter, **kwargs)


Expand All @@ -99,13 +117,12 @@ def metrics_path(
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.METRICS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_output
return znfields.field(default=default, getter=plugin_getter, **kwargs)


def deps_path(default=dataclasses.MISSING, *, cache: bool = True, **kwargs):
def deps_path(default=dataclasses.MISSING, **kwargs):
kwargs["metadata"] = kwargs.get("metadata", {})
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.DEPS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter
kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_input
return znfields.field(default=default, getter=plugin_getter, **kwargs)
19 changes: 18 additions & 1 deletion zntrack/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
from zntrack.state import NodeStatus
from zntrack.utils.misc import get_plugins_from_env

from .config import NOT_AVAILABLE, ZNTRACK_LAZY_VALUE, NodeStatusEnum
from .config import (
NOT_AVAILABLE,
ZNTRACK_LAZY_VALUE,
ZNTRACK_OPTION,
NodeStatusEnum,
ZnTrackOptionEnum,
)

try:
from typing import dataclass_transform
Expand Down Expand Up @@ -78,6 +84,17 @@ def __post_init__(self):
log.warning(
"Node name should not contain '_'. This character is used for defining groups."
)
for field in dataclasses.fields(self):
# `x_path` fields should be resolved instead of being passed as
# a connection. Their values are known at runtime.
if field.metadata.get(ZNTRACK_OPTION, None) in [
ZnTrackOptionEnum.PARAMS_PATH,
ZnTrackOptionEnum.DEPS_PATH,
ZnTrackOptionEnum.OUTS_PATH,
ZnTrackOptionEnum.PLOTS_PATH,
ZnTrackOptionEnum.METRICS_PATH,
]:
self._protected_.append(field.name)

def _post_load_(self):
"""Called after `from_rev` is called."""
Expand Down
Loading