From dd6b6f6880b0d23f57ba02123040b9013aceac49 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 27 Feb 2025 09:17:41 -0500 Subject: [PATCH] Simplify benchmarks and make them more accurate --- python/tests/test_array_api.py | 60 +++++++++++++--------------------- 1 file changed, 22 insertions(+), 38 deletions(-) diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 9adb710..f4e3708 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -19,7 +19,7 @@ from egglog.exp.array_api_loopnest import * from egglog.exp.array_api_numba import array_api_numba_schedule from egglog.exp.array_api_program_gen import * -from egglog.exp.program_gen import Program +from egglog.exp.program_gen import EvalProgram, Program some_shape = constant("some_shape", TupleInt) some_dtype = constant("some_dtype", DType) @@ -327,33 +327,19 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any: return globals[var] -def load_source(fn_program: EvalProgram, egraph: EGraph): - egraph.register(fn_program) - egraph.run(array_api_program_gen_schedule) - # dp the needed pieces in here for benchmarking - try: - return egraph.extract(fn_program.as_py_object).eval() - except Exception as err: - err.add_note(f"Failed to compile the program into a string: \n\n{egraph.extract(fn_program)}") - egraph.display(split_primitive_outputs=True, n_inline_leaves=3, split_functions=[Program]) - raise - - -def lda(X, y): +def lda(X: NDArray, y: NDArray): assume_dtype(X, X_np.dtype) assume_shape(X, X_np.shape) assume_isfinite(X) assume_dtype(y, y_np.dtype) assume_shape(y, y_np.shape) - assume_value_one_of(y, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type] + assume_value_one_of(y, tuple(map(int, np.unique(y_np)))) return run_lda(X, y) -def simplify_lda(egraph: EGraph, expr: NDArray) -> NDArray: - egraph.register(expr) - egraph.run(array_api_numba_schedule) - return egraph.extract(expr) +def lda_filled(): + return lda(NDArray.var("X"), NDArray.var("y")) @pytest.mark.parametrize( @@ -398,21 +384,20 @@ class TestLDA: """ def test_trace(self, snapshot_py, benchmark): - X = NDArray.var("X") - y = NDArray.var("y") - with EGraph().set_current(): - X_r2 = benchmark(lda, X, y) + @benchmark + def X_r2(): + with EGraph().set_current(): + return lda_filled() + res = str(X_r2) - print(res) assert res == snapshot_py def test_optimize(self, snapshot_py, benchmark): egraph = EGraph() - X = NDArray.var("X") - y = NDArray.var("y") with egraph.set_current(): - expr = lda(X, y) - simplified = benchmark.pedantic(simplify_lda, args=(egraph, expr)) + expr = lda_filled() + simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule) + assert str(simplified) == snapshot_py # @pytest.mark.xfail(reason="Original source is not working") @@ -423,18 +408,17 @@ def test_optimize(self, snapshot_py, benchmark): def test_source_optimized(self, snapshot_py, benchmark): egraph = EGraph() - X = NDArray.var("X") - y = NDArray.var("y") with egraph.set_current(): - expr = lda(X, y) - optimized_expr = simplify_lda(egraph, expr) - egraph = EGraph() - fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y")) - py_object = benchmark(load_source, fn_program, egraph) + expr = lda_filled() + optimized_expr = egraph.simplify(expr, array_api_numba_schedule) + + @benchmark + def py_object(): + fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y")) + return try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object) + assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np)) - with egraph.set_current(): - fn_object = cast(FunctionType, fn_program.as_py_object.eval()) - assert inspect.getsource(fn_object) == snapshot_py + assert inspect.getsource(py_object) == snapshot_py @pytest.mark.parametrize( "fn_thunk",