Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 5, 2024
1 parent f3207f1 commit 7923ada
Show file tree
Hide file tree
Showing 15 changed files with 58 additions and 0 deletions.
2 changes: 2 additions & 0 deletions zntrack/core/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _import_from_tempfile(package_and_module: str, remote, rev):
If the module could not be found.
FileNotFoundError
If the file could not be found.
"""
file = pathlib.Path(*package_and_module.split(".")).with_suffix(".py")
fs = dvc.api.DVCFileSystem(url=remote, rev=rev)
Expand Down Expand Up @@ -93,6 +94,7 @@ def from_rev(name, remote=".", rev=None, **kwargs) -> T:
-------
Node
The loaded node.
"""
if isinstance(name, Node):
name = name.name
Expand Down
4 changes: 4 additions & 0 deletions zntrack/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class NodeStatus:
The temporary path used for loading the data.
This is only set within the context manager 'use_tmp_path'.
If neither 'remote' nor 'rev' are set, tmp_path will not be used.
"""

loaded: bool
Expand Down Expand Up @@ -182,6 +183,7 @@ class Node(zninit.ZnInit, znflow.Node):
information about the state of the Node.
nwd : pathlib.Path
the node working directory.
"""

_state: NodeStatus = None
Expand Down Expand Up @@ -215,6 +217,7 @@ def convert_notebook(cls, nb_name: str = None):
----------
nb_name: str
Notebook name when not using config.nb_name (this is not recommended)
"""
# TODO this should not be a class method, but a function.
jupyter_class_to_file(nb_name=nb_name, module_name=cls.__name__)
Expand Down Expand Up @@ -302,6 +305,7 @@ def load(self, lazy: bool = None, results: bool = True) -> None:
Whether to load the node lazily. If None, the value from the config is used.
results : bool, default = True
Whether to load the results. If False, only the parameters are loaded.
"""
from zntrack.fields.field import Field, FieldGroup

