From 44e484c4bbb9365b5a6ea24b2992cde410090446 Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Fri, 24 Jan 2025 10:41:14 +0100 Subject: [PATCH] add `zntrack.config` to update default values (#870) * add `zntrack.config` to update default values * update version * update tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add metrics * fix test * update tests, use `zntrack.config` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 2 +- tests/files/dvc_config/user_config.yaml | 8 ++++ tests/files/params_config/user_config.yaml | 1 + tests/files/test_user_config.py | 49 +++++++++++++++++++++ tests/files/zntrack_config/user_config.json | 12 +++++ zntrack/__init__.py | 2 + zntrack/config.py | 7 +++ zntrack/fields/outs_and_metrics.py | 5 ++- zntrack/fields/x_path.py | 5 ++- zntrack/plugins/dvc_plugin/__init__.py | 8 +++- 10 files changed, 94 insertions(+), 5 deletions(-) create mode 100644 tests/files/dvc_config/user_config.yaml create mode 100644 tests/files/params_config/user_config.yaml create mode 100644 tests/files/test_user_config.py create mode 100644 tests/files/zntrack_config/user_config.json diff --git a/pyproject.toml b/pyproject.toml index f50cbf0a..6bda61c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zntrack" -version = "0.8.1" +version = "0.8.2" description = "Create, Run and Benchmark DVC Pipelines in Python" authors = ["zincwarecode "] license = "Apache-2.0" diff --git a/tests/files/dvc_config/user_config.yaml b/tests/files/dvc_config/user_config.yaml new file mode 100644 index 00000000..76fcc45f --- /dev/null +++ b/tests/files/dvc_config/user_config.yaml @@ -0,0 +1,8 @@ +stages: + MyNode: + cmd: zntrack run test_user_config.MyNode --name MyNode + metrics: + - nodes/MyNode/metric.json + - nodes/MyNode/node-meta.json: + cache: true + - nodes/MyNode/results.json diff --git a/tests/files/params_config/user_config.yaml b/tests/files/params_config/user_config.yaml new file mode 100644 index 00000000..0967ef42 --- /dev/null +++ b/tests/files/params_config/user_config.yaml @@ -0,0 +1 @@ +{} diff --git a/tests/files/test_user_config.py b/tests/files/test_user_config.py new file mode 100644 index 00000000..6b32f1c1 --- /dev/null +++ b/tests/files/test_user_config.py @@ -0,0 +1,49 @@ +"""Test params such as always_changed, ...""" + +import json +import pathlib + +import yaml + +import zntrack + +CWD = pathlib.Path(__file__).parent.resolve() + + +def test_node(proj_path): + assert zntrack.config.ALWAYS_CACHE is False + zntrack.config.ALWAYS_CACHE = True + assert zntrack.config.ALWAYS_CACHE is True + + # We define the node here, because the config has to be set + # bevore calling zntrack.metrics() + class MyNode(zntrack.Node): + """Some Node.""" + + metric: dict = zntrack.metrics() + metrics_path: pathlib.Path = zntrack.metrics_path(zntrack.nwd / "results.json") + + def run(self) -> None: + self.metric = {"a": 1, "b": 2} + + with zntrack.Project() as proj: + node = MyNode() + + proj.build() + + zntrack.config.ALWAYS_CACHE = False # reset to default value + assert zntrack.config.ALWAYS_CACHE is False + + assert json.loads( + (CWD / "zntrack_config" / "user_config.json").read_text() + ) == json.loads((proj_path / "zntrack.json").read_text()) + assert yaml.safe_load( + (CWD / "dvc_config" / "user_config.yaml").read_text() + ) == yaml.safe_load((proj_path / "dvc.yaml").read_text()) + assert (CWD / "params_config" / "user_config.yaml").read_text() == ( + proj_path / "params.yaml" + ).read_text() + + +if __name__ == "__main__": + test_node("") diff --git a/tests/files/zntrack_config/user_config.json b/tests/files/zntrack_config/user_config.json new file mode 100644 index 00000000..f8b52f26 --- /dev/null +++ b/tests/files/zntrack_config/user_config.json @@ -0,0 +1,12 @@ +{ + "MyNode": { + "nwd": { + "_type": "pathlib.Path", + "value": "nodes/MyNode" + }, + "metrics_path": { + "_type": "pathlib.Path", + "value": "$nwd$/results.json" + } + } +} diff --git a/zntrack/__init__.py b/zntrack/__init__.py index e9bbd1b7..21fc9c22 100644 --- a/zntrack/__init__.py +++ b/zntrack/__init__.py @@ -2,6 +2,7 @@ import logging import sys +from zntrack import config from zntrack.add import add from zntrack.apply import apply from zntrack.fields import ( @@ -40,6 +41,7 @@ "apply", "add", "field", + "config", ] logger = logging.getLogger(__name__) diff --git a/zntrack/config.py b/zntrack/config.py index 9b678d46..2e7476c6 100644 --- a/zntrack/config.py +++ b/zntrack/config.py @@ -9,6 +9,13 @@ EXP_INFO_PATH = pathlib.Path(".exp_info.yaml") +# For "node-meta.json" and "dvc stage add ... --metrics-no-cache" the default is using +# git tracked files. Setting this to True will override the default behavior to always +# use the DVC cache. If you have a DVC cache setup, this might be desirable, to avoid +# a mixture between DVC cache and git tracked files. +ALWAYS_CACHE: bool = False + + # Use sentinel object for zntrack specific configurations. Use # a class to give it a better repr. class _ZNTRACK_OPTION_TYPE: diff --git a/zntrack/fields/outs_and_metrics.py b/zntrack/fields/outs_and_metrics.py index df4b4f49..f7158dae 100644 --- a/zntrack/fields/outs_and_metrics.py +++ b/zntrack/fields/outs_and_metrics.py @@ -2,6 +2,7 @@ import znjson +from zntrack import config from zntrack.config import NOT_AVAILABLE, ZnTrackOptionEnum from zntrack.fields.base import field from zntrack.node import Node @@ -45,7 +46,9 @@ def outs(*, cache: bool = True, independent: bool = False, **kwargs): ) -def metrics(*, cache: bool = False, independent: bool = False, **kwargs): +def metrics(*, cache: bool | None = None, independent: bool = False, **kwargs): + if cache is None: + cache = config.ALWAYS_CACHE return field( default=NOT_AVAILABLE, cache=cache, diff --git a/zntrack/fields/x_path.py b/zntrack/fields/x_path.py index e00250ed..97a73cc0 100644 --- a/zntrack/fields/x_path.py +++ b/zntrack/fields/x_path.py @@ -4,6 +4,7 @@ import znfields import znjson +from zntrack import config from zntrack.config import ( NOT_AVAILABLE, ZNTRACK_CACHE, @@ -88,10 +89,12 @@ def plots_path( def metrics_path( default=dataclasses.MISSING, *, - cache: bool = False, + cache: bool | None = None, independent: bool = False, **kwargs, ): + if cache is None: + cache = config.ALWAYS_CACHE kwargs["metadata"] = kwargs.get("metadata", {}) kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.METRICS_PATH kwargs["metadata"][ZNTRACK_CACHE] = cache diff --git a/zntrack/plugins/dvc_plugin/__init__.py b/zntrack/plugins/dvc_plugin/__init__.py index c3e1a74d..ab72673b 100644 --- a/zntrack/plugins/dvc_plugin/__init__.py +++ b/zntrack/plugins/dvc_plugin/__init__.py @@ -12,7 +12,7 @@ import znflow.utils import znjson -from zntrack import converter +from zntrack import config, converter from zntrack.config import ( NOT_AVAILABLE, PARAMS_FILE_PATH, @@ -129,7 +129,11 @@ def convert_to_dvc_yaml(self) -> dict | object: stages = { "cmd": cmd, "metrics": [ - {(self.node.nwd / "node-meta.json").as_posix(): {"cache": False}} + { + (self.node.nwd / "node-meta.json").as_posix(): { + "cache": config.ALWAYS_CACHE + } + } ], } if self.node.always_changed: