Skip to content

Commit

Permalink
Merge pull request #267 from egraphs-good/fix-benchmark
Browse files Browse the repository at this point in the history
Simplify LDA benchmarks even more
  • Loading branch information
saulshanabrook authored Feb 28, 2025
2 parents 4fd253f + 6856926 commit d282494
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 102 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
with:
python-version: ${{ matrix.py }}
- run: uv sync --extra test --locked
- run: uv run pytest --benchmark-disable -vvv
- run: uv run pytest --benchmark-disable -vvv --durations=10

mypy:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions python/egglog/exp/array_api_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ def jit(fn: X) -> X:

fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
fn = try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)
fn.initial_expr = res # type: ignore[attr-defined]
fn.expr = res_optimized # type: ignore[attr-defined]
return cast(X, fn)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
NDArray.var("x") + NDArray.var("y")
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
NDArray.var("x")[IndexKey.int((NDArray.var("x").shape + TupleInt.from_vec(Vec[Int](Int(1), Int(2))))[Int(100)])]
130 changes: 29 additions & 101 deletions python/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# mypy: disable-error-code="empty-body"
import ast
import inspect
from collections.abc import Callable
from itertools import product
from pathlib import Path
from types import FunctionType
from typing import Any

import numba
import pytest
Expand Down Expand Up @@ -306,42 +304,6 @@ def run_lda(x, y):
X_np, y_np = (iris.data, iris.target)


def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
    """
    Execute the stored python snapshot for `fn` and return a value from it.

    The snapshot is read from `__snapshots__/test_array_api/TestLDA.<fn.__name__>.py`
    next to this module and is executed after `import numpy as np` and a star
    import of `egglog.exp.array_api` are prepended to it.

    If `var` is given, the global with that name is returned. Otherwise the
    final statement of the snapshot is evaluated once more as an expression
    and that value is returned.
    """
    snapshot_path = Path(__file__).parent / "__snapshots__" / "test_array_api" / f"TestLDA.{fn.__name__}.py"
    source = "import numpy as np\nfrom egglog.exp.array_api import *\n" + snapshot_path.read_text()
    namespace: dict[str, Any] = {}
    # Run the whole snapshot first so every name it defines lands in `namespace`.
    exec(source, namespace)
    if var is not None:
        return namespace[var]
    # No variable requested: evaluate the trailing statement to obtain its value.
    trailing_stmt = ast.unparse(ast.parse(source).body[-1])
    return eval(trailing_stmt, namespace)


def lda(X: NDArray, y: NDArray):
    """
    Register assumptions about `X` and `y`, then return `run_lda(X, y)`.

    The dtype and shape assumed for each argument are taken from the module
    level arrays `X_np` and `y_np`; `y` is additionally restricted to the
    distinct values found in `y_np`.
    """
    # Facts about the feature matrix.
    assume_dtype(X, X_np.dtype)
    assume_shape(X, X_np.shape)
    assume_isfinite(X)

    # Facts about the target vector.
    assume_dtype(y, y_np.dtype)
    assume_shape(y, y_np.shape)
    assume_value_one_of(y, tuple(int(label) for label in np.unique(y_np)))

    return run_lda(X, y)


def lda_filled():
    """Call `lda` on the symbolic arrays `NDArray.var("X")` and `NDArray.var("y")`."""
    features = NDArray.var("X")
    targets = NDArray.var("y")
    return lda(features, targets)


