def test_use_tmp_path(proj_path):
    """Check ``outs`` paths and contents with and without a temporary path.

    Four nodes are used: default file output, default directory output,
    a custom ``str`` file output and a custom ``str`` directory output.
    For nodes living in the current workspace ``use_tmp_path`` must leave
    the paths untouched; for nodes loaded with ``from_rev(..., remote=".")``
    the paths must point into ``node.state.tmp_path`` while keeping their
    original type (``pathlib`` path vs. ``str``).
    """
    with zntrack.Project(automatic_node_names=True) as proj:
        node = zntrack.examples.WriteDVCOuts(params="test")
        node2 = zntrack.examples.WriteDVCOutsPath(params="test2")

        node3 = zntrack.examples.WriteDVCOuts(params="test", outs="result.txt")
        node4 = zntrack.examples.WriteDVCOutsPath(
            params="test2", outs=(zntrack.nwd / "data").as_posix()
        )

    proj.run()

    # These comparisons must be asserted; as bare expressions they would be
    # evaluated and silently discarded.
    assert node.get_outs_content() == "test"
    assert node2.get_outs_content() == "test2"
    assert node3.get_outs_content() == "test"
    assert node4.get_outs_content() == "test2"

    assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")
    assert node2.outs == pathlib.Path("nodes", "WriteDVCOutsPath", "data")
    assert node3.outs == "result.txt"
    assert isinstance(node4.outs, str)
    assert node4.outs == pathlib.Path("nodes", "WriteDVCOutsPath_1", "data").as_posix()

    # Neither 'remote' nor 'rev' is set: the context manager must not
    # redirect any path.
    with node.state.use_tmp_path():
        assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")
    with node2.state.use_tmp_path():
        assert node2.outs == pathlib.Path("nodes", "WriteDVCOutsPath", "data")
    with node3.state.use_tmp_path():
        assert node3.outs == "result.txt"
        assert isinstance(node3.outs, str)
    with node4.state.use_tmp_path():
        assert (
            node4.outs == pathlib.Path("nodes", "WriteDVCOutsPath_1", "data").as_posix()
        )

    # fake remote by passing the current directory
    node = node.from_rev(node.name, remote=".")
    node2 = node2.from_rev(node2.name, remote=".")
    node3 = node3.from_rev(node3.name, remote=".")
    node4 = node4.from_rev(node4.name, remote=".")

    assert node.get_outs_content() == "test"
    assert node2.get_outs_content() == "test2"
    assert node3.get_outs_content() == "test"
    assert node4.get_outs_content() == "test2"

    # Outside of the context manager the paths are still the workspace paths.
    assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")
    assert node2.outs == pathlib.Path("nodes", "WriteDVCOutsPath", "data")
    assert node3.outs == "result.txt"
    assert isinstance(node4.outs, str)
    assert node4.outs == pathlib.Path("nodes", "WriteDVCOutsPath_1", "data").as_posix()

    # Inside of the context manager the paths point into the temporary
    # directory and keep their original type.
    with node.state.use_tmp_path():
        assert node.outs == node.state.tmp_path / "output.txt"
        assert isinstance(node.outs, pathlib.PurePath)
    with node2.state.use_tmp_path():
        assert node2.outs == node2.state.tmp_path / "data"
        assert isinstance(node2.outs, pathlib.PurePath)
    with node3.state.use_tmp_path():
        assert node3.outs == (node3.state.tmp_path / "result.txt").as_posix()
        assert isinstance(node3.outs, str)
    with node4.state.use_tmp_path():
        assert node4.outs == (node4.state.tmp_path / "data").as_posix()
        assert isinstance(node4.outs, str)


