Skip to content

Commit

Permalink
update pre-commit lint and format (#876)
Browse files Browse the repository at this point in the history
* add & bump pre-commit, remove isort and black

* unsafe fixes

* rename and shorten line length

* add missing f-string

* typos
  • Loading branch information
PythonFZ authored Feb 1, 2025
1 parent 3433220 commit 9a38ade
Show file tree
Hide file tree
Showing 42 changed files with 254 additions and 164 deletions.
16 changes: 3 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,19 @@ repos:
args: ['--fix=lf']
- id: sort-simple-yaml
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.10.0
hooks:
- id: black
additional_dependencies: [".[jupyter]"]
types_or: [python, pyi, jupyter]
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.1
hooks:
- id: codespell
additional_dependencies: ["tomli"]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: 'v0.8.2'
rev: 'v0.9.4'
hooks:
- id: ruff
args: ['--fix']
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.19
rev: 0.7.22
hooks:
- id: mdformat
args: ["--wrap=80"]
5 changes: 2 additions & 3 deletions examples/docs/04_metrics_and_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,8 @@
" def run(self):\n",
" self.my_metric = {\"alpha\": 1.0, \"beta\": 0.00473}\n",
" self.my_plot = pd.DataFrame({\"val\": np.sin(np.linspace(0, 3.14, 100))})\n",
" self.my_plot.index.name = ( # For DVC it is required that the index has a column name\n",
" \"index\"\n",
" )\n",
" # For DVC it is required that the index has a column name\n",
" self.my_plot.index.name = \"index\"\n",
"\n",
"\n",
"with zntrack.Project() as project:\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/docs/09_lazy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@
"\n",
" add_one = AddOne(deps=random_number, name=\"AddOne_0\")\n",
" for index in range(10):\n",
" add_one = AddOne(deps=add_one, name=f\"AddOne_{index+1}\")\n",
" add_one = AddOne(deps=add_one, name=f\"AddOne_{index + 1}\")\n",
"\n",
"project.run()"
]
Expand Down
2 changes: 1 addition & 1 deletion examples/docs/parameter_optimization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@
"source": [
"# Setup temporary directory and initialize git and dvc\n",
"from zntrack import config\n",
"from zntrack.utils import cwd_temp_dir\n",
"\n",
"config.nb_name = \"parameter_optimization.ipynb\"\n",
"\n",
"from zntrack.utils import cwd_temp_dir\n",
"\n",
"temp_dir = cwd_temp_dir()\n",
"\n",
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dev = [
"mlflow>=2.20.0",
"nbsphinx>=0.9.6",
"nbsphinx-link>=1.3.1",
"pre-commit>=4.1.0",
"pytest>=8.3.4",
"pytest-benchmark>=5.1.0",
"sphinx>=8.1.3",
Expand Down Expand Up @@ -94,7 +95,7 @@ disable = [
line-length = 90

[tool.ruff.lint]
select = ["E", "F", "D", "N", "C", "I"] #, "ANN"]
select = ["E", "F", "N", "C", "I"] #, "ANN"]
extend-ignore = [
"D213", "D203",
"D401",
Expand Down
8 changes: 4 additions & 4 deletions tests/files/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
def test_apply(proj_path) -> None:
project = zntrack.Project()

JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join")
JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") # noqa N806

with project:
a = zntrack.examples.ParamsToOuts(params=["a", "b"])
b = JoinedParamsToOuts(params=["a", "b"])
c = zntrack.apply(zntrack.examples.ParamsToOuts, "join")(params=["a", "b", "c"])
zntrack.examples.ParamsToOuts(params=["a", "b"])
JoinedParamsToOuts(params=["a", "b"])
zntrack.apply(zntrack.examples.ParamsToOuts, "join")(params=["a", "b", "c"])

project.build()

Expand Down
2 changes: 1 addition & 1 deletion tests/files/test_custom_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run(self):

def test_text_node(proj_path):
with zntrack.Project() as project:
node = TextNode()
TextNode()
project.build()

assert json.loads(
Expand Down
2 changes: 1 addition & 1 deletion tests/files/test_meta_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def run(self) -> None:

def test_node(proj_path):
with zntrack.Project() as proj:
node = MyNode(name="some-node", always_changed=True)
MyNode(name="some-node", always_changed=True)

proj.repro()

Expand Down
4 changes: 2 additions & 2 deletions tests/files/test_metrics_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def test_metrics_as_deps(proj_path):
with project:
metrics_node = zntrack.examples.ParamsToMetrics(params={"loss": 0.01})

metrics_deps = zntrack.examples.DepsToMetrics(deps=metrics_node.metrics)
params_deps = zntrack.examples.DepsToMetrics(deps=metrics_node.params)
zntrack.examples.DepsToMetrics(deps=metrics_node.metrics)
zntrack.examples.DepsToMetrics(deps=metrics_node.params)

project.build()

Expand Down
4 changes: 2 additions & 2 deletions tests/files/test_user_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_node(proj_path):
assert zntrack.config.ALWAYS_CACHE is True

# We define the node here, because the config has to be set
# bevore calling zntrack.metrics()
# before calling zntrack.metrics()
class MyNode(zntrack.Node):
"""Some Node."""

Expand All @@ -27,7 +27,7 @@ def run(self) -> None:
self.metric = {"a": 1, "b": 2}

with zntrack.Project() as proj:
node = MyNode()
MyNode()

proj.build()

Expand Down
6 changes: 3 additions & 3 deletions tests/integration/skip_test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class NodeWithEnv(zntrack.Node):
def run(self):
import os

assert (
os.environ["OMP_NUM_THREADS"] == self.OMP_NUM_THREADS
), f'{os.environ["OMP_NUM_THREADS"]} != {self.OMP_NUM_THREADS}'
assert os.environ["OMP_NUM_THREADS"] == self.OMP_NUM_THREADS, (
f"{os.environ['OMP_NUM_THREADS']} != {self.OMP_NUM_THREADS}"
)

self.result = os.environ["OMP_NUM_THREADS"]

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_apply_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_apply(proj_path, eager) -> None:
"""Test the "zntrack.apply" function."""
project = zntrack.Project()

JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join")
JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") # noqa N806

with project:
a = zntrack.examples.ParamsToOuts(params=["a", "b"])
Expand All @@ -33,7 +33,7 @@ def test_deps_apply(proj_path, eager, attribute):
"""Test connecting applied nodes to other nodes."""
project = zntrack.Project()

JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join")
JoinedParamsToOuts = zntrack.apply(zntrack.examples.ParamsToOuts, "join") # noqa N806

assert issubclass(JoinedParamsToOuts, zntrack.Node)

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_nodes_not_created(proj_path):
b = zntrack.examples.ParamsToOuts(
params=18,
)
c = zntrack.examples.AddNodeAttributes(
zntrack.examples.AddNodeAttributes(
a=a.outs,
b=b.outs,
)
Expand All @@ -25,7 +25,7 @@ def test_nodes_not_created(proj_path):
e = zntrack.examples.ParamsToOuts(
params=18,
)
f = zntrack.examples.AddNodeAttributes(
zntrack.examples.AddNodeAttributes(
a=d.outs,
b=e.outs,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_dvc_outs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def run(self):
def test_run_temp_path(proj_path):
project = zntrack.Project()
with project:
node = AssertTempPath()
AssertTempPath()
project.repro()


Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_node_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_AddNodes(proj_path, eager):
project.repro()

assert add_numbers_a.c == 3
# TODO: Node status is not beind updated when not using from_rev
# TODO: Node status is not being updated when not using from_rev
if eager:
assert add_numbers_a.state.state == NodeStatusEnum.FINISHED
assert add_numbers_b.c == 4
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/test_node_nwd_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import zntrack

# TODO UNCLEAR ERROR, __file__ atribute is not the same as the test file we want to collect
# TODO UNCLEAR ERROR, __file__ attribute is not the
# same as the test file we want to collect


class WriteToNWD(zntrack.Node):
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_options_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_groups(proj_path):

def test_autosave(proj_path):
with zntrack.Project() as proj:
node = AutoSavePandasPlotNode(n=10)
AutoSavePandasPlotNode(n=10)

proj.build()
subprocess.run(["dvc", "repro"], cwd=proj_path, check=True)
26 changes: 19 additions & 7 deletions tests/integration/test_plugins_aim.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class RangePlotter(zntrack.Node):
plots: pd.DataFrame = zntrack.plots(y="range")

def run(self):
self.plots = pd.DataFrame({"idx": [idx for idx in range(self.start, self.stop)]})
self.plots = pd.DataFrame({"idx": list(range(self.start, self.stop))})


# fixture to set the os.env before the test and remove if after the test
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_aim_plotting(aim_proj_path):
metrics = {}
for metric in run.metrics():
metrics[metric.name] = list(metric.data.values())[0]
npt.assert_array_equal(metrics["plots.idx"], [[idx for idx in range(10)]])
npt.assert_array_equal(metrics["plots.idx"], [list(range(10))])

proj.finalize(msg="test")
repo = git.Repo()
Expand Down Expand Up @@ -170,17 +170,26 @@ def test_multiple_nodes(aim_proj_path):

aim_repo = aim.Repo(path=os.environ["AIM_TRACKING_URI"])
for run_metrics_col in aim_repo.query_metrics(
f"run.dvc_stage_name == '{a.name}' and run.git_commit_hash == '{repo.head.commit.hexsha}'"
(
f"run.dvc_stage_name == '{a.name}' and "
f"run.git_commit_hash == '{repo.head.commit.hexsha}'"
)
).iter():
assert "original_run_id" not in run_metrics_col.run.dataframe().columns

for run_metrics_col in aim_repo.query_metrics(
f"run.dvc_stage_name == '{c.name}' and run.git_commit_hash == '{repo.head.commit.hexsha}'"
(
f"run.dvc_stage_name == '{c.name}' and "
f"run.git_commit_hash == '{repo.head.commit.hexsha}'"
)
).iter():
assert "original_run_id" not in run_metrics_col.run.dataframe().columns

for run_metrics_col in aim_repo.query_metrics(
f"run.dvc_stage_name == '{b.name}' and run.git_commit_hash == '{repo.head.commit.hexsha}'"
(
f"run.dvc_stage_name == '{b.name}' and "
f"run.git_commit_hash == '{repo.head.commit.hexsha}'"
)
).iter():
assert run_metrics_col.run.dataframe()["original_run_id"].tolist() == [b_run_id]

Expand All @@ -189,7 +198,7 @@ def test_project_tags(aim_proj_path):
with zntrack.Project(tags={"lorem": "ipsum", "hello": "world"}) as proj:
a = zntrack.examples.ParamsToOuts(params=3)
b = zntrack.examples.ParamsToOuts(params=7)
c = zntrack.examples.SumNodeAttributesToMetrics(inputs=[a.outs, b.outs], shift=0)
zntrack.examples.SumNodeAttributesToMetrics(inputs=[a.outs, b.outs], shift=0)

proj.repro()

Expand Down Expand Up @@ -234,5 +243,8 @@ def test_dataclass_deps(aim_proj_path):
with md.from_rev().state.plugins["AIMPlugin"].get_aim_run() as run:
df = run.dataframe()
assert df["t"].tolist() == [
'[{"_cls": "test_plugins_aim.T1", "temperature": 1}, {"_cls": "test_plugins_aim.T2", "temperature": 1}]'
(
'[{"_cls": "test_plugins_aim.T1", "temperature": 1},'
' {"_cls": "test_plugins_aim.T2", "temperature": 1}]'
)
]
30 changes: 20 additions & 10 deletions tests/integration/test_plugins_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class RangePlotter(zntrack.Node):
plots: pd.DataFrame = zntrack.plots(y="range")

def run(self):
self.plots = pd.DataFrame({"idx": [idx for idx in range(self.start, self.stop)]})
self.plots = pd.DataFrame({"idx": list(range(self.start, self.stop))})


@pytest.fixture
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_multiple_nodes(mlflow_proj_path, skip_cached):
assert c.metrics == {"value": 10.0}

with a.state.plugins["MLFlowPlugin"]:
a_run = mlflow.get_run(a.state.plugins["MLFlowPlugin"].child_run_id)
mlflow.get_run(a.state.plugins["MLFlowPlugin"].child_run_id)

with b.state.plugins["MLFlowPlugin"]:
b_run = mlflow.get_run(b.state.plugins["MLFlowPlugin"].child_run_id)
Expand Down Expand Up @@ -210,14 +210,20 @@ def test_multiple_nodes(mlflow_proj_path, skip_cached):
assert len(runs) == 4

a_run_2 = mlflow.search_runs(
filter_string=f"tags.git_commit_hash = '{repo.head.commit.hexsha}' and tags.dvc_stage_name = '{a.name}'",
filter_string=(
f"tags.git_commit_hash = '{repo.head.commit.hexsha}' "
f"and tags.dvc_stage_name = '{a.name}'"
),
output_format="list",
)
assert len(a_run_2) == 1
a_run_2 = a_run_2[0]

b_run_2 = mlflow.search_runs(
filter_string=f"tags.git_commit_hash = '{repo.head.commit.hexsha}' and tags.dvc_stage_name = '{b.name}'",
filter_string=(
f"tags.git_commit_hash = '{repo.head.commit.hexsha}'"
f" and tags.dvc_stage_name = '{b.name}'"
),
output_format="list",
)
if skip_cached:
Expand All @@ -226,11 +232,15 @@ def test_multiple_nodes(mlflow_proj_path, skip_cached):
assert len(b_run_2) == 1
b_run_2 = b_run_2[0]
assert b_run_2.data.tags["original_run_id"] == b_run.info.run_id
# original runs will not be updated with a new name to indicate that they are cached
# original runs will not be updated with a new name to
# indicate that they are cached
assert b_run_2.data.tags[mlflow_tags.MLFLOW_RUN_NAME] == "ParamsToOuts_1"

c_run_2 = mlflow.search_runs(
filter_string=f"tags.git_commit_hash = '{repo.head.commit.hexsha}' and tags.dvc_stage_name = '{c.name}'",
filter_string=(
f"tags.git_commit_hash = '{repo.head.commit.hexsha}'"
f" and tags.dvc_stage_name = '{c.name}'"
),
output_format="list",
)
assert len(c_run_2) == 1
Expand Down Expand Up @@ -302,7 +312,7 @@ def test_dataclass_deps(mlflow_proj_path):
)

proj.finalize(msg="run1 exp.")
repo = git.Repo()
git.Repo()

mdx = MD.from_rev()
assert mdx.__run_note__() != ""
Expand Down Expand Up @@ -338,9 +348,9 @@ def test_dataclass_deps(mlflow_proj_path):
with md.state.plugins["MLFlowPlugin"]:
run = mlflow.get_run(md.state.plugins["MLFlowPlugin"].child_run_id)

assert (
run.data.params["t"]
== "[{'temperature': 1, '_cls': 'test_plugins_mlflow.T1'}, {'temperature': 1, '_cls': 'test_plugins_mlflow.T2'}]"
assert run.data.params["t"] == (
"[{'temperature': 1, '_cls': 'test_plugins_mlflow.T1'},"
" {'temperature': 1, '_cls': 'test_plugins_mlflow.T2'}]"
)

md = zntrack.from_rev(md.name)
Expand Down
Loading

0 comments on commit 9a38ade

Please sign in to comment.