@pytest.mark.parametrize(
"program",
[
Expand All @@ -364,80 +326,46 @@ def test_program_compile(program: Program, snapshot_py):
assert "\n".join([*statements.split("\n"), expr]) == snapshot_py(name="code")


def lda(X: NDArray, y: NDArray):
    """
    Register assumptions about `X` and `y`, then return `run_lda(X, y)`.

    The dtype and shape assumed for each argument are taken from the module
    level arrays `X_np` and `y_np`; `y` is additionally restricted to the
    distinct values found in `y_np`.
    """
    # Facts about the feature matrix.
    assume_dtype(X, X_np.dtype)
    assume_shape(X, X_np.shape)
    assume_isfinite(X)

    # Facts about the target vector.
    assume_dtype(y, y_np.dtype)
    assume_shape(y, y_np.shape)
    assume_value_one_of(y, tuple(int(label) for label in np.unique(y_np)))

    return run_lda(X, y)


@pytest.mark.parametrize(
"program",
[
pytest.param(lambda x, y: x + y, id="add"),
pytest.param(lambda x, y: x[(x.shape + TupleInt.from_vec((1, 2)))[100]], id="tuple"),
pytest.param(lda, id="lda"),
],
)
def test_jit(program, snapshot_py):
jitted = jit(program)
def test_jit(program, snapshot_py, benchmark):
jitted = benchmark(jit, program)
assert str(jitted.initial_expr) == snapshot_py(name="initial_expr")
assert str(jitted.expr) == snapshot_py(name="expr")
assert inspect.getsource(jitted) == snapshot_py(name="code")


# The benchmark marker below is applied with min_rounds=3 to the whole class.
@pytest.mark.benchmark(min_rounds=3)
class TestLDA:
"""
Incrementally benchmark each part of the LDA to see how long it takes to run.
"""

# Benchmarks building the LDA expression with lda_filled() while a fresh EGraph
# is set as current; the string form of the result is compared to the snapshot.
def test_trace(self, snapshot_py, benchmark):
@benchmark
def X_r2():
with EGraph().set_current():
return lda_filled()

res = str(X_r2)
assert res == snapshot_py

# Builds the LDA expression, then benchmarks egraph.simplify on it with
# array_api_numba_schedule; the string form of the result is snapshot-compared.
# NOTE(review): indentation is missing here, so the extent of the `with` body
# cannot be read off this text - confirm against the real file.
def test_optimize(self, snapshot_py, benchmark):
egraph = EGraph()
with egraph.set_current():
expr = lda_filled()
simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule)

assert str(simplified) == snapshot_py

# @pytest.mark.xfail(reason="Original source is not working")
# def test_source(self, snapshot_py, benchmark):
# egraph = EGraph()
# expr = trace_lda(egraph)
# assert benchmark(load_source, expr, egraph) == snapshot_py

# Simplifies the LDA expression, then benchmarks turning it into a Python
# object via ndarray_function_two + try_evaling. The result is called on
# (X_np, y_np) and compared with run_lda, and its source is snapshot-compared.
def test_source_optimized(self, snapshot_py, benchmark):
egraph = EGraph()
with egraph.set_current():
expr = lda_filled()
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)

@benchmark
def py_object():
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
return try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)

assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
assert inspect.getsource(py_object) == snapshot_py

# Each parameter is a zero-argument thunk, so the implementation under test is
# only constructed when the test body calls fn_thunk().
@pytest.mark.parametrize(
"fn_thunk",
[
pytest.param(lambda: LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"),
pytest.param(lambda: run_lda, id="array_api"),
pytest.param(lambda: _load_py_snapshot(TestLDA.test_source_optimized, "__fn"), id="array_api-optimized"),
pytest.param(
lambda: numba.njit(_load_py_snapshot(TestLDA.test_source_optimized, "__fn")),
id="array_api-optimized-numba",
),
pytest.param(lambda: jit(lda), id="array_api-jit"),
],
)
# Checks the implementation agrees with run_lda on (X_np, y_np) within
# rtol=1e-03, then benchmarks calling it on the same inputs.
def test_execution(self, fn_thunk, benchmark):
fn = fn_thunk()
# warmup once for numba
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np), rtol=1e-03)
benchmark(fn, X_np, y_np)
@pytest.mark.parametrize(
    "fn_thunk",
    [
        pytest.param(lambda: LinearDiscriminantAnalysis(n_components=2).fit_transform, id="base"),
        pytest.param(lambda: run_lda, id="array_api"),
        pytest.param(lambda: jit(lda), id="array_api-optimized"),
        pytest.param(lambda: numba.njit(jit(lda)), id="array_api-optimized-numba"),
    ],
)
def test_run_lda(fn_thunk, benchmark):
    """
    Check one LDA implementation against `run_lda`, then benchmark it.

    `fn_thunk` is a zero-argument callable returning the implementation, so
    the implementation is only built when this test runs.
    """
    candidate = fn_thunk()
    expected = run_lda(X_np, y_np)
    # This first call also serves as the single warmup run for numba.
    actual = candidate(X_np, y_np)
    assert np.allclose(expected, actual, rtol=1e-03)
    benchmark(candidate, X_np, y_np)


# if calling as script, print out egglog source for test
Expand Down

0 comments on commit d282494

Please sign in to comment.