def test_use_tmp_path_multi(proj_path):
    """Check ``use_tmp_path`` for a node with several ``outs_path`` fields."""
    with zntrack.Project(automatic_node_names=True) as proj:
        node = zntrack.examples.WriteMultipleDVCOuts(params=["Lorem", "Ipsum", "Dolor"])

    proj.run()

    assert node.get_outs_content() == ("Lorem", "Ipsum", "Dolor")

    assert node.outs1 == pathlib.Path("nodes", "WriteMultipleDVCOuts", "output.txt")
    assert node.outs2 == pathlib.Path("nodes", "WriteMultipleDVCOuts", "output2.txt")
    assert node.outs3 == pathlib.Path("nodes", "WriteMultipleDVCOuts", "data")

    # Workspace node: paths are not redirected.
    with node.state.use_tmp_path():
        assert node.outs1 == pathlib.Path("nodes", "WriteMultipleDVCOuts", "output.txt")
        assert node.outs2 == pathlib.Path("nodes", "WriteMultipleDVCOuts", "output2.txt")
        assert node.outs3 == pathlib.Path("nodes", "WriteMultipleDVCOuts", "data")

        assert pathlib.Path(node.outs1).read_text() == "Lorem"
        assert pathlib.Path(node.outs2).read_text() == "Ipsum"
        assert (pathlib.Path(node.outs3) / "file.txt").read_text() == "Dolor"

    node = node.from_rev(remote=".")  # fake remote by passing the current directory

    # Node loaded from a "remote": every field points into the temporary
    # directory and the data is readable there.
    with node.state.use_tmp_path():
        assert node.outs1 == (node.state.tmp_path / "output.txt")
        assert node.outs2 == (node.state.tmp_path / "output2.txt")
        assert node.outs3 == (node.state.tmp_path / "data")

        assert pathlib.Path(node.outs1).read_text() == "Lorem"
        assert pathlib.Path(node.outs2).read_text() == "Ipsum"
        assert (pathlib.Path(node.outs3) / "file.txt").read_text() == "Dolor"


def test_use_tmp_path_sequence(proj_path):
    """Check ``use_tmp_path`` for an ``outs_path`` field holding a list."""
    with zntrack.Project(automatic_node_names=True) as proj:
        node = zntrack.examples.WriteDVCOutsSequence(
            params=["Lorem", "Ipsum", "Dolor"],
            outs=[zntrack.nwd / x for x in ["output.txt", "output2.txt", "output3.txt"]],
        )

    proj.run()

    assert node.outs == [
        pathlib.Path("nodes", "WriteDVCOutsSequence", "output.txt"),
        pathlib.Path("nodes", "WriteDVCOutsSequence", "output2.txt"),
        pathlib.Path("nodes", "WriteDVCOutsSequence", "output3.txt"),
    ]

    for outs in node.outs:
        assert pathlib.Path(outs).exists()

    # Workspace node: files stay in the node working directory.
    with node.state.use_tmp_path():
        for outs in node.outs:
            assert pathlib.Path(outs).exists()
            assert pathlib.Path(outs).read_text() in ("Lorem", "Ipsum", "Dolor")
            assert pathlib.Path(outs).parent == pathlib.Path(
                "nodes", "WriteDVCOutsSequence"
            )

    assert node.get_outs_content() == ["Lorem", "Ipsum", "Dolor"]

    node = node.from_rev(remote=".")  # fake remote by passing the current directory

    # Node loaded from a "remote": every list entry lives directly inside
    # the temporary directory.
    with node.state.use_tmp_path():
        for outs in node.outs:
            assert pathlib.Path(outs).exists()
            assert pathlib.Path(outs).read_text() in ("Lorem", "Ipsum", "Dolor")
            assert pathlib.Path(outs).parent == node.state.tmp_path

    assert node.get_outs_content() == ["Lorem", "Ipsum", "Dolor"]


def test_use_tmp_path_exp(tmp_path_2):
    """Check ``use_tmp_path`` for nodes loaded from experiments."""
    with zntrack.Project(automatic_node_names=True) as proj:
        node = zntrack.examples.WriteDVCOuts(params="test")

    proj.run()

    with proj.create_experiment() as exp1:
        node.params = "test1"

    with proj.create_experiment() as exp2:
        node.params = "test2"

    proj.run_exp()

    exp1.load()
    node1 = exp1["WriteDVCOuts"]
    assert node1.get_outs_content() == "test1"

    # Experiment nodes read their data through the temporary directory.
    with node1.state.use_tmp_path():
        assert node1.outs == node1.state.tmp_path / "output.txt"
        assert isinstance(node1.outs, pathlib.PurePath)
        assert pathlib.Path(node1.outs).read_text() == "test1"

    exp2.load()
    node2 = exp2["WriteDVCOuts"]
    assert node2.get_outs_content() == "test2"

    with node2.state.use_tmp_path():
        assert node2.outs == node2.state.tmp_path / "output.txt"
        assert isinstance(node2.outs, pathlib.PurePath)
        assert pathlib.Path(node2.outs).read_text() == "test2"

    # The workspace node is unaffected by the experiments.
    assert node.get_outs_content() == "test"
    assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")

    with node.state.use_tmp_path():
        assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")
        assert pathlib.Path(node.outs).read_text() == "test"
