Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Jan 24, 2025
1 parent 4bfb32c commit 17c9f24
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 24 deletions.
3 changes: 1 addition & 2 deletions tests/files/dvc_config/user_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ stages:
MyNode:
cmd: zntrack run test_user_config.MyNode --name MyNode
metrics:
- nodes/MyNode/metric.json
- nodes/MyNode/node-meta.json:
cache: true
- nodes/MyNode/metrics.json:
cache: true
33 changes: 22 additions & 11 deletions tests/files/test_user_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,38 @@
CWD = pathlib.Path(__file__).parent.resolve()


class MyNode(zntrack.Node):
"""Some Node."""
# class MyNode(zntrack.Node):
# """Some Node."""

metric: dict = zntrack.metrics()
# metric: dict = zntrack.metrics()

def run(self) -> None:
self.metric = {"a": 1, "b": 2}
# def run(self) -> None:
# self.metric = {"a": 1, "b": 2}


def test_node(proj_path):
assert zntrack.config.ALWAYS_CACHE is False
zntrack.config.ALWAYS_CACHE = True
assert zntrack.config.ALWAYS_CACHE is True
assert zntrack.Config.ALWAYS_CACHE is False
zntrack.Config.ALWAYS_CACHE = True
assert zntrack.Config.ALWAYS_CACHE is True

# We define the node here, because the config has to be set
# bevore calling zntrack.metrics()
class MyNode(zntrack.Node):
"""Some Node."""

metric: dict = zntrack.metrics()

def run(self) -> None:
self.metric = {"a": 1, "b": 2}


with zntrack.Project() as proj:
node = MyNode()

proj.repro()
proj.build()

zntrack.config.ALWAYS_CACHE = False # reset to default value
assert zntrack.config.ALWAYS_CACHE is False
zntrack.Config.ALWAYS_CACHE = False # reset to default value
assert zntrack.Config.ALWAYS_CACHE is False

assert json.loads(
(CWD / "zntrack_config" / "user_config.json").read_text()
Expand Down
4 changes: 2 additions & 2 deletions zntrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from zntrack.add import add
from zntrack.apply import apply
from zntrack.config import config
from zntrack.config import Config
from zntrack.fields import (
deps,
deps_path,
Expand Down Expand Up @@ -41,7 +41,7 @@
"apply",
"add",
"field",
"config",
"Config",
]

logger = logging.getLogger(__name__)
Expand Down
3 changes: 0 additions & 3 deletions zntrack/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ class Config:
ALWAYS_CACHE: bool = False


config = Config()


# Use sentinel object for zntrack specific configurations. Use
# a class to give it a better repr.
class _ZNTRACK_OPTION_TYPE:
Expand Down
6 changes: 4 additions & 2 deletions zntrack/fields/outs_and_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import znjson

from zntrack.config import NOT_AVAILABLE, ZnTrackOptionEnum, config
from zntrack.config import NOT_AVAILABLE, ZnTrackOptionEnum, Config
from zntrack.fields.base import field
from zntrack.node import Node

Expand Down Expand Up @@ -45,7 +45,9 @@ def outs(*, cache: bool = True, independent: bool = False, **kwargs):
)


def metrics(*, cache: bool = config.ALWAYS_CACHE, independent: bool = False, **kwargs):
def metrics(*, cache: bool|None = None, independent: bool = False, **kwargs):
if cache is None:
cache = Config.ALWAYS_CACHE
return field(
default=NOT_AVAILABLE,
cache=cache,
Expand Down
4 changes: 2 additions & 2 deletions zntrack/fields/x_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ZNTRACK_LAZY_VALUE,
ZNTRACK_OPTION,
ZnTrackOptionEnum,
config,
Config,
)

# if t.TYPE_CHECKING:
Expand Down Expand Up @@ -89,7 +89,7 @@ def plots_path(
def metrics_path(
default=dataclasses.MISSING,
*,
cache: bool = config.ALWAYS_CACHE,
cache: bool = Config.ALWAYS_CACHE,
independent: bool = False,
**kwargs,
):
Expand Down
4 changes: 2 additions & 2 deletions zntrack/plugins/dvc_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
ZNTRACK_OPTION,
ZNTRACK_OPTION_PLOTS_CONFIG,
ZnTrackOptionEnum,
config,
Config,
)

# if t.TYPE_CHECKING:
Expand Down Expand Up @@ -132,7 +132,7 @@ def convert_to_dvc_yaml(self) -> dict | object:
"metrics": [
{
(self.node.nwd / "node-meta.json").as_posix(): {
"cache": config.ALWAYS_CACHE
"cache": Config.ALWAYS_CACHE
}
}
],
Expand Down

0 comments on commit 17c9f24

Please sign in to comment.