Skip to content

Commit

Permalink
add zntrack.config to update default values (#870)
Browse files Browse the repository at this point in the history
* add `zntrack.config` to update default values

* update version

* update tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add metrics

* fix test

* update tests, use `zntrack.config`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Jan 24, 2025
1 parent 3cd1929 commit 44e484c
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zntrack"
version = "0.8.1"
version = "0.8.2"
description = "Create, Run and Benchmark DVC Pipelines in Python"
authors = ["zincwarecode <zincwarecode@gmail.com>"]
license = "Apache-2.0"
Expand Down
8 changes: 8 additions & 0 deletions tests/files/dvc_config/user_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
stages:
MyNode:
cmd: zntrack run test_user_config.MyNode --name MyNode
metrics:
- nodes/MyNode/metric.json
- nodes/MyNode/node-meta.json:
cache: true
- nodes/MyNode/results.json
1 change: 1 addition & 0 deletions tests/files/params_config/user_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
49 changes: 49 additions & 0 deletions tests/files/test_user_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Test params such as always_changed, ..."""

import json
import pathlib

import yaml

import zntrack

CWD = pathlib.Path(__file__).parent.resolve()


def test_node(proj_path):
assert zntrack.config.ALWAYS_CACHE is False
zntrack.config.ALWAYS_CACHE = True
assert zntrack.config.ALWAYS_CACHE is True

# We define the node here, because the config has to be set
# bevore calling zntrack.metrics()
class MyNode(zntrack.Node):
"""Some Node."""

metric: dict = zntrack.metrics()
metrics_path: pathlib.Path = zntrack.metrics_path(zntrack.nwd / "results.json")

def run(self) -> None:
self.metric = {"a": 1, "b": 2}

with zntrack.Project() as proj:
node = MyNode()

proj.build()

zntrack.config.ALWAYS_CACHE = False # reset to default value
assert zntrack.config.ALWAYS_CACHE is False

assert json.loads(
(CWD / "zntrack_config" / "user_config.json").read_text()
) == json.loads((proj_path / "zntrack.json").read_text())
assert yaml.safe_load(
(CWD / "dvc_config" / "user_config.yaml").read_text()
) == yaml.safe_load((proj_path / "dvc.yaml").read_text())
assert (CWD / "params_config" / "user_config.yaml").read_text() == (
proj_path / "params.yaml"
).read_text()


if __name__ == "__main__":
test_node("")
12 changes: 12 additions & 0 deletions tests/files/zntrack_config/user_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"MyNode": {
"nwd": {
"_type": "pathlib.Path",
"value": "nodes/MyNode"
},
"metrics_path": {
"_type": "pathlib.Path",
"value": "$nwd$/results.json"
}
}
}
2 changes: 2 additions & 0 deletions zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import sys

from zntrack import config
from zntrack.add import add
from zntrack.apply import apply
from zntrack.fields import (
Expand Down Expand Up @@ -40,6 +41,7 @@
"apply",
"add",
"field",
"config",
]

logger = logging.getLogger(__name__)
Expand Down
7 changes: 7 additions & 0 deletions zntrack/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
EXP_INFO_PATH = pathlib.Path(".exp_info.yaml")


# For "node-meta.json" and "dvc stage add ... --metrics-no-cache" the default is using
# git tracked files. Setting this to True will override the default behavior to always
# use the DVC cache. If you have a DVC cache setup, this might be desirable, to avoid
# a mixture between DVC cache and git tracked files.
ALWAYS_CACHE: bool = False


# Use sentinel object for zntrack specific configurations. Use
# a class to give it a better repr.
class _ZNTRACK_OPTION_TYPE:
Expand Down
5 changes: 4 additions & 1 deletion zntrack/fields/outs_and_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import znjson

from zntrack import config
from zntrack.config import NOT_AVAILABLE, ZnTrackOptionEnum
from zntrack.fields.base import field
from zntrack.node import Node
Expand Down Expand Up @@ -45,7 +46,9 @@ def outs(*, cache: bool = True, independent: bool = False, **kwargs):
)


def metrics(*, cache: bool = False, independent: bool = False, **kwargs):
def metrics(*, cache: bool | None = None, independent: bool = False, **kwargs):
if cache is None:
cache = config.ALWAYS_CACHE
return field(
default=NOT_AVAILABLE,
cache=cache,
Expand Down
5 changes: 4 additions & 1 deletion zntrack/fields/x_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import znfields
import znjson

from zntrack import config
from zntrack.config import (
NOT_AVAILABLE,
ZNTRACK_CACHE,
Expand Down Expand Up @@ -88,10 +89,12 @@ def plots_path(
def metrics_path(
default=dataclasses.MISSING,
*,
cache: bool = False,
cache: bool | None = None,
independent: bool = False,
**kwargs,
):
if cache is None:
cache = config.ALWAYS_CACHE
kwargs["metadata"] = kwargs.get("metadata", {})
kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.METRICS_PATH
kwargs["metadata"][ZNTRACK_CACHE] = cache
Expand Down
8 changes: 6 additions & 2 deletions zntrack/plugins/dvc_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import znflow.utils
import znjson

from zntrack import converter
from zntrack import config, converter
from zntrack.config import (
NOT_AVAILABLE,
PARAMS_FILE_PATH,
Expand Down Expand Up @@ -129,7 +129,11 @@ def convert_to_dvc_yaml(self) -> dict | object:
stages = {
"cmd": cmd,
"metrics": [
{(self.node.nwd / "node-meta.json").as_posix(): {"cache": False}}
{
(self.node.nwd / "node-meta.json").as_posix(): {
"cache": config.ALWAYS_CACHE
}
}
],
}
if self.node.always_changed:
Expand Down

0 comments on commit 44e484c

Please sign in to comment.