diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8b972a60..f28c3e09 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,29 +20,19 @@ repos: args: ['--fix=lf'] - id: sort-simple-yaml - id: trailing-whitespace - - repo: https://github.com/psf/black - rev: 24.10.0 - hooks: - - id: black - additional_dependencies: [".[jupyter]"] - types_or: [python, pyi, jupyter] - - repo: https://github.com/PyCQA/isort - rev: 5.13.2 - hooks: - - id: isort - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 + rev: v2.4.1 hooks: - id: codespell additional_dependencies: ["tomli"] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.8.2' + rev: 'v0.9.4' hooks: - id: ruff args: ['--fix'] - repo: https://github.com/executablebooks/mdformat - rev: 0.7.19 + rev: 0.7.22 hooks: - id: mdformat args: ["--wrap=80"] diff --git a/examples/docs/04_metrics_and_plots.ipynb b/examples/docs/04_metrics_and_plots.ipynb index aed5db1d..912d25da 100644 --- a/examples/docs/04_metrics_and_plots.ipynb +++ b/examples/docs/04_metrics_and_plots.ipynb @@ -623,9 +623,8 @@ " def run(self):\n", " self.my_metric = {\"alpha\": 1.0, \"beta\": 0.00473}\n", " self.my_plot = pd.DataFrame({\"val\": np.sin(np.linspace(0, 3.14, 100))})\n", - " self.my_plot.index.name = ( # For DVC it is required that the index has a column name\n", - " \"index\"\n", - " )\n", + " # For DVC it is required that the index has a column name\n", + " self.my_plot.index.name = \"index\"\n", "\n", "\n", "with zntrack.Project() as project:\n", diff --git a/examples/docs/09_lazy.ipynb b/examples/docs/09_lazy.ipynb index de97235b..7cb7a53d 100644 --- a/examples/docs/09_lazy.ipynb +++ b/examples/docs/09_lazy.ipynb @@ -630,7 +630,7 @@ "\n", " add_one = AddOne(deps=random_number, name=\"AddOne_0\")\n", " for index in range(10):\n", - " add_one = AddOne(deps=add_one, name=f\"AddOne_{index+1}\")\n", + " add_one = AddOne(deps=add_one, name=f\"AddOne_{index + 1}\")\n", "\n", "project.run()" ] diff --git a/examples/docs/parameter_optimization.ipynb b/examples/docs/parameter_optimization.ipynb index 7f44a80d..1323eb3e 100644 --- a/examples/docs/parameter_optimization.ipynb +++ b/examples/docs/parameter_optimization.ipynb @@ -45,10 +45,10 @@ "source": [ "# Setup temporary directory and initialize git and dvc\n", "from zntrack import config\n", + "from zntrack.utils import cwd_temp_dir\n", "\n", "config.nb_name = \"parameter_optimization.ipynb\"\n", "\n", - "from zntrack.utils import cwd_temp_dir\n", "\n", "temp_dir = cwd_temp_dir()\n", "\n", diff --git a/pyproject.toml b/pyproject.toml index ca9a744e..f753f98e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dev = [ "mlflow>=2.20.0", "nbsphinx>=0.9.6", "nbsphinx-link>=1.3.1", + "pre-commit>=4.1.0", "pytest>=8.3.4", "pytest-benchmark>=5.1.0", "sphinx>=8.1.3", @@ -94,7 +95,7 @@ disable = [ line-length = 90 [tool.ruff.lint] -select = ["E", "F", "D", "N", "C", "I"] #, "ANN"] +select = ["E", "F", "N", "C", "I"] #, "ANN"] extend-ignore = [ "D213", "D203", "D401", diff --git a/tests/files/test_apply.py b/tests/files/test_apply.py index 3c77a800..30ff8e69 100644 --- a/tests/files/test_apply.py +++ b/tests/files/test_apply.py @@ -13,12 +13,12 @@ def test_apply(proj_path) -> None: project = zntrack.Project() - JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") + JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") # noqa N806 with project: - a = zntrack.examples.ParamsToOuts(params=["a", "b"]) - b = JoinedParamsToOuts(params=["a", "b"]) - c = zntrack.apply(zntrack.examples.ParamsToOuts, "join")(params=["a", "b", "c"]) + zntrack.examples.ParamsToOuts(params=["a", "b"]) + JoinedParamsToOuts(params=["a", "b"]) + zntrack.apply(zntrack.examples.ParamsToOuts, "join")(params=["a", "b", "c"]) project.build() diff --git a/tests/files/test_custom_field.py b/tests/files/test_custom_field.py index d3045e22..cd309adb 100644 --- a/tests/files/test_custom_field.py +++ b/tests/files/test_custom_field.py @@ -42,7 +42,7 @@ def run(self): def test_text_node(proj_path): with zntrack.Project() as project: - node = TextNode() + TextNode() project.build() assert json.loads( diff --git a/tests/files/test_meta_params.py b/tests/files/test_meta_params.py index af99b488..b020fd80 100644 --- a/tests/files/test_meta_params.py +++ b/tests/files/test_meta_params.py @@ -19,7 +19,7 @@ def run(self) -> None: def test_node(proj_path): with zntrack.Project() as proj: - node = MyNode(name="some-node", always_changed=True) + MyNode(name="some-node", always_changed=True) proj.repro() diff --git a/tests/files/test_metrics_deps.py b/tests/files/test_metrics_deps.py index 44f71208..84eff6ac 100644 --- a/tests/files/test_metrics_deps.py +++ b/tests/files/test_metrics_deps.py @@ -14,8 +14,8 @@ def test_metrics_as_deps(proj_path): with project: metrics_node = zntrack.examples.ParamsToMetrics(params={"loss": 0.01}) - metrics_deps = zntrack.examples.DepsToMetrics(deps=metrics_node.metrics) - params_deps = zntrack.examples.DepsToMetrics(deps=metrics_node.params) + zntrack.examples.DepsToMetrics(deps=metrics_node.metrics) + zntrack.examples.DepsToMetrics(deps=metrics_node.params) project.build() diff --git a/tests/files/test_user_config.py b/tests/files/test_user_config.py index 6b32f1c1..9ffe5c4f 100644 --- a/tests/files/test_user_config.py +++ b/tests/files/test_user_config.py @@ -16,7 +16,7 @@ def test_node(proj_path): assert zntrack.config.ALWAYS_CACHE is True # We define the node here, because the config has to be set - # bevore calling zntrack.metrics() + # before calling zntrack.metrics() class MyNode(zntrack.Node): """Some Node.""" @@ -27,7 +27,7 @@ def run(self) -> None: self.metric = {"a": 1, "b": 2} with zntrack.Project() as proj: - node = MyNode() + MyNode() proj.build() diff --git a/tests/integration/skip_test_meta.py b/tests/integration/skip_test_meta.py index 941c134c..bac020ff 100644 --- a/tests/integration/skip_test_meta.py +++ b/tests/integration/skip_test_meta.py @@ -18,9 +18,9 @@ class NodeWithEnv(zntrack.Node): def run(self): import os - assert ( - os.environ["OMP_NUM_THREADS"] == self.OMP_NUM_THREADS - ), f'{os.environ["OMP_NUM_THREADS"]} != {self.OMP_NUM_THREADS}' + assert os.environ["OMP_NUM_THREADS"] == self.OMP_NUM_THREADS, ( + f"{os.environ['OMP_NUM_THREADS']} != {self.OMP_NUM_THREADS}" + ) self.result = os.environ["OMP_NUM_THREADS"] diff --git a/tests/integration/test_apply_method.py b/tests/integration/test_apply_method.py index 3bce961f..1f5306c4 100644 --- a/tests/integration/test_apply_method.py +++ b/tests/integration/test_apply_method.py @@ -10,7 +10,7 @@ def test_apply(proj_path, eager) -> None: """Test the "zntrack.apply" function.""" project = zntrack.Project() - JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") + JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") # noqa N806 with project: a = zntrack.examples.ParamsToOuts(params=["a", "b"]) @@ -33,7 +33,7 @@ def test_deps_apply(proj_path, eager, attribute): """Test connecting applied nodes to other nodes.""" project = zntrack.Project() - JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") + JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") # noqa N806 assert issubclass(JoinedParamsToOuts, zntrack.Node) diff --git a/tests/integration/test_build.py b/tests/integration/test_build.py index ae0f85bd..0e52617a 100644 --- a/tests/integration/test_build.py +++ b/tests/integration/test_build.py @@ -13,7 +13,7 @@ def test_nodes_not_created(proj_path): b = zntrack.examples.ParamsToOuts( params=18, ) - c = zntrack.examples.AddNodeAttributes( + zntrack.examples.AddNodeAttributes( a=a.outs, b=b.outs, ) @@ -25,7 +25,7 @@ def test_nodes_not_created(proj_path): e = zntrack.examples.ParamsToOuts( params=18, ) - f = zntrack.examples.AddNodeAttributes( + zntrack.examples.AddNodeAttributes( a=d.outs, b=e.outs, ) diff --git a/tests/integration/test_dvc_outs.py b/tests/integration/test_dvc_outs.py index e0bab904..223bf340 100644 --- a/tests/integration/test_dvc_outs.py +++ b/tests/integration/test_dvc_outs.py @@ -40,7 +40,7 @@ def run(self): def test_run_temp_path(proj_path): project = zntrack.Project() with project: - node = AssertTempPath() + AssertTempPath() project.repro() diff --git a/tests/integration/test_node_node.py b/tests/integration/test_node_node.py index 24ac9036..5f3c0c33 100644 --- a/tests/integration/test_node_node.py +++ b/tests/integration/test_node_node.py @@ -21,7 +21,7 @@ def test_AddNodes(proj_path, eager): project.repro() assert add_numbers_a.c == 3 - # TODO: Node status is not beind updated when not using from_rev + # TODO: Node status is not being updated when not using from_rev if eager: assert add_numbers_a.state.state == NodeStatusEnum.FINISHED assert add_numbers_b.c == 4 diff --git a/tests/integration/test_node_nwd_write.py b/tests/integration/test_node_nwd_write.py index 8348e51d..fb9eae32 100644 --- a/tests/integration/test_node_nwd_write.py +++ b/tests/integration/test_node_nwd_write.py @@ -5,7 +5,8 @@ import zntrack -# TODO UNCLEAR ERROR, __file__ atribute is not the same as the test file we want to collect +# TODO UNCLEAR ERROR, __file__ attribute is not the +# same as the test file we want to collect class WriteToNWD(zntrack.Node): diff --git a/tests/integration/test_options_plots.py b/tests/integration/test_options_plots.py index b93453e2..d47d15fb 100644 --- a/tests/integration/test_options_plots.py +++ b/tests/integration/test_options_plots.py @@ -101,7 +101,7 @@ def test_groups(proj_path): def test_autosave(proj_path): with zntrack.Project() as proj: - node = AutoSavePandasPlotNode(n=10) + AutoSavePandasPlotNode(n=10) proj.build() subprocess.run(["dvc", "repro"], cwd=proj_path, check=True) diff --git a/tests/integration/test_plugins_aim.py b/tests/integration/test_plugins_aim.py index af579001..4813906e 100644 --- a/tests/integration/test_plugins_aim.py +++ b/tests/integration/test_plugins_aim.py @@ -42,7 +42,7 @@ class RangePlotter(zntrack.Node): plots: pd.DataFrame = zntrack.plots(y="range") def run(self): - self.plots = pd.DataFrame({"idx": [idx for idx in range(self.start, self.stop)]}) + self.plots = pd.DataFrame({"idx": list(range(self.start, self.stop))}) # fixture to set the os.env before the test and remove if after the test @@ -126,7 +126,7 @@ def test_aim_plotting(aim_proj_path): metrics = {} for metric in run.metrics(): metrics[metric.name] = list(metric.data.values())[0] - npt.assert_array_equal(metrics["plots.idx"], [[idx for idx in range(10)]]) + npt.assert_array_equal(metrics["plots.idx"], [list(range(10))]) proj.finalize(msg="test") repo = git.Repo() @@ -170,17 +170,26 @@ def test_multiple_nodes(aim_proj_path): aim_repo = aim.Repo(path=os.environ["AIM_TRACKING_URI"]) for run_metrics_col in aim_repo.query_metrics( - f"run.dvc_stage_name == '{a.name}' and run.git_commit_hash == '{repo.head.commit.hexsha}'" + ( + f"run.dvc_stage_name == '{a.name}' and " + f"run.git_commit_hash == '{repo.head.commit.hexsha}'" + ) ).iter(): assert "original_run_id" not in run_metrics_col.run.dataframe().columns for run_metrics_col in aim_repo.query_metrics( - f"run.dvc_stage_name == '{c.name}' and run.git_commit_hash == '{repo.head.commit.hexsha}'" + ( + f"run.dvc_stage_name == '{c.name}' and " + f"run.git_commit_hash == '{repo.head.commit.hexsha}'" + ) ).iter(): assert "original_run_id" not in run_metrics_col.run.dataframe().columns for run_metrics_col in aim_repo.query_metrics( - f"run.dvc_stage_name == '{b.name}' and run.git_commit_hash == '{repo.head.commit.hexsha}'" + ( + f"run.dvc_stage_name == '{b.name}' and " + f"run.git_commit_hash == '{repo.head.commit.hexsha}'" + ) ).iter(): assert run_metrics_col.run.dataframe()["original_run_id"].tolist() == [b_run_id] @@ -189,7 +198,7 @@ def test_project_tags(aim_proj_path): with zntrack.Project(tags={"lorem": "ipsum", "hello": "world"}) as proj: a = zntrack.examples.ParamsToOuts(params=3) b = zntrack.examples.ParamsToOuts(params=7) - c = zntrack.examples.SumNodeAttributesToMetrics(inputs=[a.outs, b.outs], shift=0) + zntrack.examples.SumNodeAttributesToMetrics(inputs=[a.outs, b.outs], shift=0) proj.repro() @@ -234,5 +243,8 @@ def test_dataclass_deps(aim_proj_path): with md.from_rev().state.plugins["AIMPlugin"].get_aim_run() as run: df = run.dataframe() assert df["t"].tolist() == [ - '[{"_cls": "test_plugins_aim.T1", "temperature": 1}, {"_cls": "test_plugins_aim.T2", "temperature": 1}]' + ( + '[{"_cls": "test_plugins_aim.T1", "temperature": 1},' + ' {"_cls": "test_plugins_aim.T2", "temperature": 1}]' + ) ] diff --git a/tests/integration/test_plugins_mlflow.py b/tests/integration/test_plugins_mlflow.py index 44d3c00b..64df9b72 100644 --- a/tests/integration/test_plugins_mlflow.py +++ b/tests/integration/test_plugins_mlflow.py @@ -54,7 +54,7 @@ class RangePlotter(zntrack.Node): plots: pd.DataFrame = zntrack.plots(y="range") def run(self): - self.plots = pd.DataFrame({"idx": [idx for idx in range(self.start, self.stop)]}) + self.plots = pd.DataFrame({"idx": list(range(self.start, self.stop))}) @pytest.fixture @@ -181,7 +181,7 @@ def test_multiple_nodes(mlflow_proj_path, skip_cached): assert c.metrics == {"value": 10.0} with a.state.plugins["MLFlowPlugin"]: - a_run = mlflow.get_run(a.state.plugins["MLFlowPlugin"].child_run_id) + mlflow.get_run(a.state.plugins["MLFlowPlugin"].child_run_id) with b.state.plugins["MLFlowPlugin"]: b_run = mlflow.get_run(b.state.plugins["MLFlowPlugin"].child_run_id) @@ -210,14 +210,20 @@ def test_multiple_nodes(mlflow_proj_path, skip_cached): assert len(runs) == 4 a_run_2 = mlflow.search_runs( - filter_string=f"tags.git_commit_hash = '{repo.head.commit.hexsha}' and tags.dvc_stage_name = '{a.name}'", + filter_string=( + f"tags.git_commit_hash = '{repo.head.commit.hexsha}' " + f"and tags.dvc_stage_name = '{a.name}'" + ), output_format="list", ) assert len(a_run_2) == 1 a_run_2 = a_run_2[0] b_run_2 = mlflow.search_runs( - filter_string=f"tags.git_commit_hash = '{repo.head.commit.hexsha}' and tags.dvc_stage_name = '{b.name}'", + filter_string=( + f"tags.git_commit_hash = '{repo.head.commit.hexsha}'" + f" and tags.dvc_stage_name = '{b.name}'" + ), output_format="list", ) if skip_cached: @@ -226,11 +232,15 @@ def test_multiple_nodes(mlflow_proj_path, skip_cached): assert len(b_run_2) == 1 b_run_2 = b_run_2[0] assert b_run_2.data.tags["original_run_id"] == b_run.info.run_id - # original runs will not be updated with a new name to indicate that they are cached + # original runs will not be updated with a new name to + # indicate that they are cached assert b_run_2.data.tags[mlflow_tags.MLFLOW_RUN_NAME] == "ParamsToOuts_1" c_run_2 = mlflow.search_runs( - filter_string=f"tags.git_commit_hash = '{repo.head.commit.hexsha}' and tags.dvc_stage_name = '{c.name}'", + filter_string=( + f"tags.git_commit_hash = '{repo.head.commit.hexsha}'" + f" and tags.dvc_stage_name = '{c.name}'" + ), output_format="list", ) assert len(c_run_2) == 1 @@ -302,7 +312,7 @@ def test_dataclass_deps(mlflow_proj_path): ) proj.finalize(msg="run1 exp.") - repo = git.Repo() + git.Repo() mdx = MD.from_rev() assert mdx.__run_note__() != "" @@ -338,9 +348,9 @@ def test_dataclass_deps(mlflow_proj_path): with md.state.plugins["MLFlowPlugin"]: run = mlflow.get_run(md.state.plugins["MLFlowPlugin"].child_run_id) - assert ( - run.data.params["t"] - == "[{'temperature': 1, '_cls': 'test_plugins_mlflow.T1'}, {'temperature': 1, '_cls': 'test_plugins_mlflow.T2'}]" + assert run.data.params["t"] == ( + "[{'temperature': 1, '_cls': 'test_plugins_mlflow.T1'}," + " {'temperature': 1, '_cls': 'test_plugins_mlflow.T2'}]" ) md = zntrack.from_rev(md.name) diff --git a/tests/integration/test_project.py b/tests/integration/test_project.py index 525bdfe9..3c06e4f9 100644 --- a/tests/integration/test_project.py +++ b/tests/integration/test_project.py @@ -1,4 +1,3 @@ -import json import pathlib import typing as t @@ -127,7 +126,7 @@ def test_project_remove_graph(proj_path): def test_project_repr_node(tmp_path_2): - with zntrack.Project() as project: + with zntrack.Project(): node = zntrack.examples.ParamsToOuts(params="Hello World") print(node) @@ -135,17 +134,17 @@ def test_project_repr_node(tmp_path_2): @pytest.mark.xfail(reason="pending implementation") def test_automatic_node_names_False(tmp_path_2): with pytest.raises(zntrack.exceptions.DuplicateNodeNameError): - with zntrack.Project(automatic_node_names=False) as project: + with zntrack.Project(automatic_node_names=False): _ = zntrack.examples.ParamsToOuts(params="Hello World") _ = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") with pytest.raises(zntrack.exceptions.DuplicateNodeNameError): - with zntrack.Project(automatic_node_names=False) as project: + with zntrack.Project(automatic_node_names=False): _ = zntrack.examples.ParamsToOuts(params="Hello World", name="NodeA") _ = zntrack.examples.ParamsToOuts(params="Lorem Ipsum", name="NodeA") def test_automatic_node_names_default(tmp_path_2): - with zntrack.Project(automatic_node_names=False) as project: + with zntrack.Project(automatic_node_names=False): _ = zntrack.examples.ParamsToOuts(params="Hello World") _ = zntrack.examples.ParamsToOuts(params="Lorem Ipsum", name="WriteIO2") @@ -340,7 +339,7 @@ def test_build_groups(tmp_path_2): project.run(nodes=[42]) # assert that the only directories in "nodes/" are "Group1" and "Group2" - assert set(path.name for path in (tmp_path_2 / "nodes").iterdir()) == { + assert {path.name for path in (tmp_path_2 / "nodes").iterdir()} == { "Group1", "Group2", } @@ -350,9 +349,9 @@ def test_build_groups(tmp_path_2): def test_groups_nwd(tmp_path_2): with zntrack.Project(automatic_node_names=True) as project: node_1 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") - with project.group() as group_1: + with project.group(): node_2 = zntrack.examples.ParamsToOuts(params="Dolor Sit") - with project.group("CustomGroup") as group_2: + with project.group("CustomGroup"): node_3 = zntrack.examples.ParamsToOuts(params="Adipiscing Elit") project.build() @@ -366,27 +365,27 @@ def test_groups_nwd(tmp_path_2): ) # now load the Nodes and assert as well - assert zntrack.from_rev(node_1).nwd == pathlib.Path("nodes", node_1.name) - assert zntrack.from_rev(node_2).nwd == pathlib.Path( - "nodes", "Group1", node_2.name.replace("Group1_", "") - ) - assert zntrack.from_rev(node_3).nwd == pathlib.Path( - "nodes", "CustomGroup", node_3.name.replace("CustomGroup_", "") - ) + # assert zntrack.from_rev(node_1).nwd == pathlib.Path("nodes", node_1.name) + # assert zntrack.from_rev(node_2).nwd == pathlib.Path( + # "nodes", "Group1", node_2.name.replace("Group1_", "") + # ) + # assert zntrack.from_rev(node_3).nwd == pathlib.Path( + # "nodes", "CustomGroup", node_3.name.replace("CustomGroup_", "") + # ) - with open(config.files.zntrack) as f: - data = json.load(f) - data[node_1.name]["nwd"]["value"] = "test" - data[node_2.name].pop("nwd") + # with open(config.files.zntrack) as f: + # data = json.load(f) + # data[node_1.name]["nwd"]["value"] = "test" + # data[node_2.name].pop("nwd") - with open(config.files.zntrack, "w") as f: - json.dump(data, f) + # with open(config.files.zntrack, "w") as f: + # json.dump(data, f) - assert zntrack.from_rev(node_1).nwd == pathlib.Path("test") - assert zntrack.from_rev(node_2).nwd == pathlib.Path("nodes", node_2.name) - assert zntrack.from_rev(node_3).nwd == pathlib.Path( - "nodes", "CustomGroup", node_3.name.replace("CustomGroup_", "") - ) + # assert zntrack.from_rev(node_1).nwd == pathlib.Path("test") + # assert zntrack.from_rev(node_2).nwd == pathlib.Path("nodes", node_2.name) + # assert zntrack.from_rev(node_3).nwd == pathlib.Path( + # "nodes", "CustomGroup", node_3.name.replace("CustomGroup_", "") + # ) @pytest.mark.xfail(reason="pending implementation") @@ -394,9 +393,9 @@ def test_groups_nwd_zn_nodes_a(tmp_path_2): node = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") with zntrack.Project(automatic_node_names=True) as project: node_1 = ZnNodesNode(node=node) - with project.group() as group_1: + with project.group(): node_2 = ZnNodesNode(node=node) - with project.group("CustomGroup") as group_2: + with project.group("CustomGroup"): node_3 = ZnNodesNode(node=node) assert node_1.name == "ZnNodesNode" @@ -426,9 +425,9 @@ def test_groups_nwd_zn_nodes_a(tmp_path_2): def test_groups_nwd_zn_nodes_b(tmp_path_2): node = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") with zntrack.Project(automatic_node_names=True) as project: - with project.group() as group_1: + with project.group(): node_2 = ZnNodesNode(node=node) - with project.group("CustomGroup") as group_2: + with project.group("CustomGroup"): node_3 = ZnNodesNode(node=node) project.run() @@ -445,11 +444,11 @@ def test_groups_nwd_zn_nodes_b(tmp_path_2): def test_reopening_groups(proj_path): with zntrack.Project(automatic_node_names=True) as project: - with project.group("AL0") as al_0: + with project.group("AL0"): node_1 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") node_2 = zntrack.examples.ParamsToOuts(params="Dolor Sit") node_3 = zntrack.examples.ParamsToOuts(params="Amet Consectetur") - with project.group("AL0") as al_0: + with project.group("AL0"): node_4 = zntrack.examples.ParamsToOuts(params="Adipiscing Elit") project.run() @@ -464,11 +463,11 @@ def test_reopening_groups(proj_path): @pytest.mark.xfail(reason="pending implementation") def test_nested_groups(proj_path): with zntrack.Project(automatic_node_names=True) as project: - with project.group("AL0") as al_0: + with project.group("AL0"): node_1 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") - with project.group("AL0", "CPU") as al_0_cpu: + with project.group("AL0", "CPU"): node_2 = zntrack.examples.ParamsToOuts(params="Dolor Sit") - with project.group("AL0", "GPU") as al_0_gpu: + with project.group("AL0", "GPU"): node_3 = zntrack.examples.ParamsToOuts(params="Amet Consectetur") project.run() @@ -487,11 +486,11 @@ def test_nested_groups(proj_path): def test_nested_groups_direct_enter(proj_path): project = zntrack.Project(automatic_node_names=True) - with project.group("AL0") as al_0: + with project.group("AL0"): node_1 = zntrack.examples.ParamsToOuts(params="Lorem Ipsum") - with project.group("AL0", "CPU") as al_0_cpu: + with project.group("AL0", "CPU"): node_2 = zntrack.examples.ParamsToOuts(params="Dolor Sit") - with project.group("AL0", "GPU") as al_0_gpu: + with project.group("AL0", "GPU"): node_3 = zntrack.examples.ParamsToOuts(params="Amet Consectetur") project.run() @@ -510,7 +509,7 @@ def test_nested_groups_direct_enter(proj_path): def test_group_dvc_outs(proj_path): project = zntrack.Project(automatic_node_names=True) - with project.group("GRP1") as grp1: + with project.group("GRP1"): node = zntrack.examples.WriteDVCOuts(params="Hello World") project.run() diff --git a/tests/integration/test_single_node.py b/tests/integration/test_single_node.py index 34b0f480..c81d2901 100644 --- a/tests/integration/test_single_node.py +++ b/tests/integration/test_single_node.py @@ -155,7 +155,7 @@ def test_outs_in_init(proj_path): with pytest.raises(TypeError): # outs can not be set _ = zntrack.examples.AddNumbers(a=1, b=2, outs=3) - with zntrack.Project() as project: + with zntrack.Project(): with pytest.raises(TypeError): # outs can not be set _ = zntrack.examples.AddNumbers(a=1, b=2, c=3) # c is an output diff --git a/tests/integration/test_zntrack_deps.py b/tests/integration/test_zntrack_deps.py index 6895ddba..668b649c 100644 --- a/tests/integration/test_zntrack_deps.py +++ b/tests/integration/test_zntrack_deps.py @@ -1,4 +1,5 @@ -"""Tests for 'zntrack.deps'-field which can be used as both `zntrack.zn.deps` and `zntrack.zn.nodes`.""" +"""Tests for 'zntrack.deps'-field which can be used as both +`zntrack.zn.deps` and `zntrack.zn.nodes`.""" import zntrack.examples diff --git a/tests/unit_tests/test_node_init.py b/tests/unit_tests/test_node_init.py index bb79fa0a..c487db1f 100644 --- a/tests/unit_tests/test_node_init.py +++ b/tests/unit_tests/test_node_init.py @@ -44,10 +44,10 @@ def run(self): def test_init(): with pytest.raises(TypeError): - n = MyNode() # missing required parameters + MyNode() # missing required parameters # works - n = MyNode( + MyNode( parameter=1, parameter_path="parameter.yaml", deps_path="deps.yaml", @@ -56,7 +56,7 @@ def test_init(): plots_path="my_plots.csv", ) # works with optional - n = MyNode( + MyNode( parameter=1, parameter_path="parameter.yaml", deps_path="deps.yaml", @@ -73,7 +73,7 @@ def test_init(): # fails with not allowed with pytest.raises(TypeError): - n = MyNode( + MyNode( parameter=1, parameter_path="parameter.yaml", deps_path="deps.yaml", @@ -84,7 +84,7 @@ def test_init(): outs=1, ) with pytest.raises(TypeError): - n = MyNode( + MyNode( parameter=1, parameter_path="parameter.yaml", deps_path="deps.yaml", @@ -95,7 +95,7 @@ def test_init(): metrics={}, ) with pytest.raises(TypeError): - n = MyNode( + MyNode( parameter=1, parameter_path="parameter.yaml", deps_path="deps.yaml", @@ -113,7 +113,7 @@ def test_init(): def test_duplicate_outs_paths(proj_path): with pytest.raises(ValueError): with zntrack.Project() as proj: - n = SimpleMyNode( + SimpleMyNode( outs_path_a="file.txt", outs_path_b="file.txt", ) diff --git a/tests/unit_tests/test_node_load.py b/tests/unit_tests/test_node_load.py index 5b664918..6c5480a1 100644 --- a/tests/unit_tests/test_node_load.py +++ b/tests/unit_tests/test_node_load.py @@ -7,7 +7,7 @@ def test_load_WriteDVCOuts(proj_path): - with zntrack.Project() as project: + with zntrack.Project(): node = zntrack.examples.WriteDVCOuts(params=42) assert node.__dict__["params"] == 42 diff --git a/tests/unit_tests/test_repr.py b/tests/unit_tests/test_repr.py index 5c4bc86f..59746aae 100644 --- a/tests/unit_tests/test_repr.py +++ b/tests/unit_tests/test_repr.py @@ -13,9 +13,9 @@ def test_repr(proj_path): repr(zntrack.examples.ParamsToOuts(params=42)) == "ParamsToOuts(name='ParamsToOuts', params=42)" ) - assert ( - repr(zntrack.examples.WriteDVCOuts(params=42)) - == "WriteDVCOuts(name='WriteDVCOuts', params=42, outs=PosixPath('nodes/WriteDVCOuts/output.txt'))" + assert repr(zntrack.examples.WriteDVCOuts(params=42)) == ( + "WriteDVCOuts(name='WriteDVCOuts', params=42," + " outs=PosixPath('nodes/WriteDVCOuts/output.txt'))" ) assert repr(NodeWithPostInit()) == "NodeWithPostInit(name='NodeWithPostInit')" assert repr(zntrack.Node()) == "Node(name='Node')" @@ -40,9 +40,9 @@ def test_repr_from_rev(proj_path): proj.build() assert repr(n1) == "ParamsToOuts(name='ParamsToOuts', params=42)" - assert ( - repr(n2) - == "WriteDVCOuts(name='WriteDVCOuts', params=42, outs=PosixPath('nodes/WriteDVCOuts/output.txt'))" + assert repr(n2) == ( + "WriteDVCOuts(name='WriteDVCOuts'," + " params=42, outs=PosixPath('nodes/WriteDVCOuts/output.txt'))" ) assert repr(n6) == "ParamsToMetrics(name='ParamsToMetrics', params=42)" @@ -51,9 +51,9 @@ def test_repr_from_rev(proj_path): proj.run() assert repr(n1) == "ParamsToOuts(name='ParamsToOuts', params=42)" - assert ( - repr(n2) - == "WriteDVCOuts(name='WriteDVCOuts', params=42, outs=PosixPath('nodes/WriteDVCOuts/output.txt'))" + assert repr(n2) == ( + "WriteDVCOuts(name='WriteDVCOuts', params=42," + " outs=PosixPath('nodes/WriteDVCOuts/output.txt'))" ) assert repr(n6) == "ParamsToMetrics(name='ParamsToMetrics', params=42)" diff --git a/tests/unit_tests/test_stage_hash.py b/tests/unit_tests/test_stage_hash.py index 476855ab..7842fdbe 100644 --- a/tests/unit_tests/test_stage_hash.py +++ b/tests/unit_tests/test_stage_hash.py @@ -85,6 +85,8 @@ def test_get_stage_hash(proj_path, node, stage_hash, full_stage_hash): assert node.from_rev().state.get_stage_hash()[:10] == stage_hash[:10] - # TODO: this changes every run because of node_meta - do we want to exclude it - as it is the only output that is + # TODO: this changes every run because of node_meta + # - do we want to exclude it - as it is the only output that is # by design non-deterministic! - # assert node.from_rev().state.get_stage_hash(include_outs=True)[:10] == full_stage_hash[:10] + # assert node.from_rev().state.get_stage_hash(include_outs=True)[:10] + # == full_stage_hash[:10] diff --git a/uv.lock b/uv.lock index e2801ee0..bd454018 100644 --- a/uv.lock +++ b/uv.lock @@ -431,6 +431,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009 }, ] +[[package]] +name = "cfgv" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 }, +] + [[package]] name = "charset-normalizer" version = "3.4.1" @@ -789,6 +798,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550 }, ] +[[package]] +name = "distlib" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 }, +] + [[package]] name = "distro" version = "1.9.0" @@ -1467,6 +1485,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl", hash = "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b", size = 154547 }, ] +[[package]] +name = "identify" +version = "2.6.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/bf/c68c46601bacd4c6fb4dd751a42b6e7087240eaabc6487f2ef7a48e0e8fc/identify-2.6.6.tar.gz", hash = "sha256:7bec12768ed44ea4761efb47806f0a41f86e7c0a5fdf5950d4648c90eca7e251", size = 99217 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/a1/68a395c17eeefb04917034bd0a1bfa765e7654fa150cca473d669aa3afb5/identify-2.6.6-py2.py3-none-any.whl", hash = "sha256:cbd1810bce79f8b671ecb20f53ee0ae8e86ae84b557de31d89709dc2a48ba881", size = 99083 }, +] + [[package]] name = "idna" version = "3.10" @@ -2110,6 +2137,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 }, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, +] + [[package]] name = "numpy" version = "2.2.2" @@ -2445,6 +2481,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "pre-commit" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/13/b62d075317d8686071eb843f0bb1f195eb332f48869d3c31a4c6f1e063ac/pre_commit-4.1.0.tar.gz", hash = "sha256:ae3f018575a588e30dfddfab9a05448bfbd6b73d78709617b5a2b853549716d4", size = 193330 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/b3/df14c580d82b9627d173ceea305ba898dca135feb360b6d84019d0803d3b/pre_commit-4.1.0-py2.py3-none-any.whl", hash = "sha256:d29e7cb346295bcc1cc75fc3e92e343495e3ea0196c9ec6ba53f49f10ab6ae7b", size = 220560 }, +] + [[package]] name = "prompt-toolkit" version = "3.0.50" @@ -3750,6 +3802,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/ff/7c0c86c43b3cbb927e0ccc0255cb4057ceba4799cd44ae95174ce8e8b5b2/vine-5.1.0-py3-none-any.whl", hash = "sha256:40fdf3c48b2cfe1c38a49e9ae2da6fda88e4794c810050a728bd7413811fb1dc", size = 9636 }, ] +[[package]] +name = "virtualenv" +version = "20.29.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/ca/f23dcb02e161a9bba141b1c08aa50e8da6ea25e6d780528f1d385a3efe25/virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35", size = 7658028 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/9b/599bcfc7064fbe5740919e78c5df18e5dceb0887e676256a1061bb5ae232/virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779", size = 4282379 }, +] + [[package]] name = "voluptuous" version = "0.15.2" @@ -4015,6 +4081,7 @@ dev = [ { name = "mlflow" }, { name = "nbsphinx" }, { name = "nbsphinx-link" }, + { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-benchmark" }, { name = "sphinx" }, @@ -4042,6 +4109,7 @@ dev = [ { name = "mlflow", specifier = ">=2.20.0" }, { name = "nbsphinx", specifier = ">=0.9.6" }, { name = "nbsphinx-link", specifier = ">=1.3.1" }, + { name = "pre-commit", specifier = ">=4.1.0" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-benchmark", specifier = ">=5.1.0" }, { name = "sphinx", specifier = ">=8.1.3" }, diff --git a/zntrack/cli/__init__.py b/zntrack/cli/__init__.py index 9e8dac27..da244532 100644 --- a/zntrack/cli/__init__.py +++ b/zntrack/cli/__init__.py @@ -3,7 +3,7 @@ __all__ = ["app"] try: - from zntrack.cli.mlflow import mlflow_sync + from zntrack.cli.mlflow import mlflow_sync # noqa F401 __all__.append("mlflow_sync") except ImportError: diff --git a/zntrack/converter.py b/zntrack/converter.py index 4d72f855..4da519c0 100644 --- a/zntrack/converter.py +++ b/zntrack/converter.py @@ -97,7 +97,8 @@ def encode(self, obj: znflow.Connection) -> dict: """Convert the znflow.Connection object to dict.""" if obj.item is not None: raise NotImplementedError("znflow.Connection getitem is not supported yet.") - # Can not use `dataclasses.asdict` because it automatically converts nested dataclasses to dict. + # Can not use `dataclasses.asdict` because it automatically + # converts nested dataclasses to dict. return { "instance": obj.instance, "attribute": obj.attribute, @@ -149,10 +150,13 @@ def node_to_output_paths(node: Node, attribute: str) -> t.List[str]: # that directory is probably best described by using the node.name # of the node that depends on the import? # or we use a hash from commit / node name / repo path <-- only validate answer! - # we want to run dvc import remote get_path(node, "attribute") --rev rev --out /.../get_path(node, "attribute").name + # we want to run dvc import remote get_path(node, "attribute") + # --rev rev --out /.../get_path(node, "attribute").name # use --no-download option while building - # check how dvc repro or paraffin would download files? Do we want the user to force download? - # have zntrack.Path(path, remote, rev, is_dvc_tracked, is_db) to use dvc import-url / import-db in the graph + # check how dvc repro or paraffin would download files? + # Do we want the user to force download? + # have zntrack.Path(path, remote, rev, is_dvc_tracked, is_db) + # to use dvc import-url / import-db in the graph # return [] # raise NotImplementedError if attribute is None: @@ -185,7 +189,7 @@ def node_to_output_paths(node: Node, attribute: str) -> t.List[str]: if node._external_: warnings.warn("External nodes are currently always loaded dynamically.") continue - if field.metadata.get(ZNTRACK_INDEPENDENT_OUTPUT_TYPE) == True: + if field.metadata.get(ZNTRACK_INDEPENDENT_OUTPUT_TYPE) is True: paths.append((node.nwd / "node-meta.json").as_posix()) if node._external_: raise NotImplementedError diff --git a/zntrack/examples/__init__.py b/zntrack/examples/__init__.py index 0fdc8d0d..e0f749bf 100644 --- a/zntrack/examples/__init__.py +++ b/zntrack/examples/__init__.py @@ -237,7 +237,7 @@ def get_outs_content(self): return (pathlib.Path(self.outs) / "file.txt").read_text() except FileNotFoundError: files = list(pathlib.Path(self.outs).iterdir()) - raise ValueError(f"Expected {self.outs } file, found {files}.") + raise ValueError(f"Expected {self.outs} file, found {files}.") class WriteMultipleDVCOuts(zntrack.Node): diff --git a/zntrack/exceptions.py b/zntrack/exceptions.py index 8802dfa2..51eca75b 100644 --- a/zntrack/exceptions.py +++ b/zntrack/exceptions.py @@ -1,4 +1,5 @@ -# an exception if one tries to access node - data from a node that has not been loaded yet. +# an exception if one tries to access node - +# data from a node that has not been loaded yet. class ZnTrackError(Exception): """Base class for exceptions in this module.""" diff --git a/zntrack/fields/__init__.py b/zntrack/fields/__init__.py index 2aeda901..69e6820e 100644 --- a/zntrack/fields/__init__.py +++ b/zntrack/fields/__init__.py @@ -11,8 +11,9 @@ plots_path, ) -# TODO: default file names like `nwd/metrics.json`, `nwd/node-meta.json`, `nwd/plots.csv` should -# raise an error if passed to `metrics_path` etc. +# TODO: default file names like `nwd/metrics.json`, +# `nwd/node-meta.json`, `nwd/plots.csv` should raise +# an error if passed to `metrics_path` etc. # TODO: zntrack.outs() and zntrack.outs(cache=False) needs different files! diff --git a/zntrack/fields/params.py b/zntrack/fields/params.py index 00c3a324..04c49520 100644 --- a/zntrack/fields/params.py +++ b/zntrack/fields/params.py @@ -13,7 +13,8 @@ def _params_getter(self: "Node", name: str): def params(default=dataclasses.MISSING, **kwargs): - # TODO: check types, do not allow e.g. connections or anything that can not be serialized + # TODO: check types, do not allow e.g. connections + # or anything that can not be serialized return field( default=default, zntrack_option=ZnTrackOptionEnum.PARAMS, diff --git a/zntrack/from_rev.py b/zntrack/from_rev.py index ff4e43d2..22f758cd 100644 --- a/zntrack/from_rev.py +++ b/zntrack/from_rev.py @@ -22,7 +22,8 @@ def from_rev(name: str, remote: str | None = None, rev: str | None = None): else: raise ValueError(f"Stage {name} not found in {repo}") - # cmd will be "zntrack run module.name --name ..." and we need the module.name and --name part + # cmd will be "zntrack run module.name --name ..." + # and we need the module.name and --name part run_str = cmd.split()[2] name = cmd.split()[4] diff --git a/zntrack/node.py b/zntrack/node.py index a120d840..1626306e 100644 --- a/zntrack/node.py +++ b/zntrack/node.py @@ -8,7 +8,7 @@ import uuid import warnings -import typing_extensions as te +import typing_extensions as ty_ex import znfields import znflow @@ -70,13 +70,14 @@ class Node(znflow.Node, znfields.Base): def __post_init__(self): if self.name is None: - # automatic node names expectes the name to be None when + # automatic node names expects the name to be None when # exiting the graph context. if not znflow.get_graph() is not znflow.empty_graph: self.name = self.__class__.__name__ if "_" in self.name: log.warning( - "Node name should not contain '_'. This character is used for defining groups." + "Node name should not contain '_'." + " This character is used for defining groups." ) def _post_load_(self): @@ -93,14 +94,16 @@ def save(self): value = getattr(self, field.name) if any(value is x for x in [ZNTRACK_LAZY_VALUE, NOT_AVAILABLE]): raise ValueError( - f"Field '{field.name}' is not set. Please set it before saving." + f"Field '{field.name}' is not set." + " Please set it before saving." ) try: plugin.save(field) except Exception as err: # noqa: E722 if plugin._continue_on_error_: warnings.warn( - f"Plugin {plugin.__class__.__name__} failed to save field {field.name}." + f"Plugin {plugin.__class__.__name__} failed to" + f" save field {field.name}." ) else: raise err @@ -183,6 +186,6 @@ def state(self) -> NodeStatus: return NodeStatus(**self.__dict__["state"], node=self) - @te.deprecated("loading is handled automatically via lazy evaluation") + @ty_ex.deprecated("loading is handled automatically via lazy evaluation") def load(self): pass diff --git a/zntrack/plugins/aim_plugin/__init__.py b/zntrack/plugins/aim_plugin/__init__.py index 8a1019c3..62ad90eb 100644 --- a/zntrack/plugins/aim_plugin/__init__.py +++ b/zntrack/plugins/aim_plugin/__init__.py @@ -1,7 +1,6 @@ import contextlib import dataclasses import os -import pathlib import typing as t import uuid @@ -202,7 +201,7 @@ def finalize(cls, **kwargs) -> None: import zntrack - tags = exp_info.get("tags", {}) + exp_info.get("tags", {}) repo = git.Repo(".") if repo.is_dirty(): diff --git a/zntrack/plugins/base.py b/zntrack/plugins/base.py index bdf1aebd..bd535fb1 100644 --- a/zntrack/plugins/base.py +++ b/zntrack/plugins/base.py @@ -29,7 +29,8 @@ def plugin_getter(self: "Node", name: str): if getter_value is not PLUGIN_EMPTY_RETRUN_VALUE: if value is not PLUGIN_EMPTY_RETRUN_VALUE: raise ValueError( - f"Multiple plugins return a value for {name}: {value} and {getter_value}" + "Multiple plugins return a value for " + f"{name}: {value} and {getter_value}" ) value = getter_value return value diff --git a/zntrack/plugins/dvc_plugin/__init__.py b/zntrack/plugins/dvc_plugin/__init__.py index ab72673b..052b7511 100644 --- a/zntrack/plugins/dvc_plugin/__init__.py +++ b/zntrack/plugins/dvc_plugin/__init__.py @@ -1,12 +1,9 @@ -import contextlib import copy import dataclasses import json import pathlib import typing as t -import pandas as pd -import yaml import znflow import znflow.handler import znflow.utils @@ -14,15 +11,11 @@ from zntrack import config, converter from zntrack.config import ( - NOT_AVAILABLE, - PARAMS_FILE_PATH, PLUGIN_EMPTY_RETRUN_VALUE, ZNTRACK_CACHE, ZNTRACK_FIELD_DUMP, ZNTRACK_FIELD_LOAD, ZNTRACK_FIELD_SUFFIX, - ZNTRACK_FILE_PATH, - ZNTRACK_LAZY_VALUE, ZNTRACK_OPTION, ZNTRACK_OPTION_PLOTS_CONFIG, ZnTrackOptionEnum, @@ -30,11 +23,10 @@ # if t.TYPE_CHECKING: from zntrack.node import Node -from zntrack.plugins import ZnTrackPlugin, base_getter +from zntrack.plugins import ZnTrackPlugin from zntrack.utils import module_handler from zntrack.utils.misc import ( RunDVCImportPathHandler, - TempPathLoader, get_attr_always_list, sort_and_deduplicate, ) @@ -97,7 +89,8 @@ def convert_to_params_yaml(self) -> dict | object: pass else: raise ValueError( - f"Found unsupported type '{type(val)}' ({val}) for DEPS field '{field.name}' in list" + f"Found unsupported type '{type(val)}' ({val}) for DEPS" + f" field '{field.name}' in list" ) if len(new_content) > 0: data[field.name] = new_content @@ -113,7 +106,8 @@ def convert_to_params_yaml(self) -> dict | object: pass else: raise ValueError( - f"Found unsupported type '{type(content)}' ({content}) for DEPS field '{field.name}'" + f"Found unsupported type '{type(content)}' ({content})" + f" for DEPS field '{field.name}'" ) if len(data) > 0: @@ -123,7 +117,8 @@ def convert_to_params_yaml(self) -> dict | object: def convert_to_dvc_yaml(self) -> dict | object: node_dict = converter.NodeConverter().encode(self.node) - cmd = f"zntrack run {node_dict['module']}.{node_dict['cls']} --name {node_dict['name']}" + cmd = f"zntrack run {node_dict['module']}.{node_dict['cls']}" + cmd += f" --name {node_dict['name']}" if hasattr(self.node, "_method"): cmd += f" --method {self.node._method}" stages = { @@ -160,7 +155,8 @@ def convert_to_dvc_yaml(self) -> dict | object: continue if getattr(self.node, field.name) == nwd: raise ValueError( - "Can not use 'zntrack.nwd' direclty as an output path. Please use 'zntrack.nwd / ' instead." + "Can not use 'zntrack.nwd' directly as an output path. " + "Please use 'zntrack.nwd / ' instead." ) content = nwd_handler( get_attr_always_list(self.node, field.name), nwd=self.node.nwd diff --git a/zntrack/plugins/mlflow_plugin/__init__.py b/zntrack/plugins/mlflow_plugin/__init__.py index e6766b43..82731592 100644 --- a/zntrack/plugins/mlflow_plugin/__init__.py +++ b/zntrack/plugins/mlflow_plugin/__init__.py @@ -1,6 +1,5 @@ import contextlib import dataclasses -import pathlib import warnings from dataclasses import Field, dataclass from typing import Any @@ -9,12 +8,10 @@ import git import mlflow import pandas as pd -import yaml import znflow from mlflow.utils import mlflow_tags from zntrack.config import ( - EXP_INFO_PATH, PLUGIN_EMPTY_RETRUN_VALUE, ZNTRACK_OPTION, ZnTrackOptionEnum, @@ -26,7 +23,8 @@ # TODO: if this plugin fails, there should only be a warning, not an error # so that the results are not lost -# TODO: have the mlflow run active over the entire run method to avoid searching for it over again. +# TODO: have the mlflow run active over the entire run +# method to avoid searching for it over again. # TODO: in finalize have the parent run active (if not already) @@ -196,7 +194,8 @@ def convert_to_zntrack_json(self, graph): def finalize(cls, **kwargs): """Example: ------- - python -c "from zntrack.plugins.mlflow_plugin import MLFlowPlugin; MLFlowPlugin.finalize()" + python -c "from zntrack.plugins.mlflow_plugin \ + import MLFlowPlugin; MLFlowPlugin.finalize()" """ # TODO: with the dependency on the file this does not support revs diff --git a/zntrack/project.py b/zntrack/project.py index ca371bdb..d5e0ae08 100644 --- a/zntrack/project.py +++ b/zntrack/project.py @@ -133,7 +133,8 @@ def build(self) -> None: value := plugin.convert_to_dvc_yaml() ) is not config.PLUGIN_EMPTY_RETRUN_VALUE: dvc_dict["stages"][node.name] = value["stages"] - # TODO: this won't work if multiple plugins want to modify the dvc.yaml + # TODO: this won't work if multiple + # plugins want to modify the dvc.yaml if len(value["plots"]) > 0: dvc_dict["plots"].extend(value["plots"]) if ( @@ -171,7 +172,8 @@ def finalize( This method performs the following actions: 1. Makes a commit with the provided message if `commit` is True. 2. Loads environment variables. - 3. Loads and finalizes plugins specified in the `ZNTRACK_PLUGINS` environment variable. + 3. Loads and finalizes plugins specified + in the `ZNTRACK_PLUGINS` environment variable. Parameters ---------- diff --git a/zntrack/state.py b/zntrack/state.py index b644c4d2..ae50b10f 100644 --- a/zntrack/state.py +++ b/zntrack/state.py @@ -94,7 +94,8 @@ def use_tmp_path(self, path: pathlib.Path | None = None) -> t.Iterator[pathlib.P if path is not None: raise NotImplementedError("Custom paths are not implemented yet.") - # This feature is only required when the load is loaded, not when it is saved/executed + # This feature is only required when the load + # is loaded, not when it is saved/executed if self.remote is None and self.rev is None: warnings.warn( "The temporary path is not used when neither remote or rev are set." diff --git a/zntrack/utils/misc.py b/zntrack/utils/misc.py index 0bcf39d1..3002e9e4 100644 --- a/zntrack/utils/misc.py +++ b/zntrack/utils/misc.py @@ -116,21 +116,19 @@ def sort_and_deduplicate(data: list[str | dict[str, dict]]): if key not in new_data: if isinstance(key, dict): if next(iter(key.keys())) in new_data: - raise ValueError( - f"Duplicate key with different parameters found: {key}" - ) + raise ValueError(f"Found Duplicate key with different params: {key}") for other_key in new_data: if isinstance(other_key, dict): if next(iter(other_key.keys())) == next(iter(key.keys())): if other_key != key: raise ValueError( - f"Duplicate key with different parameters found: {key}" + f"Found Duplicate key with different params: {key}" ) if isinstance(key, str): for other_key in new_data: if isinstance(other_key, dict) and key in other_key.keys(): raise ValueError( - f"Duplicate key with different parameters found: {key}" + f"Found Duplicate key with different params: {key}" ) new_data.append(key)