Expand Down
6 changes: 6 additions & 0 deletions zntrack/core/nodify.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class DVCRunOptions:
References
----------
https://dvc.org/doc/command-reference/run#options.
"""

no_commit: bool
Expand All @@ -51,6 +52,7 @@ def dvc_args(self) -> list:
-------
list: A list of strings for the subprocess call, e.g.:
["--no-commit", "--external"].
"""
out = []
for datacls_field in dataclasses.fields(self):
Expand Down Expand Up @@ -97,6 +99,7 @@ def prepare_dvc_script(
-------
list[str]
The list to be passed to the subprocess call.
"""
script = ["stage", "add", "-n", node_name]
script += dvc_run_option.dvc_args
Expand Down Expand Up @@ -134,6 +137,7 @@ def check_type(
accept None even if not in types.
allow_dict:
allow for {key: types}
"""
if isinstance(obj, (list, tuple, set)) and allow_iterable:
for value in obj:
Expand Down Expand Up @@ -254,6 +258,7 @@ def save_node_config_to_files(cfg: NodeConfig, node_name: str):
The NodeConfig object which should be serialized to zntrack.json / params.yaml
node_name: str
The name of the node, usually func.__name__.
"""
for value_name, value in dataclasses.asdict(cfg).items():
if value_name == "params":
Expand Down Expand Up @@ -339,6 +344,7 @@ def nodify(
References
----------
https://dvc.org/doc/command-reference/run#options
"""
cfg_ = NodeConfig(
outs=outs,
Expand Down
3 changes: 3 additions & 0 deletions zntrack/exceptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(self, arg):
----------
arg : str|Node
Custom Error message or Node that is not available.
"""
if isinstance(arg, str):
super().__init__(arg)
Expand All @@ -33,6 +34,7 @@ def __init__(self, node, field, instance):
The 'zn.nodes' field
instance : Node
The node that contains the 'zn.nodes' field
"""
msg = (
f"Can not set '{field.name}' of Node<'{instance.name}'> to"
Expand All @@ -59,6 +61,7 @@ def __init__(self, node):
----------
node: Node
The node that is already on the graph.
"""
msg = (
f"Node name '{node.name}' is already used in the graph. Please use"
Expand Down
1 change: 1 addition & 0 deletions zntrack/fields/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _get_nodes_on_off_graph(self, instance) -> t.Tuple[list, list]:
The nodes that are on the graph.
off_graph : list
The nodes that are off the graph.
"""
values = getattr(instance, self.name)
# TODO use IterableHandler?
Expand Down
1 change: 1 addition & 0 deletions zntrack/fields/dvc/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def get_data(self, instance: "Node") -> any:
-------
any
The value of the field from the configuration file.
"""
zntrack_dict = json.loads(
instance.state.fs.read_text("zntrack.json"),
Expand Down
9 changes: 9 additions & 0 deletions zntrack/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Field(zninit.Descriptor, abc.ABC):
----------
dvc_option : str
The dvc command option for this field.
"""

dvc_option: str = None
Expand All @@ -49,6 +50,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The Node instance to save the field for.
"""
raise NotImplementedError

Expand All @@ -70,6 +72,7 @@ def get_files(self, instance: "Node") -> list:
-------
list
The affected files.
"""
raise NotImplementedError

Expand All @@ -83,6 +86,7 @@ def load(self, instance: "Node", lazy: bool = None):
lazy : bool, optional
Whether to load the field lazily.
This only applies to 'LazyField' classes.
"""
try:
instance.__dict__[self.name] = self.get_data(instance)
Expand All @@ -103,6 +107,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
-------
typing.List[tuple]
The stage add argument for this field.
"""
return [
(f"--{self.dvc_option}", pathlib.Path(x).as_posix())
Expand All @@ -127,6 +132,7 @@ def get_optional_dvc_cmd(
-------
typing.List[str]
The optional dvc commands.
"""
return []

Expand Down Expand Up @@ -173,6 +179,7 @@ def get_value_except_lazy(self, instance):
------
DataIsLazyError
If the value is lazy.
"""
with contextlib.suppress(KeyError):
if instance.__dict__[self.name] is LazyOption:
Expand All @@ -198,6 +205,7 @@ def load(self, instance: "Node", lazy: bool = None):
The Node instance to load the field for.
lazy : bool, optional
Whether to load the field lazily, by default 'zntrack.config.lazy'.
"""
if lazy in {None, True} and config.lazy:
instance.__dict__[self.name] = LazyOption
Expand Down Expand Up @@ -226,6 +234,7 @@ def __init__(
----------
use_global_plots : bool
Save the plots config not in 'stages' but in 'plots' in the dvc.yaml file.
"""
super().__init__(*args, **kwargs)
self.plots_options = {}
Expand Down
5 changes: 5 additions & 0 deletions zntrack/fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def outs():
The object is serialized and deserialized by ZnTrack
and stored in the node working directory.
see https://dvc.org/doc/command-reference/stage/add#-o
"""
return Output(dvc_option="outs", use_repr=False)

Expand Down Expand Up @@ -49,6 +50,7 @@ def params(*args, **kwargs):
see https://dvc.org/doc/command-reference/stage/add#-p
kwargs: dict
Additional keyword arguments.
"""
return Params(*args, **kwargs)

Expand All @@ -63,6 +65,7 @@ def deps(*data):
This can either be a Node or an attribute of a Node.
It can not be an object that is not part of the Node graph.
see https://dvc.org/doc/command-reference/stage/add#-d
"""
return Dependency(*data)

Expand Down Expand Up @@ -132,6 +135,7 @@ def params_path(*args, **kwargs):
see https://dvc.org/doc/command-reference/stage/add#-p
kwargs: dict
Additional keyword arguments.
"""
return DVCOption(*args, dvc_option="params", **kwargs)

Expand Down Expand Up @@ -163,5 +167,6 @@ def plots_path(*args, dvc_option="plots", **kwargs):
The DVC option to use for this field.
kwargs: dict
Additional keyword arguments that are used for plotting.
"""
return PlotsOption(*args, dvc_option=dvc_option, **kwargs)
8 changes: 8 additions & 0 deletions zntrack/fields/zn/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class Params(Field):
----------
dvc_option: str
The DVC option to use. Default is "params".
"""

dvc_option: str = "params"
Expand All @@ -115,6 +116,7 @@ def get_files(self, instance: "Node") -> list:
-------
list
A list of file paths.
"""
return [config.files.params]

Expand All @@ -125,6 +127,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The node instance associated with this field.
"""
file = self.get_files(instance)[0]

Expand Down Expand Up @@ -161,6 +164,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
-------
list
A list of tuples containing the DVC option and the file path.
"""
file = self.get_files(instance)[0]
return [(f"--{self.dvc_option}", f"{file}:{instance.name}")]
Expand All @@ -180,6 +184,7 @@ def __init__(self, dvc_option: str, **kwargs):
The DVC option used to specify the output file.
**kwargs
Additional arguments to pass to the parent constructor.
"""
self.dvc_option = dvc_option
super().__init__(**kwargs)
Expand All @@ -196,6 +201,7 @@ def get_files(self, instance) -> list:
-------
list
A list containing the path of the file.
"""
return [get_nwd(instance) / f"{self.name}.json"]

Expand All @@ -206,6 +212,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The node instance.
"""
try:
value = self.get_value_except_lazy(instance)
Expand Down Expand Up @@ -236,6 +243,7 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]:
-------
list
A list containing the DVC command for this field.
"""
file = self.get_files(instance)[0]
return [(f"--{self.dvc_option}", file.as_posix())]
Expand Down
4 changes: 4 additions & 0 deletions zntrack/project/zntrack_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class Project:
This will require a DVC remote to be setup.
force : bool, default = False
overwrite existing nodes.
"""

graph: ZnTrackGraph = dataclasses.field(default_factory=ZnTrackGraph, init=False)
Expand All @@ -105,6 +106,7 @@ def __post_init__(self):
remove_existing_graph : bool, default = False
If True, remove 'dvc.yaml', 'zntrack.json' and 'params.yaml'
before writing new nodes.
"""
self.graph.project = self
if self.initialize:
Expand Down Expand Up @@ -146,6 +148,7 @@ def group(self, *names: typing.List[str]):
The name of the group. If None, the group will be named 'GroupX' where X is
the number of groups + 1. If more than one name is given, the groups will
be nested to 'nwd = name[0]/name[1]/.../name[-1]'
"""
if not names:
name = "Group1"
Expand Down Expand Up @@ -237,6 +240,7 @@ def run(
auto_remove : bool, default = False
If True, remove all nodes from 'dvc.yaml' that are not in the graph.
This is the same as calling 'project.auto_remove()'
"""
if not save and not eager:
raise ValueError("Save can only be false if eager is True")
Expand Down
1 change: 1 addition & 0 deletions zntrack/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def timeit(field: str):
field : str
The field to store the time in.
The value is stored as {func_name: time} or {func_name: [time1, time2, ...]}
"""

def decorator(func):
Expand Down
5 changes: 5 additions & 0 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self) -> None:
------
NotImplementedError:
This class is not meant to be instantiated.
"""
raise NotImplementedError("This class is not meant to be instantiated.")

Expand All @@ -60,6 +61,7 @@ def module_handler(obj) -> str:
----------
obj:
Any object that implements __module__
"""
if config.nb_name:
try:
Expand Down Expand Up @@ -111,6 +113,7 @@ def run_dvc_cmd(script, stdout=None):
------
DVCProcessError:
if the dvc cli command fails.
"""
dvc_short_string = " ".join(script[:5])
if len(script) > 5:
Expand Down Expand Up @@ -177,6 +180,7 @@ class NodeStatusResults(enum.Enum):
the Node instance has failed to run.
AVAILABLE : int
the Node instance was loaded and results are available.
"""

UNKNOWN = 0
Expand All @@ -202,6 +206,7 @@ def cwd_temp_dir(required_files=None) -> tempfile.TemporaryDirectory:
-------
temp_dir:
The temporary directory file. Close with temp_dir.cleanup() at the end.
"""
temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
# add ignore_cleanup_errors=True in Py3.10?
Expand Down
Loading

0 comments on commit 7923ada

Please sign in to comment.