diff --git a/pyproject.toml b/pyproject.toml index f50cbf0a..6bda61c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zntrack" -version = "0.8.1" +version = "0.8.2" description = "Create, Run and Benchmark DVC Pipelines in Python" authors = ["zincwarecode "] license = "Apache-2.0" diff --git a/tests/files/dvc_config/user_config.yaml b/tests/files/dvc_config/user_config.yaml new file mode 100644 index 00000000..76fcc45f --- /dev/null +++ b/tests/files/dvc_config/user_config.yaml @@ -0,0 +1,8 @@ +stages: + MyNode: + cmd: zntrack run test_user_config.MyNode --name MyNode + metrics: + - nodes/MyNode/metric.json + - nodes/MyNode/node-meta.json: + cache: true + - nodes/MyNode/results.json diff --git a/tests/files/params_config/user_config.yaml b/tests/files/params_config/user_config.yaml new file mode 100644 index 00000000..0967ef42 --- /dev/null +++ b/tests/files/params_config/user_config.yaml @@ -0,0 +1 @@ +{} diff --git a/tests/files/test_user_config.py b/tests/files/test_user_config.py new file mode 100644 index 00000000..6b32f1c1 --- /dev/null +++ b/tests/files/test_user_config.py @@ -0,0 +1,49 @@ +"""Test params such as always_changed, ...""" + +import json +import pathlib + +import yaml + +import zntrack + +CWD = pathlib.Path(__file__).parent.resolve() + + +def test_node(proj_path): + assert zntrack.config.ALWAYS_CACHE is False + zntrack.config.ALWAYS_CACHE = True + assert zntrack.config.ALWAYS_CACHE is True + + # We define the node here, because the config has to be set + # bevore calling zntrack.metrics() + class MyNode(zntrack.Node): + """Some Node.""" + + metric: dict = zntrack.metrics() + metrics_path: pathlib.Path = zntrack.metrics_path(zntrack.nwd / "results.json") + + def run(self) -> None: + self.metric = {"a": 1, "b": 2} + + with zntrack.Project() as proj: + node = MyNode() + + proj.build() + + zntrack.config.ALWAYS_CACHE = False # reset to default value + assert zntrack.config.ALWAYS_CACHE is False + + assert json.loads( + (CWD / "zntrack_config" / "user_config.json").read_text() + ) == json.loads((proj_path / "zntrack.json").read_text()) + assert yaml.safe_load( + (CWD / "dvc_config" / "user_config.yaml").read_text() + ) == yaml.safe_load((proj_path / "dvc.yaml").read_text()) + assert (CWD / "params_config" / "user_config.yaml").read_text() == ( + proj_path / "params.yaml" + ).read_text() + + +if __name__ == "__main__": + test_node("") diff --git a/tests/files/zntrack_config/user_config.json b/tests/files/zntrack_config/user_config.json new file mode 100644 index 00000000..f8b52f26 --- /dev/null +++ b/tests/files/zntrack_config/user_config.json @@ -0,0 +1,12 @@ +{ + "MyNode": { + "nwd": { + "_type": "pathlib.Path", + "value": "nodes/MyNode" + }, + "metrics_path": { + "_type": "pathlib.Path", + "value": "$nwd$/results.json" + } + } +} diff --git a/zntrack/__init__.py b/zntrack/__init__.py index e9bbd1b7..21fc9c22 100644 --- a/zntrack/__init__.py +++ b/zntrack/__init__.py @@ -2,6 +2,7 @@ import logging import sys +from zntrack import config from zntrack.add import add from zntrack.apply import apply from zntrack.fields import ( @@ -40,6 +41,7 @@ "apply", "add", "field", + "config", ] logger = logging.getLogger(__name__) diff --git a/zntrack/config.py b/zntrack/config.py index 9b678d46..2e7476c6 100644 --- a/zntrack/config.py +++ b/zntrack/config.py @@ -9,6 +9,13 @@ EXP_INFO_PATH = pathlib.Path(".exp_info.yaml") +# For "node-meta.json" and "dvc stage add ... --metrics-no-cache" the default is using +# git tracked files. Setting this to True will override the default behavior to always +# use the DVC cache. If you have a DVC cache setup, this might be desirable, to avoid +# a mixture between DVC cache and git tracked files. +ALWAYS_CACHE: bool = False + + # Use sentinel object for zntrack specific configurations. Use # a class to give it a better repr. class _ZNTRACK_OPTION_TYPE: diff --git a/zntrack/fields/outs_and_metrics.py b/zntrack/fields/outs_and_metrics.py index df4b4f49..f7158dae 100644 --- a/zntrack/fields/outs_and_metrics.py +++ b/zntrack/fields/outs_and_metrics.py @@ -2,6 +2,7 @@ import znjson +from zntrack import config from zntrack.config import NOT_AVAILABLE, ZnTrackOptionEnum from zntrack.fields.base import field from zntrack.node import Node @@ -45,7 +46,9 @@ def outs(*, cache: bool = True, independent: bool = False, **kwargs): ) -def metrics(*, cache: bool = False, independent: bool = False, **kwargs): +def metrics(*, cache: bool | None = None, independent: bool = False, **kwargs): + if cache is None: + cache = config.ALWAYS_CACHE return field( default=NOT_AVAILABLE, cache=cache, diff --git a/zntrack/fields/x_path.py b/zntrack/fields/x_path.py index e00250ed..97a73cc0 100644 --- a/zntrack/fields/x_path.py +++ b/zntrack/fields/x_path.py @@ -4,6 +4,7 @@ import znfields import znjson +from zntrack import config from zntrack.config import ( NOT_AVAILABLE, ZNTRACK_CACHE, @@ -88,10 +89,12 @@ def plots_path( def metrics_path( default=dataclasses.MISSING, *, - cache: bool = False, + cache: bool | None = None, independent: bool = False, **kwargs, ): + if cache is None: + cache = config.ALWAYS_CACHE kwargs["metadata"] = kwargs.get("metadata", {}) kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.METRICS_PATH kwargs["metadata"][ZNTRACK_CACHE] = cache diff --git a/zntrack/plugins/dvc_plugin/__init__.py b/zntrack/plugins/dvc_plugin/__init__.py index c3e1a74d..ab72673b 100644 --- a/zntrack/plugins/dvc_plugin/__init__.py +++ b/zntrack/plugins/dvc_plugin/__init__.py @@ -12,7 +12,7 @@ import znflow.utils import znjson -from zntrack import converter +from zntrack import config, converter from zntrack.config import ( NOT_AVAILABLE, PARAMS_FILE_PATH, @@ -129,7 +129,11 @@ def convert_to_dvc_yaml(self) -> dict | object: stages = { "cmd": cmd, "metrics": [ - {(self.node.nwd / "node-meta.json").as_posix(): {"cache": False}} + { + (self.node.nwd / "node-meta.json").as_posix(): { + "cache": config.ALWAYS_CACHE + } + } ], } if self.node.always_changed: