Skip to content

Commit

Permalink
Access *_path with node.use_tmp_paths (#749)
Browse files Browse the repository at this point in the history
* add use_tmp_paths

* bump version

* fix for str

* fix path/str

* test multiple files in one tmp_path

* support lists/nested for outs_path in tmp_path

* add further testing

* debug messages

* disable tmp_path automatically, if neither remote nor rev are used.

* rename `use_tmp_paths` to `use_tmp_path`
  • Loading branch information
PythonFZ authored Dec 19, 2023
1 parent 9def0fe commit 1a50fc6
Show file tree
Hide file tree
Showing 8 changed files with 346 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ZnTrack"
version = "0.7.1"
version = "0.7.2"
description = "Create, Run and Benchmark DVC Pipelines in Python"
authors = ["zincwarecode <zincwarecode@gmail.com>"]
license = "Apache-2.0"
Expand Down
180 changes: 179 additions & 1 deletion tests/integration/test_dvc_outs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pathlib
import typing

import zntrack
import zntrack.examples
from zntrack import Node, dvc, nwd


Expand Down Expand Up @@ -102,3 +102,181 @@ def test_SingleNodeDefaultNWD(proj_path):
assert SingleNodeDefaultNWD.from_rev(name="SampleNode").path1 == pathlib.Path(
"nodes", "SampleNode", "test.json"
)


def test_use_tmp_path(proj_path):
    """Check how ``node.state.use_tmp_path`` affects ``outs`` for four example nodes.

    Two phases are covered:

    1. Nodes created in the current project: inside ``use_tmp_path`` the
       ``outs`` values are expected to be the same as outside of it.
    2. Nodes re-loaded through ``from_rev(..., remote=".")``: inside
       ``use_tmp_path`` the ``outs`` values are expected to point into
       ``node.state.tmp_path`` while keeping their original type
       (``pathlib`` path stays a path, ``str`` stays a ``str``).
    """
    with zntrack.Project(automatic_node_names=True) as proj:
        node = zntrack.examples.WriteDVCOuts(params="test")
        node2 = zntrack.examples.WriteDVCOutsPath(params="test2")

        node3 = zntrack.examples.WriteDVCOuts(params="test", outs="result.txt")
        node4 = zntrack.examples.WriteDVCOutsPath(
            params="test2", outs=(zntrack.nwd / "data").as_posix()
        )

    proj.run()

    # Assert explicitly: a bare ``==`` expression is evaluated and discarded,
    # so it would never fail the test.
    assert node.get_outs_content() == "test"
    assert node2.get_outs_content() == "test2"
    assert node3.get_outs_content() == "test"
    assert node4.get_outs_content() == "test2"

    assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")
    assert node2.outs == pathlib.Path("nodes", "WriteDVCOutsPath", "data")
    assert node3.outs == "result.txt"
    assert isinstance(node4.outs, str)
    assert node4.outs == pathlib.Path("nodes", "WriteDVCOutsPath_1", "data").as_posix()

    # Locally created nodes: the paths inside the context match the ones above.
    with node.state.use_tmp_path():
        assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")
    with node2.state.use_tmp_path():
        assert node2.outs == pathlib.Path("nodes", "WriteDVCOutsPath", "data")
    with node3.state.use_tmp_path():
        assert node3.outs == "result.txt"
        assert isinstance(node3.outs, str)
    with node4.state.use_tmp_path():
        assert (
            node4.outs == pathlib.Path("nodes", "WriteDVCOutsPath_1", "data").as_posix()
        )

    # fake remote by passing the current directory
    node = node.from_rev(node.name, remote=".")
    node2 = node2.from_rev(node2.name, remote=".")
    node3 = node3.from_rev(node3.name, remote=".")
    node4 = node4.from_rev(node4.name, remote=".")

    assert node.get_outs_content() == "test"
    assert node2.get_outs_content() == "test2"
    assert node3.get_outs_content() == "test"
    assert node4.get_outs_content() == "test2"

    # Outside of the context manager the paths are still the workspace paths.
    assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")
    assert node2.outs == pathlib.Path("nodes", "WriteDVCOutsPath", "data")
    assert node3.outs == "result.txt"
    assert isinstance(node4.outs, str)
    assert node4.outs == pathlib.Path("nodes", "WriteDVCOutsPath_1", "data").as_posix()

    # Inside the context manager the paths are relocated below tmp_path.
    with node.state.use_tmp_path():
        assert node.outs == node.state.tmp_path / "output.txt"
        assert isinstance(node.outs, pathlib.PurePath)
    with node2.state.use_tmp_path():
        assert node2.outs == node2.state.tmp_path / "data"
        assert isinstance(node2.outs, pathlib.PurePath)
    with node3.state.use_tmp_path():
        assert node3.outs == (node3.state.tmp_path / "result.txt").as_posix()
        assert isinstance(node3.outs, str)
    with node4.state.use_tmp_path():
        assert node4.outs == (node4.state.tmp_path / "data").as_posix()
        assert isinstance(node4.outs, str)


def test_use_tmp_path_multi(proj_path):
    """Exercise ``use_tmp_path`` on a node that declares three ``outs_path`` fields.

    The node is checked once as created in the project (paths stay below
    ``nodes/WriteMultipleDVCOuts``) and once after ``from_rev(remote=".")``
    (paths are expected below ``node.state.tmp_path``).
    """
    workspace_dir = pathlib.Path("nodes", "WriteMultipleDVCOuts")

    def check_locations(current, base):
        # All three outs must sit directly below ``base``.
        assert current.outs1 == base / "output.txt"
        assert current.outs2 == base / "output2.txt"
        assert current.outs3 == base / "data"

    def check_contents(current):
        # Two plain files plus one file inside the 'data' directory.
        assert pathlib.Path(current.outs1).read_text() == "Lorem"
        assert pathlib.Path(current.outs2).read_text() == "Ipsum"
        assert (pathlib.Path(current.outs3) / "file.txt").read_text() == "Dolor"

    with zntrack.Project(automatic_node_names=True) as proj:
        node = zntrack.examples.WriteMultipleDVCOuts(params=["Lorem", "Ipsum", "Dolor"])

    proj.run()

    assert node.get_outs_content() == ("Lorem", "Ipsum", "Dolor")

    check_locations(node, workspace_dir)

    with node.state.use_tmp_path():
        check_locations(node, workspace_dir)
        check_contents(node)

    node = node.from_rev(remote=".")  # fake remote by passing the current directory

    with node.state.use_tmp_path():
        # tmp_path is only read while the context manager is active.
        check_locations(node, node.state.tmp_path)
        check_contents(node)


def test_use_tmp_path_sequence(proj_path):
    """Exercise ``use_tmp_path`` on a node whose ``outs`` is a list of paths.

    Every entry of the list is checked, first for the node as created in
    the project and then for the node re-loaded via ``from_rev(remote=".")``.
    """
    file_names = ["output.txt", "output2.txt", "output3.txt"]
    workspace_dir = pathlib.Path("nodes", "WriteDVCOutsSequence")

    def check_entries(current, expected_parent):
        # Each listed file must exist, hold one of the params and live in
        # the expected directory.
        for entry in current.outs:
            entry_path = pathlib.Path(entry)
            assert entry_path.exists()
            assert entry_path.read_text() in ("Lorem", "Ipsum", "Dolor")
            assert entry_path.parent == expected_parent

    with zntrack.Project(automatic_node_names=True) as proj:
        node = zntrack.examples.WriteDVCOutsSequence(
            params=["Lorem", "Ipsum", "Dolor"],
            outs=[zntrack.nwd / name for name in file_names],
        )

    proj.run()

    assert node.outs == [workspace_dir / name for name in file_names]

    for entry in node.outs:
        assert pathlib.Path(entry).exists()

    with node.state.use_tmp_path():
        check_entries(node, workspace_dir)

    assert node.get_outs_content() == ["Lorem", "Ipsum", "Dolor"]

    node = node.from_rev(remote=".")  # fake remote by passing the current directory

    with node.state.use_tmp_path():
        # tmp_path is only read while the context manager is active.
        check_entries(node, node.state.tmp_path)

    assert node.get_outs_content() == ["Lorem", "Ipsum", "Dolor"]


def test_use_tmp_path_exp(tmp_path_2):
    """Check ``use_tmp_path`` for nodes loaded from two experiments.

    The same ``WriteDVCOuts`` node is run once in the workspace
    (params ``"test"``) and in two experiments (``"test1"`` / ``"test2"``).
    Nodes obtained from the experiments are expected to expose ``outs``
    below ``node.state.tmp_path`` inside the context manager, while the
    original workspace node keeps its ``nodes/WriteDVCOuts`` path.
    """
    with zntrack.Project(automatic_node_names=True) as proj:
        node = zntrack.examples.WriteDVCOuts(params="test")

    proj.run()

    # NOTE(review): presumably each block registers one experiment with the
    # modified params and 'run_exp' executes them - confirm against Project.
    with proj.create_experiment() as exp1:
        node.params = "test1"

    with proj.create_experiment() as exp2:
        node.params = "test2"

    proj.run_exp()

    exp1.load()
    node1 = exp1["WriteDVCOuts"]
    assert node1.get_outs_content() == "test1"

    # The experiment node reads its output through the temporary directory.
    with node1.state.use_tmp_path():
        assert node1.outs == node1.state.tmp_path / "output.txt"
        assert isinstance(node1.outs, pathlib.PurePath)
        assert pathlib.Path(node1.outs).read_text() == "test1"

    exp2.load()
    node2 = exp2["WriteDVCOuts"]
    assert node2.get_outs_content() == "test2"

    with node2.state.use_tmp_path():
        assert node2.outs == node2.state.tmp_path / "output.txt"
        assert isinstance(node2.outs, pathlib.PurePath)
        assert pathlib.Path(node2.outs).read_text() == "test2"

    # The workspace node still holds the result of the initial 'proj.run()'.
    assert node.get_outs_content() == "test"
    assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")

    # For the workspace node the path inside the context is unchanged.
    with node.state.use_tmp_path():
        assert node.outs == pathlib.Path("nodes", "WriteDVCOuts", "output.txt")
        assert pathlib.Path(node.outs).read_text() == "test"
2 changes: 1 addition & 1 deletion tests/test_zntrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@

def test_version():
    """Test 'ZnTrack' version."""
    # The package version was bumped to 0.7.2; additionally asserting the
    # previous value 0.7.1 could never pass together with this check.
    assert __version__ == "0.7.2"
45 changes: 45 additions & 0 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import os
import pathlib
import tempfile
import time
import typing
import unittest.mock
Expand All @@ -24,6 +25,7 @@
from zntrack import exceptions
from zntrack.notebooks.jupyter import jupyter_class_to_file
from zntrack.utils import (
DISABLE_TMP_PATH,
NodeName,
NodeStatusResults,
config,
Expand Down Expand Up @@ -52,12 +54,19 @@ class NodeStatus:
a "remote" location, such as a git repository.
rev : str, default = None
The revision of the Node. This could be the current "HEAD" or a specific revision.
tmp_path : pathlib.Path, default = DISABLE_TMP_PATH|None
The temporary path used for loading the data.
This is only set within the context manager 'use_tmp_path'.
If neither 'remote' nor 'rev' are set, tmp_path will not be used.
"""

loaded: bool
results: "NodeStatusResults"
remote: str = None
rev: str = None
tmp_path: pathlib.Path = dataclasses.field(
default=DISABLE_TMP_PATH, init=False, repr=False
)

@functools.cached_property
def fs(self) -> dvc.api.DVCFileSystem:
Expand Down Expand Up @@ -104,6 +113,37 @@ def _listdir(path, *args, **kwargs):
# Jupyter Notebooks replace open with io.open
yield

@contextlib.contextmanager
def use_tmp_path(self, path: pathlib.Path = None) -> typing.ContextManager:
    """Load the data for '*_path' into a temporary directory.

    If you can not use 'node.state.fs.open' you can use
    this as an alternative. This will load the data into
    a temporary directory and then delete it afterwards.
    The respective paths 'node.*_path' will be replaced
    automatically inside the context manager.

    This is only set, if either 'remote' or 'rev' are set.
    Otherwise, the data will be loaded from the current directory.

    Parameters
    ----------
    path : pathlib.Path, default = None
        Reserved for a user-defined directory. Any value other than
        None is rejected.

    Raises
    ------
    NotImplementedError
        If 'path' is not None.
    """
    if path is not None:
        raise NotImplementedError("Custom paths are not implemented yet.")

    if self.tmp_path is DISABLE_TMP_PATH:
        # tmp_path is disabled for this node: leave all paths untouched.
        yield
        return

    with tempfile.TemporaryDirectory() as tmpdir:
        self.tmp_path = pathlib.Path(tmpdir)
        log.debug(f"Using temporary directory {self.tmp_path}")
        try:
            yield
        finally:
            # Reset the attribute first, so that it is cleared even if
            # collecting the debug information below fails.
            used_dir, self.tmp_path = self.tmp_path, None
            # Walking the directory tree is only needed for the debug
            # message; skip it when debug logging is turned off.
            if log.isEnabledFor(logging.DEBUG):
                files = list(used_dir.glob("**/*"))
                log.debug(
                    f"Deleting temporary directory {used_dir} containing {files}"
                )


class _NameDescriptor(zninit.Descriptor):
"""A descriptor for the name attribute."""
Expand Down Expand Up @@ -317,6 +357,11 @@ def from_rev(
with config.updated_config(**kwargs):
node.load(results=results)

if remote is not None or rev is not None:
# by default, tmp_path is disabled.
# if remote or rev is set, we enable it.
node.state.tmp_path = None

return node


Expand Down
72 changes: 71 additions & 1 deletion zntrack/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,77 @@ class WriteDVCOuts(zntrack.Node):

def run(self):
    """Write 'params', converted to a string, into the 'outs' file."""
    # 'outs' may be given as a plain str (e.g. "result.txt"), which has no
    # 'write_text' method; wrap it in pathlib.Path so both str and path
    # values work, and write the file exactly once.
    pathlib.Path(self.outs).write_text(str(self.params))

def get_outs_content(self):
    """Return the text stored in the 'outs' file.

    The file is read while 'self.state.use_tmp_path()' is active.
    """
    with self.state.use_tmp_path():
        target = pathlib.Path(self.outs)
        content = target.read_text()
    return content


class WriteDVCOutsSequence(zntrack.Node):
    """Write one output file per entry of 'params'.

    'params' and 'outs' are paired element by element: the n-th value is
    written to the n-th path.
    """

    params: list = zntrack.params()
    outs: list | tuple | set | dict = zntrack.outs_path()

    def run(self):
        """Write every value of 'params' to its matching path in 'outs'."""
        for item, target in zip(self.params, self.outs):
            text = str(item)
            pathlib.Path(target).write_text(text)

    def get_outs_content(self):
        """Return the text of every file in 'outs' as a list.

        The files are read while 'self.state.use_tmp_path()' is active.
        """
        with self.state.use_tmp_path():
            return [pathlib.Path(target).read_text() for target in self.outs]


class WriteDVCOutsPath(zntrack.Node):
    """Write a single file 'file.txt' into an output directory.

    'outs' is the directory (default: 'data' inside the node working
    directory); the value of 'params' is stored in 'file.txt' below it.
    """

    params = zntrack.params()
    outs = zntrack.outs_path(zntrack.nwd / "data")

    def run(self):
        """Create the 'outs' directory and write 'params' to 'file.txt'."""
        pathlib.Path(self.outs).mkdir(parents=True, exist_ok=True)
        (pathlib.Path(self.outs) / "file.txt").write_text(str(self.params))

    def get_outs_content(self):
        """Return the text of 'file.txt' inside the 'outs' directory.

        The file is read while 'self.state.use_tmp_path()' is active.

        Raises
        ------
        ValueError
            If 'file.txt' can not be found. The message lists the entries
            that are present in the directory instead.
        """
        with self.state.use_tmp_path():
            try:
                return (pathlib.Path(self.outs) / "file.txt").read_text()
            except FileNotFoundError as err:
                # List what is actually there to make the failure debuggable,
                # and chain the original error so its traceback is kept as
                # the explicit cause.
                files = list(pathlib.Path(self.outs).iterdir())
                raise ValueError(f"Expected {self.outs} file, found {files}.") from err


class WriteMultipleDVCOuts(zntrack.Node):
    """Write three values of 'params' to three separate outputs.

    The first two values go to plain files ('outs1', 'outs2'); the third
    one goes to 'file.txt' inside the directory 'outs3'.
    """

    params = zntrack.params()
    outs1 = zntrack.outs_path(zntrack.nwd / "output.txt")
    outs2 = zntrack.outs_path(zntrack.nwd / "output2.txt")
    outs3 = zntrack.outs_path(zntrack.nwd / "data")

    def run(self):
        """Write 'params[0]', 'params[1]' and 'params[2]' to the outputs."""
        first_file = pathlib.Path(self.outs1)
        first_file.write_text(str(self.params[0]))

        second_file = pathlib.Path(self.outs2)
        second_file.write_text(str(self.params[1]))

        # The third output is a directory that has to exist before writing.
        pathlib.Path(self.outs3).mkdir(parents=True, exist_ok=True)
        nested_file = pathlib.Path(self.outs3) / "file.txt"
        nested_file.write_text(str(self.params[2]))

    def get_outs_content(self) -> t.Tuple[str, str, str]:
        """Return the text of the three outputs as a tuple.

        The files are read while 'self.state.use_tmp_path()' is active.
        """
        with self.state.use_tmp_path():
            return (
                pathlib.Path(self.outs1).read_text(),
                pathlib.Path(self.outs2).read_text(),
                (pathlib.Path(self.outs3) / "file.txt").read_text(),
            )


class ComputeRandomNumber(zntrack.Node):
Expand Down
Loading

0 comments on commit 1a50fc6

Please sign in to comment.