diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 3aa03e7e..7468e086 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -36,7 +36,7 @@ jobs: with: python-version: ${{ matrix.py }} - run: uv sync --extra test --locked - - run: uv run pytest --benchmark-disable -vvv + - run: uv run pytest --benchmark-disable -vvv --durations=10 mypy: runs-on: ubuntu-latest diff --git a/python/egglog/exp/array_api_jit.py b/python/egglog/exp/array_api_jit.py index 74083ead..7344d8eb 100644 --- a/python/egglog/exp/array_api_jit.py +++ b/python/egglog/exp/array_api_jit.py @@ -23,5 +23,6 @@ def jit(fn: X) -> X: fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2)) fn = try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object) + fn.initial_expr = res # type: ignore[attr-defined] fn.expr = res_optimized # type: ignore[attr-defined] return cast(X, fn) diff --git a/python/tests/__snapshots__/test_array_api/test_jit[add][initial_expr].py b/python/tests/__snapshots__/test_array_api/test_jit[add][initial_expr].py new file mode 100644 index 00000000..958f510b --- /dev/null +++ b/python/tests/__snapshots__/test_array_api/test_jit[add][initial_expr].py @@ -0,0 +1 @@ +NDArray.var("x") + NDArray.var("y") \ No newline at end of file diff --git a/python/tests/__snapshots__/test_array_api/TestLDA.test_source_optimized.py b/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py similarity index 100% rename from python/tests/__snapshots__/test_array_api/TestLDA.test_source_optimized.py rename to python/tests/__snapshots__/test_array_api/test_jit[lda][code].py diff --git a/python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py b/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py similarity index 100% rename from python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py rename to python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py diff --git a/python/tests/__snapshots__/test_array_api/TestLDA.test_trace.py b/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py similarity index 100% rename from python/tests/__snapshots__/test_array_api/TestLDA.test_trace.py rename to python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py diff --git a/python/tests/__snapshots__/test_array_api/test_jit[tuple][initial_expr].py b/python/tests/__snapshots__/test_array_api/test_jit[tuple][initial_expr].py new file mode 100644 index 00000000..fd442e81 --- /dev/null +++ b/python/tests/__snapshots__/test_array_api/test_jit[tuple][initial_expr].py @@ -0,0 +1 @@ +NDArray.var("x")[IndexKey.int((NDArray.var("x").shape + TupleInt.from_vec(Vec[Int](Int(1), Int(2))))[Int(100)])] \ No newline at end of file diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index f4e37084..8707e169 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -1,11 +1,9 @@ # mypy: disable-error-code="empty-body" -import ast import inspect from collections.abc import Callable from itertools import product from pathlib import Path from types import FunctionType -from typing import Any import numba import pytest @@ -306,42 +304,6 @@ def run_lda(x, y): X_np, y_np = (iris.data, iris.target) -def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any: - """ - Load a python snapshot, evaling the code, and returning the `var` defined in it. - - If no var is provided, then return the last expression. - """ - path = Path(__file__).parent / "__snapshots__" / "test_array_api" / f"TestLDA.{fn.__name__}.py" - contents = path.read_text() - - contents = "import numpy as np\nfrom egglog.exp.array_api import *\n" + contents - globals: dict[str, Any] = {} - if var is None: - # exec once as a full statement - exec(contents, globals) - # Eval the last statement - last_expr = ast.unparse(ast.parse(contents).body[-1]) - return eval(last_expr, globals) - exec(contents, globals) - return globals[var] - - -def lda(X: NDArray, y: NDArray): - assume_dtype(X, X_np.dtype) - assume_shape(X, X_np.shape) - assume_isfinite(X) - - assume_dtype(y, y_np.dtype) - assume_shape(y, y_np.shape) - assume_value_one_of(y, tuple(map(int, np.unique(y_np)))) - return run_lda(X, y) - - -def lda_filled(): - return lda(NDArray.var("X"), NDArray.var("y")) - - @pytest.mark.parametrize( "program", [ @@ -364,80 +326,46 @@ def test_program_compile(program: Program, snapshot_py): assert "\n".join([*statements.split("\n"), expr]) == snapshot_py(name="code") +def lda(X: NDArray, y: NDArray): + assume_dtype(X, X_np.dtype) + assume_shape(X, X_np.shape) + assume_isfinite(X) + + assume_dtype(y, y_np.dtype) + assume_shape(y, y_np.shape) + assume_value_one_of(y, tuple(map(int, np.unique(y_np)))) + return run_lda(X, y) + + @pytest.mark.parametrize( "program", [ pytest.param(lambda x, y: x + y, id="add"), pytest.param(lambda x, y: x[(x.shape + TupleInt.from_vec((1, 2)))[100]], id="tuple"), + pytest.param(lda, id="lda"), ], ) -def test_jit(program, snapshot_py): - jitted = jit(program) +def test_jit(program, snapshot_py, benchmark): + jitted = benchmark(jit, program) + assert str(jitted.initial_expr) == snapshot_py(name="initial_expr") assert str(jitted.expr) == snapshot_py(name="expr") assert inspect.getsource(jitted) == snapshot_py(name="code") -@pytest.mark.benchmark(min_rounds=3) -class TestLDA: - """ - Incrementally benchmark each part of the LDA to see how long it takes to run. - """ - - def test_trace(self, snapshot_py, benchmark): - @benchmark - def X_r2(): - with EGraph().set_current(): - return lda_filled() - - res = str(X_r2) - assert res == snapshot_py - - def test_optimize(self, snapshot_py, benchmark): - egraph = EGraph() - with egraph.set_current(): - expr = lda_filled() - simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule) - - assert str(simplified) == snapshot_py - - # @pytest.mark.xfail(reason="Original source is not working") - # def test_source(self, snapshot_py, benchmark): - # egraph = EGraph() - # expr = trace_lda(egraph) - # assert benchmark(load_source, expr, egraph) == snapshot_py - - def test_source_optimized(self, snapshot_py, benchmark): - egraph = EGraph() - with egraph.set_current(): - expr = lda_filled() - optimized_expr = egraph.simplify(expr, array_api_numba_schedule) - - @benchmark - def py_object(): - fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y")) - return try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object) - - assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np)) - assert inspect.getsource(py_object) == snapshot_py - - @pytest.mark.parametrize( - "fn_thunk", - [ - pytest.param(lambda: LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"), - pytest.param(lambda: run_lda, id="array_api"), - pytest.param(lambda: _load_py_snapshot(TestLDA.test_source_optimized, "__fn"), id="array_api-optimized"), - pytest.param( - lambda: numba.njit(_load_py_snapshot(TestLDA.test_source_optimized, "__fn")), - id="array_api-optimized-numba", - ), - pytest.param(lambda: jit(lda), id="array_api-jit"), - ], - ) - def test_execution(self, fn_thunk, benchmark): - fn = fn_thunk() - # warmup once for numba - assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np), rtol=1e-03) - benchmark(fn, X_np, y_np) +@pytest.mark.parametrize( + "fn_thunk", + [ + pytest.param(lambda: LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"), + pytest.param(lambda: run_lda, id="array_api"), + pytest.param(lambda: jit(lda), id="array_api-optimized"), + pytest.param(lambda: numba.njit(jit(lda)), id="array_api-optimized-numba"), + ], +) +def test_run_lda(fn_thunk, benchmark): + fn = fn_thunk() + # warmup once for numba + assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np), rtol=1e-03) + benchmark(fn, X_np, y_np) # if calling as script, print out egglog source for test