Skip to content

Commit

Permalink
Simplify benchmarks and make them more accurate
Browse files: browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Feb 27, 2025
1 parent c005804 commit dd6b6f6
Showing 1 changed file with 22 additions and 38 deletions.
60 changes: 22 additions & 38 deletions python/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from egglog.exp.array_api_loopnest import *
from egglog.exp.array_api_numba import array_api_numba_schedule
from egglog.exp.array_api_program_gen import *
from egglog.exp.program_gen import Program
from egglog.exp.program_gen import EvalProgram, Program

some_shape = constant("some_shape", TupleInt)
some_dtype = constant("some_dtype", DType)
Expand Down Expand Up @@ -327,33 +327,19 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
return globals[var]


def load_source(fn_program: EvalProgram, egraph: EGraph):
egraph.register(fn_program)
egraph.run(array_api_program_gen_schedule)
# do the needed pieces in here for benchmarking
try:
return egraph.extract(fn_program.as_py_object).eval()
except Exception as err:
err.add_note(f"Failed to compile the program into a string: \n\n{egraph.extract(fn_program)}")
egraph.display(split_primitive_outputs=True, n_inline_leaves=3, split_functions=[Program])
raise


def lda(X, y):
def lda(X: NDArray, y: NDArray):
assume_dtype(X, X_np.dtype)
assume_shape(X, X_np.shape)
assume_isfinite(X)

assume_dtype(y, y_np.dtype)
assume_shape(y, y_np.shape)
assume_value_one_of(y, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type]
assume_value_one_of(y, tuple(map(int, np.unique(y_np))))
return run_lda(X, y)


def simplify_lda(egraph: EGraph, expr: NDArray) -> NDArray:
egraph.register(expr)
egraph.run(array_api_numba_schedule)
return egraph.extract(expr)
def lda_filled():
return lda(NDArray.var("X"), NDArray.var("y"))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -398,21 +384,20 @@ class TestLDA:
"""

def test_trace(self, snapshot_py, benchmark):
X = NDArray.var("X")
y = NDArray.var("y")
with EGraph().set_current():
X_r2 = benchmark(lda, X, y)
@benchmark
def X_r2():
with EGraph().set_current():
return lda_filled()

res = str(X_r2)
print(res)
assert res == snapshot_py

def test_optimize(self, snapshot_py, benchmark):
egraph = EGraph()
X = NDArray.var("X")
y = NDArray.var("y")
with egraph.set_current():
expr = lda(X, y)
simplified = benchmark.pedantic(simplify_lda, args=(egraph, expr))
expr = lda_filled()
simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule)

assert str(simplified) == snapshot_py

# @pytest.mark.xfail(reason="Original source is not working")
Expand All @@ -423,18 +408,17 @@ def test_optimize(self, snapshot_py, benchmark):

def test_source_optimized(self, snapshot_py, benchmark):
egraph = EGraph()
X = NDArray.var("X")
y = NDArray.var("y")
with egraph.set_current():
expr = lda(X, y)
optimized_expr = simplify_lda(egraph, expr)
egraph = EGraph()
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
py_object = benchmark(load_source, fn_program, egraph)
expr = lda_filled()
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)

@benchmark
def py_object():
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
return try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object)

assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
with egraph.set_current():
fn_object = cast(FunctionType, fn_program.as_py_object.eval())
assert inspect.getsource(fn_object) == snapshot_py
assert inspect.getsource(py_object) == snapshot_py

@pytest.mark.parametrize(
"fn_thunk",
Expand Down

0 comments on commit dd6b6f6

Please sign in to comment.