From 701751428e07934384f587b831f4ea6ec559217e Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 26 Feb 2025 17:50:02 -0500 Subject: [PATCH 1/5] Remove eval and rework primitive extraction --- docs/changelog.md | 3 + python/egglog/bindings.pyi | 9 +- python/egglog/builtins.py | 186 +++++++++++++++++- python/egglog/conversion.py | 6 +- python/egglog/declarations.py | 2 +- python/egglog/egraph.py | 186 +++++++++--------- python/egglog/egraph_state.py | 6 +- python/egglog/exp/array_api.py | 165 ++++++++-------- python/egglog/exp/array_api_jit.py | 23 +-- python/egglog/exp/program_gen.py | 11 +- python/egglog/runtime.py | 57 ++++-- .../test_array_api/TestLDA.test_optimize.py | 4 +- .../test_array_api/TestLDA.test_trace.py | 98 +++++---- python/tests/test_array_api.py | 42 ++-- python/tests/test_bindings.py | 14 -- python/tests/test_high_level.py | 109 +++++++--- python/tests/test_program_gen.py | 25 +-- python/tests/test_py_object_sort.py | 29 ++- src/egraph.rs | 51 +---- src/py_object_sort.rs | 43 +++- test-data/unit/check-high-level.test | 8 +- 21 files changed, 668 insertions(+), 409 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 98b903e1..a8a42085 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -24,6 +24,9 @@ _This project uses semantic versioning_ - Updates function constructor to remove `default` and `on_merge`. You also can't set a `cost` when you use a `merge` function or return a primitive. - `eq` now only takes two args, instead of being able to compare any number of values. +- Removes `eval` method from `EGraph` and moves primitive evaluation to methods on each builtin and support `int(...)` type conversions on primitives. +- Change how to set global EGraph context with `with egraph.set_current()` and `EGraph.current` and add support for setting global schedule as well with `with schedule.set_current()` and `Schedule.current`. +- Adds support for using `==` and `!=` directly on values instead of `eq` and `ne` functions. ## 8.0.1 (2024-10-24) diff --git a/python/egglog/bindings.pyi b/python/egglog/bindings.pyi index 1789469c..f090c9b2 100644 --- a/python/egglog/bindings.pyi +++ b/python/egglog/bindings.pyi @@ -92,12 +92,14 @@ class SerializedEGraph: class PyObjectSort: def __init__(self) -> None: ... def store(self, __o: object, /) -> _Expr: ... + def load(self, __e: _Expr, /) -> object: ... @final class EGraph: def __init__( self, - __py_object_sort: PyObjectSort | None = None, + py_object_sort: PyObjectSort | None = None, + /, *, fact_directory: str | Path | None = None, seminaive: bool = True, @@ -116,11 +118,6 @@ class EGraph: max_calls_per_function: int | None = None, include_temporary_functions: bool = False, ) -> SerializedEGraph: ... - def eval_py_object(self, __expr: _Expr) -> object: ... - def eval_i64(self, __expr: _Expr) -> int: ... - def eval_f64(self, __expr: _Expr) -> float: ... - def eval_string(self, __expr: _Expr) -> str: ... - def eval_bool(self, __expr: _Expr) -> bool: ... @final class EggSmolError(Exception): diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index e6fd3c42..ceb2c3a5 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -5,20 +5,24 @@ from __future__ import annotations +from fractions import Fraction from functools import partial, reduce -from types import FunctionType +from types import FunctionType, MethodType from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload from typing_extensions import TypeVarTuple, Unpack +from . import bindings from .conversion import convert, converter, get_type_args -from .egraph import BaseExpr, BuiltinExpr, Unit, function, get_current_ruleset, method +from .declarations import * +from .egraph import BaseExpr, BuiltinExpr, EGraph, expr_fact, function, get_current_ruleset, method +from .egraph_state import GLOBAL_PY_OBJECT_SORT from .functionalize import functionalize from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction from .thunk import Thunk if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Iterator __all__ = [ @@ -32,6 +36,7 @@ "SetLike", "String", "StringLike", + "Unit", "UnstableFn", "Vec", "VecLike", @@ -46,7 +51,25 @@ ] +class Unit(BuiltinExpr, egg_sort="Unit"): + """ + The unit type. This is used to reprsent if a value exists in the e-graph or not. + """ + + def __init__(self) -> None: ... + + @method(preserve=True) + def __bool__(self) -> bool: + return bool(expr_fact(self)) + + class String(BuiltinExpr): + @method(preserve=True) + def eval(self) -> str: + value = _extract_lit(self) + assert isinstance(value, bindings.String) + return value.value + def __init__(self, value: str) -> None: ... @method(egg_fn="replace") @@ -62,10 +85,20 @@ def join(*strings: StringLike) -> String: ... converter(str, String, String) -BoolLike = Union["Bool", bool] +BoolLike: TypeAlias = Union["Bool", bool] class Bool(BuiltinExpr, egg_sort="bool"): + @method(preserve=True) + def eval(self) -> bool: + value = _extract_lit(self) + assert isinstance(value, bindings.Bool) + return value.value + + @method(preserve=True) + def __bool__(self) -> bool: + return self.eval() + def __init__(self, value: bool) -> None: ... @method(egg_fn="not") @@ -91,6 +124,20 @@ def implies(self, other: BoolLike) -> Bool: ... class i64(BuiltinExpr): # noqa: N801 + @method(preserve=True) + def eval(self) -> int: + value = _extract_lit(self) + assert isinstance(value, bindings.Int) + return value.value + + @method(preserve=True) + def __index__(self) -> int: + return self.eval() + + @method(preserve=True) + def __int__(self) -> int: + return self.eval() + def __init__(self, value: int) -> None: ... @method(egg_fn="+") @@ -193,6 +240,20 @@ def count_matches(s: StringLike, pattern: StringLike) -> i64: ... class f64(BuiltinExpr): # noqa: N801 + @method(preserve=True) + def eval(self) -> float: + value = _extract_lit(self) + assert isinstance(value, bindings.Float) + return value.value + + @method(preserve=True) + def __float__(self) -> float: + return self.eval() + + @method(preserve=True) + def __int__(self) -> int: + return int(self.eval()) + def __init__(self, value: float) -> None: ... @method(egg_fn="neg") @@ -265,6 +326,33 @@ def to_string(self) -> String: ... class Map(BuiltinExpr, Generic[T, V]): + @method(preserve=True) + def eval(self) -> dict[T, V]: + call = _extract_call(self) + expr = cast(RuntimeExpr, self) + d = {} + while call.callable != ClassMethodRef("Map", "empty"): + assert call.callable == MethodRef("Map", "insert") + call_typed, k_typed, v_typed = call.args + assert isinstance(call_typed.expr, CallDecl) + k = cast(T, expr.__with_expr__(k_typed)) + v = cast(V, expr.__with_expr__(v_typed)) + d[k] = v + call = call_typed.expr + return d + + @method(preserve=True) + def __iter__(self) -> Iterator[T]: + return iter(self.eval()) + + @method(preserve=True) + def __len__(self) -> int: + return len(self.eval()) + + @method(preserve=True) + def __contains__(self, key: T) -> bool: + return key in self.eval() + @method(egg_fn="map-empty") @classmethod def empty(cls) -> Map[T, V]: ... @@ -305,6 +393,24 @@ def rebuild(self) -> Map[T, V]: ... class Set(BuiltinExpr, Generic[T]): + @method(preserve=True) + def eval(self) -> set[T]: + call = _extract_call(self) + assert call.callable == InitRef("Set") + return {cast(T, cast(RuntimeExpr, self).__with_expr__(x)) for x in call.args} + + @method(preserve=True) + def __iter__(self) -> Iterator[T]: + return iter(self.eval()) + + @method(preserve=True) + def __len__(self) -> int: + return len(self.eval()) + + @method(preserve=True) + def __contains__(self, key: T) -> bool: + return key in self.eval() + @method(egg_fn="set-of") def __init__(self, *args: T) -> None: ... @@ -349,6 +455,28 @@ def rebuild(self) -> Set[T]: ... class Rational(BuiltinExpr): + @method(preserve=True) + def eval(self) -> Fraction: + call = _extract_call(self) + assert call.callable == InitRef("Rational") + + def _to_int(e: TypedExprDecl) -> int: + expr = e.expr + assert isinstance(expr, LitDecl) + assert isinstance(expr.value, int) + return expr.value + + num, den = call.args + return Fraction(_to_int(num), _to_int(den)) + + @method(preserve=True) + def __float__(self) -> float: + return float(self.eval()) + + @method(preserve=True) + def __int__(self) -> int: + return int(self.eval()) + @method(egg_fn="rational") def __init__(self, num: i64Like, den: i64Like) -> None: ... @@ -410,6 +538,26 @@ def denom(self) -> i64: ... class Vec(BuiltinExpr, Generic[T]): + @method(preserve=True) + def eval(self) -> tuple[T, ...]: + call = _extract_call(self) + if call.callable == ClassMethodRef("Vec", "empty"): + return () + assert call.callable == InitRef("Vec") + return tuple(cast(T, cast(RuntimeExpr, self).__with_expr__(x)) for x in call.args) + + @method(preserve=True) + def __iter__(self) -> Iterator[T]: + return iter(self.eval()) + + @method(preserve=True) + def __len__(self) -> int: + return len(self.eval()) + + @method(preserve=True) + def __contains__(self, key: T) -> bool: + return key in self.eval() + @method(egg_fn="vec-of") def __init__(self, *args: T) -> None: ... @@ -461,6 +609,13 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ... class PyObject(BuiltinExpr): + @method(preserve=True) + def eval(self) -> object: + report = (EGraph.current or EGraph())._run_extract(cast(RuntimeExpr, self), 0) + assert isinstance(report, bindings.Best) + expr = report.termdag.term_to_expr(report.term, bindings.PanicSpan()) + return GLOBAL_PY_OBJECT_SORT.load(expr) + def __init__(self, value: object) -> None: ... @method(egg_fn="py-from-string") @@ -554,6 +709,8 @@ def __init__(self, f, *partial) -> None: ... def __call__(self, *args: Unpack[TS]) -> T: ... +# Method Type is for builtins like __getitem__ +converter(MethodType, UnstableFn, lambda m: UnstableFn(m.__func__, m.__self__)) converter(RuntimeFunction, UnstableFn, UnstableFn) converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args)) @@ -590,3 +747,24 @@ def value_to_annotation(a: object) -> type | None: converter(FunctionType, UnstableFn, _convert_function) + + +def _extract_lit(e: BaseExpr) -> bindings._Literal: + """ + Special case extracting literals to make this faster by using termdag directly. + """ + report = (EGraph.current or EGraph())._run_extract(cast(RuntimeExpr, e), 0) + assert isinstance(report, bindings.Best) + term = report.term + assert isinstance(term, bindings.TermLit) + return term.value + + +def _extract_call(e: BaseExpr) -> CallDecl: + """ + Extracts the call form of an expression + """ + extracted = cast(RuntimeExpr, (EGraph.current or EGraph()).extract(e)) + expr = extracted.__egg_typed_expr__.expr + assert isinstance(expr, CallDecl) + return expr diff --git a/python/egglog/conversion.py b/python/egglog/conversion.py index a0adaea0..8ac6544e 100644 --- a/python/egglog/conversion.py +++ b/python/egglog/conversion.py @@ -16,7 +16,7 @@ from .egraph import BaseExpr -__all__ = ["convert", "convert_to_same_type", "converter", "resolve_literal", "ConvertError"] +__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"] # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {} # Global declerations to store all convertable types so we can query if they have certain methods or not @@ -153,9 +153,9 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef: b_converts_to = { to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name) } - if isinstance(a_tp, JustTypeRef): + if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name): a_converts_to[a_tp] = 0 - if isinstance(b_tp, JustTypeRef): + if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name): b_converts_to[b_tp] = 0 common = set(a_converts_to) & set(b_converts_to) if not common: diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index ac989e12..4bc9c08c 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -196,7 +196,7 @@ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911 return self._classes[class_name].properties[property_name] case InitRef(class_name): init_fn = self._classes[class_name].init - assert init_fn + assert init_fn, f"Class {class_name} does not have an init function." return init_fn case UnnamedFunctionRef(): return ConstructorDecl(ref.signature) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index adeea9d2..79c1bce9 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -4,8 +4,8 @@ import inspect import pathlib import tempfile -from collections.abc import Callable, Generator, Iterable -from contextvars import ContextVar, Token +from collections.abc import Callable, Generator, Iterable, Iterator +from contextvars import ContextVar from dataclasses import InitVar, dataclass, field from functools import partial from inspect import Parameter, currentframe, signature @@ -17,7 +17,7 @@ Generic, Literal, Never, - NoReturn, + Protocol, TypeAlias, TypedDict, TypeVar, @@ -39,16 +39,16 @@ from .thunk import * if TYPE_CHECKING: - from .builtins import Bool, PyObject, String, f64, i64 + from .builtins import String, Unit __all__ = [ "Action", + "BaseExpr", + "BuiltinExpr", "Command", "Command", "EGraph", - "BuiltinExpr", - "BaseExpr", "Expr", "Fact", "Fact", @@ -56,7 +56,6 @@ "RewriteOrRule", "Ruleset", "Schedule", - "Unit", "_BirewriteBuilder", "_EqBuilder", "_NeBuilder", @@ -86,6 +85,7 @@ "set_", "simplify", "subsume", + "try_evaling", "union", "unstable_combine_rulesets", "var", @@ -373,11 +373,9 @@ class BaseExpr(metaclass=_ExprMetaclass): Either a builtin or a user defined expression type. """ - def __ne__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body] - ... + def __ne__(self, other: Self) -> Unit: ... # type: ignore[override, empty-body] - def __eq__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body] - ... + def __eq__(self, other: Self) -> Fact: ... # type: ignore[override, empty-body] class BuiltinExpr(BaseExpr, metaclass=_ExprMetaclass): @@ -435,7 +433,6 @@ def _generate_class_decls( # noqa: C901,PLR0912 ## # Register methods, classmethods, preserved methods, and properties ## - # Get all the methods from the class filtered_namespace: list[tuple[str, Any]] = [ (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod) @@ -508,7 +505,6 @@ def _generate_class_decls( # noqa: C901,PLR0912 # in the bodies for add_rewrite in add_default_funcs: add_rewrite() - return decls @@ -713,10 +709,12 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[.. Creates a function whose return type is `Unit` and has a default value. """ decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn) - return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, Thunk.value(FunctionRef(name)))) + return cast(Callable[..., "Unit"], RuntimeFunction(decls_thunk, Thunk.value(FunctionRef(name)))) def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations: + from .builtins import Unit + decls = Declarations() decls |= cast(RuntimeClass, Unit) arg_types = tuple(resolve_type_annotation_mutate(decls, tp).to_just() for tp in tps) @@ -848,6 +846,7 @@ class EGraph: Can run actions, check facts, run schedules, or extract minimal cost expressions. """ + current: ClassVar[EGraph | None] = None seminaive: InitVar[bool] = True save_egglog_string: InitVar[bool] = False @@ -855,7 +854,7 @@ class EGraph: # For pushing/popping with egglog _state_stack: list[EGraphState] = field(default_factory=list, repr=False) # For storing the global "current" egraph - _token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False) + _token_stack: list[EGraph] = field(default_factory=list, repr=False) def __post_init__(self, seminaive: bool, save_egglog_string: bool) -> None: egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string) @@ -970,6 +969,19 @@ def _run_schedule(self, schedule: Schedule) -> bindings.RunReport: raise ValueError(msg) return run_report + def check_bool(self, *facts: FactLike) -> bool: + """ + Returns true if the facts are true in the egraph. + """ + try: + self.check(*facts) + # TODO: Make a separate exception class for this + except Exception as e: + if "Check failed" in str(e): + return False + raise + return True + def check(self, *facts: FactLike) -> None: """ Check if a fact is true in the egraph. @@ -999,14 +1011,14 @@ def extract(self, expr: BASE_EXPR, include_cost: bool = False) -> BASE_EXPR | tu Extract the lowest cost expression from the egraph. """ runtime_expr = to_runtime_expr(expr) - self._add_decls(runtime_expr) - typed_expr = runtime_expr.__egg_typed_expr__ - extract_report = self._run_extract(typed_expr, 0) + extract_report = self._run_extract(runtime_expr, 0) if not isinstance(extract_report, bindings.Best): msg = "No extract report saved" raise ValueError(msg) # noqa: TRY004 - (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp) + (new_typed_expr,) = self._state.exprs_from_egg( + extract_report.termdag, [extract_report.term], runtime_expr.__egg_typed_expr__.tp + ) res = cast(BASE_EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, new_typed_expr)) if include_cost: @@ -1018,21 +1030,25 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]: Extract multiple expressions from the egraph. """ runtime_expr = to_runtime_expr(expr) - self._add_decls(runtime_expr) - typed_expr = runtime_expr.__egg_typed_expr__ - - extract_report = self._run_extract(typed_expr, n) + extract_report = self._run_extract(runtime_expr, n) if not isinstance(extract_report, bindings.Variants): msg = "Wrong extract report type" raise ValueError(msg) # noqa: TRY004 - new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp) + new_exprs = self._state.exprs_from_egg( + extract_report.termdag, extract_report.terms, runtime_expr.__egg_typed_expr__.tp + ) return [cast(BASE_EXPR, RuntimeExpr.__from_values__(self.__egg_decls__, expr)) for expr in new_exprs] - def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport: - expr = self._state.typed_expr_to_egg(typed_expr) - self._egraph.run_program( - bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n)))) - ) + def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._ExtractReport: + self._add_decls(expr) + expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__) + try: + self._egraph.run_program( + bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n)))) + ) + except BaseException as e: + e.add_note("Extracting: " + str(expr)) + raise extract_report = self._egraph.extract_report() if not extract_report: msg = "No extract report saved" @@ -1060,50 +1076,12 @@ def __enter__(self) -> Self: Also sets the current egraph to this one. """ - self._token_stack.append(CURRENT_EGRAPH.set(self)) self.push() return self def __exit__(self, exc_type, exc, exc_tb) -> None: - CURRENT_EGRAPH.reset(self._token_stack.pop()) self.pop() - @overload - def eval(self, expr: Bool) -> bool: ... - - @overload - def eval(self, expr: i64) -> int: ... - - @overload - def eval(self, expr: f64) -> float: ... - - @overload - def eval(self, expr: String) -> str: ... - - @overload - def eval(self, expr: PyObject) -> object: ... - - def eval(self, expr: BuiltinExpr) -> object: - """ - Evaluates the given expression (which must be a primitive type), returning the result. - """ - runtime_expr = to_runtime_expr(expr) - self._add_decls(runtime_expr) - typed_expr = runtime_expr.__egg_typed_expr__ - egg_expr = self._state.typed_expr_to_egg(typed_expr) - match typed_expr.tp: - case JustTypeRef("i64"): - return self._egraph.eval_i64(egg_expr) - case JustTypeRef("f64"): - return self._egraph.eval_f64(egg_expr) - case JustTypeRef("Bool"): - return self._egraph.eval_bool(egg_expr) - case JustTypeRef("String"): - return self._egraph.eval_string(egg_expr) - case JustTypeRef("PyObject"): - return self._egraph.eval_py_object(egg_expr) - raise TypeError(f"Eval not implemented for {typed_expr.tp}") - def _serialize( self, **kwargs: Unpack[GraphvizKwargs], @@ -1221,15 +1199,15 @@ def to_json() -> str: if visualize: VisualizerWidget(egraphs=egraphs).display_or_open() - @classmethod - def current(cls) -> EGraph: + @contextlib.contextmanager + def set_current(self) -> Iterator[None]: """ - Returns the current egraph, which is the one in the context. + Context manager that will set the current egraph. It will be set back after. """ - try: - return CURRENT_EGRAPH.get() - except LookupError: - return cls(save_egglog_string=True) + prev_current = EGraph.current + EGraph.current = self + yield + EGraph.current = prev_current @property def _egraph(self) -> bindings.EGraph: @@ -1279,9 +1257,6 @@ def _command_to_egg(self, cmd: Command) -> bindings._Command: return self._state.command_to_egg(cmd_decl, ruleset_name) -CURRENT_EGRAPH = ContextVar[EGraph]("CURRENT_EGRAPH") - - @dataclass(frozen=True) class _WrappedMethod: """ @@ -1302,14 +1277,6 @@ def __call__(self, *args, **kwargs) -> Never: raise NotImplementedError(msg) -class Unit(BuiltinExpr, egg_sort="Unit"): - """ - The unit type. This is also used to reprsent if a value exists, if it is resolved or not. - """ - - def __init__(self) -> None: ... - - def ruleset( rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator | None = None, *rules: RewriteOrRule, @@ -1334,6 +1301,8 @@ class Schedule(DelayedDeclerations): A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met """ + current: ClassVar[Schedule | None] = None + # Defer declerations so that we can have rule generators that used not yet defined yet schedule: ScheduleDecl @@ -1361,6 +1330,16 @@ def __add__(self, other: Schedule) -> Schedule: """ return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule))) + @contextlib.contextmanager + def set_current(self) -> Iterator[None]: + """ + Context manager that will set the current schedule. It will be set back after + """ + prev_current = Schedule.current + Schedule.current = self + yield + Schedule.current = prev_current + @dataclass class Ruleset(Schedule): @@ -1505,11 +1484,17 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) + def __bool__(self) -> bool: + """ + Returns True if the two sides of an equality are equal in the egraph or the expression is in the egraph. + """ + return (EGraph.current or EGraph()).check_bool(self) + @dataclass class Action: """ - A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking. + A change to an EGraph, either unioning multiple expressions, setting the value of a function call, deleting an expression, or panicing. """ __egg_decls__: Declarations @@ -1700,11 +1685,12 @@ class _NeBuilder(Generic[BASE_EXPR]): lhs: BASE_EXPR def to(self, rhs: BASE_EXPR) -> Unit: + from .builtins import Unit + lhs = to_runtime_expr(self.lhs) rhs = convert_to_same_type(rhs, lhs) - assert isinstance(Unit, RuntimeClass) res = RuntimeExpr.__from_values__( - Declarations.create(Unit, lhs, rhs), + Declarations.create(cast(RuntimeClass, Unit), lhs, rhs), TypedExprDecl( JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__)) ), @@ -1888,3 +1874,27 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]: yield finally: _CURRENT_RULESET.reset(token) + + +T_co = TypeVar("T_co", covariant=True) + + +class _EvalsTo(Protocol, Generic[T_co]): + def eval(self) -> T_co: ... + + +def try_evaling(schedule: Schedule, expr: Expr, prim_expr: _EvalsTo[T]) -> T: + """ + Try evaling the expression that will result in a primitive expression being fill. + if it fails, display the egraph and raise an error. + """ + egraph = EGraph.current or EGraph() + egraph.register(expr) + egraph.run(Schedule.current or schedule) + try: + with egraph.set_current(): + return prim_expr.eval() + except BaseException as e: + # egraph.display(n_inline_leaves=1, split_primitive_outputs=True) + e.add_note(f"Cannot evaluate {egraph.extract(expr)}") + raise diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index c95c45b6..e1e22fac 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from collections.abc import Iterable -__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT", "span"] +__all__ = ["GLOBAL_PY_OBJECT_SORT", "EGraphState", "span"] # Create a global sort for python objects, so we can store them without an e-graph instance # Needed when serializing commands to egg commands when creating modules @@ -519,7 +519,7 @@ def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl: elif isinstance(term, bindings.TermApp): if term.name == "py-object": call = self.termdag.term_to_expr(term, span()) - expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call)) + expr_decl = PyObjectDecl(GLOBAL_PY_OBJECT_SORT.load(call)) elif term.name == "unstable-fn": # Get function name fn_term, *arg_terms = term.args @@ -528,7 +528,7 @@ def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl: fn_name = fn_value.expr.value assert isinstance(fn_name, str) - # Resolve what types the partiallied applied args are + # Resolve what types the partially applied args are assert tp.name == "UnstableFn" call_decl = self.from_call(tp.args[0], bindings.TermApp(fn_name, arg_terms)) expr_decl = PartialCallDecl(call_decl) diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 50aa6786..853a0275 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -27,11 +27,23 @@ the `.length` and `.__getitem__` methods, unles you want to to depend on it having known length, in which case you can match directly on the cons list. +To be a list, you must implement two methods: -* Functional --partial--> cons/empty -* cons/empty <--> vec +* `l.length() -> Int` +* `l.__getitem__(i: Int) -> T` -Q: Why are they implented as SNOC lists instead of CONS lists? +There are three main types of constructors for lists which all implement these methods: + +* Functional `List(length, idx_fn)` +* cons (well reversed cons) lists `List.EMPTY` and `l.append(x)` +* Vectors `List.from_vec(vec)` + +Also all lists constructors must be converted to the functional representation, so that we can match on it +and convert lists with known lengths into cons lists and into vectors. + +This is neccessary so that known length lists are properly materialized during extraction. + +Q: Why are they implemented as SNOC lists instead of CONS lists? A: So that when converting from functional to lists we can use the same index function by starting at the end and folding that way recursively. @@ -49,7 +61,7 @@ import sys from copy import copy from types import EllipsisType -from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast, overload +from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast import numpy as np @@ -73,12 +85,18 @@ class Boolean(Expr, ruleset=array_api_ruleset): + def __init__(self, value: BoolLike) -> None: ... + @method(preserve=True) def __bool__(self) -> bool: - return try_evaling(self, self.bool) + return self.eval() + + @method(preserve=True) + def eval(self) -> bool: + return try_evaling(array_api_schedule, self, self.to_bool) @property - def bool(self) -> Bool: ... + def to_bool(self) -> Bool: ... def __or__(self, other: BooleanLike) -> Boolean: ... @@ -89,18 +107,17 @@ def __invert__(self) -> Boolean: ... def __eq__(self, other: BooleanLike) -> Boolean: ... # type: ignore[override] -BooleanLike = Boolean | bool +BooleanLike = Boolean | BoolLike -TRUE = constant("TRUE", Boolean) -FALSE = constant("FALSE", Boolean) -converter(bool, Boolean, lambda x: TRUE if x else FALSE) +TRUE = Boolean(True) +FALSE = Boolean(False) +converter(Bool, Boolean, Boolean) @array_api_ruleset.register -def _bool(x: Boolean, i: Int, j: Int): +def _bool(x: Boolean, i: Int, j: Int, b: Bool): return [ - rule(eq(x).to(TRUE)).then(set_(x.bool).to(Bool(True))), - rule(eq(x).to(FALSE)).then(set_(x.bool).to(Bool(False))), + rule(eq(x).to(Boolean(b))).then(set_(x.to_bool).to(b)), rewrite(TRUE | x).to(TRUE), rewrite(FALSE | x).to(x), rewrite(TRUE & x).to(x), @@ -197,23 +214,27 @@ def __rxor__(self, other: IntLike) -> Int: ... def __ror__(self, other: IntLike) -> Int: ... @property - def i64(self) -> i64: ... + def to_i64(self) -> i64: ... @method(preserve=True) - def __int__(self) -> int: - return try_evaling(self, self.i64) + def eval(self) -> int: + return try_evaling(array_api_schedule, self, self.to_i64) @method(preserve=True) def __index__(self) -> int: - return int(self) + return self.eval() + + @method(preserve=True) + def __int__(self) -> int: + return self.eval() @method(preserve=True) def __float__(self) -> float: - return float(int(self)) + return float(self.eval()) @method(preserve=True) def __bool__(self) -> bool: - return bool(int(self)) + return bool(self.eval()) @classmethod def if_(cls, b: BooleanLike, i: IntLike, j: IntLike) -> Int: ... @@ -236,7 +257,7 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int): yield rule(eq(r).to(Int(i) > Int(j)), i > j).then(union(r).with_(TRUE)) yield rule(eq(r).to(Int(i) > Int(j)), i < j).then(union(r).with_(FALSE)) - yield rule(eq(o).to(Int(j))).then(set_(o.i64).to(j)) + yield rule(eq(o).to(Int(j))).then(set_(o.to_i64).to(j)) yield rule(eq(Int(i)).to(Int(j)), ne(i).to(j)).then(panic("Real ints cannot be equal to different ints")) @@ -304,6 +325,13 @@ class Float(Expr, ruleset=array_api_ruleset): @method(cost=3) def __init__(self, value: f64Like) -> None: ... + @property + def to_f64(self) -> f64: ... + + @method(preserve=True) + def eval(self) -> float: + return try_evaling(array_api_schedule, self, self.to_f64) + def abs(self) -> Float: ... @method(cost=2) @@ -334,8 +362,9 @@ def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override] @array_api_ruleset.register -def _float(f: f64, f2: f64, i: i64, r: Rational, r1: Rational): +def _float(fl: Float, f: f64, f2: f64, i: i64, r: Rational, r1: Rational): return [ + rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)), rewrite(Float(f).abs()).to(Float(f), f >= 0.0), rewrite(Float(f).abs()).to(Float(-f), f < 0.0), # Convert from float to rationl, if its a whole number i.e. can be converted to int @@ -398,12 +427,18 @@ def __getitem__(self, i: IntLike) -> Int: ... @method(preserve=True) def __len__(self) -> int: - return int(self.length()) + return self.length().eval() @method(preserve=True) def __iter__(self) -> Iterator[Int]: - # TODO: Change to use as_vec when we have supported execing vecs - return iter(self[i] for i in range(len(self))) + return iter(self.eval()) + + @property + def to_vec(self) -> Vec[Int]: ... + + @method(preserve=True) + def eval(self) -> tuple[Int, ...]: + return try_evaling(array_api_schedule, self, self.to_vec) def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ... def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ... @@ -427,10 +462,6 @@ def map(self, f: Callable[[Int], Int]) -> TupleInt: @classmethod def if_(cls, b: BooleanLike, i: TupleIntLike, j: TupleIntLike) -> TupleInt: ... - @method(preserve=True) - def to_py(self) -> tuple[int, ...]: - return tuple(int(i) for i in self) - def drop(self, n: Int) -> TupleInt: return TupleInt(self.length() - n, lambda i: self[i + n]) @@ -475,6 +506,7 @@ def _tuple_int( k: i64, ): return [ + rule(eq(ti).to(TupleInt.from_vec(vs))).then(set_(ti.to_vec).to(vs)), # Functional access rewrite(TupleInt(i, idx_fn).length()).to(i), rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(check_index(i, i2))), @@ -543,11 +575,18 @@ def __getitem__(self, i: IntLike) -> TupleInt: ... @method(preserve=True) def __len__(self) -> int: - return int(self.length()) + return self.length().eval() @method(preserve=True) def __iter__(self) -> Iterator[TupleInt]: - return iter(self[i] for i in range(len(self))) + return iter(self.eval()) + + @property + def to_vec(self) -> Vec[TupleInt]: ... + + @method(preserve=True) + def eval(self) -> tuple[TupleInt, ...]: + return try_evaling(array_api_schedule, self, self.to_vec) def drop(self, n: Int) -> TupleTupleInt: return TupleTupleInt(self.length() - n, lambda i: self[i + n]) @@ -595,6 +634,7 @@ def _tuple_tuple_int( tti: TupleTupleInt, tti1: TupleTupleInt, ): + yield rule(eq(tti).to(TupleTupleInt.from_vec(vs))).then(set_(tti.to_vec).to(vs)) yield rewrite(TupleTupleInt(length, idx_fn).length()).to(length) yield rewrite(TupleTupleInt(length, idx_fn)[idx]).to(idx_fn(check_index(idx, length))) @@ -1034,14 +1074,14 @@ def shape(self) -> TupleInt: ... @method(preserve=True) def __bool__(self) -> bool: - return bool(self.to_value().to_bool) + return self.to_value().to_bool.eval() @property def size(self) -> Int: ... @method(preserve=True) def __len__(self) -> int: - return int(self.size) + return self.size.eval() @method(preserve=True) def __iter__(self) -> Iterator[NDArray]: @@ -1231,11 +1271,18 @@ def __getitem__(self, i: IntLike) -> NDArray: ... @method(preserve=True) def __len__(self) -> int: - return int(self.length()) + return self.length().eval() @method(preserve=True) def __iter__(self) -> Iterator[NDArray]: - return iter(self[i] for i in range(len(self))) + return iter(self.eval()) + + @property + def to_vec(self) -> Vec[NDArray]: ... + + @method(preserve=True) + def eval(self) -> tuple[NDArray, ...]: + return try_evaling(array_api_schedule, self, self.to_vec) converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x)) @@ -1256,6 +1303,7 @@ def _tuple_ndarray( tv1: TupleNDArray, b: Boolean, ): + yield rule(eq(tv).to(TupleNDArray.from_vec(vs))).then(set_(tv.to_vec).to(vs)) yield rewrite(TupleNDArray(length, idx_fn).length()).to(length) yield rewrite(TupleNDArray(length, idx_fn)[idx]).to(idx_fn(check_index(idx, length))) @@ -1478,7 +1526,8 @@ def unique_counts(x: NDArray) -> TupleNDArray: @array_api_ruleset.register def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value, dtype: DType): return [ - rewrite(unique_counts(x).length()).to(Int(2)), + # rewrite(unique_counts(x).length()).to(Int(2)), + rewrite(unique_counts(x)).to(TupleNDArray(2, unique_counts(x).__getitem__)), # Sum of all unique counts is the size of the array rewrite(sum(unique_counts(x)[Int(1)])).to(NDArray.scalar(Value.int(x.size))), # Same but with astype in the middle @@ -1522,7 +1571,8 @@ def unique_inverse(x: NDArray) -> TupleNDArray: @array_api_ruleset.register def _unique_inverse(x: NDArray, i: Int): return [ - rewrite(unique_inverse(x).length()).to(Int(2)), + # rewrite(unique_inverse(x).length()).to(Int(2)), + rewrite(unique_inverse(x)).to(TupleNDArray(2, unique_inverse(x).__getitem__)), # Shape of unique_inverse first element is same as shape of unique_values rewrite(unique_inverse(x)[Int(0)]).to(unique_values(x)), ] @@ -1572,7 +1622,8 @@ def svd(x: NDArray, full_matrices: Boolean = TRUE) -> TupleNDArray: @array_api_ruleset.register def _linalg(x: NDArray, full_matrices: Boolean): return [ - rewrite(svd(x, full_matrices).length()).to(Int(3)), + # rewrite(svd(x, full_matrices).length()).to(Int(3)), + rewrite(svd(x, full_matrices)).to(TupleNDArray(3, svd(x, full_matrices).__getitem__)), ] @@ -1860,46 +1911,6 @@ def _size(x: NDArray): yield rewrite(x.size).to(x.shape.foldl(Int.__mul__, Int(1))) -@overload -def try_evaling(expr: Expr, prim_expr: i64) -> int: ... - - -@overload -def try_evaling(expr: Expr, prim_expr: Bool) -> bool: ... - - -def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool: - """ - Try evaling the expression, and if it fails, display the egraph and raise an error. - """ - egraph = EGraph.current() - egraph.register(expr) - egraph.run(array_api_schedule) - try: - extracted = egraph.extract(prim_expr) - # Catch base exceptions so that we catch rust panics which happen when trying to extract subsumed nodes - except BaseException as exc: - egraph.display(n_inline_leaves=1, split_primitive_outputs=True) - # Try giving some context, by showing the smallest version of the larger expression - try: - expr_extracted = egraph.extract(expr) - except BaseException as inner_exc: - inner_exc.add_note(f"Cannot simplify {expr}") - raise - exc.add_note(f"Cannot simplify to primitive {expr_extracted}") - raise - return egraph.eval(extracted) - - # string = ( - # egraph.as_egglog_string - # + "\n" - # + str(egraph._state.typed_expr_to_egg(cast(RuntimeExpr, prim_expr).__egg_typed_expr__)) - # ) - # # save to "tmp.egg" - # with open("tmp.egg", "w") as f: - # f.write(string) - - # Seperate rulseset so we can use it in program gen @ruleset def array_api_vec_to_cons_ruleset( diff --git a/python/egglog/exp/array_api_jit.py b/python/egglog/exp/array_api_jit.py index 5681e106..74083ead 100644 --- a/python/egglog/exp/array_api_jit.py +++ b/python/egglog/exp/array_api_jit.py @@ -2,7 +2,7 @@ from collections.abc import Callable from typing import TypeVar, cast -from egglog import EGraph +from egglog import EGraph, try_evaling from egglog.exp.array_api import NDArray from egglog.exp.array_api_numba import array_api_numba_schedule from egglog.exp.array_api_program_gen import array_api_program_gen_schedule, ndarray_function_two @@ -14,27 +14,14 @@ def jit(fn: X) -> X: """ Jit compiles a function """ - # 1. Create variables for each of the two args in the functions sig = inspect.signature(fn) arg1, arg2 = sig.parameters.keys() egraph = EGraph() - with egraph: + with egraph.set_current(): res = fn(NDArray.var(arg1), NDArray.var(arg2)) - egraph.register(res) - egraph.run(array_api_numba_schedule) - res_optimized = egraph.extract(res) - # egraph.display(split_primitive_outputs=True, n_inline_leaves=3) + res_optimized = egraph.simplify(res, array_api_numba_schedule) fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2)) - egraph.register(fn_program) - egraph.run(array_api_program_gen_schedule) - # egraph.display(split_primitive_outputs=True, n_inline_leaves=3) - try: - fn = cast(X, egraph.eval(egraph.extract(fn_program.py_object))) - except Exception as err: - err.add_note(f"Failed to compile the program into a string: \n\n{egraph.extract(fn_program)}") - egraph.display() - raise + fn = try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object) fn.expr = res_optimized # type: ignore[attr-defined] - fn.statements = egraph.eval(fn_program.statements) # type: ignore[attr-defined] - return fn + return cast(X, fn) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index 0a83b631..e2e151a6 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -109,17 +109,11 @@ def __init__(self, program: Program, globals: object) -> None: # Only allow it to be set once, b/c hash of functions not stable @method(merge=lambda old, _new: old) # type: ignore[misc] @property - def py_object(self) -> PyObject: + def as_py_object(self) -> PyObject: """ Returns the python object of the program, if it's been evaluated. """ - @property - def statements(self) -> String: - """ - Returns the statements of the program, if it's been compiled - """ - @ruleset def eval_program_rulseset(ep: EvalProgram, p: Program, expr: String, statements: String, g: PyObject): @@ -131,13 +125,12 @@ def eval_program_rulseset(ep: EvalProgram, p: Program, expr: String, statements: eq(p.statements).to(statements), eq(p.expr).to(expr), ).then( - set_(ep.py_object).to( + set_(ep.as_py_object).to( py_eval( "l['___res']", PyObject.dict(PyObject.from_string("l"), py_exec(join(statements, "\n", "___res = ", expr), g)), ) ), - set_(ep.statements).to(statements), ) diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 6b8c8571..6a3ff9d3 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -15,7 +15,7 @@ from dataclasses import dataclass, replace from inspect import Parameter, Signature from itertools import zip_longest -from typing import TYPE_CHECKING, NoReturn, TypeVar, Union, cast, get_args, get_origin +from typing import TYPE_CHECKING, TypeVar, Union, cast, get_args, get_origin from .declarations import * from .pretty import * @@ -25,7 +25,8 @@ if TYPE_CHECKING: from collections.abc import Iterable - from .egraph import Expr + from .egraph import Fact + __all__ = [ "LIT_CLASS_NAMES", @@ -59,6 +60,19 @@ "__ror__": "__or__", } +# Methods that need to return real Python values not expressions +PRESERVED_METHODS = [ + "__bool__", + "__len__", + "__complex__", + "__int__", + "__float__", + "__iter__", + "__index__", + "__float__", + "__int__", +] + # Set this globally so we can get access to PyObject when we have a type annotation of just object. # This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically. _PY_OBJECT_CLASS: RuntimeClass | None = None @@ -155,12 +169,10 @@ def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None: del args # Assumes we don't have types set for UnstableFn w/ generics, that they have to be inferred - # 1. Create a runtime function for the first arg - assert isinstance(fn_arg, RuntimeFunction) - # 2. Call it with the partial args, and use untyped vars for the rest of the args - res = fn_arg(*partial_args, _egg_partial_function=True) + # 1. Call it with the partial args, and use untyped vars for the rest of the args + res = cast(Callable, fn_arg)(*partial_args, _egg_partial_function=True) assert res is not None, "Mutable partial functions not supported" - # 3. Use the inferred return type and inferred rest arg types as the types of the function, and + # 2. Use the inferred return type and inferred rest arg types as the types of the function, and # the partially applied args as the args. call = (res_typed_expr := res.__egg_typed_expr__).expr return_tp = res_typed_expr.tp @@ -262,8 +274,6 @@ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable: return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.name, name))) msg = f"Class {self.__egg_tp__.name} has no method {name}" - if name == "__ne__": - msg += ". Did you mean to use the ne(...).to(...)?" raise AttributeError(msg) from None def __str__(self) -> str: @@ -451,6 +461,9 @@ class RuntimeExpr(DelayedDeclerations): def __from_values__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr: return cls(Thunk.value(d), Thunk.value(e)) + def __with_expr__(self, e: TypedExprDecl) -> RuntimeExpr: + return RuntimeExpr(self.__egg_decls_thunk__, Thunk.value(e)) + @property def __egg_typed_expr__(self) -> TypedExprDecl: return self.__egg_typed_expr_thunk__() @@ -497,12 +510,12 @@ def __egg_class_name__(self) -> str: def __egg_class_decl__(self) -> ClassDecl: return self.__egg_decls__.get_class_decl(self.__egg_class_name__) - # Have __eq__ take no NoReturn (aka Never https://docs.python.org/3/library/typing.html#typing.Never) because - # we don't wany any type that MyPy thinks is an expr to be used with __eq__. - # That's because we want to reserve __eq__ for domain specific equality checks, overloading this method. - # To check if two exprs are equal, use the expr_eq method. - # At runtime, this will resolve if there is a defined egg function for `__eq__` - def __eq__(self, other: NoReturn) -> Expr: ... # type: ignore[override, empty-body] + # These both will be overriden below in the special methods section, but add these here for type hinting purposes + def __eq__(self, other: object) -> Fact: # type: ignore[override, empty-body] + ... + + def __ne__(self, other: object) -> RuntimeExpr: # type: ignore[override, empty-body] + ... # Implement these so that copy() works on this object # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion @@ -526,7 +539,7 @@ def _special_method( *args: object, __name: str = name, **kwargs: object, - ) -> RuntimeExpr | None: + ) -> RuntimeExpr | Fact | None: from .conversion import ConvertError class_name = self.__egg_class_name__ @@ -552,6 +565,16 @@ def _special_method( if __name in class_decl.methods: fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self) return fn(*args, **kwargs) # type: ignore[arg-type] + # Handle == and != fallbacks to eq and ne helpers if the methods aren't defined on the class explicitly. + if __name == "__eq__": + from .egraph import BaseExpr, eq + + return eq(cast(BaseExpr, self)).to(cast(BaseExpr, args[0])) + if __name == "__ne__": + from .egraph import BaseExpr, ne + + return cast(RuntimeExpr, ne(cast(BaseExpr, self)).to(cast(BaseExpr, args[0]))) + if __name in PARTIAL_METHODS: return NotImplemented raise TypeError(f"{class_name!r} object does not support {__name}") @@ -580,7 +603,7 @@ def call_method_min_conversion(slf: object, other: object, name: str) -> Runtime return method(other) -for name in ["__bool__", "__len__", "__complex__", "__int__", "__float__", "__iter__", "__index__"]: +for name in PRESERVED_METHODS: def _preserved_method(self: RuntimeExpr, __name: str = name): try: diff --git a/python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py b/python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py index af730e73..7e83c265 100644 --- a/python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py +++ b/python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py @@ -37,7 +37,7 @@ _NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)]))) _NDArray_11 = copy(_NDArray_10) _NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float.rational(Rational(1, 1)))) -_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.float(Float.rational(Rational(1, 147)))), OptionalDType.some(DType.float64))) * (_NDArray_8 / _NDArray_11), FALSE) +_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.float(Float.rational(Rational(1, 147)))), OptionalDType.some(DType.float64))) * (_NDArray_8 / _NDArray_11), Boolean(False)) _Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int)) _NDArray_12 = ( _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] @@ -46,7 +46,7 @@ _TupleNDArray_2 = svd( (sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(Rational(1, 2))))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_12, - FALSE, + Boolean(False), ) ( (_NDArray_1 - (_NDArray_3 @ _NDArray_4)) diff --git a/python/tests/__snapshots__/test_array_api/TestLDA.test_trace.py b/python/tests/__snapshots__/test_array_api/TestLDA.test_trace.py index ada559d9..e7078b6c 100644 --- a/python/tests/__snapshots__/test_array_api/TestLDA.test_trace.py +++ b/python/tests/__snapshots__/test_array_api/TestLDA.test_trace.py @@ -5,63 +5,81 @@ _NDArray_2 = NDArray.var("y") assume_dtype(_NDArray_2, DType.int64) assume_shape(_NDArray_2, TupleInt.from_vec(Vec[Int](Int(150)))) -assume_value_one_of(_NDArray_2, TupleValue.from_vec(Vec[Value](Value.int(Int(0)), Value.int(Int(1)), Value.int(Int(2))))) -_NDArray_3 = asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1))))) -_NDArray_4 = astype(unique_counts(_NDArray_3)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(150.0))) -_NDArray_5 = zeros( - TupleInt.from_vec(Vec[Int](unique_inverse(_NDArray_3)[Int(0)].shape[Int(0)], asarray(_NDArray_1).shape[Int(1)])), +_TupleValue_1 = TupleValue.from_vec(Vec[Value](Value.int(Int(0)), Value.int(Int(1)), Value.int(Int(2)))) +assume_value_one_of(_NDArray_2, _TupleValue_1) +_NDArray_3 = zeros( + TupleInt.from_vec(Vec[Int](NDArray.vector(_TupleValue_1).shape[Int(0)], asarray(_NDArray_1).shape[Int(1)])), OptionalDType.some(asarray(_NDArray_1).dtype), OptionalDevice.some(asarray(_NDArray_1).device), ) _MultiAxisIndexKeyItem_1 = MultiAxisIndexKeyItem.slice(Slice()) _IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1))) +_IndexKey_2 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(0)))) _OptionalIntOrTuple_1 = OptionalIntOrTuple.some(IntOrTuple.int(Int(0))) -_NDArray_5[_IndexKey_1] = mean(asarray(_NDArray_1)[IndexKey.ndarray(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(0))))], _OptionalIntOrTuple_1) -_IndexKey_2 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1))) -_NDArray_5[_IndexKey_2] = mean(asarray(_NDArray_1)[IndexKey.ndarray(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(1))))], _OptionalIntOrTuple_1) -_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) -_NDArray_5[_IndexKey_3] = mean(asarray(_NDArray_1)[IndexKey.ndarray(unique_inverse(_NDArray_3)[Int(1)] == NDArray.scalar(Value.int(Int(2))))], _OptionalIntOrTuple_1) -_NDArray_6 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_3)))))) -_NDArray_7 = concat( +_NDArray_3[_IndexKey_1] = mean(asarray(_NDArray_1)[_IndexKey_2], _OptionalIntOrTuple_1) +_IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1))) +_IndexKey_4 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(1)))) +_NDArray_3[_IndexKey_3] = mean(asarray(_NDArray_1)[_IndexKey_4], _OptionalIntOrTuple_1) +_IndexKey_5 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) +_IndexKey_6 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(2)))) +_NDArray_3[_IndexKey_5] = mean(asarray(_NDArray_1)[_IndexKey_6], _OptionalIntOrTuple_1) +_NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device)) +_IndexKey_7 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1))) +_NDArray_4[_IndexKey_7] = mean(_NDArray_1[_IndexKey_2], _OptionalIntOrTuple_1) +_IndexKey_8 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(1)), _MultiAxisIndexKeyItem_1))) +_NDArray_4[_IndexKey_8] = mean(_NDArray_1[_IndexKey_4], _OptionalIntOrTuple_1) +_IndexKey_9 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) +_NDArray_4[_IndexKey_9] = mean(_NDArray_1[_IndexKey_6], _OptionalIntOrTuple_1) +_NDArray_5 = concat( TupleNDArray.from_vec( Vec[NDArray]( - asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_3 == _NDArray_6[IndexKey.int(Int(0))])] - _NDArray_5[_IndexKey_1], - asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_3 == _NDArray_6[IndexKey.int(Int(1))])] - _NDArray_5[_IndexKey_2], - asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_3 == _NDArray_6[IndexKey.int(Int(2))])] - _NDArray_5[_IndexKey_3], + _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(0))))] - _NDArray_4[_IndexKey_7], + _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(1))))] - _NDArray_4[_IndexKey_8], + _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))] - _NDArray_4[_IndexKey_9], ) ), OptionalInt.some(Int(0)), ) -_NDArray_8 = std(_NDArray_7, _OptionalIntOrTuple_1) -_NDArray_8[IndexKey.ndarray(std(_NDArray_7, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0))) -_TupleNDArray_1 = svd( - sqrt( - asarray( - NDArray.scalar(Value.float(Float(1.0) / Float.from_int(asarray(_NDArray_1).shape[Int(0)] - _NDArray_6.shape[Int(0)]))), OptionalDType.some(asarray(_NDArray_1).dtype) - ) - ) - * (_NDArray_7 / _NDArray_8), - FALSE, -) +_NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1) +_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float.rational(Rational(1, 1)))) +_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.float(Float.rational(Rational(1, 147)))), OptionalDType.some(DType.float64))) * (_NDArray_5 / _NDArray_6), Boolean(False)) _Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int)) -_NDArray_9 = ( - _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] / _NDArray_8 -).T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)] +_NDArray_7 = asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1))))) +_NDArray_8 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_7)))))) +_NDArray_9 = std( + concat( + TupleNDArray.from_vec( + Vec[NDArray]( + asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(0))])] - _NDArray_3[_IndexKey_1], + asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(1))])] - _NDArray_3[_IndexKey_3], + asarray(_NDArray_1)[IndexKey.ndarray(_NDArray_7 == _NDArray_8[IndexKey.int(Int(2))])] - _NDArray_3[_IndexKey_5], + ) + ), + OptionalInt.some(Int(0)), + ), + _OptionalIntOrTuple_1, +) +_NDArray_10 = copy(_NDArray_9) +_NDArray_10[IndexKey.ndarray(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0))) +_NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(Rational(150, 1)))) _TupleNDArray_2 = svd( - ( - sqrt( - (NDArray.scalar(Value.int(asarray(_NDArray_1).shape[Int(0)])) * _NDArray_4) - * NDArray.scalar(Value.float(Float(1.0) / Float.from_int(_NDArray_6.shape[Int(0)] - Int(1)))) - ) - * (_NDArray_5 - (_NDArray_4 @ _NDArray_5)).T - ).T - @ _NDArray_9, - FALSE, + (sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(Rational(1, 2))))) * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T).T + @ ( + ( + _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] + / _NDArray_6 + ).T + / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)] + ), + Boolean(False), ) ( - (asarray(_NDArray_1) - (_NDArray_4 @ _NDArray_5)) + (asarray(_NDArray_1) - ((astype(unique_counts(_NDArray_2)[Int(1)], asarray(_NDArray_1).dtype) / NDArray.scalar(Value.float(Float(150.0)))) @ _NDArray_3)) @ ( - _NDArray_9 + ( + (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] / _NDArray_10).T + / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)] + ) @ _TupleNDArray_2[Int(2)].T[ IndexKey.multi_axis( MultiAxisIndexKey.from_vec( diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index d62cccdc..ec69ebc9 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -1,8 +1,10 @@ # mypy: disable-error-code="empty-body" import ast +import inspect from collections.abc import Callable from itertools import product from pathlib import Path +from types import FunctionType from typing import Any import numba @@ -279,23 +281,15 @@ def test_index_codegen(self, snapshot_py): value_program(simplified_index).function_three(ndarray_program(X), int_program(i), int_program(j)), {"np": np}, ) - egraph = EGraph() - egraph.register(res) - egraph.run(array_api_program_gen_schedule) - print( - egraph.extract( - value_program(simplified_index).function_three(ndarray_program(X), int_program(i), int_program(j)) - ) - ) - # egraph.display(split_primitive_outputs=True, n_inline_leaves=3, split_functions=[TupleInt.EMPTY, TupleInt.append, Int]) - assert egraph.eval(res.statements) == snapshot_py(name="code") + fn = cast(FunctionType, try_evaling(array_api_program_gen_schedule, res, res.as_py_object)) + + assert inspect.getsource(fn) == snapshot_py(name="code") - fn_value = cast(Callable, egraph.eval(res.py_object)) X = np.random.random((3, 2, 3, 4)) expect = np.linalg.norm(X, axis=(0, 1)) for idxs in np.ndindex(*expect.shape): - assert np.allclose(fn_value(X, *idxs), expect[idxs], rtol=1e-03) + assert np.allclose(fn(X, *idxs), expect[idxs], rtol=1e-03) # This test happens in different steps. Each will be benchmarked and saved as a snapshot. @@ -338,7 +332,7 @@ def load_source(fn_program: EvalProgram, egraph: EGraph): egraph.run(array_api_program_gen_schedule) # dp the needed pieces in here for benchmarking try: - return egraph.eval(egraph.extract(fn_program.py_object)) + return egraph.extract(fn_program.as_py_object).eval() except Exception as err: err.add_note(f"Failed to compile the program into a string: \n\n{egraph.extract(fn_program)}") egraph.display(split_primitive_outputs=True, n_inline_leaves=3, split_functions=[Program]) @@ -378,8 +372,9 @@ def test_program_compile(program: Program, snapshot_py): egraph = EGraph() egraph.register(simplified_program.compile()) egraph.run(array_api_program_gen_schedule) - statements = egraph.eval(simplified_program.statements) - expr = egraph.eval(simplified_program.expr) + with egraph.set_current(): + statements = simplified_program.statements.eval() + expr = simplified_program.expr.eval() assert "\n".join([*statements.split("\n"), expr]) == snapshot_py(name="code") @@ -393,7 +388,7 @@ def test_program_compile(program: Program, snapshot_py): def test_jit(program, snapshot_py): jitted = jit(program) assert str(jitted.expr) == snapshot_py(name="expr") - assert jitted.statements == snapshot_py(name="code") + assert inspect.getsource(jitted) == snapshot_py(name="code") @pytest.mark.benchmark(min_rounds=3) @@ -405,15 +400,17 @@ class TestLDA: def test_trace(self, snapshot_py, benchmark): X = NDArray.var("X") y = NDArray.var("y") - with EGraph(): + with EGraph().set_current(): X_r2 = benchmark(lda, X, y) - assert str(X_r2) == snapshot_py + res = str(X_r2) + print(res) + assert res == snapshot_py def test_optimize(self, snapshot_py, benchmark): egraph = EGraph() X = NDArray.var("X") y = NDArray.var("y") - with egraph: + with egraph.set_current(): expr = lda(X, y) simplified = benchmark(simplify_lda, egraph, expr) assert str(simplified) == snapshot_py @@ -428,13 +425,16 @@ def test_source_optimized(self, snapshot_py, benchmark): egraph = EGraph() X = NDArray.var("X") y = NDArray.var("y") - with egraph: + with egraph.set_current(): expr = lda(X, y) optimized_expr = simplify_lda(egraph, expr) + egraph = EGraph() fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y")) py_object = benchmark(load_source, fn_program, egraph) assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np)) - assert egraph.eval(fn_program.statements) == snapshot_py + with egraph.set_current(): + fn_object = cast(FunctionType, fn_program.as_py_object.eval()) + assert inspect.getsource(fn_object) == snapshot_py @pytest.mark.parametrize( "fn_thunk", diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index 6a40c492..9ca7af7b 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -218,20 +218,6 @@ def test_compare(self): assert Variant(DUMMY_SPAN, "name", []) != 10 # type: ignore[comparison-overlap] -class TestEval: - def test_i64(self): - assert EGraph().eval_i64(Lit(DUMMY_SPAN, Int(1))) == 1 - - def test_f64(self): - assert EGraph().eval_f64(Lit(DUMMY_SPAN, Float(1.0))) == 1.0 - - def test_string(self): - assert EGraph().eval_string(Lit(DUMMY_SPAN, String("hi"))) == "hi" - - def test_bool(self): - assert EGraph().eval_bool(Lit(DUMMY_SPAN, Bool(True))) is True - - class TestThreads: """ Verify that objects can be accessed from multiple threads at the same time. diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 30e1206c..77c76be8 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -4,6 +4,7 @@ import importlib import pathlib from copy import copy +from fractions import Fraction from typing import ClassVar, TypeAlias, Union import pytest @@ -242,33 +243,30 @@ def foo(x: i64Like, y: i64Like = i64(1)) -> i64: ... class TestPyObject: def test_from_string(self): - assert EGraph().eval(PyObject.from_string("foo")) == "foo" + assert PyObject.from_string("foo").eval() == "foo" def test_to_string(self): - x: String = PyObject("foo").to_string() - # reveal_type(cast(Bool, x))) - # reveal_type(EGraph().eval(x)) - assert EGraph().eval(x) == "foo" + assert PyObject("foo").to_string().eval() == "foo" def test_dict_update(self): original_d = {"foo": "bar"} - res = EGraph().eval(PyObject(original_d).dict_update("foo", "baz")) + res = PyObject(original_d).dict_update("foo", "baz").eval() assert res == {"foo": "baz"} assert original_d == {"foo": "bar"} def test_eval(self): - assert EGraph().eval(py_eval("x + y", {"x": 10, "y": 20}, {})) == 30 + assert py_eval("x + y", {"x": 10, "y": 20}, {}).eval() == 30 def test_eval_local(self): x = "hi" res = py_eval("my_add(x, y)", PyObject(locals()).dict_update("y", "there"), globals()) - assert EGraph().eval(res) == "hithere" + assert res.eval() == "hithere" # Updated to call eval() directly on res def test_exec(self): - assert EGraph().eval(py_exec("x = 10")) == {"x": 10} + assert py_exec("x = 10").eval() == {"x": 10} # Updated to call eval() directly on res def test_exec_globals(self): - assert EGraph().eval(py_exec("x = y + 1", {"y": 10})) == {"x": 11} + assert py_exec("x = y + 1", {"y": 10}).eval() == {"x": 11} def my_add(a, b): @@ -474,13 +472,64 @@ def from_int(cls, other: Int) -> NDArray: ... assert expr_parts(r) == expr_parts(NDArray.from_int(Int("x")) + NDArray("y")) -def test_eval(): - egraph = EGraph() - assert egraph.eval(String("hi")) == "hi" - assert egraph.eval(i64(10)) == 10 - assert egraph.eval(f64(10.0)) == 10.0 - assert egraph.eval(Bool(True)) is True - assert egraph.eval(PyObject((1, 2))) == (1, 2) +class TestEval: + def test_string(self): + assert String("hi").eval() == "hi" + + def test_bool(self): + assert Bool(True).eval() is True + assert bool(Bool(True)) is True + + def test_i64(self): + assert i64(10).eval() == 10 + assert int(i64(10)) == 10 + assert [10][i64(0)] == 10 + + def test_f64(self): + assert f64(10.0).eval() == 10.0 + assert int(f64(10.0)) == 10 + assert float(f64(10.0)) == 10.0 + + def test_map(self): + assert Map[String, i64].empty().eval() == {} + m = Map[String, i64].empty().insert(String("a"), i64(1)).insert(String("b"), i64(2)) + # TODO: Add __eq__ with eq() that evals to True on boolean comparison? And same with ne? + assert m.eval() == {String("a"): i64(1), String("b"): i64(2)} + + assert set(m) == {String("a"), String("b")} + assert len(m) == 2 + assert String("a") in m + assert String("c") not in m + + def test_set(self): + assert Set[i64].empty().eval() == set() + s = Set(i64(1), i64(2)) + assert s.eval() == {i64(1), i64(2)} + + assert set(s) == {i64(1), i64(2)} + assert len(s) == 2 + assert i64(1) in s + assert i64(3) not in s + + def test_rational(self): + assert Rational(1, 2).eval() == Fraction(1, 2) + assert float(Rational(1, 2)) == 0.5 + assert int(Rational(1, 1)) == 1 + + def test_vec(self): + assert Vec[i64].empty().eval() == () + s = Vec(i64(1), i64(2)) + assert s.eval() == (i64(1), i64(2)) + + assert list(s) == [i64(1), i64(2)] + assert len(s) == 2 + assert i64(1) in s + assert i64(3) not in s + + def test_py_object(self): + assert PyObject(10).eval() == 10 + o = object() + assert PyObject(o).eval() is o # def test_egglog_string(): @@ -496,9 +545,7 @@ def test_eval(): def test_eval_fn(): - egraph = EGraph() - - assert egraph.eval(py_eval_fn(lambda x: (x,))(PyObject.from_int(1))) == (1,) + assert py_eval_fn(lambda x: (x,))(PyObject.from_int(1)).eval() == (1,) def _global_make_tuple(x): @@ -506,18 +553,16 @@ def _global_make_tuple(x): def test_eval_fn_globals(): - egraph = EGraph() - - assert egraph.eval(py_eval_fn(lambda x: _global_make_tuple(x))(PyObject.from_int(1))) == (1,) + assert py_eval_fn(lambda x: _global_make_tuple(x))(PyObject.from_int(1)).eval() == (1,) def test_eval_fn_locals(): - egraph = EGraph() + EGraph() def _locals_make_tuple(x): return (x,) - assert egraph.eval(py_eval_fn(lambda x: _locals_make_tuple(x))(PyObject.from_int(1))) == (1,) + assert py_eval_fn(lambda x: _locals_make_tuple(x))(PyObject.from_int(1)).eval() == (1,) def test_lazy_types(): @@ -774,6 +819,20 @@ def my_fn(xs: MapLike[i64, String, i64Like, StringLike]) -> Unit: ... assert expr_parts(my_fn({})) == expr_parts(my_fn(Map[i64, String].empty())) +class TestEqNE: + def test_eq(self): + assert (i64(1) + 2) == 3 + + def test_ne(self): + assert (i64(1) + 2) != 4 + + def test_eq_false(self): + assert not ((i64(1) + 2) == 4) # noqa: SIM201 + + def test_ne_false(self): + assert not ((i64(1) + 2) != 3) # noqa: SIM202 + + EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py")) diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index 7ac48822..af503e49 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -2,6 +2,8 @@ from __future__ import annotations import inspect +from types import FunctionType +from typing import cast from egglog import * from egglog.exp.program_gen import * @@ -53,12 +55,11 @@ def test_to_string(snapshot_py) -> None: first = assume_pos(-Math.var("x")) + Math.var("y") fn = (first + Math(2) + first).program.function_two(Math.var("x").program, Math.var("y").program, "my_fn") egraph = EGraph() - egraph.register(fn) egraph.register(fn.compile()) egraph.run((to_program_ruleset | program_gen_ruleset).saturate()) - # egraph.display(n_inline_leaves=1) - assert egraph.eval(fn.expr) == "my_fn" - assert egraph.eval(fn.statements) == snapshot_py + with egraph.set_current(): + assert fn.expr.eval() == "my_fn" + assert fn.statements.eval() == snapshot_py def test_to_string_function_three(snapshot_py) -> None: @@ -67,12 +68,11 @@ def test_to_string_function_three(snapshot_py) -> None: Math.var("x").program, Math.var("y").program, Math.var("z").program, "my_fn" ) egraph = EGraph() - egraph.register(fn) egraph.register(fn.compile()) egraph.run((to_program_ruleset | program_gen_ruleset).saturate()) - # egraph.display(n_inline_leaves=1) - assert egraph.eval(fn.expr) == "my_fn" - assert egraph.eval(fn.statements) == snapshot_py + with egraph.set_current(): + assert fn.expr.eval() == "my_fn" + assert fn.statements.eval() == snapshot_py def test_py_object(): @@ -80,10 +80,11 @@ def test_py_object(): y = Math.var("y") z = Math.var("z") fn = (x + y + z).program.function_two(x.program, y.program) - egraph = EGraph() evalled = EvalProgram(fn, {"z": 10}) + egraph = EGraph() egraph.register(evalled) egraph.run((to_program_ruleset | eval_program_rulseset | program_gen_ruleset).saturate()) - res = egraph.eval(evalled.py_object) - assert res(1, 2) == 13 # type: ignore[operator] - assert inspect.getsource(res) # type: ignore[arg-type] + with egraph.set_current(): + res = cast(FunctionType, evalled.as_py_object.eval()) + assert res(1, 2) == 13 + assert inspect.getsource(res) diff --git a/python/tests/test_py_object_sort.py b/python/tests/test_py_object_sort.py index d69e6f4f..c7129c1c 100644 --- a/python/tests/test_py_object_sort.py +++ b/python/tests/test_py_object_sort.py @@ -25,9 +25,8 @@ class TestSaveLoad: ) def test_adding_retrieving_object(self, obj: object): sort = PyObjectSort() - egraph = EGraph(sort) expr = sort.store(obj) - assert egraph.eval_py_object(expr) == obj + assert sort.load(expr) == obj def test_objects_cleaned_up(self): sort = PyObjectSort() @@ -48,7 +47,7 @@ def test_object_keeps_ref(self): del my_object gc.collect() assert ref() is not None - assert EGraph(sort).eval_py_object(expr) == MyObject() + assert sort.load(expr) == MyObject() class TestDictUpdate: @@ -69,9 +68,13 @@ def test_dict_update(self): "new_dict", Call(DUMMY_SPAN, "py-dict-update", [dict_expr, a_expr, new_value_expr, b_expr, new_value_expr]), ) - ) + ), + ActionCommand(Extract(DUMMY_SPAN, Var(DUMMY_SPAN, "new_dict"), Lit(DUMMY_SPAN, Int(0)))), ) - assert egraph.eval_py_object(Var(DUMMY_SPAN, "new_dict")) == {"a": 2, "b": 2} + report = egraph.extract_report() + assert isinstance(report, Best) + expr = report.termdag.term_to_expr(report.term, DUMMY_SPAN) + assert sort.load(expr) == {"a": 2, "b": 2} # Verify that the original dict is unchanged assert initial_dict == {"a": 1} @@ -107,9 +110,13 @@ def test_eval(self): ], ), ) - ) + ), + ActionCommand(Extract(DUMMY_SPAN, Var(DUMMY_SPAN, "res"), Lit(DUMMY_SPAN, Int(0)))), ) - assert egraph.eval_py_object(Var(DUMMY_SPAN, "res")) == 3 + report = egraph.extract_report() + assert isinstance(report, Best) + expr = report.termdag.term_to_expr(report.term, DUMMY_SPAN) + assert sort.load(expr) == 3 class TestConversion: @@ -120,8 +127,10 @@ def test_to_string(self): sort = PyObjectSort() egraph = EGraph(sort) - egraph.run_program(ActionCommand(Let(DUMMY_SPAN, "res", Call(DUMMY_SPAN, "py-to-string", [sort.store("hi")])))) - assert egraph.eval_string(Var(DUMMY_SPAN, "res")) == "hi" + egraph.run_program( + ActionCommand(Let(DUMMY_SPAN, "res", Call(DUMMY_SPAN, "py-to-string", [sort.store("hi")]))), + Check(DUMMY_SPAN, [Eq(DUMMY_SPAN, Var(DUMMY_SPAN, "res"), Lit(DUMMY_SPAN, String("hi")))]), + ) def test_from_string(self): """ @@ -131,5 +140,5 @@ def test_from_string(self): egraph = EGraph(sort) egraph.run_program( ActionCommand(Let(DUMMY_SPAN, "res", Call(DUMMY_SPAN, "py-from-string", [Lit(DUMMY_SPAN, String("hi"))]))), + Check(DUMMY_SPAN, [Eq(DUMMY_SPAN, Var(DUMMY_SPAN, "res"), sort.store("hi"))]), ) - assert egraph.eval_py_object(Var(DUMMY_SPAN, "res")) == "hi" diff --git a/src/egraph.rs b/src/egraph.rs index b5fca9dc..dc5997e1 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -2,15 +2,13 @@ use crate::conversions::*; use crate::error::{EggResult, WrappedError}; -use crate::py_object_sort::{ArcPyObjectSort, MyPyObject, PyObjectSort}; +use crate::py_object_sort::ArcPyObjectSort; use crate::serialize::SerializedEGraph; -use egglog::sort::{BoolSort, F64Sort, I64Sort, StringSort}; use egglog::{span, SerializeConfig}; use log::info; use pyo3::prelude::*; use std::path::PathBuf; -use std::sync::Arc; /// EGraph() /// -- @@ -19,7 +17,6 @@ use std::sync::Arc; #[pyclass(unsendable)] pub struct EGraph { pub(crate) egraph: egglog::EGraph, - py_object_arcsort: Option>, cmds: Option, } @@ -39,17 +36,13 @@ impl EGraph { let mut egraph = egglog_experimental::new_experimental_egraph(); egraph.fact_directory = fact_directory; egraph.seminaive = seminaive; - let py_object_arcsort = if let Some(py_object_sort) = py_object_sort { + if let Some(py_object_sort) = py_object_sort { egraph .add_arcsort(py_object_sort.0.clone(), span!()) .unwrap(); - Some(py_object_sort.0) - } else { - None - }; + } Self { egraph, - py_object_arcsort, cmds: if record { Some(String::new()) } else { None }, } } @@ -137,42 +130,4 @@ impl EGraph { }), } } - - /// Evaluates an expression in the EGraph and returns the result as a Python object. - #[pyo3(signature = (expr, /))] - fn eval_py_object(&mut self, expr: Expr) -> EggResult { - self.eval_sort(expr, self.py_object_arcsort.clone().unwrap()) - } - #[pyo3(signature = (expr, /))] - fn eval_i64(&mut self, expr: Expr) -> EggResult { - self.eval_sort(expr, Arc::new(I64Sort)) - } - - #[pyo3(signature = (expr, /))] - fn eval_f64(&mut self, expr: Expr) -> EggResult { - self.eval_sort(expr, Arc::new(F64Sort)) - } - - #[pyo3(signature = (expr, /))] - fn eval_string(&mut self, expr: Expr) -> EggResult { - let s: egglog::ast::Symbol = self.eval_sort(expr, Arc::new(StringSort))?; - Ok(s.to_string()) - } - - #[pyo3(signature = (expr, /))] - fn eval_bool(&mut self, expr: Expr) -> EggResult { - self.eval_sort(expr, Arc::new(BoolSort)) - } -} - -impl EGraph { - fn eval_sort>( - &mut self, - expr: Expr, - arcsort: Arc, - ) -> EggResult { - let expr: egglog::ast::Expr = expr.into(); - let (_, value) = self.egraph.eval_expr(&expr)?; - Ok(V::load(&arcsort, &value)) - } } diff --git a/src/py_object_sort.rs b/src/py_object_sort.rs index 47960367..f2bf323c 100644 --- a/src/py_object_sort.rs +++ b/src/py_object_sort.rs @@ -52,6 +52,22 @@ impl PyObjectIdent { } }) } + + pub fn from_expr(expr: &Expr) -> Self { + match expr { + Expr::Call(_, head, args) => match head.as_str() { + "py-object" => match args.as_slice() { + [Expr::Lit(_, Literal::Int(type_hash)), Expr::Lit(_, Literal::Int(hash))] => { + PyObjectIdent::Hashable(*type_hash as isize, *hash as isize) + } + [Expr::Lit(_, Literal::Int(id))] => PyObjectIdent::Unhashable(*id as usize), + _ => panic!("Unexpected children when loading PyObjectIdent"), + }, + _ => panic!("Unexpected head when loading PyObjectIdent"), + }, + _ => panic!("Unexpected expr when loading PyObjectIdent"), + } + } pub fn to_expr(self) -> Expr { let children = match self { PyObjectIdent::Unhashable(id) => { @@ -84,10 +100,16 @@ impl PyObjectSort { self.0.lock().unwrap().insert_full(key, value).0 } - // /// Retrives the Python object at the given index. - // pub fn get_index(&self, index: usize) -> PyObject { - // self.0.lock().unwrap().get_index(index).unwrap().1.clone() - // } + /// Retrieves the Python object at the given index. + pub fn get_index(&self, py: Python<'_>, index: usize) -> PyObject { + self.0 + .lock() + .unwrap() + .get_index(index) + .unwrap() + .1 + .clone_ref(py) + } /// Retrieves the index of the given key. pub fn get_index_of(&self, key: &PyObjectIdent) -> usize { @@ -109,10 +131,8 @@ impl PyObjectSort { } pub fn load(&self, py: Python<'_>, value: Value) -> PyObject { - let objects = self.0.lock().unwrap(); let i = value.bits as usize; - let (_ident, obj) = objects.get_index(i).unwrap(); - obj.clone_ref(py) + self.get_index(py, i) } } @@ -155,6 +175,15 @@ impl ArcPyObjectSort { Ok(ident.to_expr().into()) } + // Retrieve the Python object from an expression + #[pyo3(name="load", signature = (expr, /))] + fn load_py(&self, expr: crate::conversions::Expr) -> PyObject { + let expr: Expr = expr.into(); + let ident = PyObjectIdent::from_expr(&expr); + let index = self.0.get_index_of(&ident); + Python::with_gil(|py| self.0.get_index(py, index)) + } + // Integrate with Python garbage collector // https://pyo3.rs/main/class/protocols#garbage-collector-integration diff --git a/test-data/unit/check-high-level.test b/test-data/unit/check-high-level.test index c52e57ad..f1faabb6 100644 --- a/test-data/unit/check-high-level.test +++ b/test-data/unit/check-high-level.test @@ -1,10 +1,10 @@ -[case eqNotAllowed] +[case eqAllowed] from egglog import * -_ = i64(0) == i64(0) # E: Unsupported operand types for == ("i64" and "i64") +_ = i64(0) == i64(0) -[case notEqNotAllowed] +[case notEqAllowed] from egglog import * -_ = i64(0) != i64(0) # E: Unsupported operand types for != ("i64" and "i64") +_ = i64(0) != i64(0) [case eqToAllowed] from egglog import * From 8f948d11f96e7e1e0eb09b6f85878853e036cdf1 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 26 Feb 2025 17:58:32 -0500 Subject: [PATCH 2/5] try removing "fast" hash --- python/egglog/declarations.py | 37 +++++++++++++++++------------------ 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 4bc9c08c..73c6aafc 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -7,7 +7,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from functools import cached_property from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypeVar, Union, runtime_checkable from typing_extensions import Self, assert_never @@ -575,24 +574,24 @@ def __post_init__(self) -> None: msg = "Cannot bind type parameters to a non-class method callable." raise ValueError(msg) - def __hash__(self) -> int: - return self._cached_hash - - @cached_property - def _cached_hash(self) -> int: - return hash((self.callable, self.args, self.bound_tp_params)) - - def __eq__(self, other: object) -> bool: - # Override eq to use cached hash for perf - if not isinstance(other, CallDecl): - return False - if hash(self) != hash(other): - return False - return ( - self.callable == other.callable - and self.args == other.args - and self.bound_tp_params == other.bound_tp_params - ) + # def __hash__(self) -> int: + # return self._cached_hash + + # @cached_property + # def _cached_hash(self) -> int: + # return hash((self.callable, self.args, self.bound_tp_params)) + + # def __eq__(self, other: object) -> bool: + # # Override eq to use cached hash for perf + # if not isinstance(other, CallDecl): + # return False + # if hash(self) != hash(other): + # return False + # return ( + # self.callable == other.callable + # and self.args == other.args + # and self.bound_tp_params == other.bound_tp_params + # ) @dataclass(frozen=True) From eb118bb0a7b3dd9fb36dd963ce9aa70697971061 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 27 Feb 2025 07:25:14 -0500 Subject: [PATCH 3/5] Switch to pooled objects --- python/egglog/declarations.py | 57 +++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 73c6aafc..deeeafde 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -7,7 +7,9 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypeVar, Union, runtime_checkable +from functools import cached_property +from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable +from weakref import WeakValueDictionary from typing_extensions import Self, assert_never @@ -569,29 +571,46 @@ class CallDecl: # Used for pretty printing classmethod calls with type parameters bound_tp_params: tuple[JustTypeRef, ...] | None = None + # pool objects for faster __eq__ + _args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({}) + + def __new__(cls, *args: object, **kwargs: object) -> Self: + """ + Pool CallDecls so that they can be compared by identity more quickly. + + Neccessary bc we search for common parents when serializing CallDecl trees to egglog to + only serialize each sub-tree once. + """ + # normalize the args/kwargs to a tuple so that they can be compared + callable = args[0] if args else kwargs["callable"] + args_ = args[1] if len(args) > 1 else kwargs.get("args", ()) + bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params") + + normalized_args = (callable, args_, bound_tp_params) + try: + return cast(Self, cls._args_to_value[normalized_args]) + except KeyError: + res = super().__new__(cls) + cls._args_to_value[normalized_args] = res + return res + def __post_init__(self) -> None: if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef): msg = "Cannot bind type parameters to a non-class method callable." raise ValueError(msg) - # def __hash__(self) -> int: - # return self._cached_hash - - # @cached_property - # def _cached_hash(self) -> int: - # return hash((self.callable, self.args, self.bound_tp_params)) - - # def __eq__(self, other: object) -> bool: - # # Override eq to use cached hash for perf - # if not isinstance(other, CallDecl): - # return False - # if hash(self) != hash(other): - # return False - # return ( - # self.callable == other.callable - # and self.args == other.args - # and self.bound_tp_params == other.bound_tp_params - # ) + def __hash__(self) -> int: + return self._cached_hash + + @cached_property + def _cached_hash(self) -> int: + return hash((self.callable, self.args, self.bound_tp_params)) + + def __eq__(self, other: object) -> bool: + return self is other + + def __ne__(self, other: object) -> bool: + return self is not other @dataclass(frozen=True) From 49b8cca3f7dbc813a4387b6f240e5a0dc8e9aeca Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 27 Feb 2025 07:57:00 -0500 Subject: [PATCH 4/5] Try eagerly evaluating primitives for speed --- python/egglog/egraph.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 79c1bce9..96d14a5d 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -1889,6 +1889,13 @@ def try_evaling(schedule: Schedule, expr: Expr, prim_expr: _EvalsTo[T]) -> T: if it fails, display the egraph and raise an error. """ egraph = EGraph.current or EGraph() + with egraph.set_current(): + try: + return prim_expr.eval() + except BaseException: # noqa: S110 + pass + # If this primitive doesn't exist in the egraph, we need to try to create it by + # registering the expression and running the schedule egraph.register(expr) egraph.run(Schedule.current or schedule) try: From 87edb7df39714a1064e43ca0ee8dd4be1fb147a6 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 27 Feb 2025 08:06:33 -0500 Subject: [PATCH 5/5] Update changelog --- docs/changelog.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index a8a42085..2d2f5fcb 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -15,7 +15,7 @@ _This project uses semantic versioning_ - Add conversions from generic types to be supported at runtime and typing level (so can go from `(1, 2, 3)` to `TupleInt`) - Open files with webbrowser instead of internal graphviz util for better support - Add support for not visualizing when using `.saturate()` method [#254](https://github.com/egraphs-good/egglog-python/pull/254) -- Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/b0db06832264c9b22694bd3de2bdacd55bbe9e32...saulshanabrook:egg-smol:889ca7635368d7e382e16a93b2883aba82f1078f) +- Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/b0d b06832264c9b22694bd3de2bdacd55bbe9e32...saulshanabrook:egg-smol:889ca7635368d7e382e16a93b2883aba82f1078f) [#258](https://github.com/egraphs-good/egglog-python/pull/258) - This includes a few big changes to the underlying bindings, which I won't go over in full detail here. See the [pyi diff](https://github.com/egraphs-good/egglog-python/pull/258/files#diff-f34a5dd5d6568cd258ed9f786e5abce03df5ee95d356ea9e1b1b39e3505e5d62) for all public changes. - Creates seperate parent classes for `BuiltinExpr` vs `Expr` (aka eqsort aka user defined expressions). This is to allow us statically to differentiate between the two, to be more precise about what behavior is allowed. For example, @@ -24,9 +24,9 @@ _This project uses semantic versioning_ - Updates function constructor to remove `default` and `on_merge`. You also can't set a `cost` when you use a `merge` function or return a primitive. - `eq` now only takes two args, instead of being able to compare any number of values. -- Removes `eval` method from `EGraph` and moves primitive evaluation to methods on each builtin and support `int(...)` type conversions on primitives. -- Change how to set global EGraph context with `with egraph.set_current()` and `EGraph.current` and add support for setting global schedule as well with `with schedule.set_current()` and `Schedule.current`. -- Adds support for using `==` and `!=` directly on values instead of `eq` and `ne` functions. +- Removes `eval` method from `EGraph` and moves primitive evaluation to methods on each builtin and support `int(...)` type conversions on primitives. [#265](https://github.com/egraphs-good/egglog-python/pull/265) +- Change how to set global EGraph context with `with egraph.set_current()` and `EGraph.current` and add support for setting global schedule as well with `with schedule.set_current()` and `Schedule.current`. [#265](https://github.com/egraphs-good/egglog-python/pull/265) +- Adds support for using `==` and `!=` directly on values instead of `eq` and `ne` functions. [#265](https://github.com/egraphs-good/egglog-python/pull/265) ## 8.0.1 (2024-10-24)