From a0772f7c3948307e0665290304113216bb471768 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 6 Feb 2024 21:04:07 -0500 Subject: [PATCH] Lint and format tests Closes #113 --- docs/conf.py | 2 +- pyproject.toml | 25 ++- python/egglog/exp/array_api.py | 2 +- python/tests/__init__.py | 1 + python/tests/conftest.py | 5 +- python/tests/test_array_api.py | 32 ++-- python/tests/test_bindings.py | 16 +- python/tests/test_convert.py | 34 ++-- python/tests/test_high_level.py | 107 +++++-------- python/tests/test_modules.py | 12 +- python/tests/test_program_gen.py | 9 +- python/tests/test_py_object_sort.py | 12 +- python/tests/test_runtime.py | 164 ++++++++++---------- python/tests/test_type_constraint_solver.py | 27 ++-- 14 files changed, 219 insertions(+), 229 deletions(-) create mode 100644 python/tests/__init__.py diff --git a/docs/conf.py b/docs/conf.py index aef67543..70879cf7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -44,7 +44,7 @@ output_dir = cwd / "presentations" subprocess.run( - [ # noqa: S607, S603 + [ # noqa: S607 "jupyter", "nbconvert", str(presentation_file), diff --git a/pyproject.toml b/pyproject.toml index a1e3a889..f5a0e0a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,12 @@ dependencies = ["typing-extensions", "black", "graphviz"] [project.optional-dependencies] -array = ["scikit-learn", "array_api_compat", "numba==0.59.0rc1", "llvmlite==0.42.0rc1"] +array = [ + "scikit-learn", + "array_api_compat", + "numba==0.59.0rc1", + "llvmlite==0.42.0rc1", +] dev = ["pre-commit", "ruff", "mypy", "anywidget[dev]", "egglog[docs,test]"] test = [ @@ -59,6 +64,17 @@ docs = [ [tool.ruff] ignore = [ + # Allow uppercase vars + "N806", + "N802", + # Allow subprocess run + "S603", + # ALlow any + "ANN401", + # Allow exec + "S102", + "S307", + "PGH001", # allow star imports "F405", "F403", @@ -158,13 +174,18 @@ ignore = [ # Allow private member refs "SLF001", ] + line-length = 120 # Allow lines to be as long as 120. src = ["python"] select = ["ALL"] -extend-exclude = ["python/tests"] +extend-exclude = ["python/tests/__snapshots__"] unsafe-fixes = true +[tool.ruff.lint.per-file-ignores] +# Don't require annotations for tests +"python/tests/**" = ["ANN001", "ANN201"] + [tool.mypy] ignore_missing_imports = true warn_redundant_casts = true diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index cc6be53a..80acc792 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -819,7 +819,7 @@ def to_value(self) -> Value: ... @property - def T(self) -> NDArray: # noqa: N802 + def T(self) -> NDArray: """ https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.T.html#array_api.array.T """ diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 00000000..75c6c6df --- /dev/null +++ b/python/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for egglog.""" diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 1fd8c1dc..cd344b10 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -5,8 +5,9 @@ @pytest.fixture(autouse=True) -def reset_conversions(): +def _reset_conversions(): import egglog.runtime + old_conversions = copy.copy(egglog.runtime.CONVERSIONS) yield egglog.runtime.CONVERSIONS = old_conversions @@ -19,6 +20,6 @@ def serialize(self, data, **kwargs) -> bytes: return str(data).encode() -@pytest.fixture +@pytest.fixture() def snapshot_py(snapshot): return snapshot.with_defaults(extension_class=PythonSnapshotExtension) diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index f2c46b78..1cc3eccf 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -1,15 +1,16 @@ -import pytest +import ast +from collections.abc import Callable +from pathlib import Path +from typing import Any, cast import numba -from pathlib import Path -from typing import Any, Callable, cast -import ast +import pytest +from sklearn import config_context, datasets +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from egglog.exp.array_api import * from egglog.exp.array_api_numba import array_api_numba_schedule from egglog.exp.array_api_program_gen import * -from sklearn import config_context, datasets -from sklearn.discriminant_analysis import LinearDiscriminantAnalysis def test_simplify_any_unique(): @@ -98,9 +99,8 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any: # Eval the last statement last_expr = ast.unparse(ast.parse(contents).body[-1]) return eval(last_expr, globals) - else: - exec(contents, globals) - return globals[var] + exec(contents, globals) + return globals[var] def load_source(expr): @@ -108,7 +108,7 @@ def load_source(expr): fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y"))) egraph.run(array_api_program_gen_schedule) # cast b/c issue with it not recognizing py_object as property - fn = cast(Any, egraph.eval(fn_program.py_object)) + cast(Any, egraph.eval(fn_program.py_object)) assert np.allclose(res, run_lda(X_np, y_np)) return egraph.eval(fn_program.statements) @@ -120,13 +120,13 @@ def test_trace(self, snapshot_py, benchmark): def X_r2(): X_arr = NDArray.var("X") assume_dtype(X_arr, X_np.dtype) - assume_shape(X_arr, X_np.shape) # type: ignore + assume_shape(X_arr, X_np.shape) assume_isfinite(X_arr) y_arr = NDArray.var("y") assume_dtype(y_arr, y_np.dtype) - assume_shape(y_arr, y_np.shape) # type: ignore - assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np)))) # type: ignore + assume_shape(y_arr, y_np.shape) + assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type] with EGraph(): return run_lda(X_arr, y_arr) @@ -148,14 +148,12 @@ def test_source_optimized(self, snapshot_py, benchmark): assert benchmark(load_source, expr) == snapshot_py @pytest.mark.parametrize( - ("fn",), + "fn", [ pytest.param(LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"), pytest.param(run_lda, id="array_api"), pytest.param(_load_py_snapshot(test_source_optimized, "__fn"), id="array_api-optimized"), - pytest.param( - numba.njit(_load_py_snapshot(test_source_optimized, "__fn")), id="array_api-optimized-numba" - ), + pytest.param(numba.njit(_load_py_snapshot(test_source_optimized, "__fn")), id="array_api-optimized-numba"), ], ) def test_execution(self, fn, benchmark): diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index 4b317a18..903e8d84 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -1,10 +1,11 @@ +import fractions import json import os import pathlib import subprocess -import fractions import pytest + from egglog.bindings import * @@ -13,12 +14,12 @@ def get_egglog_folder() -> pathlib.Path: Return the egglog source folder """ metadata_process = subprocess.run( - ["cargo", "metadata", "--format-version", "1", "-q"], + ["cargo", "metadata", "--format-version", "1", "-q"], # noqa: S607 capture_output=True, check=True, ) metadata = json.loads(metadata_process.stdout) - (egglog_package,) = [package for package in metadata["packages"] if package["name"] == "egglog"] + (egglog_package,) = (package for package in metadata["packages"] if package["name"] == "egglog") return pathlib.Path(egglog_package["manifest_path"]).parent @@ -183,7 +184,7 @@ def test_cost(self): def test_compare(self): assert Variant("name", []) == Variant("name", []) assert Variant("name", []) != Variant("name", ["a"]) - assert Variant("name", []) != 10 # type: ignore + assert Variant("name", []) != 10 # type: ignore[comparison-overlap] class TestEval: @@ -193,18 +194,15 @@ def test_i64(self): def test_f64(self): assert EGraph().eval_f64(Lit(F64(1.0))) == 1.0 - def test_string(self): assert EGraph().eval_string(Lit(String("hi"))) == "hi" def test_bool(self): - assert EGraph().eval_bool(Lit(Bool(True))) == True + assert EGraph().eval_bool(Lit(Bool(True))) is True @pytest.mark.xfail(reason="Depends on getting actual sort from egraph") def test_rational(self): egraph = EGraph() rational = Call("rational", [Lit(Int(1)), Lit(Int(2))]) - egraph.run_program( - ActionCommand(Expr_(Call("rational", [Lit(Int(1)), Lit(Int(2))]))) - ) + egraph.run_program(ActionCommand(Expr_(Call("rational", [Lit(Int(1)), Lit(Int(2))])))) assert egraph.eval_rational(rational) == fractions.Fraction(1, 2) diff --git a/python/tests/test_convert.py b/python/tests/test_convert.py index 017d327f..e4c53610 100644 --- a/python/tests/test_convert.py +++ b/python/tests/test_convert.py @@ -8,11 +8,10 @@ class MyMeta(type): class MyType(metaclass=MyMeta): pass - egraph = EGraph() - + EGraph() class MyTypeExpr(Expr): - def __init__(self): + def __init__(self) -> None: ... converter(MyMeta, MyTypeExpr, lambda x: MyTypeExpr()) @@ -20,14 +19,13 @@ def __init__(self): def test_conversion(): - egraph = EGraph() + EGraph() class MyType: pass - class MyTypeExpr(Expr): - def __init__(self): + def __init__(self) -> None: ... converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) @@ -36,19 +34,17 @@ def __init__(self): def test_conversion_transitive_forward(): - egraph = EGraph() + EGraph() class MyType: pass - class MyTypeExpr(Expr): - def __init__(self): + def __init__(self) -> None: ... - class MyTypeExpr2(Expr): - def __init__(self): + def __init__(self) -> None: ... converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) @@ -58,19 +54,17 @@ def __init__(self): def test_conversion_transitive_backward(): - egraph = EGraph() + EGraph() class MyType: pass - class MyTypeExpr(Expr): - def __init__(self): + def __init__(self) -> None: ... - class MyTypeExpr2(Expr): - def __init__(self): + def __init__(self) -> None: ... converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2()) @@ -79,19 +73,17 @@ def __init__(self): def test_conversion_transitive_cycle(): - egraph = EGraph() + EGraph() class MyType: pass - class MyTypeExpr(Expr): - def __init__(self): + def __init__(self) -> None: ... - class MyTypeExpr2(Expr): - def __init__(self): + def __init__(self) -> None: ... converter(MyType, MyTypeExpr, lambda x: MyTypeExpr()) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 39903d8f..1aa13615 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -2,12 +2,12 @@ from __future__ import annotations import importlib -from multiprocessing import Value import pathlib from copy import copy from typing import ClassVar, Union import pytest + from egglog import * from egglog.declarations import ( CallDecl, @@ -34,10 +34,10 @@ def test_unwrap_lit(self): def test_ne(self): assert str(ne(i64(1)).to(i64(2))) == "ne(i64(1)).to(i64(2))" + def test_eqsat_basic(): egraph = EGraph() - class Math(Expr): def __init__(self, value: i64Like) -> None: ... @@ -96,7 +96,6 @@ def fib(x: i64Like) -> i64: def test_fib_demand(): egraph = EGraph() - class Num(Expr): def __init__(self, i: i64Like) -> None: ... @@ -110,23 +109,9 @@ def fib(x: i64Like) -> Num: @egraph.register def _fib(a: i64, b: i64): - yield rewrite( - Num(a) + Num(b) - ).to( - Num(a + b) - ) - yield rewrite( - fib(a) - ).to( - fib(a - 1) + fib(a - 2), - a > 1 - ) - yield rewrite( - fib(a) - ).to( - Num(a), - a <= 1 - ) + yield rewrite(Num(a) + Num(b)).to(Num(a + b)) + yield rewrite(fib(a)).to(fib(a - 1) + fib(a - 2), a > 1) + yield rewrite(fib(a)).to(Num(a), a <= 1) f7 = egraph.let("f7", fib(7)) egraph.run(14) @@ -155,9 +140,9 @@ def foo() -> i64: def test_constants(): egraph = EGraph() - class A(Expr): pass + one = constant("one", A) two = constant("two", A) @@ -168,18 +153,18 @@ class A(Expr): def test_class_vars(): egraph = EGraph() - class A(Expr): ONE: ClassVar[A] + two = constant("two", A) egraph.register(union(A.ONE).with_(two)) egraph.check(eq(A.ONE).to(two)) + def test_simplify_constant(): egraph = EGraph() - class Numeric(Expr): ONE: ClassVar[Numeric] @@ -197,17 +182,18 @@ def test_extract_constant_twice(): # Sometimes extracting a constant twice will give an error egraph = EGraph() - class Numeric(Expr): ONE: ClassVar[Numeric] egraph.extract(Numeric.ONE) egraph.extract(Numeric.ONE) + def test_extract_include_cost(): _, cost = EGraph().extract(i64(0), include_cost=True) assert cost == 1 + def test_relation(): egraph = EGraph() @@ -227,7 +213,7 @@ def test_generic_sort(): def test_keyword_args(): - egraph = EGraph() + EGraph() @function def foo(x: i64Like, y: i64Like) -> i64: @@ -239,18 +225,15 @@ def foo(x: i64Like, y: i64Like) -> i64: def test_keyword_args_init(): - egraph = EGraph() - + EGraph() class Foo(Expr): def __init__(self, x: i64Like) -> None: ... - assert expr_parts(Foo(1)) == expr_parts(Foo(x=1)) - def test_modules() -> None: m = Module() @@ -279,7 +262,6 @@ def from_numeric(n: Numeric) -> OtherNumeric: def test_property(): egraph = EGraph() - class Foo(Expr): def __init__(self) -> None: ... @@ -293,7 +275,7 @@ def bar(self) -> i64: def test_default_args(): - egraph = EGraph() + EGraph() @function def foo(x: i64Like, y: i64Like = i64(1)) -> i64: @@ -323,11 +305,7 @@ def test_eval(self): def test_eval_local(self): x = "hi" - res = py_eval( - "my_add(x, y)", - PyObject(locals()).dict_update("y", "there"), - globals() - ) + res = py_eval("my_add(x, y)", PyObject(locals()).dict_update("y", "there"), globals()) assert EGraph().eval(res) == "hithere" def test_exec(self): @@ -356,7 +334,7 @@ def test_f64_negation() -> None: expr2 = egraph.let("expr2", f64(2.0)) # expr3 = -(-2.0) - expr3 = egraph.let("expr3", -(-f64(2.0))) + expr3 = egraph.let("expr3", -(-f64(2.0))) # noqa: B002 egraph.check(eq(expr1).to(-expr2)) egraph.check(eq(expr3).to(expr2)) @@ -369,15 +347,14 @@ def test_not_equals(): def test_custom_equality(): egraph = EGraph() - class Boolean(Expr): def __init__(self, value: BoolLike) -> None: ... - def __eq__(self, other: Boolean) -> Boolean: # type: ignore[override] + def __eq__(self, other: Boolean) -> Boolean: # type: ignore[override] ... - def __ne__(self, other: Boolean) -> Boolean: # type: ignore[override] + def __ne__(self, other: Boolean) -> Boolean: # type: ignore[override] ... egraph.register(rewrite(Boolean(True) == Boolean(True)).to(Boolean(False))) @@ -390,10 +367,10 @@ def __ne__(self, other: Boolean) -> Boolean: # type: ignore[override] egraph.check(eq(should_be_true).to(Boolean(False))) egraph.check(eq(should_be_false).to(Boolean(True))) + class TestMutate: def test_setitem_defaults(self): - egraph = EGraph() - + EGraph() class Foo(Expr): def __init__(self) -> None: @@ -413,7 +390,6 @@ def __setitem__(self, key: i64Like, value: i64Like) -> None: def test_function(self): egraph = EGraph() - class Math(Expr): def __init__(self, i: i64Like) -> None: ... @@ -457,8 +433,7 @@ def test_builtin_reflected(): def test_reflected_binary_method(): # If we have a reflected binary method, it should be converted into the non-reflected version - egraph = EGraph() - + EGraph() class Math(Expr): def __init__(self, value: i64Like) -> None: @@ -482,8 +457,7 @@ def __radd__(self, other: Math) -> Math: def test_upcast_args(): # -0.1 + Int(x) -> -0.1 + Float(x) - egraph = EGraph() - + EGraph() class Int(Expr): def __init__(self, value: i64Like) -> None: @@ -492,7 +466,6 @@ def __init__(self, value: i64Like) -> None: def __add__(self, other: Int) -> Int: ... - class Float(Expr): def __init__(self, value: f64Like) -> None: ... @@ -508,18 +481,19 @@ def from_int(cls, other: Int) -> Float: converter(f64, Float, Float) converter(Int, Float, Float.from_int) - res: Expr = -0.1 + Int(10) # type: ignore + res: Expr = -0.1 + Int(10) # type: ignore[operator,assignment] assert expr_parts(res) == expr_parts(Float(-0.1) + Float.from_int(Int(10))) - res: Expr = Int(10) + -0.1 # type: ignore + res: Expr = Int(10) + -0.1 # type: ignore[operator,assignment] assert expr_parts(res) == expr_parts(Float.from_int(Int(10)) + Float(-0.1)) + def test_rewrite_upcasts(): - rewrite(i64(1)).to(0) # type: ignore + rewrite(i64(1)).to(0) # type: ignore[arg-type] def test_function_default_upcasts(): - egraph = EGraph() + EGraph() @function def f(x: i64Like) -> i64: @@ -527,11 +501,11 @@ def f(x: i64Like) -> i64: assert expr_parts(f(1)) == expr_parts(f(i64(1))) + def test_upcast_self_lower_cost(): # Verifies that self will be upcasted, if that upcast has a lower cast than converting the other arg # i.e. Int(x) + NDArray(y) -> NDArray(Int(x)) + NDArray(y) instead of Int(x) + NDArray(y).to_int() - egraph = EGraph() - + EGraph() class Int(Expr): def __init__(self, name: StringLike) -> None: @@ -542,7 +516,6 @@ def __add__(self, other: Int) -> Int: NDArrayLike = Union[Int, "NDArray"] - class NDArray(Expr): def __init__(self, name: StringLike) -> None: ... @@ -588,7 +561,6 @@ def test_eval(): # egraph.as_egglog_string - def test_eval_fn(): egraph = EGraph() @@ -598,15 +570,16 @@ def test_eval_fn(): def _global_make_tuple(x): return (x,) + def test_eval_fn_globals(): egraph = EGraph() assert egraph.eval(py_eval_fn(lambda x: _global_make_tuple(x))(PyObject.from_int(1))) == (1,) + def test_eval_fn_locals(): egraph = EGraph() - def _locals_make_tuple(x): return (x,) @@ -634,27 +607,33 @@ def test_functions_seperate_pop(): class T(Expr): def __init__(self, x: i64Like) -> None: ... + with egraph: + @function - def f(x: T) -> T: ... + def f(x: T) -> T: + ... egraph.register(f(T(1))) with egraph: + @function - def f(x: T, y: T) -> T: ... + def f(x: T, y: T) -> T: + ... + + egraph.register(f(T(1), T(2))) # type: ignore[call-arg] - egraph.register(f(T(1), T(2))) # type: ignore[call-arg] # https://github.com/egraphs-good/egglog/issues/113 def test_multiple_generics(): - @function - def f() -> Vec[i64]: ... + def f() -> Vec[i64]: + ... @function - def g() -> Vec[String]: ... - + def g() -> Vec[String]: + ... egraph = EGraph() diff --git a/python/tests/test_modules.py b/python/tests/test_modules.py index 87a3bf75..93c067d1 100644 --- a/python/tests/test_modules.py +++ b/python/tests/test_modules.py @@ -1,16 +1,18 @@ import pytest + # from egglog.declarations import ModuleDeclarations from egglog.egraph import * + # from egglog.egraph import _BUILTIN_DECLS, BUILTINS def test_tree_modules(): """ - BUILTINS - / | \ + BUILTINS + / | \ A B C - | / - D + | / + D """ # assert _BUILTIN_DECLS # assert BUILTINS._mod_decls == ModuleDeclarations(_BUILTIN_DECLS, []) @@ -34,5 +36,5 @@ def test_tree_modules(): egraph = EGraph([D, B]) # assert egraph._flatted_deps == [A, B, D] egraph.check(a(), b(), d()) - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017, PT011 egraph.check(c()) diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index 5d9a323b..7c6e8e78 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -7,8 +7,6 @@ from egglog.exp.program_gen import * - - class Math(Expr): def __init__(self, value: i64Like) -> None: ... @@ -26,7 +24,7 @@ def __mul__(self, other: Math) -> Math: def __neg__(self) -> Math: ... - @method(cost=1000) # type: ignore + @method(cost=1000) # type: ignore[misc] @property def program(self) -> Program: ... @@ -36,6 +34,7 @@ def program(self) -> Program: def assume_pos(x: Math) -> Math: ... + @ruleset def to_program_ruleset( s: String, @@ -78,5 +77,5 @@ def test_py_object(): egraph.register(fn.eval_py_object({"z": 10})) egraph.run(to_program_ruleset * 100 + program_gen_ruleset * 100) res = egraph.eval(fn.py_object) - assert res(1, 2) == 13 # type: ignore - assert inspect.getsource(res) # type: ignore + assert res(1, 2) == 13 # type: ignore[operator] + assert inspect.getsource(res) # type: ignore[arg-type] diff --git a/python/tests/test_py_object_sort.py b/python/tests/test_py_object_sort.py index a7afda73..79af9af6 100644 --- a/python/tests/test_py_object_sort.py +++ b/python/tests/test_py_object_sort.py @@ -3,6 +3,7 @@ import weakref import pytest + from egglog.bindings import * @@ -45,7 +46,7 @@ def test_object_keeps_ref(self): del my_object gc.collect() assert ref() is not None - assert EGraph(sort).eval_py_object(expr)== MyObject() + assert EGraph(sort).eval_py_object(expr) == MyObject() class TestDictUpdate: @@ -98,8 +99,8 @@ def test_eval(self): Lit(String("my_add(x, y)")), globals_, Call("py-dict-update", [locals_, x_expr, one, y_expr, two]), - ] - ) + ], + ), ) ) ) @@ -114,9 +115,7 @@ def test_to_string(self): sort = PyObjectSort() egraph = EGraph(sort) - egraph.run_program( - ActionCommand(Let("res", Call("py-to-string", [sort.store("hi")]))) - ) + egraph.run_program(ActionCommand(Let("res", Call("py-to-string", [sort.store("hi")])))) assert egraph.eval_string(Var("res")) == "hi" def test_from_string(self): @@ -129,4 +128,3 @@ def test_from_string(self): ActionCommand(Let("res", Call("py-from-string", [Lit(String("hi"))]))), ) assert egraph.eval_py_object(Var("res")) == "hi" - diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py index 2ce1cf9a..f3b61424 100644 --- a/python/tests/test_runtime.py +++ b/python/tests/test_runtime.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + from egglog.declarations import * from egglog.runtime import * from egglog.type_constraint_solver import * @@ -7,11 +9,11 @@ def test_type_str(): decls = Declarations( - _classes={ - "i64": ClassDecl(), - "Map": ClassDecl(type_vars=("K", "V")), - } - ) + _classes={ + "i64": ClassDecl(), + "Map": ClassDecl(type_vars=("K", "V")), + } + ) i64 = RuntimeClass(decls.update_other, "i64") Map = RuntimeClass(decls.update_other, "Map") assert str(i64) == "i64" @@ -20,99 +22,97 @@ def test_type_str(): def test_function_call(): decls = Declarations( - _classes={ - "i64": ClassDecl(), - }, - _functions={ - "one": FunctionDecl( - (), - (), - (), - TypeRefWithVars("i64"), - False, - ), - }, - ) + _classes={ + "i64": ClassDecl(), + }, + _functions={ + "one": FunctionDecl( + (), + (), + (), + TypeRefWithVars("i64"), + False, + ), + }, + ) one = RuntimeFunction(decls, "one") assert ( - one().__egg_typed_expr__ # type: ignore + one().__egg_typed_expr__ # type: ignore[union-attr] == RuntimeExpr(decls, TypedExprDecl(JustTypeRef("i64"), CallDecl(FunctionRef("one")))).__egg_typed_expr__ ) def test_classmethod_call(): - from pytest import raises - K, V = ClassTypeVarRef("K"), ClassTypeVarRef("V") decls = Declarations( - _classes={ - "i64": ClassDecl(), - "unit": ClassDecl(), - "Map": ClassDecl( - type_vars=("K", "V"), - class_methods={ - "create": FunctionDecl( - (), - (), - (), - TypeRefWithVars("Map", (K, V)), - False, - ) - }, - ), - }, - _type_ref_to_egg_sort={ - JustTypeRef("i64"): "i64", - JustTypeRef("unit"): "unit", - JustTypeRef("Map"): "Map", - } - ) + _classes={ + "i64": ClassDecl(), + "unit": ClassDecl(), + "Map": ClassDecl( + type_vars=("K", "V"), + class_methods={ + "create": FunctionDecl( + (), + (), + (), + TypeRefWithVars("Map", (K, V)), + False, + ) + }, + ), + }, + _type_ref_to_egg_sort={ + JustTypeRef("i64"): "i64", + JustTypeRef("unit"): "unit", + JustTypeRef("Map"): "Map", + }, + ) Map = RuntimeClass(decls.update_other, "Map") - with raises(TypeConstraintError): - Map.create() # type: ignore + with pytest.raises(TypeConstraintError): + Map.create() # type: ignore[operator] i64 = RuntimeClass(decls.update_other, "i64") unit = RuntimeClass(decls.update_other, "unit") assert ( - Map[i64, unit].create().__egg_typed_expr__ # type: ignore + Map[i64, unit].create().__egg_typed_expr__ # type: ignore[union-attr] == TypedExprDecl( - JustTypeRef("Map", (JustTypeRef("i64"), JustTypeRef("unit"))), - CallDecl( - ClassMethodRef("Map", "create"), - (), - (JustTypeRef("i64"), JustTypeRef("unit")), - ), - ) + JustTypeRef("Map", (JustTypeRef("i64"), JustTypeRef("unit"))), + CallDecl( + ClassMethodRef("Map", "create"), + (), + (JustTypeRef("i64"), JustTypeRef("unit")), + ), + ) ) def test_expr_special(): decls = Declarations( - _classes={ - "i64": ClassDecl( - methods={ - "__add__": FunctionDecl( - (TypeRefWithVars("i64"), TypeRefWithVars("i64")), - (), - (None, None), - TypeRefWithVars("i64"), - False, - ) - }, - class_methods={ - "__init__": FunctionDecl( - (TypeRefWithVars("i64"),), - (), - (None,), - TypeRefWithVars("i64"), - False, - ) - }, - ), - }, - ) + _classes={ + "i64": ClassDecl( + methods={ + "__add__": FunctionDecl( + (TypeRefWithVars("i64"), TypeRefWithVars("i64")), + (), + (None, None), + TypeRefWithVars("i64"), + False, + ) + }, + class_methods={ + "__init__": FunctionDecl( + (TypeRefWithVars("i64"),), + (), + (None,), + TypeRefWithVars("i64"), + False, + ) + }, + ), + }, + ) i64 = RuntimeClass(decls.update_other, "i64") - one = i64(1) # type: ignore - res = one + one # type: ignore + one = i64(1) + res = one + one # type: ignore[operator] expected_res = RuntimeExpr( decls, TypedExprDecl( @@ -128,10 +128,10 @@ def test_expr_special(): def test_class_variable(): decls = Declarations( - _classes={ - "i64": ClassDecl(class_variables={"one": JustTypeRef("i64")}), - }, - ) + _classes={ + "i64": ClassDecl(class_variables={"one": JustTypeRef("i64")}), + }, + ) i64 = RuntimeClass(decls.update_other, "i64") one = i64.one assert isinstance(one, RuntimeExpr) diff --git a/python/tests/test_type_constraint_solver.py b/python/tests/test_type_constraint_solver.py index fa928c1f..580caa08 100644 --- a/python/tests/test_type_constraint_solver.py +++ b/python/tests/test_type_constraint_solver.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest + from egglog.declarations import * from egglog.type_constraint_solver import * @@ -9,21 +10,20 @@ K, V = ClassTypeVarRef("K"), ClassTypeVarRef("V") map = TypeRefWithVars("Map", (K, V)) map_i64_unit = JustTypeRef("Map", (i64, unit)) -decls = Declarations( - _classes={ - "Map": ClassDecl( - type_vars = ("K", "V") - ) - } -) +decls = Declarations(_classes={"Map": ClassDecl(type_vars=("K", "V"))}) + def test_simple() -> None: - assert TypeConstraintSolver(Declarations()).infer_return_type([i64.to_var()], i64.to_var(), None, [i64], None) == i64 + assert ( + TypeConstraintSolver(Declarations()).infer_return_type([i64.to_var()], i64.to_var(), None, [i64], None) == i64 + ) + def test_wrong_arg() -> None: with pytest.raises(TypeConstraintError): TypeConstraintSolver(Declarations()).infer_return_type([i64.to_var()], i64.to_var(), None, [unit], None) + def test_wrong_number_args() -> None: with pytest.raises(TypeConstraintError): TypeConstraintSolver(Declarations()).infer_return_type([], i64.to_var(), None, [unit], "Map") @@ -32,28 +32,29 @@ def test_wrong_number_args() -> None: def test_generic() -> None: assert TypeConstraintSolver(decls).infer_return_type([map, K], V, None, [map_i64_unit, i64], "Map") == unit + def test_generic_wrong() -> None: with pytest.raises(TypeConstraintError): TypeConstraintSolver(decls).infer_return_type([map, K], V, None, [map_i64_unit, unit], "Map") + def test_variable() -> None: - assert ( - TypeConstraintSolver(decls).infer_return_type([map, K], V, V, [map_i64_unit, i64, unit, unit], "Map") - == unit - ) + assert TypeConstraintSolver(decls).infer_return_type([map, K], V, V, [map_i64_unit, i64, unit, unit], "Map") == unit + def test_variable_wrong() -> None: with pytest.raises(TypeConstraintError): TypeConstraintSolver(decls).infer_return_type([map, K], V, V, [map_i64_unit, i64, unit, i64], "Map") + def test_bound() -> None: bound_cs = TypeConstraintSolver(decls) bound_cs.bind_class(map_i64_unit) assert bound_cs.infer_return_type([K], V, None, [i64], "Map") == unit + def test_bound_wrong(): bound_cs = TypeConstraintSolver(decls) bound_cs.bind_class(map_i64_unit) with pytest.raises(TypeConstraintError): bound_cs.infer_return_type([K], V, None, [unit], "Map") -