Skip to content

Commit

Permalink
Lint and format tests
Browse files Browse the repository at this point in the history
Closes #113
  • Loading branch information
saulshanabrook committed Feb 7, 2024
1 parent 6189a38 commit a0772f7
Show file tree
Hide file tree
Showing 14 changed files with 219 additions and 229 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
output_dir = cwd / "presentations"

subprocess.run(
[ # noqa: S607, S603
[ # noqa: S607
"jupyter",
"nbconvert",
str(presentation_file),
Expand Down
25 changes: 23 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ dependencies = ["typing-extensions", "black", "graphviz"]

[project.optional-dependencies]

array = ["scikit-learn", "array_api_compat", "numba==0.59.0rc1", "llvmlite==0.42.0rc1"]
array = [
"scikit-learn",
"array_api_compat",
"numba==0.59.0rc1",
"llvmlite==0.42.0rc1",
]
dev = ["pre-commit", "ruff", "mypy", "anywidget[dev]", "egglog[docs,test]"]

test = [
Expand Down Expand Up @@ -59,6 +64,17 @@ docs = [

[tool.ruff]
ignore = [
# Allow uppercase vars
"N806",
"N802",
# Allow subprocess run
"S603",
# ALlow any
"ANN401",
# Allow exec
"S102",
"S307",
"PGH001",
# allow star imports
"F405",
"F403",
Expand Down Expand Up @@ -158,13 +174,18 @@ ignore = [
# Allow private member refs
"SLF001",
]

line-length = 120
# Allow lines to be as long as 120.
src = ["python"]
select = ["ALL"]
extend-exclude = ["python/tests"]
extend-exclude = ["python/tests/__snapshots__"]
unsafe-fixes = true

[tool.ruff.lint.per-file-ignores]
# Don't require annotations for tests
"python/tests/**" = ["ANN001", "ANN201"]

[tool.mypy]
ignore_missing_imports = true
warn_redundant_casts = true
Expand Down
2 changes: 1 addition & 1 deletion python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ def to_value(self) -> Value:
...

@property
def T(self) -> NDArray: # noqa: N802
def T(self) -> NDArray:
"""
https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.T.html#array_api.array.T
"""
Expand Down
1 change: 1 addition & 0 deletions python/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for egglog."""
5 changes: 3 additions & 2 deletions python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@


@pytest.fixture(autouse=True)
def reset_conversions():
def _reset_conversions():
import egglog.runtime

old_conversions = copy.copy(egglog.runtime.CONVERSIONS)
yield
egglog.runtime.CONVERSIONS = old_conversions
Expand All @@ -19,6 +20,6 @@ def serialize(self, data, **kwargs) -> bytes:
return str(data).encode()


@pytest.fixture
@pytest.fixture()
def snapshot_py(snapshot):
return snapshot.with_defaults(extension_class=PythonSnapshotExtension)
32 changes: 15 additions & 17 deletions python/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import pytest
import ast
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast

import numba
from pathlib import Path
from typing import Any, Callable, cast
import ast
import pytest
from sklearn import config_context, datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from egglog.exp.array_api import *
from egglog.exp.array_api_numba import array_api_numba_schedule
from egglog.exp.array_api_program_gen import *
from sklearn import config_context, datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


def test_simplify_any_unique():
Expand Down Expand Up @@ -98,17 +99,16 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
# Eval the last statement
last_expr = ast.unparse(ast.parse(contents).body[-1])
return eval(last_expr, globals)
else:
exec(contents, globals)
return globals[var]
exec(contents, globals)
return globals[var]


def load_source(expr):
egraph = EGraph()
fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y")))
egraph.run(array_api_program_gen_schedule)
# cast b/c issue with it not recognizing py_object as property
fn = cast(Any, egraph.eval(fn_program.py_object))
cast(Any, egraph.eval(fn_program.py_object))
assert np.allclose(res, run_lda(X_np, y_np))
return egraph.eval(fn_program.statements)

Expand All @@ -120,13 +120,13 @@ def test_trace(self, snapshot_py, benchmark):
def X_r2():
X_arr = NDArray.var("X")
assume_dtype(X_arr, X_np.dtype)
assume_shape(X_arr, X_np.shape) # type: ignore
assume_shape(X_arr, X_np.shape)
assume_isfinite(X_arr)

y_arr = NDArray.var("y")
assume_dtype(y_arr, y_np.dtype)
assume_shape(y_arr, y_np.shape) # type: ignore
assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np)))) # type: ignore
assume_shape(y_arr, y_np.shape)
assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type]

with EGraph():
return run_lda(X_arr, y_arr)
Expand All @@ -148,14 +148,12 @@ def test_source_optimized(self, snapshot_py, benchmark):
assert benchmark(load_source, expr) == snapshot_py

@pytest.mark.parametrize(
("fn",),
"fn",
[
pytest.param(LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"),
pytest.param(run_lda, id="array_api"),
pytest.param(_load_py_snapshot(test_source_optimized, "__fn"), id="array_api-optimized"),
pytest.param(
numba.njit(_load_py_snapshot(test_source_optimized, "__fn")), id="array_api-optimized-numba"
),
pytest.param(numba.njit(_load_py_snapshot(test_source_optimized, "__fn")), id="array_api-optimized-numba"),
],
)
def test_execution(self, fn, benchmark):
Expand Down
16 changes: 7 additions & 9 deletions python/tests/test_bindings.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import fractions
import json
import os
import pathlib
import subprocess
import fractions

import pytest

from egglog.bindings import *


Expand All @@ -13,12 +14,12 @@ def get_egglog_folder() -> pathlib.Path:
Return the egglog source folder
"""
metadata_process = subprocess.run(
["cargo", "metadata", "--format-version", "1", "-q"],
["cargo", "metadata", "--format-version", "1", "-q"], # noqa: S607
capture_output=True,
check=True,
)
metadata = json.loads(metadata_process.stdout)
(egglog_package,) = [package for package in metadata["packages"] if package["name"] == "egglog"]
(egglog_package,) = (package for package in metadata["packages"] if package["name"] == "egglog")
return pathlib.Path(egglog_package["manifest_path"]).parent


Expand Down Expand Up @@ -183,7 +184,7 @@ def test_cost(self):
def test_compare(self):
assert Variant("name", []) == Variant("name", [])
assert Variant("name", []) != Variant("name", ["a"])
assert Variant("name", []) != 10 # type: ignore
assert Variant("name", []) != 10 # type: ignore[comparison-overlap]


class TestEval:
Expand All @@ -193,18 +194,15 @@ def test_i64(self):
def test_f64(self):
assert EGraph().eval_f64(Lit(F64(1.0))) == 1.0


def test_string(self):
assert EGraph().eval_string(Lit(String("hi"))) == "hi"

def test_bool(self):
assert EGraph().eval_bool(Lit(Bool(True))) == True
assert EGraph().eval_bool(Lit(Bool(True))) is True

@pytest.mark.xfail(reason="Depends on getting actual sort from egraph")
def test_rational(self):
egraph = EGraph()
rational = Call("rational", [Lit(Int(1)), Lit(Int(2))])
egraph.run_program(
ActionCommand(Expr_(Call("rational", [Lit(Int(1)), Lit(Int(2))])))
)
egraph.run_program(ActionCommand(Expr_(Call("rational", [Lit(Int(1)), Lit(Int(2))]))))
assert egraph.eval_rational(rational) == fractions.Fraction(1, 2)
34 changes: 13 additions & 21 deletions python/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,24 @@ class MyMeta(type):
class MyType(metaclass=MyMeta):
pass

egraph = EGraph()

EGraph()

class MyTypeExpr(Expr):
def __init__(self):
def __init__(self) -> None:
...

converter(MyMeta, MyTypeExpr, lambda x: MyTypeExpr())
assert expr_parts(convert(MyType(), MyTypeExpr)) == expr_parts(MyTypeExpr())


def test_conversion():
egraph = EGraph()
EGraph()

class MyType:
pass


class MyTypeExpr(Expr):
def __init__(self):
def __init__(self) -> None:
...

converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
Expand All @@ -36,19 +34,17 @@ def __init__(self):


def test_conversion_transitive_forward():
egraph = EGraph()
EGraph()

class MyType:
pass


class MyTypeExpr(Expr):
def __init__(self):
def __init__(self) -> None:
...


class MyTypeExpr2(Expr):
def __init__(self):
def __init__(self) -> None:
...

converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
Expand All @@ -58,19 +54,17 @@ def __init__(self):


def test_conversion_transitive_backward():
egraph = EGraph()
EGraph()

class MyType:
pass


class MyTypeExpr(Expr):
def __init__(self):
def __init__(self) -> None:
...


class MyTypeExpr2(Expr):
def __init__(self):
def __init__(self) -> None:
...

converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2())
Expand All @@ -79,19 +73,17 @@ def __init__(self):


def test_conversion_transitive_cycle():
egraph = EGraph()
EGraph()

class MyType:
pass


class MyTypeExpr(Expr):
def __init__(self):
def __init__(self) -> None:
...


class MyTypeExpr2(Expr):
def __init__(self):
def __init__(self) -> None:
...

converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
Expand Down
Loading

0 comments on commit a0772f7

Please sign in to comment.