Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove eval and rework primitive extraction #265

Merged
merged 5 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ _This project uses semantic versioning_
- Add conversions from generic types to be supported at runtime and typing level (so can go from `(1, 2, 3)` to `TupleInt`)
- Open files with webbrowser instead of internal graphviz util for better support
- Add support for not visualizing when using `.saturate()` method [#254](https://github.com/egraphs-good/egglog-python/pull/254)
- Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/b0db06832264c9b22694bd3de2bdacd55bbe9e32...saulshanabrook:egg-smol:889ca7635368d7e382e16a93b2883aba82f1078f)
- Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/b0d b06832264c9b22694bd3de2bdacd55bbe9e32...saulshanabrook:egg-smol:889ca7635368d7e382e16a93b2883aba82f1078f) [#258](https://github.com/egraphs-good/egglog-python/pull/258)
- This includes a few big changes to the underlying bindings, which I won't go over in full detail here. See the [pyi diff](https://github.com/egraphs-good/egglog-python/pull/258/files#diff-f34a5dd5d6568cd258ed9f786e5abce03df5ee95d356ea9e1b1b39e3505e5d62) for all public changes.
- Creates seperate parent classes for `BuiltinExpr` vs `Expr` (aka eqsort aka user defined expressions). This is to
allow us statically to differentiate between the two, to be more precise about what behavior is allowed. For example,
Expand All @@ -24,6 +24,9 @@ _This project uses semantic versioning_
- Updates function constructor to remove `default` and `on_merge`. You also can't set a `cost` when you use a `merge`
function or return a primitive.
- `eq` now only takes two args, instead of being able to compare any number of values.
- Removes `eval` method from `EGraph` and moves primitive evaluation to methods on each builtin and support `int(...)` type conversions on primitives. [#265](https://github.com/egraphs-good/egglog-python/pull/265)
- Change how to set global EGraph context with `with egraph.set_current()` and `EGraph.current` and add support for setting global schedule as well with `with schedule.set_current()` and `Schedule.current`. [#265](https://github.com/egraphs-good/egglog-python/pull/265)
- Adds support for using `==` and `!=` directly on values instead of `eq` and `ne` functions. [#265](https://github.com/egraphs-good/egglog-python/pull/265)

## 8.0.1 (2024-10-24)

Expand Down
9 changes: 3 additions & 6 deletions python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,14 @@ class SerializedEGraph:
class PyObjectSort:
def __init__(self) -> None: ...
def store(self, __o: object, /) -> _Expr: ...
def load(self, __e: _Expr, /) -> object: ...

@final
class EGraph:
def __init__(
self,
__py_object_sort: PyObjectSort | None = None,
py_object_sort: PyObjectSort | None = None,
/,
*,
fact_directory: str | Path | None = None,
seminaive: bool = True,
Expand All @@ -116,11 +118,6 @@ class EGraph:
max_calls_per_function: int | None = None,
include_temporary_functions: bool = False,
) -> SerializedEGraph: ...
def eval_py_object(self, __expr: _Expr) -> object: ...
def eval_i64(self, __expr: _Expr) -> int: ...
def eval_f64(self, __expr: _Expr) -> float: ...
def eval_string(self, __expr: _Expr) -> str: ...
def eval_bool(self, __expr: _Expr) -> bool: ...

@final
class EggSmolError(Exception):
Expand Down
186 changes: 182 additions & 4 deletions python/egglog/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@

from __future__ import annotations

from fractions import Fraction
from functools import partial, reduce
from types import FunctionType
from types import FunctionType, MethodType
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload

from typing_extensions import TypeVarTuple, Unpack

from . import bindings
from .conversion import convert, converter, get_type_args
from .egraph import BaseExpr, BuiltinExpr, Unit, function, get_current_ruleset, method
from .declarations import *
from .egraph import BaseExpr, BuiltinExpr, EGraph, expr_fact, function, get_current_ruleset, method
from .egraph_state import GLOBAL_PY_OBJECT_SORT
from .functionalize import functionalize
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
from .thunk import Thunk

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Callable, Iterator


__all__ = [
Expand All @@ -32,6 +36,7 @@
"SetLike",
"String",
"StringLike",
"Unit",
"UnstableFn",
"Vec",
"VecLike",
Expand All @@ -46,7 +51,25 @@
]


class Unit(BuiltinExpr, egg_sort="Unit"):
"""
The unit type. This is used to reprsent if a value exists in the e-graph or not.
"""

def __init__(self) -> None: ...

@method(preserve=True)
def __bool__(self) -> bool:
return bool(expr_fact(self))


class String(BuiltinExpr):
@method(preserve=True)
def eval(self) -> str:
value = _extract_lit(self)
assert isinstance(value, bindings.String)
return value.value

def __init__(self, value: str) -> None: ...

@method(egg_fn="replace")
Expand All @@ -62,10 +85,20 @@ def join(*strings: StringLike) -> String: ...

converter(str, String, String)

BoolLike = Union["Bool", bool]
BoolLike: TypeAlias = Union["Bool", bool]


class Bool(BuiltinExpr, egg_sort="bool"):
@method(preserve=True)
def eval(self) -> bool:
value = _extract_lit(self)
assert isinstance(value, bindings.Bool)
return value.value

@method(preserve=True)
def __bool__(self) -> bool:
return self.eval()

def __init__(self, value: bool) -> None: ...

@method(egg_fn="not")
Expand All @@ -91,6 +124,20 @@ def implies(self, other: BoolLike) -> Bool: ...


class i64(BuiltinExpr): # noqa: N801
@method(preserve=True)
def eval(self) -> int:
value = _extract_lit(self)
assert isinstance(value, bindings.Int)
return value.value

@method(preserve=True)
def __index__(self) -> int:
return self.eval()

@method(preserve=True)
def __int__(self) -> int:
return self.eval()

def __init__(self, value: int) -> None: ...

@method(egg_fn="+")
Expand Down Expand Up @@ -193,6 +240,20 @@ def count_matches(s: StringLike, pattern: StringLike) -> i64: ...


class f64(BuiltinExpr): # noqa: N801
@method(preserve=True)
def eval(self) -> float:
value = _extract_lit(self)
assert isinstance(value, bindings.Float)
return value.value

@method(preserve=True)
def __float__(self) -> float:
return self.eval()

@method(preserve=True)
def __int__(self) -> int:
return int(self.eval())

def __init__(self, value: float) -> None: ...

@method(egg_fn="neg")
Expand Down Expand Up @@ -265,6 +326,33 @@ def to_string(self) -> String: ...


class Map(BuiltinExpr, Generic[T, V]):
@method(preserve=True)
def eval(self) -> dict[T, V]:
call = _extract_call(self)
expr = cast(RuntimeExpr, self)
d = {}
while call.callable != ClassMethodRef("Map", "empty"):
assert call.callable == MethodRef("Map", "insert")
call_typed, k_typed, v_typed = call.args
assert isinstance(call_typed.expr, CallDecl)
k = cast(T, expr.__with_expr__(k_typed))
v = cast(V, expr.__with_expr__(v_typed))
d[k] = v
call = call_typed.expr
return d

@method(preserve=True)
def __iter__(self) -> Iterator[T]:
return iter(self.eval())

@method(preserve=True)
def __len__(self) -> int:
return len(self.eval())

@method(preserve=True)
def __contains__(self, key: T) -> bool:
return key in self.eval()

@method(egg_fn="map-empty")
@classmethod
def empty(cls) -> Map[T, V]: ...
Expand Down Expand Up @@ -305,6 +393,24 @@ def rebuild(self) -> Map[T, V]: ...


class Set(BuiltinExpr, Generic[T]):
@method(preserve=True)
def eval(self) -> set[T]:
call = _extract_call(self)
assert call.callable == InitRef("Set")
return {cast(T, cast(RuntimeExpr, self).__with_expr__(x)) for x in call.args}

@method(preserve=True)
def __iter__(self) -> Iterator[T]:
return iter(self.eval())

@method(preserve=True)
def __len__(self) -> int:
return len(self.eval())

@method(preserve=True)
def __contains__(self, key: T) -> bool:
return key in self.eval()

@method(egg_fn="set-of")
def __init__(self, *args: T) -> None: ...

Expand Down Expand Up @@ -349,6 +455,28 @@ def rebuild(self) -> Set[T]: ...


class Rational(BuiltinExpr):
@method(preserve=True)
def eval(self) -> Fraction:
call = _extract_call(self)
assert call.callable == InitRef("Rational")

def _to_int(e: TypedExprDecl) -> int:
expr = e.expr
assert isinstance(expr, LitDecl)
assert isinstance(expr.value, int)
return expr.value

num, den = call.args
return Fraction(_to_int(num), _to_int(den))

@method(preserve=True)
def __float__(self) -> float:
return float(self.eval())

@method(preserve=True)
def __int__(self) -> int:
return int(self.eval())

@method(egg_fn="rational")
def __init__(self, num: i64Like, den: i64Like) -> None: ...

Expand Down Expand Up @@ -410,6 +538,26 @@ def denom(self) -> i64: ...


class Vec(BuiltinExpr, Generic[T]):
@method(preserve=True)
def eval(self) -> tuple[T, ...]:
call = _extract_call(self)
if call.callable == ClassMethodRef("Vec", "empty"):
return ()
assert call.callable == InitRef("Vec")
return tuple(cast(T, cast(RuntimeExpr, self).__with_expr__(x)) for x in call.args)

@method(preserve=True)
def __iter__(self) -> Iterator[T]:
return iter(self.eval())

@method(preserve=True)
def __len__(self) -> int:
return len(self.eval())

@method(preserve=True)
def __contains__(self, key: T) -> bool:
return key in self.eval()

@method(egg_fn="vec-of")
def __init__(self, *args: T) -> None: ...

Expand Down Expand Up @@ -461,6 +609,13 @@ def set(self, index: i64Like, value: T) -> Vec[T]: ...


class PyObject(BuiltinExpr):
@method(preserve=True)
def eval(self) -> object:
report = (EGraph.current or EGraph())._run_extract(cast(RuntimeExpr, self), 0)
assert isinstance(report, bindings.Best)
expr = report.termdag.term_to_expr(report.term, bindings.PanicSpan())
return GLOBAL_PY_OBJECT_SORT.load(expr)

def __init__(self, value: object) -> None: ...

@method(egg_fn="py-from-string")
Expand Down Expand Up @@ -554,6 +709,8 @@ def __init__(self, f, *partial) -> None: ...
def __call__(self, *args: Unpack[TS]) -> T: ...


# Method Type is for builtins like __getitem__
converter(MethodType, UnstableFn, lambda m: UnstableFn(m.__func__, m.__self__))
converter(RuntimeFunction, UnstableFn, UnstableFn)
converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))

Expand Down Expand Up @@ -590,3 +747,24 @@ def value_to_annotation(a: object) -> type | None:


converter(FunctionType, UnstableFn, _convert_function)


def _extract_lit(e: BaseExpr) -> bindings._Literal:
"""
Special case extracting literals to make this faster by using termdag directly.
"""
report = (EGraph.current or EGraph())._run_extract(cast(RuntimeExpr, e), 0)
assert isinstance(report, bindings.Best)
term = report.term
assert isinstance(term, bindings.TermLit)
return term.value


def _extract_call(e: BaseExpr) -> CallDecl:
"""
Extracts the call form of an expression
"""
extracted = cast(RuntimeExpr, (EGraph.current or EGraph()).extract(e))
expr = extracted.__egg_typed_expr__.expr
assert isinstance(expr, CallDecl)
return expr
6 changes: 3 additions & 3 deletions python/egglog/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from .egraph import BaseExpr

__all__ = ["convert", "convert_to_same_type", "converter", "resolve_literal", "ConvertError"]
__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"]
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
# Global declerations to store all convertable types so we can query if they have certain methods or not
Expand Down Expand Up @@ -153,9 +153,9 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
b_converts_to = {
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
}
if isinstance(a_tp, JustTypeRef):
if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
a_converts_to[a_tp] = 0
if isinstance(b_tp, JustTypeRef):
if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
b_converts_to[b_tp] = 0
common = set(a_converts_to) & set(b_converts_to)
if not common:
Expand Down
Loading