@contextlib.contextmanager
def use_tmp_path(self, path: pathlib.Path = None) -> typing.Iterator[None]:
    """Load the data for '*_path' into a temporary directory.

    If you can not use 'node.state.fs.open' you can use
    this as an alternative. While the context is active,
    'self.tmp_path' points to a freshly created temporary
    directory which is deleted when the context is left.

    If 'self.tmp_path' is 'DISABLE_TMP_PATH' (the default) nothing
    is created and the context manager is a no-op, i.e. the data
    is read from the current directory.

    Parameters
    ----------
    path : pathlib.Path, default = None
        Reserved for a custom target directory. Must be None.

    Raises
    ------
    NotImplementedError
        If 'path' is not None.
    """
    if path is not None:
        raise NotImplementedError("Custom paths are not implemented yet.")

    if self.tmp_path is DISABLE_TMP_PATH:
        yield
        return

    # Remember the current value so that a nested use restores the
    # still-active outer temporary path instead of clobbering it.
    previous = self.tmp_path
    with tempfile.TemporaryDirectory() as tmpdir:
        self.tmp_path = pathlib.Path(tmpdir)
        log.debug(f"Using temporary directory {self.tmp_path}")
        try:
            yield
        finally:
            # Restore the state first, so that a failure while collecting
            # the debug information can not leave a stale 'tmp_path' behind.
            tmp_dir, self.tmp_path = self.tmp_path, previous
            # Only walk the directory tree if the message is actually emitted.
            if log.isEnabledFor(logging.DEBUG):
                files = list(tmp_dir.glob("**/*"))
                log.debug(
                    f"Deleting temporary directory {tmp_dir} containing {files}"
                )
class WriteMultipleDVCOuts(zntrack.Node):
    """Example node writing two files and one directory output.

    'params' is indexed at positions 0, 1 and 2; one entry is written
    per output.
    """

    params = zntrack.params()
    outs1 = zntrack.outs_path(zntrack.nwd / "output.txt")
    outs2 = zntrack.outs_path(zntrack.nwd / "output2.txt")
    outs3 = zntrack.outs_path(zntrack.nwd / "data")

    def run(self):
        """Write the first three entries of 'params' to the outputs."""
        first_file = pathlib.Path(self.outs1)
        first_file.write_text(str(self.params[0]))

        second_file = pathlib.Path(self.outs2)
        second_file.write_text(str(self.params[1]))

        # 'outs3' is a directory: create it, then place a single file inside.
        pathlib.Path(self.outs3).mkdir(parents=True, exist_ok=True)
        nested_file = pathlib.Path(self.outs3) / "file.txt"
        nested_file.write_text(str(self.params[2]))

    def get_outs_content(self) -> t.Tuple[str, str, str]:
        """Return the text stored in 'outs1', 'outs2' and 'outs3/file.txt'."""
        with self.state.use_tmp_path():
            first = pathlib.Path(self.outs1).read_text()
            second = pathlib.Path(self.outs2).read_text()
            nested = pathlib.Path(self.outs3, "file.txt").read_text()
            return (first, second, nested)
class _ReprByNameMeta(type):
    """Metaclass that renders a class by its bare name."""

    def __repr__(cls) -> str:
        """Return the class name; 'str()' falls back to this as well."""
        return cls.__name__


class DISABLE_TMP_PATH(metaclass=_ReprByNameMeta):
    """Identifier for disabling loading data into a temporary directory.

    The class object itself is the sentinel and is compared by identity.
    It can not be instantiated. Because the sentinel is the class, its
    readable representation has to come from the metaclass; '__repr__'
    and '__str__' defined on the class would only apply to instances.
    """

    def __init__(self) -> None:
        """Prohibit instantiation."""
        raise NotImplementedError("This class can not be instantiated.")

    def __repr__(self) -> str:
        """Provide better representation."""
        return "DISABLE_TMP_PATH"

    def __str__(self) -> str:
        """Provide better representation."""
        return "DISABLE_TMP_PATH"