diff --git a/docs/changelog.md b/docs/changelog.md index 48956045..78f39fee 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,11 @@ _This project uses semantic versioning_ ## UNRELEASED +- Defers adding rules in functions until they are used, so that you can use types that are not present yet. +- Removes the ability to set a custom default ruleset for an egraph. Either just use the empty default ruleset or explicitly set it for every run. +- Automatically marks Python builtin operators as preserved if they must return a real Python value. +- Properly pretty prints all items (rewrites, actions, exprs, etc.) so that expressions are de-duplicated and state is handled correctly. + ## 6.1.0 (2024-03-06) - Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/4cc011f6b48029dd72104a38a2ca0c7657846e0b...0113af1d6476b75d4319591cc3d675f96a71cdc5) diff --git a/pyproject.toml b/pyproject.toml index ed7f5445..281f86e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,3 +214,13 @@ filterwarnings = [ "ignore::numba.core.errors.NumbaPerformanceWarning", "ignore::pytest_benchmark.logger.PytestBenchmarkWarning", ] + +[tool.coverage.report] +exclude_also = [ + "def __repr__", + "raise NotImplementedError", + "if TYPE_CHECKING:", + "class .*\\bProtocol\\):", + "@(abc\\.)?abstractmethod", + "assert_never\\(", +] diff --git a/python/egglog/__init__.py b/python/egglog/__init__.py index a66589fa..946a642d 100644 --- a/python/egglog/__init__.py +++ b/python/egglog/__init__.py @@ -4,7 +4,7 @@ from . 
import config, ipython_magic # noqa: F401 from .builtins import * # noqa: UP029 +from .conversion import convert, converter # noqa: F401 from .egraph import * -from .runtime import convert, converter # noqa: F401 del ipython_magic diff --git a/python/egglog/bindings.pyi b/python/egglog/bindings.pyi index 3553ac0c..47c26a59 100644 --- a/python/egglog/bindings.pyi +++ b/python/egglog/bindings.pyi @@ -29,7 +29,9 @@ class EGraph: fact_directory: str | Path | None = None, seminaive: bool = True, terms_encoding: bool = False, + record: bool = False, ) -> None: ... + def commands(self) -> str | None: ... def parse_program(self, __input: str, /) -> list[_Command]: ... def run_program(self, *commands: _Command) -> list[str]: ... def extract_report(self) -> _ExtractReport | None: ... diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index c2b3a5f9..7bed35cd 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union +from .conversion import converter from .egraph import Expr, Unit, function, method -from .runtime import converter if TYPE_CHECKING: from collections.abc import Callable diff --git a/python/egglog/conversion.py b/python/egglog/conversion.py new file mode 100644 index 00000000..20e0d23f --- /dev/null +++ b/python/egglog/conversion.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypeVar, cast + +from .declarations import * +from .pretty import * +from .runtime import * +from .thunk import * + +if TYPE_CHECKING: + from collections.abc import Callable + + from .declarations import HasDeclerations + from .egraph import Expr + +__all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"] +# Mapping from (source type, target type) to a function which takes in the runtime value of the source and returns the target +CONVERSIONS: dict[tuple[type | 
JustTypeRef, JustTypeRef], tuple[int, Callable]] = {} +# Global declarations to store all convertible types so we can query if they have certain methods or not +# Defer it as a thunk so we can register conversions without triggering type signature loading +CONVERSIONS_DECLS: Callable[[], Declarations] = Thunk.value(Declarations()) + +T = TypeVar("T") +V = TypeVar("V", bound="Expr") + + +class ConvertError(Exception): + pass + + +def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: int = 1) -> None: + """ + Register a converter from some type to an egglog type. + """ + to_type_name = process_tp(to_type) + if not isinstance(to_type_name, JustTypeRef): + raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}") + _register_converter(process_tp(from_type), to_type_name, fn, cost) + + +def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None: + """ + Registers a converter from some type to an egglog type, if not already registered. + + Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered. + Also, if registering A->B and there is already D->A, then D->B will be registered. + """ + if a == b: + return + if (a, b) in CONVERSIONS and CONVERSIONS[(a, b)][0] <= cost: + return + CONVERSIONS[(a, b)] = (cost, a_b) + for (c, d), (other_cost, c_d) in list(CONVERSIONS.items()): + if b == c: + _register_converter(a, d, _ComposedConverter(a_b, c_d), cost + other_cost) + if a == d: + _register_converter(c, b, _ComposedConverter(c_d, a_b), cost + other_cost) + + +@dataclass +class _ComposedConverter: + """ + A converter which is composed of multiple converters. + + _ComposedConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x)) + + We use the dataclass instead of the lambda to make it easier to debug. 
+ """ + + a_b: Callable + b_c: Callable + + def __call__(self, x: object) -> object: + return self.b_c(self.a_b(x)) + + def __str__(self) -> str: + return f"{self.b_c} ∘ {self.a_b}" + + +def convert(source: object, target: type[V]) -> V: + """ + Convert a source object to a target type. + """ + assert isinstance(target, RuntimeClass) + return cast(V, resolve_literal(target.__egg_tp__, source)) + + +def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr: + """ + Convert a source object to the same type as the target. + """ + tp = target.__egg_typed_expr__.tp + return resolve_literal(tp.to_var(), source) + + +def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type: + """ + Process a type before converting it, to add it to the global declerations and resolve to a ref. + """ + global CONVERSIONS_DECLS + if isinstance(tp, RuntimeClass): + CONVERSIONS_DECLS = Thunk.fn(_combine_decls, CONVERSIONS_DECLS, tp) + return tp.__egg_tp__.to_just() + return tp + + +def _combine_decls(d: Callable[[], Declarations], x: HasDeclerations) -> Declarations: + return Declarations.create(d(), x) + + +def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef: + """ + Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists. 
+ """ + decls = CONVERSIONS_DECLS() + a_tp = _get_tp(a) + b_tp = _get_tp(b) + a_converts_to = { + to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name) + } + b_converts_to = { + to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name) + } + if isinstance(a_tp, JustTypeRef): + a_converts_to[a_tp] = 0 + if isinstance(b_tp, JustTypeRef): + b_converts_to[b_tp] = 0 + common = set(a_converts_to) & set(b_converts_to) + if not common: + raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type") + return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp]) + + +def identity(x: object) -> object: + return x + + +def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr: + arg_type = _get_tp(arg) + + # If we have any type variables, dont bother trying to resolve the literal, just return the arg + try: + tp_just = tp.to_just() + except NotImplementedError: + # If this is a var, it has to be a runtime exprssions + assert isinstance(arg, RuntimeExpr) + return arg + if arg_type == tp_just: + # If the type is an egg type, it has to be a runtime expr + assert isinstance(arg, RuntimeExpr) + return arg + # Try all parent types as well, if we are converting from a Python type + for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]: + try: + fn = CONVERSIONS[(cast(JustTypeRef | type, arg_type_instance), tp_just)][1] + except KeyError: + continue + break + else: + raise ConvertError(f"Cannot convert {arg_type} to {tp_just}") + return fn(arg) + + +def _get_tp(x: object) -> JustTypeRef | type: + if isinstance(x, RuntimeExpr): + return x.__egg_typed_expr__.tp + tp = type(x) + # If this value has a custom metaclass, let's use that as our index instead of the type + if type(tp) != type: + return type(tp) + return tp diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index fc396b44..f738315b 100644 
--- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -1,18 +1,17 @@ """ Data only descriptions of the components of an egraph and the expressions. + +We separate it into two pieces, the references and the declarations, so that we can report mutually recursive types. """ from __future__ import annotations -from collections import defaultdict from dataclasses import dataclass, field -from inspect import Parameter, Signature -from typing import TYPE_CHECKING, Protocol, TypeAlias, Union, runtime_checkable +from functools import cached_property +from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_checkable from typing_extensions import Self, assert_never -from . import bindings - if TYPE_CHECKING: from collections.abc import Callable, Iterable @@ -20,84 +19,64 @@ __all__ = [ "Declarations", "DeclerationsLike", - "upcast_decleratioons", + "DelayedDeclerations", + "upcast_declerations", + "Declarations", "JustTypeRef", "ClassTypeVarRef", "TypeRefWithVars", "TypeOrVarRef", - "FunctionRef", "MethodRef", "ClassMethodRef", + "FunctionRef", + "ConstantRef", "ClassVariableRef", - "FunctionCallableRef", "PropertyRef", "CallableRef", - "ConstantRef", "FunctionDecl", + "RelationDecl", + "ConstantDecl", + "CallableDecl", "VarDecl", - "LitType", "PyObjectDecl", + "LitType", "LitDecl", "CallDecl", "ExprDecl", "TypedExprDecl", "ClassDecl", - "PrettyContext", - "GLOBAL_PY_OBJECT_SORT", + "RulesetDecl", + "SaturateDecl", + "RepeatDecl", + "SequenceDecl", + "RunDecl", + "ScheduleDecl", + "EqDecl", + "ExprFactDecl", + "FactDecl", + "LetDecl", + "SetDecl", + "ExprActionDecl", + "ChangeDecl", + "UnionDecl", + "PanicDecl", + "ActionDecl", + "RewriteDecl", + "BiRewriteDecl", + "RuleDecl", + "RewriteOrRuleDecl", + "ActionCommandDecl", + "CommandDecl", ] -# Create a global sort for python objects, so we can store them without an e-graph instance -# Needed when serializing commands to egg commands when creating modules -GLOBAL_PY_OBJECT_SORT = 
bindings.PyObjectSort() - -# Special methods which we might want to use as functions -# Mapping to the operator they represent for pretty printing them -# https://docs.python.org/3/reference/datamodel.html -BINARY_METHODS = { - "__lt__": "<", - "__le__": "<=", - "__eq__": "==", - "__ne__": "!=", - "__gt__": ">", - "__ge__": ">=", - # Numeric - "__add__": "+", - "__sub__": "-", - "__mul__": "*", - "__matmul__": "@", - "__truediv__": "/", - "__floordiv__": "//", - "__mod__": "%", - # TODO: Support divmod, with tuple return value - # "__divmod__": "divmod", - # TODO: Three arg power - "__pow__": "**", - "__lshift__": "<<", - "__rshift__": ">>", - "__and__": "&", - "__xor__": "^", - "__or__": "|", -} -REFLECTED_BINARY_METHODS = { - "__radd__": "__add__", - "__rsub__": "__sub__", - "__rmul__": "__mul__", - "__rmatmul__": "__matmul__", - "__rtruediv__": "__truediv__", - "__rfloordiv__": "__floordiv__", - "__rmod__": "__mod__", - "__rpow__": "__pow__", - "__rlshift__": "__lshift__", - "__rrshift__": "__rshift__", - "__rand__": "__and__", - "__rxor__": "__xor__", - "__ror__": "__or__", -} -UNARY_METHODS = { - "__pos__": "+", - "__neg__": "-", - "__invert__": "~", -} + +@dataclass +class DelayedDeclerations: + __egg_decls_thunk__: Callable[[], Declarations] + + @property + def __egg_decls__(self) -> Declarations: + return self.__egg_decls_thunk__() @runtime_checkable @@ -109,7 +88,10 @@ def __egg_decls__(self) -> Declarations: ... DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"] -def upcast_decleratioons(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]: +# TODO: Make all ClassDecls take deferred type refs, which return new decls when resolving. 
+ + +def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]: d = [] for l in declerations_like: if l is None: @@ -125,30 +107,14 @@ def upcast_decleratioons(declerations_like: Iterable[DeclerationsLike]) -> list[ @dataclass class Declarations: - _functions: dict[str, FunctionDecl] = field(default_factory=dict) + _functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict) + _constants: dict[str, ConstantDecl] = field(default_factory=dict) _classes: dict[str, ClassDecl] = field(default_factory=dict) - _constants: dict[str, JustTypeRef] = field(default_factory=dict) - - # Bidirectional mapping between egg function names and python callable references. - # Note that there are possibly mutliple callable references for a single egg function name, like `+` - # for both int and rational classes. - _egg_fn_to_callable_refs: defaultdict[str, set[CallableRef]] = field(default_factory=lambda: defaultdict(set)) - _callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=dict) - - # Bidirectional mapping between egg sort names and python type references. - _egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict) - _type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict) - - # Mapping from egg name (of sort or function) to command to create it. 
- _cmds: dict[str, bindings._Command] = field(default_factory=dict) - - def __post_init__(self) -> None: - if "!=" not in self._egg_fn_to_callable_refs: - self.register_callable_ref(FunctionRef("!="), "!=") + _rulesets: dict[str, RulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])}) @classmethod def create(cls, *others: DeclerationsLike) -> Declarations: - others = upcast_decleratioons(others) + others = upcast_declerations(others) if not others: return Declarations() first, *rest = others @@ -159,25 +125,9 @@ def create(cls, *others: DeclerationsLike) -> Declarations: return new def copy(self) -> Declarations: - return Declarations( - _functions=self._functions.copy(), - _classes=self._classes.copy(), - _constants=self._constants.copy(), - _egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self._egg_fn_to_callable_refs.items()}), - _callable_ref_to_egg_fn=self._callable_ref_to_egg_fn.copy(), - _egg_sort_to_type_ref=self._egg_sort_to_type_ref.copy(), - _type_ref_to_egg_sort=self._type_ref_to_egg_sort.copy(), - _cmds=self._cmds.copy(), - ) - - def __deepcopy__(self, memo: dict) -> Declarations: - return self.copy() - - def add_cmd(self, name: str, cmd: bindings._Command) -> None: - self._cmds[name] = cmd - - def list_cmds(self) -> list[bindings._Command]: - return list(self._cmds.values()) + new = Declarations() + new |= self + return new def update(self, *others: DeclerationsLike) -> None: for other in others: @@ -200,82 +150,26 @@ def update_other(self, other: Declarations) -> None: """ Updates the other decl with these values in palce. 
""" - # If cmds are == skip unioning for time savings - # if set(self._cmds) == set(other._cmds) and self.record_cmds and other.record_cmds: - # return self other._functions |= self._functions other._classes |= self._classes other._constants |= self._constants - other._egg_sort_to_type_ref |= self._egg_sort_to_type_ref - other._type_ref_to_egg_sort |= self._type_ref_to_egg_sort - other._cmds |= self._cmds - other._callable_ref_to_egg_fn |= self._callable_ref_to_egg_fn - for egg_fn, callable_refs in self._egg_fn_to_callable_refs.items(): - other._egg_fn_to_callable_refs[egg_fn] |= callable_refs - - def set_function_decl(self, ref: FunctionCallableRef, decl: FunctionDecl) -> None: - """ - Sets a function declaration for the given callable reference. - """ + other._rulesets |= self._rulesets + + def get_callable_decl(self, ref: CallableRef) -> CallableDecl: match ref: case FunctionRef(name): - if name in self._functions: - raise ValueError(f"Function {name} already registered") - self._functions[name] = decl + return self._functions[name] + case ConstantRef(name): + return self._constants[name] case MethodRef(class_name, method_name): - if method_name in self._classes[class_name].methods: - raise ValueError(f"Method {class_name}.{method_name} already registered") - self._classes[class_name].methods[method_name] = decl - case ClassMethodRef(class_name, method_name): - if method_name in self._classes[class_name].class_methods: - raise ValueError(f"Class method {class_name}.{method_name} already registered") - self._classes[class_name].class_methods[method_name] = decl + return self._classes[class_name].methods[method_name] + case ClassVariableRef(class_name, name): + return self._classes[class_name].class_variables[name] + case ClassMethodRef(class_name, name): + return self._classes[class_name].class_methods[name] case PropertyRef(class_name, property_name): - if property_name in self._classes[class_name].properties: - raise ValueError(f"Property 
{class_name}.{property_name} already registered") - self._classes[class_name].properties[property_name] = decl - case _: - assert_never(ref) - - def set_constant_type(self, ref: ConstantCallableRef, tp: JustTypeRef) -> None: - match ref: - case ConstantRef(name): - if name in self._constants: - raise ValueError(f"Constant {name} already registered") - self._constants[name] = tp - case ClassVariableRef(class_name, variable_name): - if variable_name in self._classes[class_name].class_variables: - raise ValueError(f"Class variable {class_name}.{variable_name} already registered") - self._classes[class_name].class_variables[variable_name] = tp - case _: - assert_never(ref) - - def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None: - """ - Registers a callable reference with the given egg name. - - The callable's function needs to be registered first. - """ - if ref in self._callable_ref_to_egg_fn: - raise ValueError(f"Callable ref {ref} already registered") - self._callable_ref_to_egg_fn[ref] = egg_name - self._egg_fn_to_callable_refs[egg_name].add(ref) - - def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]: - return self._egg_fn_to_callable_refs[egg_name] - - def get_egg_fn(self, ref: CallableRef) -> str: - return self._callable_ref_to_egg_fn[ref] - - def get_egg_sort(self, ref: JustTypeRef) -> str: - return self._type_ref_to_egg_sort[ref] - - def op_mapping(self) -> dict[str, str]: - """ - Create a mapping of egglog function name to Python function name, for use in the serialized format - for better visualization. 
- """ - return {k: str(next(iter(v))) for k, v in self._egg_fn_to_callable_refs.items() if len(v) == 1} + return self._classes[class_name].properties[property_name] + assert_never(ref) def has_method(self, class_name: str, method_name: str) -> bool | None: """ @@ -285,138 +179,31 @@ def has_method(self, class_name: str, method_name: str) -> bool | None: return method_name in self._classes[class_name].methods return None - def get_function_decl(self, ref: CallableRef) -> FunctionDecl: - match ref: - case ConstantRef(name): - return self._constants[name].to_constant_function_decl() - case ClassVariableRef(class_name, variable_name): - return self._classes[class_name].class_variables[variable_name].to_constant_function_decl() - case FunctionRef(name): - return self._functions[name] - case MethodRef(class_name, method_name): - return self._classes[class_name].methods[method_name] - case ClassMethodRef(class_name, method_name): - return self._classes[class_name].class_methods[method_name] - case PropertyRef(class_name, property_name): - return self._classes[class_name].properties[property_name] - assert_never(ref) - def get_class_decl(self, name: str) -> ClassDecl: return self._classes[name] - def get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]: - """ - Given a class name, returns all possible registered types that it can be. - """ - return frozenset(tp for tp in self._type_ref_to_egg_sort if tp.name == cls_name) - def register_class(self, name: str, type_vars: tuple[str, ...], builtin: bool, egg_sort: str | None) -> None: - # Register class first - if name in self._classes: - raise ValueError(f"Class {name} already registered") - decl = ClassDecl(type_vars=type_vars) - self._classes[name] = decl - self.register_sort(JustTypeRef(name), builtin, egg_sort) +@dataclass +class ClassDecl: + egg_name: str | None = None + type_vars: tuple[str, ...] 
= () + builtin: bool = False + class_methods: dict[str, FunctionDecl] = field(default_factory=dict) + # These have to be seperate from class_methods so that printing them can be done easily + class_variables: dict[str, ConstantDecl] = field(default_factory=dict) + methods: dict[str, FunctionDecl] = field(default_factory=dict) + properties: dict[str, FunctionDecl] = field(default_factory=dict) + preserved_methods: dict[str, Callable] = field(default_factory=dict) - def register_sort(self, ref: JustTypeRef, builtin: bool, egg_name: str | None = None) -> str: - """ - Register a sort with the given name. If no name is given, one is generated. - If this is a type called with generic args, register the generic args as well. - """ - # If the sort is already registered, do nothing - try: - egg_sort = self.get_egg_sort(ref) - except KeyError: - pass - else: - return egg_sort - egg_name = egg_name or ref.generate_egg_name() - if egg_name in self._egg_sort_to_type_ref: - raise ValueError(f"Sort {egg_name} is already registered.") - self._egg_sort_to_type_ref[egg_name] = ref - self._type_ref_to_egg_sort[ref] = egg_name - if not builtin: - self.add_cmd( - egg_name, - bindings.Sort( - egg_name, - ( - self.get_egg_sort(JustTypeRef(ref.name)), - [bindings.Var(self.register_sort(arg, False)) for arg in ref.args], - ) - if ref.args - else None, - ), - ) - - return egg_name - - def register_function_callable( - self, - ref: FunctionCallableRef, - fn_decl: FunctionDecl, - egg_name: str | None, - cost: int | None, - default: ExprDecl | None, - merge: ExprDecl | None, - merge_action: list[bindings._Action], - unextractable: bool, - builtin: bool, - is_relation: bool = False, - ) -> None: - """ - Registers a callable with the given egg name. +@dataclass +class RulesetDecl: + rules: list[RewriteOrRuleDecl] - The callable's function needs to be registered first. 
- """ - egg_name = egg_name or ref.generate_egg_name() - self.register_callable_ref(ref, egg_name) - self.set_function_decl(ref, fn_decl) - - # Skip generating the cmds if we don't want to record them, like for the builtins - if builtin: - return - - if fn_decl.var_arg_type is not None: - msg = "egglog does not support variable arguments yet." - raise NotImplementedError(msg) - # Remove all vars from the type refs, raising an errory if we find one, - # since we cannot create egg functions with vars - arg_sorts = [self.register_sort(a.to_just(), False) for a in fn_decl.arg_types] - cmd: bindings._Command - if is_relation: - assert not default - assert not merge - assert not merge_action - assert not cost - cmd = bindings.Relation(egg_name, arg_sorts) - else: - egg_fn_decl = bindings.FunctionDecl( - egg_name, - bindings.Schema(arg_sorts, self.register_sort(fn_decl.return_type.to_just(), False)), - default.to_egg(self) if default else None, - merge.to_egg(self) if merge else None, - merge_action, - cost, - unextractable, - ) - cmd = bindings.Function(egg_fn_decl) - self.add_cmd(egg_name, cmd) - - def register_constant_callable(self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: str | None) -> None: - egg_name = egg_name or ref.generate_egg_name() - self.register_callable_ref(ref, egg_name) - self.set_constant_type(ref, type_ref) - egg_sort = self.register_sort(type_ref, False) - # self.add_cmd(egg_name, bindings.Declare(egg_name, self.get_egg_sort(type_ref))) - # Use function decleration instead of constant b/c constants cannot be extracted - # https://github.com/egraphs-good/egglog/issues/334 - fn_decl = bindings.FunctionDecl(egg_name, bindings.Schema([], egg_sort)) - self.add_cmd(egg_name, bindings.Function(fn_decl)) - - def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None: - self._classes[class_].preserved_methods[method] = fn + # Make hashable so when traversing for pretty-fying we can know which rulesets we have already 
+ # made into strings + def __hash__(self) -> int: + return hash((type(self), tuple(self.rules))) # Have two different types of type refs, one that can include vars recursively and one that cannot. @@ -427,38 +214,18 @@ class JustTypeRef: name: str args: tuple[JustTypeRef, ...] = () - def generate_egg_name(self) -> str: - """ - Generates an egg sort name for this type reference by linearizing the type. - """ - if not self.args: - return self.name - args = "_".join(a.generate_egg_name() for a in self.args) - return f"{self.name}_{args}" - def to_var(self) -> TypeRefWithVars: return TypeRefWithVars(self.name, tuple(a.to_var() for a in self.args)) - def pretty(self) -> str: - if not self.args: - return self.name - args = ", ".join(a.pretty() for a in self.args) - return f"{self.name}[{args}]" + def __str__(self) -> str: + if self.args: + return f"{self.name}[{', '.join(str(a) for a in self.args)}]" + return self.name - def to_constant_function_decl(self) -> FunctionDecl: - """ - Create a function declaration for a constant function. - This is similar to how egglog compiles the `constant` command. - """ - return FunctionDecl( - arg_types=(), - arg_names=(), - arg_defaults=(), - return_type=self.to_var(), - mutates_first_arg=False, - var_arg_type=None, - ) +## +# Type references with vars +## @dataclass(frozen=True) @@ -473,7 +240,7 @@ def to_just(self) -> JustTypeRef: msg = "egglog does not support generic classes yet." 
raise NotImplementedError(msg) - def pretty(self) -> str: + def __str__(self) -> str: return self.name @@ -485,30 +252,27 @@ class TypeRefWithVars: def to_just(self) -> JustTypeRef: return JustTypeRef(self.name, tuple(a.to_just() for a in self.args)) - def pretty(self) -> str: - if not self.args: - return self.name - args = ", ".join(a.pretty() for a in self.args) - return f"{self.name}[{args}]" + def __str__(self) -> str: + if self.args: + return f"{self.name}[{', '.join(str(a) for a in self.args)}]" + return self.name TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars +## +# Callables References +## + @dataclass(frozen=True) class FunctionRef: name: str - def generate_egg_name(self) -> str: - return self.name - - def __str__(self) -> str: - return self.name - -# Use this special character in place of the args, so that if the args are inlined -# in the viz, they will replace it -ARG = "·" +@dataclass(frozen=True) +class ConstantRef: + name: str @dataclass(frozen=True) @@ -516,123 +280,115 @@ class MethodRef: class_name: str method_name: str - def generate_egg_name(self) -> str: - return f"{self.class_name}_{self.method_name}" - - def __str__(self) -> str: # noqa: PLR0911 - match self.method_name: - case _ if self.method_name in UNARY_METHODS: - return f"{UNARY_METHODS[self.method_name]}{ARG}" - case _ if self.method_name in BINARY_METHODS: - return f"({ARG} {BINARY_METHODS[self.method_name]} {ARG})" - case "__getitem__": - return f"{ARG}[{ARG}]" - case "__call__": - return f"{ARG}({ARG})" - case "__delitem__": - return f"del {ARG}[{ARG}]" - case "__setitem__": - return f"{ARG}[{ARG}] = {ARG}" - return f"{ARG}.{self.method_name}" - @dataclass(frozen=True) class ClassMethodRef: class_name: str method_name: str - def generate_egg_name(self) -> str: - return f"{self.class_name}_{self.method_name}" - - def __str__(self) -> str: - if self.method_name == "__init__": - return self.class_name - return f"{self.class_name}.{self.method_name}" - 
@dataclass(frozen=True) -class ConstantRef: - name: str - - def generate_egg_name(self) -> str: - return self.name - - def __str__(self) -> str: - return self.name +class ClassVariableRef: + class_name: str + var_name: str @dataclass(frozen=True) -class ClassVariableRef: +class PropertyRef: class_name: str - variable_name: str + property_name: str - def generate_egg_name(self) -> str: - return f"{self.class_name}_{self.variable_name}" - def __str__(self) -> str: - return f"{self.class_name}.{self.variable_name}" +CallableRef: TypeAlias = FunctionRef | ConstantRef | MethodRef | ClassMethodRef | ClassVariableRef | PropertyRef + + +## +# Callables +## @dataclass(frozen=True) -class PropertyRef: - class_name: str - property_name: str +class RelationDecl: + arg_types: tuple[JustTypeRef, ...] + # List of defaults. None for any arg which doesn't have one. + arg_defaults: tuple[ExprDecl | None, ...] + egg_name: str | None - def generate_egg_name(self) -> str: - return f"{self.class_name}_{self.property_name}" + def to_function_decl(self) -> FunctionDecl: + return FunctionDecl( + arg_types=tuple(a.to_var() for a in self.arg_types), + arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))), + arg_defaults=self.arg_defaults, + return_type=TypeRefWithVars("Unit"), + egg_name=self.egg_name, + default=LitDecl(None), + ) - def __str__(self) -> str: - return f"{ARG}.{self.property_name}" +@dataclass(frozen=True) +class ConstantDecl: + """ + Same as `(declare)` in egglog + """ -ConstantCallableRef: TypeAlias = ConstantRef | ClassVariableRef -FunctionCallableRef: TypeAlias = FunctionRef | MethodRef | ClassMethodRef | PropertyRef -CallableRef: TypeAlias = ConstantCallableRef | FunctionCallableRef + type_ref: JustTypeRef + egg_name: str | None = None + + def to_function_decl(self) -> FunctionDecl: + return FunctionDecl( + arg_types=(), + arg_names=(), + arg_defaults=(), + return_type=self.type_ref.to_var(), + egg_name=self.egg_name, + ) @dataclass(frozen=True) class 
FunctionDecl: + # All args are delayed except for relations converted to function decls arg_types: tuple[TypeOrVarRef, ...] - # Is None for relation which doesn't have named args - arg_names: tuple[str, ...] | None + arg_names: tuple[str, ...] + # List of defaults. None for any arg which doesn't have one. arg_defaults: tuple[ExprDecl | None, ...] - return_type: TypeOrVarRef - mutates_first_arg: bool + # If None, then the first arg is mutated and returned + return_type: TypeOrVarRef | None var_arg_type: TypeOrVarRef | None = None - def __post_init__(self) -> None: - # If we mutate the first arg, then the first arg should be the same type as the return - if self.mutates_first_arg: - assert self.arg_types[0] == self.return_type - - def to_signature(self, transform_default: Callable[[TypedExprDecl], object]) -> Signature: - arg_names = self.arg_names or tuple(f"__{i}" for i in range(len(self.arg_types))) - parameters = [ - Parameter( - n, - Parameter.POSITIONAL_OR_KEYWORD, - default=transform_default(TypedExprDecl(t.to_just(), d)) if d else Parameter.empty, - ) - for n, d, t in zip(arg_names, self.arg_defaults, self.arg_types, strict=True) - ] - if self.var_arg_type is not None: - parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL)) - return Signature(parameters) + # Egg params + builtin: bool = False + egg_name: str | None = None + cost: int | None = None + default: ExprDecl | None = None + on_merge: tuple[ActionDecl, ...] = () + merge: ExprDecl | None = None + unextractable: bool = False + def to_function_decl(self) -> FunctionDecl: + return self -@dataclass(frozen=True) -class VarDecl: - name: str + @property + def semantic_return_type(self) -> TypeOrVarRef: + """ + The type that is returned by the function, which wil be in the first arg if it mutates it. 
+ """ + return self.return_type or self.arg_types[0] - @classmethod - def from_egg(cls, var: bindings.TermVar) -> ExprDecl: - return cls(var.name) + @property + def mutates(self) -> bool: + return self.return_type is None - def to_egg(self, _decls: Declarations) -> bindings.Var: - return bindings.Var(self.name) - def pretty(self, context: PrettyContext, **kwargs) -> str: - return self.name +CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl + +## +# Expressions +## + + +@dataclass(frozen=True) +class VarDecl: + name: str @dataclass(frozen=True) @@ -646,16 +402,14 @@ def __hash__(self) -> int: except TypeError: return id(self.value) - @classmethod - def from_egg(cls, egraph: bindings.EGraph, termdag: bindings.TermDag, term: bindings.TermApp) -> ExprDecl: - call = bindings.termdag_term_to_expr(termdag, term) - return cls(egraph.eval_py_object(call)) - - def to_egg(self, _decls: Declarations) -> bindings._Expr: - return GLOBAL_PY_OBJECT_SORT.store(self.value) + def __eq__(self, other: object) -> bool: + if not isinstance(other, PyObjectDecl): + return False + return self.parts == other.parts - def pretty(self, context: PrettyContext, **kwargs) -> str: - return repr(self.value) + @property + def parts(self) -> tuple[type, object]: + return (type(self.value), self.value) LitType: TypeAlias = int | str | float | bool | None @@ -665,53 +419,30 @@ def pretty(self, context: PrettyContext, **kwargs) -> str: class LitDecl: value: LitType - @classmethod - def from_egg(cls, lit: bindings.TermLit) -> ExprDecl: - value = lit.value - if isinstance(value, bindings.Unit): - return cls(None) - return cls(value.value) - - def to_egg(self, _decls: Declarations) -> bindings.Lit: - if self.value is None: - return bindings.Lit(bindings.Unit()) - if isinstance(self.value, bool): - return bindings.Lit(bindings.Bool(self.value)) - if isinstance(self.value, int): - return bindings.Lit(bindings.Int(self.value)) - if isinstance(self.value, float): - return 
bindings.Lit(bindings.F64(self.value)) - if isinstance(self.value, str): - return bindings.Lit(bindings.String(self.value)) - assert_never(self.value) - - def pretty(self, context: PrettyContext, unwrap_lit: bool = True, **kwargs) -> str: + def __hash__(self) -> int: """ - Returns a string representation of the literal. - - :param wrap_lit: If True, wraps the literal in a call to the literal constructor. + Include type in has so that 1.0 != 1 """ - if self.value is None: - return "Unit()" - if isinstance(self.value, bool): - return f"Bool({self.value})" if not unwrap_lit else str(self.value) - if isinstance(self.value, int): - return f"i64({self.value})" if not unwrap_lit else str(self.value) - if isinstance(self.value, float): - return f"f64({self.value})" if not unwrap_lit else str(self.value) - if isinstance(self.value, str): - return f"String({self.value!r})" if not unwrap_lit else repr(self.value) - assert_never(self.value) + return hash(self.parts) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, LitDecl): + return False + return self.parts == other.parts + + @property + def parts(self) -> tuple[type, LitType]: + return (type(self.value), self.value) @dataclass(frozen=True) class CallDecl: callable: CallableRef + # TODO: Can I make these not typed expressions? args: tuple[TypedExprDecl, ...] = () # type parameters that were bound to the callable, if it is a classmethod # Used for pretty printing classmethod calls with type parameters bound_tp_params: tuple[JustTypeRef, ...] 
| None = None - _cached_hash: int | None = None def __post_init__(self) -> None: if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef): @@ -719,302 +450,165 @@ def __post_init__(self) -> None: raise ValueError(msg) def __hash__(self) -> int: - # Modified hash which will cache result for performance - if self._cached_hash is None: - res = hash((self.callable, self.args, self.bound_tp_params)) - object.__setattr__(self, "_cached_hash", res) - return res return self._cached_hash + @cached_property + def _cached_hash(self) -> int: + return hash((self.callable, self.args, self.bound_tp_params)) + def __eq__(self, other: object) -> bool: # Override eq to use cached hash for perf if not isinstance(other, CallDecl): return False return hash(self) == hash(other) - @classmethod - def from_egg( - cls, - egraph: bindings.EGraph, - decls: Declarations, - return_tp: JustTypeRef, - termdag: bindings.TermDag, - term: bindings.TermApp, - cache: dict[int, TypedExprDecl], - ) -> ExprDecl: - """ - Convert an egg expression into a typed expression by using the declerations. - Also pass in the desired type to do type checking top down. Needed to disambiguate calls like (map-create) - during expression extraction, where we always know the types. +ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl + + +@dataclass(frozen=True) +class TypedExprDecl: + tp: JustTypeRef + expr: ExprDecl + + def descendants(self) -> list[TypedExprDecl]: """ - from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver - - # Find the first callable ref that matches the call - for callable_ref in decls.get_callable_refs(term.name): - # If this is a classmethod, we might need the type params that were bound for this type - # This could be multiple types if the classmethod is ambiguous, like map create. 
- possible_types: Iterable[JustTypeRef | None] - fn_decl = decls.get_function_decl(callable_ref) - if isinstance(callable_ref, ClassMethodRef): - possible_types = decls.get_possible_types(callable_ref.class_name) - cls_name = callable_ref.class_name - else: - possible_types = [None] - cls_name = None - for possible_type in possible_types: - tcs = TypeConstraintSolver(decls) - if possible_type and possible_type.args: - tcs.bind_class(possible_type) - - try: - arg_types, bound_tp_params = tcs.infer_arg_types( - fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, return_tp, cls_name - ) - except TypeConstraintError: - continue - args: list[TypedExprDecl] = [] - for a, tp in zip(term.args, arg_types, strict=False): - if a in cache: - res = cache[a] - else: - res = TypedExprDecl.from_egg(egraph, decls, tp, termdag, termdag.nodes[a], cache) - cache[a] = res - args.append(res) - return cls(callable_ref, tuple(args), bound_tp_params) - raise ValueError(f"Could not find callable ref for call {term}") - - def to_egg(self, decls: Declarations) -> bindings._Expr: - """Convert a Call to an egg Call.""" - # This was removed when we replaced declerations constants with our b/c of unextractable constants - # # If this is a constant, then emit it just as a var, not as a call - # if isinstance(self.callable, ConstantRef | ClassVariableRef): - # decls.get_egg_fn - # return bindings.Var(egg_fn) - if hasattr(self, "_cached_egg"): - return self._cached_egg - egg_fn = decls.get_egg_fn(self.callable) - res = bindings.Call(egg_fn, [a.to_egg(decls) for a in self.args]) - object.__setattr__(self, "_cached_egg", res) - return res - - def pretty(self, context: PrettyContext, parens: bool = True, **kwargs) -> str: # noqa: C901 + Returns a list of all the descendants of this expression. """ - Pretty print the call. 
+ l = [self] + if isinstance(self.expr, CallDecl): + for a in self.expr.args: + l.extend(a.descendants()) + return l - :param parens: If true, wrap the call in parens if it is a binary method call. - """ - if self in context.names: - return context.names[self] - ref, args = self.callable, [a.expr for a in self.args] - # Special case != - if ref == FunctionRef("!="): - return f"ne({args[0].pretty(context, parens=False, unwrap_lit=False)}).to({args[1].pretty(context, parens=False, unwrap_lit=False)})" - function_decl = context.decls.get_function_decl(ref) - # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default - n_defaults = 0 - for arg, default in zip( - reversed(args), reversed(function_decl.arg_defaults), strict=not function_decl.var_arg_type - ): - if arg != default: - break - n_defaults += 1 - if n_defaults: - args = args[:-n_defaults] - if function_decl.mutates_first_arg: - first_arg = args[0] - expr_str = first_arg.pretty(context, parens=False) - # copy an identifer expression iff it has multiple parents (b/c then we can't mutate it directly) - has_multiple_parents = context.parents[first_arg] > 1 - expr_name = context.name_expr(function_decl.arg_types[0], expr_str, copy_identifier=has_multiple_parents) - # Set the first arg to be the name of the mutated arg and return the name - args[0] = VarDecl(expr_name) - else: - expr_name = None - match ref: - case FunctionRef(name): - expr = _pretty_call(context, name, args) - case ClassMethodRef(class_name, method_name): - tp_ref = JustTypeRef(class_name, self.bound_tp_params or ()) - fn_str = tp_ref.pretty() if method_name == "__init__" else f"{tp_ref.pretty()}.{method_name}" - expr = _pretty_call(context, fn_str, args) - case MethodRef(_class_name, method_name): - slf, *args = args - slf = slf.pretty(context, unwrap_lit=False) - match method_name: - case _ if method_name in UNARY_METHODS: - expr = f"{UNARY_METHODS[method_name]}{slf}" - case _ if 
method_name in BINARY_METHODS: - assert len(args) == 1 - expr = f"{slf} {BINARY_METHODS[method_name]} {args[0].pretty(context)}" - if parens: - expr = f"({expr})" - case "__getitem__": - assert len(args) == 1 - expr = f"{slf}[{args[0].pretty(context, parens=False)}]" - case "__call__": - expr = _pretty_call(context, slf, args) - case "__delitem__": - assert len(args) == 1 - expr = f"del {slf}[{args[0].pretty(context, parens=False)}]" - case "__setitem__": - assert len(args) == 2 - expr = ( - f"{slf}[{args[0].pretty(context, parens=False)}] = {args[1].pretty(context, parens=False)}" - ) - case _: - expr = _pretty_call(context, f"{slf}.{method_name}", args) - case ConstantRef(name): - expr = name - case ClassVariableRef(class_name, variable_name): - expr = f"{class_name}.{variable_name}" - case PropertyRef(_class_name, property_name): - expr = f"{args[0].pretty(context)}.{property_name}" - case _: - assert_never(ref) - # If we have a name, then we mutated - if expr_name: - context.statements.append(expr) - context.names[self] = expr_name - return expr_name - - # We use a heuristic to decide whether to name this sub-expression as a variable - # The rough goal is to reduce the number of newlines, given our line length of ~180 - # We determine it's worth making a new line for this expression if the total characters - # it would take up is > than some constant (~ line length). 
- n_parents = context.parents[self] - line_diff: int = len(expr) - LINE_DIFFERENCE - if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH: - expr_name = context.name_expr(function_decl.return_type, expr, copy_identifier=False) - context.names[self] = expr_name - return expr_name - return expr - - -MAX_LINE_LENGTH = 110 -LINE_DIFFERENCE = 10 - - -def _plot_line_length(expr: object): - """ - Plots the number of line lengths based on different max lengths - """ - global MAX_LINE_LENGTH, LINE_DIFFERENCE - import altair as alt - import pandas as pd - sizes = [] - for line_length in range(40, 180, 10): - MAX_LINE_LENGTH = line_length - for diff in range(0, 40, 5): - LINE_DIFFERENCE = diff - new_l = len(str(expr).split()) - sizes.append((line_length, diff, new_l)) +## +# Schedules +## - df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"]) # noqa: PD901 - return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q") +@dataclass(frozen=True) +class SaturateDecl: + schedule: ScheduleDecl -def _pretty_call(context: PrettyContext, fn: str, args: Iterable[ExprDecl]) -> str: - return f"{fn}({', '.join(a.pretty(context, parens=False) for a in args)})" +@dataclass(frozen=True) +class RepeatDecl: + schedule: ScheduleDecl + times: int -@dataclass -class PrettyContext: - decls: Declarations - # List of statements of "context" setting variable for the expr - statements: list[str] = field(default_factory=list) - - names: dict[ExprDecl, str] = field(default_factory=dict) - parents: dict[ExprDecl, int] = field(default_factory=lambda: defaultdict(lambda: 0)) - _traversed_exprs: set[ExprDecl] = field(default_factory=set) - - # Mapping of type to the number of times we have generated a name for that type, used to generate unique names - _gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0)) - - def generate_name(self, typ: str) -> str: - self._gen_name_types[typ] += 1 - return 
f"_{typ}_{self._gen_name_types[typ]}" - - def name_expr(self, expr_type: TypeOrVarRef, expr_str: str, copy_identifier: bool) -> str: - tp_name = expr_type.to_just().name - # If the thing we are naming is already a variable, we don't need to name it - if expr_str.isidentifier(): - if copy_identifier: - name = self.generate_name(tp_name) - self.statements.append(f"{name} = copy({expr_str})") - else: - name = expr_str - else: - name = self.generate_name(tp_name) - self.statements.append(f"{name} = {expr_str}") - return name +@dataclass(frozen=True) +class SequenceDecl: + schedules: tuple[ScheduleDecl, ...] + - def render(self, expr: str) -> str: - return "\n".join([*self.statements, expr]) +@dataclass(frozen=True) +class RunDecl: + ruleset: str + until: tuple[FactDecl, ...] | None - def traverse_for_parents(self, expr: ExprDecl) -> None: - if expr in self._traversed_exprs: - return - self._traversed_exprs.add(expr) - if isinstance(expr, CallDecl): - for arg in set(expr.args): - self.parents[arg.expr] += 1 - self.traverse_for_parents(arg.expr) +ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl -ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl +## +# Facts +## @dataclass(frozen=True) -class TypedExprDecl: +class EqDecl: tp: JustTypeRef - expr: ExprDecl + exprs: tuple[ExprDecl, ...] 
- @classmethod - def from_egg( - cls, - egraph: bindings.EGraph, - decls: Declarations, - tp: JustTypeRef, - termdag: bindings.TermDag, - term: bindings._Term, - cache: dict[int, TypedExprDecl], - ) -> TypedExprDecl: - expr_decl: ExprDecl - if isinstance(term, bindings.TermVar): - expr_decl = VarDecl.from_egg(term) - elif isinstance(term, bindings.TermLit): - expr_decl = LitDecl.from_egg(term) - elif isinstance(term, bindings.TermApp): - if term.name == "py-object": - expr_decl = PyObjectDecl.from_egg(egraph, termdag, term) - else: - expr_decl = CallDecl.from_egg(egraph, decls, tp, termdag, term, cache) - else: - assert_never(term) - return cls(tp, expr_decl) - def to_egg(self, decls: Declarations) -> bindings._Expr: - return self.expr.to_egg(decls) +@dataclass(frozen=True) +class ExprFactDecl: + typed_expr: TypedExprDecl - def descendants(self) -> list[TypedExprDecl]: - """ - Returns a list of all the descendants of this expression. - """ - l = [self] - if isinstance(self.expr, CallDecl): - for a in self.expr.args: - l.extend(a.descendants()) - return l +FactDecl: TypeAlias = EqDecl | ExprFactDecl -@dataclass -class ClassDecl: - methods: dict[str, FunctionDecl] = field(default_factory=dict) - class_methods: dict[str, FunctionDecl] = field(default_factory=dict) - class_variables: dict[str, JustTypeRef] = field(default_factory=dict) - properties: dict[str, FunctionDecl] = field(default_factory=dict) - preserved_methods: dict[str, Callable] = field(default_factory=dict) - type_vars: tuple[str, ...] 
= field(default=()) +## +# Actions +## + + +@dataclass(frozen=True) +class LetDecl: + name: str + typed_expr: TypedExprDecl + + +@dataclass(frozen=True) +class SetDecl: + tp: JustTypeRef + call: CallDecl + rhs: ExprDecl + + +@dataclass(frozen=True) +class ExprActionDecl: + typed_expr: TypedExprDecl + + +@dataclass(frozen=True) +class ChangeDecl: + tp: JustTypeRef + call: CallDecl + change: Literal["delete", "subsume"] + + +@dataclass(frozen=True) +class UnionDecl: + tp: JustTypeRef + lhs: ExprDecl + rhs: ExprDecl + + +@dataclass(frozen=True) +class PanicDecl: + msg: str + + +ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl + + +## +# Commands +## + + +@dataclass(frozen=True) +class RewriteDecl: + tp: JustTypeRef + lhs: ExprDecl + rhs: ExprDecl + conditions: tuple[FactDecl, ...] + subsume: bool + + +@dataclass(frozen=True) +class BiRewriteDecl: + tp: JustTypeRef + lhs: ExprDecl + rhs: ExprDecl + conditions: tuple[FactDecl, ...] + + +@dataclass(frozen=True) +class RuleDecl: + head: tuple[ActionDecl, ...] + body: tuple[FactDecl, ...] 
+ name: str | None + + +RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl + + +@dataclass(frozen=True) +class ActionCommandDecl: + action: ActionDecl + + +CommandDecl: TypeAlias = RewriteOrRuleDecl | ActionCommandDecl diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index d3b059bc..4dc27fee 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -3,12 +3,10 @@ import inspect import pathlib import tempfile -from abc import ABC, abstractmethod +from abc import abstractmethod from collections.abc import Callable, Iterable from contextvars import ContextVar, Token -from copy import deepcopy from dataclasses import InitVar, dataclass, field -from functools import cached_property from inspect import Parameter, currentframe, signature from types import FrameType, FunctionType from typing import ( @@ -18,6 +16,7 @@ Generic, Literal, NoReturn, + TypeAlias, TypedDict, TypeVar, cast, @@ -26,14 +25,16 @@ ) import graphviz -from typing_extensions import ParamSpec, Self, Unpack, deprecated - -from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations +from typing_extensions import ParamSpec, Self, Unpack, assert_never, deprecated from . 
import bindings +from .conversion import * from .declarations import * +from .egraph_state import * from .ipython_magic import IN_IPYTHON +from .pretty import pretty_decl from .runtime import * +from .thunk import * if TYPE_CHECKING: import ipywidgets @@ -58,6 +59,7 @@ "let", "constant", "delete", + "subsume", "union", "set_", "rule", @@ -65,6 +67,9 @@ "vars_", "Fact", "expr_parts", + "expr_action", + "expr_fact", + "action_command", "Schedule", "run", "seq", @@ -79,11 +84,10 @@ "_NeBuilder", "_SetBuilder", "_UnionBuilder", - "Rule", - "Rewrite", - "BiRewrite", - "Union_", + "RewriteOrRule", + "Fact", "Action", + "Command", ] T = TypeVar("T") @@ -112,7 +116,24 @@ } +# special methods that return none and mutate self ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"} +# special methods which must return real python values instead of lazy expressions +ALWAYS_PRESERVED = { + "__repr__", + "__str__", + "__bytes__", + "__format__", + "__hash__", + "__bool__", + "__len__", + "__length_hint__", + "__iter__", + "__reversed__", + "__contains__", + "__index__", + "__bufer__", +} def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR: @@ -124,7 +145,7 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR: return EGraph().extract(x) -def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr | Set) -> None: +def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None: """ Verifies that the fact is true given some assumptions and after running the schedule. """ @@ -136,9 +157,6 @@ def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr | egraph.check(x) -# def extract(res: ) - - @dataclass class _BaseModule: """ @@ -181,15 +199,8 @@ def class_(self, *args, **kwargs) -> Any: Registers a class. 
""" if kwargs: - assert set(kwargs.keys()) == {"egg_sort"} - - def _inner(cls: object, egg_sort: str = kwargs["egg_sort"]): - assert isinstance(cls, RuntimeClass) - assert isinstance(cls.lazy_decls, _ClassDeclerationsConstructor) - cls.lazy_decls.egg_sort = egg_sort - return cls - - return _inner + msg = "Switch to subclassing from Expr and passing egg_sort as a keyword arg to the class constructor" + raise NotImplementedError(msg) assert len(args) == 1 return args[0] @@ -280,9 +291,9 @@ def function(self, *args, **kwargs) -> Any: # If we have any positional args, then we are calling it directly on a function if args: assert len(args) == 1 - return _function(args[0], fn_locals, False) + return _FunctionConstructor(fn_locals)(args[0]) # otherwise, we are passing some keyword args, so save those, and then return a partial - return lambda fn: _function(fn, fn_locals, False, **kwargs) + return _FunctionConstructor(fn_locals, **kwargs) @deprecated("Use top level `ruleset` function instead") def ruleset(self, name: str) -> Ruleset: @@ -324,17 +335,26 @@ def constant(self, name: str, tp: type[EXPR], egg_name: str | None = None) -> EX """ return constant(name, tp, egg_name) - def register(self, /, command_or_generator: CommandLike | CommandGenerator, *command_likes: CommandLike) -> None: + def register( + self, + /, + command_or_generator: ActionLike | RewriteOrRule | RewriteOrRuleGenerator, + *command_likes: ActionLike | RewriteOrRule, + ) -> None: """ Registers any number of rewrites or rules. 
""" if isinstance(command_or_generator, FunctionType): assert not command_likes - command_likes = tuple(_command_generator(command_or_generator)) + current_frame = inspect.currentframe() + assert current_frame + original_frame = current_frame.f_back + assert original_frame + command_likes = tuple(_rewrite_or_rule_generator(command_or_generator, original_frame)) else: command_likes = (cast(CommandLike, command_or_generator), *command_likes) - - self._register_commands(list(map(_command_like, command_likes))) + commands = [_command_like(c) for c in command_likes] + self._register_commands(commands) @abstractmethod def _register_commands(self, cmds: list[Command]) -> None: @@ -417,136 +437,116 @@ def __new__( # type: ignore[misc] # If this is the Expr subclass, just return the class if not bases: return super().__new__(cls, name, bases, namespace) + # TODO: Raise error on subclassing or multiple inheritence frame = currentframe() assert frame prev_frame = frame.f_back assert prev_frame - return _ClassDeclerationsConstructor( - namespace=namespace, - # Store frame so that we can get live access to updated locals/globals - # Otherwise, f_locals returns a copy - # https://peps.python.org/pep-0667/ - frame=prev_frame, - builtin=builtin, - egg_sort=egg_sort, - cls_name=name, - ).current_cls + + # Store frame so that we can get live access to updated locals/globals + # Otherwise, f_locals returns a copy + # https://peps.python.org/pep-0667/ + decls_thunk = Thunk.fn( + _generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations + ) + return RuntimeClass(decls_thunk, TypeRefWithVars(name)) def __instancecheck__(cls, instance: object) -> bool: return isinstance(instance, RuntimeExpr) -@dataclass -class _ClassDeclerationsConstructor: +def _generate_class_decls( + namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str +) -> Declarations: """ Lazy constructor for class declerations to support classes with 
methods whose types are not yet defined. """ + parameters: list[TypeVar] = ( + # Get the generic params from the orig bases generic class + namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in namespace else [] + ) + type_vars = tuple(p.__name__ for p in parameters) + del parameters + cls_decl = ClassDecl(egg_sort, type_vars, builtin) + decls = Declarations(_classes={cls_name: cls_decl}) + + ## + # Register class variables + ## + # Create a dummy type to pass to get_type_hints to resolve the annotations we have + _Dummytype = type("_DummyType", (), {"__annotations__": namespace.get("__annotations__", {})}) + for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items(): + if getattr(v, "__origin__", None) == ClassVar: + (inner_tp,) = v.__args__ + type_ref = resolve_type_annotation(decls, inner_tp).to_just() + cls_decl.class_variables[k] = ConstantDecl(type_ref) - namespace: dict[str, Any] - frame: FrameType - builtin: bool - egg_sort: str | None - cls_name: str - current_cls: RuntimeClass = field(init=False) + else: + msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}" + raise NotImplementedError(msg) - def __post_init__(self) -> None: - self.current_cls = RuntimeClass(self, self.cls_name) - - def __call__(self, decls: Declarations) -> None: # noqa: PLR0912 - # Get all the methods from the class - cls_dict: dict[str, Any] = { - k: v for k, v in self.namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod) - } - parameters: list[TypeVar] = ( - # Get the generic params from the orig bases generic class - self.namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in self.namespace else [] - ) - type_vars = tuple(p.__name__ for p in parameters) - del parameters - - decls.register_class(self.cls_name, type_vars, self.builtin, self.egg_sort) - # The type ref of self is paramterized by the type vars - slf_type_ref = TypeRefWithVars(self.cls_name, 
tuple(map(ClassTypeVarRef, type_vars))) - - # Create a dummy type to pass to get_type_hints to resolve the annotations we have - class _Dummytype: - pass - - _Dummytype.__annotations__ = self.namespace.get("__annotations__", {}) - # Make lazy update to locals, so we keep a live handle on them after class creation - locals = self.frame.f_locals.copy() - locals[self.cls_name] = self.current_cls - for k, v in get_type_hints(_Dummytype, globalns=self.frame.f_globals, localns=locals).items(): - if getattr(v, "__origin__", None) == ClassVar: - (inner_tp,) = v.__args__ - _register_constant(decls, ClassVariableRef(self.cls_name, k), inner_tp, None) - else: - msg = f"On class {self.cls_name}, for attribute '{k}', expected a ClassVar, but got {v}" - raise NotImplementedError(msg) - - # Then register each of its methods - for method_name, method in cls_dict.items(): - is_init = method_name == "__init__" - # Don't register the init methods for literals, since those don't use the type checking mechanisms - if is_init and self.cls_name in LIT_CLASS_NAMES: - continue - if isinstance(method, _WrappedMethod): - fn = method.fn - egg_fn = method.egg_fn - cost = method.cost - default = method.default - merge = method.merge - on_merge = method.on_merge - mutates_first_arg = method.mutates_self - unextractable = method.unextractable - if method.preserve: - decls.register_preserved_method(self.cls_name, method_name, fn) - continue - else: - fn = method + ## + # Register methods, classmethods, preserved methods, and properties + ## + + # The type ref of self is paramterized by the type vars + slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars))) + + # Get all the methods from the class + filtered_namespace: list[tuple[str, Any]] = [ + (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod) + ] + + # Then register each of its methods + for method_name, method in filtered_namespace: + is_init = method_name == "__init__" 
+ # Don't register the init methods for literals, since those don't use the type checking mechanisms + if is_init and cls_name in LIT_CLASS_NAMES: + continue + match method: + case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable): + pass + case _: egg_fn, cost, default, merge, on_merge = None, None, None, None, None - unextractable = False - mutates_first_arg = False - if isinstance(fn, classmethod): - fn = fn.__func__ - is_classmethod = True - else: - # We count __init__ as a classmethod since it is called on the class - is_classmethod = is_init - - if isinstance(fn, property): - fn = fn.fget - is_property = True - if is_classmethod: - msg = "Can't have a classmethod property" - raise NotImplementedError(msg) - else: - is_property = False - ref: FunctionCallableRef = ( - ClassMethodRef(self.cls_name, method_name) - if is_classmethod - else PropertyRef(self.cls_name, method_name) - if is_property - else MethodRef(self.cls_name, method_name) - ) - _register_function( + fn = method + unextractable, preserve = False, False + mutates = method_name in ALWAYS_MUTATES_SELF + if preserve: + cls_decl.preserved_methods[method_name] = fn + continue + locals = frame.f_locals + + def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl: + return _fn_decl( decls, - ref, - egg_fn, + egg_fn, # noqa: B023 fn, - locals, - default, - cost, - merge, - on_merge, - mutates_first_arg or method_name in ALWAYS_MUTATES_SELF, - self.builtin, - "cls" if is_classmethod and not is_init else slf_type_ref, - is_init, - unextractable=unextractable, + locals, # noqa: B023 + default, # noqa: B023 + cost, # noqa: B023 + merge, # noqa: B023 + on_merge, # noqa: B023 + mutates, # noqa: B023 + builtin, + first, + is_init, # noqa: B023 + unextractable, # noqa: B023 ) + match fn: + case classmethod(): + cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls") + case property(): + cls_decl.properties[method_name] = 
create_decl(fn.fget, slf_type_ref) + case _: + if is_init: + cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref) + else: + cls_decl.methods[method_name] = create_decl(fn, slf_type_ref) + + return decls + @overload def function(fn: CALLABLE, /) -> CALLABLE: ... @@ -589,48 +589,46 @@ def function(*args, **kwargs) -> Any: # If we have any positional args, then we are calling it directly on a function if args: assert len(args) == 1 - return _function(args[0], fn_locals, False) + return _FunctionConstructor(fn_locals)(args[0]) # otherwise, we are passing some keyword args, so save those, and then return a partial - return lambda fn: _function(fn, fn_locals, **kwargs) + return _FunctionConstructor(fn_locals, **kwargs) -def _function( - fn: Callable[..., RuntimeExpr], - hint_locals: dict[str, Any], - builtin: bool = False, - mutates_first_arg: bool = False, - egg_fn: str | None = None, - cost: int | None = None, - default: RuntimeExpr | None = None, - merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None, - on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None, - unextractable: bool = False, -) -> RuntimeFunction: - """ - Uncurried version of function decorator - """ - name = fn.__name__ - decls = Declarations() - _register_function( - decls, - FunctionRef(name), - egg_fn, - fn, - hint_locals, - default, - cost, - merge, - on_merge, - mutates_first_arg, - builtin, - unextractable=unextractable, - ) - return RuntimeFunction(decls, name) +@dataclass +class _FunctionConstructor: + hint_locals: dict[str, Any] + builtin: bool = False + mutates_first_arg: bool = False + egg_fn: str | None = None + cost: int | None = None + default: RuntimeExpr | None = None + merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None + on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None + unextractable: bool = False + + def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction: + 
return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__)) + + def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations: + decls = Declarations() + decls._functions[fn.__name__] = _fn_decl( + decls, + self.egg_fn, + fn, + self.hint_locals, + self.default, + self.cost, + self.merge, + self.on_merge, + self.mutates_first_arg, + self.builtin, + unextractable=self.unextractable, + ) + return decls -def _register_function( +def _fn_decl( decls: Declarations, - ref: FunctionCallableRef, egg_name: str | None, fn: object, # Pass in the locals, retrieved from the frame when wrapping, @@ -646,7 +644,7 @@ def _register_function( first_arg: Literal["cls"] | TypeOrVarRef | None = None, is_init: bool = False, unextractable: bool = False, -) -> None: +) -> FunctionDecl: if not isinstance(fn, FunctionType): raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}") @@ -699,8 +697,8 @@ def _register_function( None if merge is None else merge( - RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))), - RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))), + RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))), + RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))), ) ) decls |= merged @@ -710,30 +708,25 @@ def _register_function( if on_merge is None else _action_likes( on_merge( - RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))), - RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))), + RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))), + RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))), ) ) ) decls.update(*merge_action) - fn_decl = FunctionDecl( - return_type=return_type, + return FunctionDecl( + return_type=None if mutates_first_arg else return_type, var_arg_type=var_arg_type, 
arg_types=arg_types, arg_names=tuple(t.name for t in params), arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults), - mutates_first_arg=mutates_first_arg, - ) - decls.register_function_callable( - ref, - fn_decl, - egg_name, - cost, - None if default is None else default.__egg_typed_expr__.expr, - merged.__egg_typed_expr__.expr if merged is not None else None, - [a._to_egg_action() for a in merge_action], - unextractable, - is_builtin, + cost=cost, + egg_name=egg_name, + merge=merged.__egg_typed_expr__.expr if merged is not None else None, + unextractable=unextractable, + builtin=is_builtin, + default=None if default is None else default.__egg_typed_expr__.expr, + on_merge=tuple(a.action for a in merge_action), ) @@ -764,49 +757,31 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[.. """ Creates a function whose return type is `Unit` and has a default value. """ + decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn) + return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name))) + + +def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations: decls = Declarations() decls |= cast(RuntimeClass, Unit) - arg_types = tuple(resolve_type_annotation(decls, tp) for tp in tps) - fn_decl = FunctionDecl(arg_types, None, tuple(None for _ in tps), TypeRefWithVars("Unit"), mutates_first_arg=False) - decls.register_function_callable( - FunctionRef(name), - fn_decl, - egg_fn, - cost=None, - default=None, - merge=None, - merge_action=[], - unextractable=False, - builtin=False, - is_relation=True, - ) - return cast(Callable[..., Unit], RuntimeFunction(decls, name)) + arg_types = tuple(resolve_type_annotation(decls, tp).to_just() for tp in tps) + decls._functions[name] = RelationDecl(arg_types, tuple(None for _ in tps), egg_fn) + return decls def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR: """ - A "constant" is implemented as 
the instantiation of a value that takes no args. This creates a function with `name` and return type `tp` and returns a value of it being called. """ - ref = ConstantRef(name) - decls = Declarations() - type_ref = _register_constant(decls, ref, tp, egg_name) - return cast(EXPR, RuntimeExpr(decls, TypedExprDecl(type_ref, CallDecl(ref)))) + return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name))) -def _register_constant( - decls: Declarations, - ref: ConstantRef | ClassVariableRef, - tp: object, - egg_name: str | None, -) -> JustTypeRef: - """ - Register a constant, returning its typeref(). - """ +def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]: + decls = Declarations() type_ref = resolve_type_annotation(decls, tp).to_just() - decls.register_constant_callable(ref, type_ref, egg_name) - return type_ref + decls._constants[name] = ConstantDecl(type_ref, egg_name) + return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name))) def _last_param_variable(params: list[Parameter]) -> bool: @@ -858,29 +833,6 @@ class GraphvizKwargs(TypedDict, total=False): split_primitive_outputs: bool -@dataclass -class _EGraphState: - """ - State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined. - """ - - # The decleratons we have added. 
The _cmds represent all the symbols we have added - decls: Declarations = field(default_factory=Declarations) - # List of rulesets already added, so we don't re-add them if they are passed again - added_rulesets: set[str] = field(default_factory=set) - - def add_decls(self, new_decls: Declarations) -> Iterable[bindings._Command]: - new_cmds = [v for k, v in new_decls._cmds.items() if k not in self.decls._cmds] - self.decls |= new_decls - return new_cmds - - def add_rulesets(self, rulesets: Iterable[Ruleset]) -> Iterable[bindings._Command]: - for ruleset in rulesets: - if ruleset.egg_name not in self.added_rulesets: - self.added_rulesets.add(ruleset.egg_name) - yield from ruleset._cmds - - @dataclass class EGraph(_BaseModule): """ @@ -892,56 +844,34 @@ class EGraph(_BaseModule): seminaive: InitVar[bool] = True save_egglog_string: InitVar[bool] = False - default_ruleset: Ruleset | None = None - _egraph: bindings.EGraph = field(repr=False, init=False) - _egglog_string: str | None = field(default=None, repr=False, init=False) - _state: _EGraphState = field(default_factory=_EGraphState, repr=False) + _state: EGraphState = field(init=False) # For pushing/popping with egglog - _state_stack: list[_EGraphState] = field(default_factory=list, repr=False) + _state_stack: list[EGraphState] = field(default_factory=list, repr=False) # For storing the global "current" egraph _token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False) def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None: - self._egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive) + egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string) + self._state = EGraphState(egraph) super().__post_init__(modules) for m in self._flatted_deps: - self._add_decls(*m.cmds) self._register_commands(m.cmds) - if save_egglog_string: - self._egglog_string = "" - - def _register_commands(self, commands: list[Command]) -> 
None: - for c in commands: - if c.ruleset: - self._add_schedule(c.ruleset) - - self._add_decls(*commands) - self._process_commands(command._to_egg_command(self._default_ruleset_name) for command in commands) - - def _process_commands(self, commands: Iterable[bindings._Command]) -> None: - commands = list(commands) - self._egraph.run_program(*commands) - if isinstance(self._egglog_string, str): - self._egglog_string += "\n".join(str(c) for c in commands) + "\n" def _add_decls(self, *decls: DeclerationsLike) -> None: - for d in upcast_decleratioons(decls): - self._process_commands(self._state.add_decls(d)) - - def _add_schedule(self, schedule: Schedule) -> None: - self._add_decls(schedule) - self._process_commands(self._state.add_rulesets(schedule._rulesets())) + for d in decls: + self._state.__egg_decls__ |= d @property def as_egglog_string(self) -> str: """ Returns the egglog string for this module. """ - if self._egglog_string is None: + cmds = self._egraph.commands() + if cmds is None: msg = "Can't get egglog string unless EGraph created with save_egglog_string=True" raise ValueError(msg) - return self._egglog_string + return cmds def _repr_mimebundle_(self, *args, **kwargs): """ @@ -954,7 +884,7 @@ def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source: kwargs.setdefault("split_primitive_outputs", True) n_inline = kwargs.pop("n_inline_leaves", 0) serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc] - serialized.map_ops(self._state.decls.op_mapping()) + serialized.map_ops(self._state.op_mapping()) for _ in range(n_inline): serialized.inline_leaves() original = serialized.to_dot() @@ -1016,17 +946,23 @@ def input(self, fn: Callable[..., String], path: str) -> None: Loads a CSV file and sets it as *input, output of the function. 
""" ref, decls = resolve_callable(fn) - fn_name = decls.get_egg_fn(ref) - self._process_commands(decls.list_cmds()) - self._process_commands([bindings.Input(fn_name, path)]) + self._add_decls(decls) + fn_name = self._state.callable_ref_to_egg(ref) + self._egraph.run_program(bindings.Input(fn_name, path)) def let(self, name: str, expr: EXPR) -> EXPR: """ Define a new expression in the egraph and return a reference to it. """ - self._register_commands([let(name, expr)]) - expr = to_runtime_expr(expr) - return cast(EXPR, RuntimeExpr(expr.__egg_decls__, TypedExprDecl(expr.__egg_typed_expr__.tp, VarDecl(name)))) + action = let(name, expr) + self.register(action) + runtime_expr = to_runtime_expr(expr) + return cast( + EXPR, + RuntimeExpr.__from_value__( + self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name)) + ), + ) @overload def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ... @@ -1041,28 +977,19 @@ def simplify( Simplifies the given expression. 
""" schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule - del limit_or_schedule - expr = to_runtime_expr(expr) - self._add_decls(expr) - self._add_schedule(schedule) - - # decls = Declarations.create(expr, schedule) - self._process_commands([bindings.Simplify(expr.__egg__, schedule._to_egg_schedule(self._default_ruleset_name))]) + del limit_or_schedule, until, ruleset + runtime_expr = to_runtime_expr(expr) + self._add_decls(runtime_expr, schedule) + egg_schedule = self._state.schedule_to_egg(schedule.schedule) + typed_expr = runtime_expr.__egg_typed_expr__ + egg_expr = self._state.expr_to_egg(typed_expr.expr) + self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule)) extract_report = self._egraph.extract_report() if not isinstance(extract_report, bindings.Best): msg = "No extract report saved" raise ValueError(msg) # noqa: TRY004 - new_typed_expr = TypedExprDecl.from_egg( - self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, extract_report.term, {} - ) - return cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr)) - - @property - def _default_ruleset_name(self) -> str: - if self.default_ruleset: - self._add_schedule(self.default_ruleset) - return self.default_ruleset.egg_name - return "" + (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp) + return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr)) def include(self, path: str) -> None: """ @@ -1092,8 +1019,9 @@ def run( return self._run_schedule(limit_or_schedule) def _run_schedule(self, schedule: Schedule) -> bindings.RunReport: - self._add_schedule(schedule) - self._process_commands([bindings.RunSchedule(schedule._to_egg_schedule(self._default_ruleset_name))]) + self._add_decls(schedule) + egg_schedule = self._state.schedule_to_egg(schedule.schedule) + self._egraph.run_program(bindings.RunSchedule(egg_schedule)) run_report = 
self._egraph.run_report() if not run_report: msg = "No run report saved" @@ -1104,18 +1032,18 @@ def check(self, *facts: FactLike) -> None: """ Check if a fact is true in the egraph. """ - self._process_commands([self._facts_to_check(facts)]) + self._egraph.run_program(self._facts_to_check(facts)) def check_fail(self, *facts: FactLike) -> None: """ Checks that one of the facts is not true """ - self._process_commands([bindings.Fail(self._facts_to_check(facts))]) + self._egraph.run_program(bindings.Fail(self._facts_to_check(facts))) - def _facts_to_check(self, facts: Iterable[FactLike]) -> bindings.Check: - facts = _fact_likes(facts) + def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check: + facts = _fact_likes(fact_likes) self._add_decls(*facts) - egg_facts = [f._to_egg_fact() for f in _fact_likes(facts)] + egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)] return bindings.Check(egg_facts) @overload @@ -1128,16 +1056,17 @@ def extract(self, expr: EXPR, include_cost: bool = False) -> EXPR | tuple[EXPR, """ Extract the lowest cost expression from the egraph. 
""" - assert isinstance(expr, RuntimeExpr) - self._add_decls(expr) - extract_report = self._run_extract(expr.__egg__, 0) + runtime_expr = to_runtime_expr(expr) + self._add_decls(runtime_expr) + typed_expr = runtime_expr.__egg_typed_expr__ + extract_report = self._run_extract(typed_expr.expr, 0) + if not isinstance(extract_report, bindings.Best): msg = "No extract report saved" raise ValueError(msg) # noqa: TRY004 - new_typed_expr = TypedExprDecl.from_egg( - self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, extract_report.term, {} - ) - res = cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr)) + (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp) + + res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr)) if include_cost: return res, extract_report.cost return res @@ -1146,23 +1075,20 @@ def extract_multiple(self, expr: EXPR, n: int) -> list[EXPR]: """ Extract multiple expressions from the egraph. 
""" - assert isinstance(expr, RuntimeExpr) - self._add_decls(expr) + runtime_expr = to_runtime_expr(expr) + self._add_decls(runtime_expr) + typed_expr = runtime_expr.__egg_typed_expr__ - extract_report = self._run_extract(expr.__egg__, n) + extract_report = self._run_extract(typed_expr.expr, n) if not isinstance(extract_report, bindings.Variants): msg = "Wrong extract report type" raise ValueError(msg) # noqa: TRY004 - new_exprs = [ - TypedExprDecl.from_egg( - self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, term, {} - ) - for term in extract_report.terms - ] - return [cast(EXPR, RuntimeExpr(self._state.decls, expr)) for expr in new_exprs] + new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp) + return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs] - def _run_extract(self, expr: bindings._Expr, n: int) -> bindings._ExtractReport: - self._process_commands([bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n))))]) + def _run_extract(self, expr: ExprDecl, n: int) -> bindings._ExtractReport: + expr = self._state.expr_to_egg(expr) + self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n))))) extract_report = self._egraph.extract_report() if not extract_report: msg = "No extract report saved" @@ -1173,15 +1099,15 @@ def push(self) -> None: """ Push the current state of the egraph, so that it can be popped later and reverted back. """ - self._process_commands([bindings.Push(1)]) + self._egraph.run_program(bindings.Push(1)) self._state_stack.append(self._state) - self._state = deepcopy(self._state) + self._state = self._state.copy() def pop(self) -> None: """ Pop the current state of the egraph, reverting back to the previous state. 
""" - self._process_commands([bindings.Pop(1)]) + self._egraph.run_program(bindings.Pop(1)) self._state = self._state_stack.pop() def __enter__(self) -> Self: @@ -1217,9 +1143,10 @@ def eval(self, expr: Expr) -> object: """ Evaluates the given expression (which must be a primitive type), returning the result. """ - assert isinstance(expr, RuntimeExpr) - typed_expr = expr.__egg_typed_expr__ - egg_expr = expr.__egg__ + runtime_expr = to_runtime_expr(expr) + self._add_decls(runtime_expr) + typed_expr = runtime_expr.__egg_typed_expr__ + egg_expr = self._state.expr_to_egg(typed_expr.expr) match typed_expr.tp: case JustTypeRef("i64"): return self._egraph.eval_i64(egg_expr) @@ -1231,7 +1158,7 @@ def eval(self, expr: Expr) -> object: return self._egraph.eval_string(egg_expr) case JustTypeRef("PyObject"): return self._egraph.eval_py_object(egg_expr) - raise NotImplementedError(f"Eval not implemented for {typed_expr.tp.name}") + raise TypeError(f"Eval not implemented for {typed_expr.tp}") def saturate( self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs] @@ -1270,6 +1197,32 @@ def current(cls) -> EGraph: """ return CURRENT_EGRAPH.get() + @property + def _egraph(self) -> bindings.EGraph: + return self._state.egraph + + @property + def __egg_decls__(self) -> Declarations: + return self._state.__egg_decls__ + + def _register_commands(self, cmds: list[Command]) -> None: + self._add_decls(*cmds) + egg_cmds = list(map(self._command_to_egg, cmds)) + self._egraph.run_program(*egg_cmds) + + def _command_to_egg(self, cmd: Command) -> bindings._Command: + ruleset_name = "" + cmd_decl: CommandDecl + match cmd: + case RewriteOrRule(_, cmd_decl, ruleset): + if ruleset: + ruleset_name = ruleset.__egg_name__ + case Action(_, action): + cmd_decl = ActionCommandDecl(action) + case _: + assert_never(cmd) + return self._state.command_to_egg(cmd_decl, ruleset_name) + CURRENT_EGRAPH = ContextVar[EGraph]("CURRENT_EGRAPH") @@ -1316,61 +1269,53 @@ def __init__(self) 
-> None: ... def ruleset( - rule_or_generator: CommandLike | CommandGenerator | None = None, *rules: Rule | Rewrite, name: None | str = None + rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator | None = None, + *rules: RewriteOrRule, + name: None | str = None, ) -> Ruleset: """ Creates a ruleset with the following rules. If no name is provided, one is generated based on the current module """ - r = Ruleset(name=name) + r = Ruleset(name) if rule_or_generator is not None: - r.register(rule_or_generator, *rules) + r.register(rule_or_generator, *rules, _increase_frame=True) return r -class Schedule(ABC): +@dataclass +class Schedule(DelayedDeclerations): """ A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met """ + # Defer declerations so that we can have rule generators that used not yet defined yet + schedule: ScheduleDecl + + def __str__(self) -> str: + return pretty_decl(self.__egg_decls__, self.schedule) + + def __repr__(self) -> str: + return str(self) + def __mul__(self, length: int) -> Schedule: """ Repeat the schedule a number of times. """ - return Repeat(length, self) + return Schedule(self.__egg_decls_thunk__, RepeatDecl(self.schedule, length)) def saturate(self) -> Schedule: """ Run the schedule until the e-graph is saturated. """ - return Saturate(self) + return Schedule(self.__egg_decls_thunk__, SaturateDecl(self.schedule)) def __add__(self, other: Schedule) -> Schedule: """ Run two schedules in sequence. """ - return Sequence((self, other)) - - @abstractmethod - def __str__(self) -> str: - raise NotImplementedError - - @abstractmethod - def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule: - raise NotImplementedError - - @abstractmethod - def _rulesets(self) -> Iterable[Ruleset]: - """ - Mapping of all the rulesets used to commands. 
- """ - raise NotImplementedError - - @property - @abstractmethod - def __egg_decls__(self) -> Declarations: - raise NotImplementedError + return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule))) @dataclass @@ -1379,422 +1324,119 @@ class Ruleset(Schedule): A collection of rules, which can be run as a schedule. """ + __egg_decls_thunk__: Callable[[], Declarations] = field(init=False) + schedule: RunDecl = field(init=False) name: str | None - rules: list[Rule | Rewrite] = field(default_factory=list) - def append(self, rule: Rule | Rewrite) -> None: + # Current declerations we have accumulated + _current_egg_decls: Declarations = field(default_factory=Declarations) + # Current rulesets we have accumulated + __egg_ruleset__: RulesetDecl = field(init=False) + # Rule generator functions that have been deferred, to allow for late type binding + deferred_rule_gens: list[Callable[[], Iterable[RewriteOrRule]]] = field(default_factory=list) + + def __post_init__(self) -> None: + self.schedule = RunDecl(self.__egg_name__, ()) + self.__egg_ruleset__ = self._current_egg_decls._rulesets[self.__egg_name__] = RulesetDecl([]) + self.__egg_decls_thunk__ = self._update_egg_decls + + def _update_egg_decls(self) -> Declarations: + """ + To return the egg decls, we go through our deferred rules and add any we haven't yet + """ + while self.deferred_rule_gens: + rules = self.deferred_rule_gens.pop()() + self._current_egg_decls.update(*rules) + self.__egg_ruleset__.rules.extend(r.decl for r in rules) + return self._current_egg_decls + + def append(self, rule: RewriteOrRule) -> None: """ Register a rule with the ruleset. 
""" - self.rules.append(rule) + self._current_egg_decls |= rule + self.__egg_ruleset__.rules.append(rule.decl) - def register(self, /, rule_or_generator: CommandLike | CommandGenerator, *rules: Rule | Rewrite) -> None: + def register( + self, + /, + rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator, + *rules: RewriteOrRule, + _increase_frame: bool = False, + ) -> None: """ Register rewrites or rules, either as a function or as values. """ - if isinstance(rule_or_generator, FunctionType): - assert not rules - rules = tuple(_command_generator(rule_or_generator)) + if isinstance(rule_or_generator, RewriteOrRule): + self.append(rule_or_generator) + for r in rules: + self.append(r) else: - rules = (cast(Rule | Rewrite, rule_or_generator), *rules) - for r in rules: - self.append(r) - - @cached_property - def __egg_decls__(self) -> Declarations: - return Declarations.create(*self.rules) - - @property - def _cmds(self) -> list[bindings._Command]: - cmds = [r._to_egg_command(self.egg_name) for r in self.rules] - if self.egg_name: - cmds.insert(0, bindings.AddRuleset(self.egg_name)) - return cmds + assert not rules + current_frame = inspect.currentframe() + assert current_frame + original_frame = current_frame.f_back + assert original_frame + if _increase_frame: + original_frame = original_frame.f_back + assert original_frame + self.deferred_rule_gens.append(Thunk.fn(_rewrite_or_rule_generator, rule_or_generator, original_frame)) def __str__(self) -> str: - return f"ruleset(name={self.egg_name!r})" + return pretty_decl(self._current_egg_decls, self.__egg_ruleset__, ruleset_name=self.name) def __repr__(self) -> str: - if not self.rules: - return str(self) - rules = ", ".join(map(repr, self.rules)) - return f"ruleset({rules}, name={self.egg_name!r})" - - def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule: - return bindings.Run(self._to_egg_config()) - - def _to_egg_config(self) -> bindings.RunConfig: - return bindings.RunConfig(self.egg_name, 
None) - - def _rulesets(self) -> Iterable[Ruleset]: - yield self - - @property - def egg_name(self) -> str: - return self.name or f"_ruleset_{id(self)}" - - -class Command(ABC): - """ - A command that can be executed in the egg interpreter. - - We only use this for commands which return no result and don't create new Python objects. - - Anything that can be passed to the `register` function in a Module is a Command. - """ - - ruleset: Ruleset | None + return str(self) + # Create a unique name if we didn't pass one from the user @property - @abstractmethod - def __egg_decls__(self) -> Declarations: - raise NotImplementedError - - @abstractmethod - def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command: - raise NotImplementedError - - @abstractmethod - def __str__(self) -> str: - raise NotImplementedError + def __egg_name__(self) -> str: + return self.name or f"ruleset_{id(self)}" @dataclass -class Rewrite(Command): - ruleset: Ruleset | None - _lhs: RuntimeExpr - _rhs: RuntimeExpr - _conditions: tuple[Fact, ...] 
- _subsume: bool - _fn_name: ClassVar[str] = "rewrite" +class RewriteOrRule: + __egg_decls__: Declarations + decl: RewriteOrRuleDecl + ruleset: Ruleset | None = None def __str__(self) -> str: - args_str = ", ".join(map(str, [self._rhs, *self._conditions])) - return f"{self._fn_name}({self._lhs}).to({args_str})" - - def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command: - return bindings.RewriteCommand( - self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite(), self._subsume - ) - - def _to_egg_rewrite(self) -> bindings.Rewrite: - return bindings.Rewrite( - self._lhs.__egg_typed_expr__.expr.to_egg(self._lhs.__egg_decls__), - self._rhs.__egg_typed_expr__.expr.to_egg(self._rhs.__egg_decls__), - [c._to_egg_fact() for c in self._conditions], - ) - - @cached_property - def __egg_decls__(self) -> Declarations: - return Declarations.create(self._lhs, self._rhs, *self._conditions) - - def with_ruleset(self, ruleset: Ruleset) -> Rewrite: - return Rewrite(ruleset, self._lhs, self._rhs, self._conditions, self._subsume) + return pretty_decl(self.__egg_decls__, self.decl) - -@dataclass -class BiRewrite(Rewrite): - _fn_name: ClassVar[str] = "birewrite" - - def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command: - return bindings.BiRewriteCommand( - self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite() - ) + def __repr__(self) -> str: + return str(self) @dataclass -class Fact(ABC): +class Fact: """ A query on an EGraph, either by an expression or an equivalence between multiple expressions. 
""" - @abstractmethod - def _to_egg_fact(self) -> bindings._Fact: - raise NotImplementedError - - @property - @abstractmethod - def __egg_decls__(self) -> Declarations: - raise NotImplementedError - - -@dataclass -class Eq(Fact): - _exprs: list[RuntimeExpr] + __egg_decls__: Declarations + fact: FactDecl def __str__(self) -> str: - first, *rest = self._exprs - args_str = ", ".join(map(str, rest)) - return f"eq({first}).to({args_str})" - - def _to_egg_fact(self) -> bindings.Eq: - return bindings.Eq([e.__egg__ for e in self._exprs]) + return pretty_decl(self.__egg_decls__, self.fact) - @cached_property - def __egg_decls__(self) -> Declarations: - return Declarations.create(*self._exprs) - - -@dataclass -class ExprFact(Fact): - _expr: RuntimeExpr - - def __str__(self) -> str: - return str(self._expr) - - def _to_egg_fact(self) -> bindings.Fact: - return bindings.Fact(self._expr.__egg__) - - @cached_property - def __egg_decls__(self) -> Declarations: - return self._expr.__egg_decls__ + def __repr__(self) -> str: + return str(self) @dataclass -class Rule(Command): - head: tuple[Action, ...] - body: tuple[Fact, ...] - name: str - ruleset: Ruleset | None - - def __str__(self) -> str: - head_str = ", ".join(map(str, self.head)) - body_str = ", ".join(map(str, self.body)) - return f"rule({body_str}).then({head_str})" - - def _to_egg_command(self, default_ruleset_name: str) -> bindings.RuleCommand: - return bindings.RuleCommand( - self.name, - self.ruleset.egg_name if self.ruleset else default_ruleset_name, - bindings.Rule( - [a._to_egg_action() for a in self.head], - [f._to_egg_fact() for f in self.body], - ), - ) - - @cached_property - def __egg_decls__(self) -> Declarations: - return Declarations.create(*self.head, *self.body) - - -class Action(Command, ABC): +class Action: """ A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking. 
""" - @abstractmethod - def _to_egg_action(self) -> bindings._Action: - raise NotImplementedError - - def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command: - return bindings.ActionCommand(self._to_egg_action()) - - @property - def ruleset(self) -> None | Ruleset: # type: ignore[override] - return None - - -@dataclass -class Let(Action): - _name: str - _value: RuntimeExpr - - def __str__(self) -> str: - return f"let({self._name}, {self._value})" - - def _to_egg_action(self) -> bindings.Let: - return bindings.Let(self._name, self._value.__egg__) - - @property - def __egg_decls__(self) -> Declarations: - return self._value.__egg_decls__ - - -@dataclass -class Set(Action): - """ - Similar to union, except can be used on primitive expressions, whereas union can only be used on user defined expressions. - """ - - _call: RuntimeExpr - _rhs: RuntimeExpr - - def __str__(self) -> str: - return f"set({self._call}).to({self._rhs})" - - def _to_egg_action(self) -> bindings.Set: - egg_call = self._call.__egg__ - if not isinstance(egg_call, bindings.Call): - raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") # noqa: TRY004 - return bindings.Set( - egg_call.name, - egg_call.args, - self._rhs.__egg__, - ) - - @cached_property - def __egg_decls__(self) -> Declarations: - return Declarations.create(self._call, self._rhs) - - -@dataclass -class ExprAction(Action): - _expr: RuntimeExpr - - def __str__(self) -> str: - return str(self._expr) - - def _to_egg_action(self) -> bindings.Expr_: - return bindings.Expr_(self._expr.__egg__) - - @property - def __egg_decls__(self) -> Declarations: - return self._expr.__egg_decls__ - - -@dataclass -class Change(Action): - """ - Change a function call in an EGraph. 
- """ - - change: Literal["delete", "subsume"] - _call: RuntimeExpr - - def __str__(self) -> str: - return f"{self.change}({self._call})" - - def _to_egg_action(self) -> bindings.Change: - egg_call = self._call.__egg__ - if not isinstance(egg_call, bindings.Call): - raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004 - change: bindings._Change = bindings.Delete() if self.change == "delete" else bindings.Subsume() - return bindings.Change(change, egg_call.name, egg_call.args) - - @property - def __egg_decls__(self) -> Declarations: - return self._call.__egg_decls__ - - -@dataclass -class Union_(Action): # noqa: N801 - """ - Merges two equivalence classes of two expressions. - """ - - _lhs: RuntimeExpr - _rhs: RuntimeExpr - - def __str__(self) -> str: - return f"union({self._lhs}).with_({self._rhs})" - - def _to_egg_action(self) -> bindings.Union: - return bindings.Union(self._lhs.__egg__, self._rhs.__egg__) - - @cached_property - def __egg_decls__(self) -> Declarations: - return Declarations.create(self._lhs, self._rhs) - - -@dataclass -class Panic(Action): - message: str - - def __str__(self) -> str: - return f"panic({self.message})" - - def _to_egg_action(self) -> bindings.Panic: - return bindings.Panic(self.message) - - @cached_property - def __egg_decls__(self) -> Declarations: - return Declarations() - - -@dataclass -class Run(Schedule): - """Configuration of a run""" - - # None if using default ruleset - ruleset: Ruleset | None - until: tuple[Fact, ...] 
- - def __str__(self) -> str: - args_str = ", ".join(map(str, [self.ruleset, *self.until])) - return f"run({args_str})" - - def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule: - return bindings.Run(self._to_egg_config(default_ruleset_name)) - - def _to_egg_config(self, default_ruleset_name: str) -> bindings.RunConfig: - return bindings.RunConfig( - self.ruleset.egg_name if self.ruleset else default_ruleset_name, - [fact._to_egg_fact() for fact in self.until] if self.until else None, - ) - - def _rulesets(self) -> Iterable[Ruleset]: - if self.ruleset: - yield self.ruleset - - @cached_property - def __egg_decls__(self) -> Declarations: - return Declarations.create(self.ruleset, *self.until) - - -@dataclass -class Saturate(Schedule): - schedule: Schedule - - def __str__(self) -> str: - return f"{self.schedule}.saturate()" - - def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule: - return bindings.Saturate(self.schedule._to_egg_schedule(default_ruleset_name)) - - def _rulesets(self) -> Iterable[Ruleset]: - return self.schedule._rulesets() - - @property - def __egg_decls__(self) -> Declarations: - return self.schedule.__egg_decls__ - - -@dataclass -class Repeat(Schedule): - length: int - schedule: Schedule - - def __str__(self) -> str: - return f"{self.schedule} * {self.length}" - - def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule: - return bindings.Repeat(self.length, self.schedule._to_egg_schedule(default_ruleset_name)) - - def _rulesets(self) -> Iterable[Ruleset]: - return self.schedule._rulesets() - - @property - def __egg_decls__(self) -> Declarations: - return self.schedule.__egg_decls__ - - -@dataclass -class Sequence(Schedule): - schedules: tuple[Schedule, ...] 
+ __egg_decls__: Declarations + action: ActionDecl def __str__(self) -> str: - return f"sequence({', '.join(map(str, self.schedules))})" - - def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule: - return bindings.Sequence([schedule._to_egg_schedule(default_ruleset_name) for schedule in self.schedules]) + return pretty_decl(self.__egg_decls__, self.action) - def _rulesets(self) -> Iterable[Ruleset]: - for s in self.schedules: - yield from s._rulesets() - - @cached_property - def __egg_decls__(self) -> Declarations: - return Declarations.create(*self.schedules) + def __repr__(self) -> str: + return str(self) # We use these builders so that when creating these structures we can type check @@ -1841,30 +1483,41 @@ def ne(expr: EXPR) -> _NeBuilder[EXPR]: def panic(message: str) -> Action: """Raise an error with the given message.""" - return Panic(message) + return Action(Declarations(), PanicDecl(message)) def let(name: str, expr: Expr) -> Action: """Create a let binding.""" - return Let(name, to_runtime_expr(expr)) + runtime_expr = to_runtime_expr(expr) + return Action(runtime_expr.__egg_decls__, LetDecl(name, runtime_expr.__egg_typed_expr__)) def expr_action(expr: Expr) -> Action: - return ExprAction(to_runtime_expr(expr)) + runtime_expr = to_runtime_expr(expr) + return Action(runtime_expr.__egg_decls__, ExprActionDecl(runtime_expr.__egg_typed_expr__)) def delete(expr: Expr) -> Action: """Create a delete expression.""" - return Change("delete", to_runtime_expr(expr)) + runtime_expr = to_runtime_expr(expr) + typed_expr = runtime_expr.__egg_typed_expr__ + call_decl = typed_expr.expr + assert isinstance(call_decl, CallDecl), "Can only delete calls, not literals or vars" + return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "delete")) def subsume(expr: Expr) -> Action: - """Subsume an expression.""" - return Change("subsume", to_runtime_expr(expr)) + """Subsume an expression so it cannot be matched against or extracted""" 
+ runtime_expr = to_runtime_expr(expr) + typed_expr = runtime_expr.__egg_typed_expr__ + call_decl = typed_expr.expr + assert isinstance(call_decl, CallDecl), "Can only subsume calls, not literals or vars" + return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "subsume")) def expr_fact(expr: Expr) -> Fact: - return ExprFact(to_runtime_expr(expr)) + runtime_expr = to_runtime_expr(expr) + return Fact(runtime_expr.__egg_decls__, ExprFactDecl(runtime_expr.__egg_typed_expr__)) def union(lhs: EXPR) -> _UnionBuilder[EXPR]: @@ -1891,6 +1544,11 @@ def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = No return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset) +@deprecated("This function is now a no-op, you can remove it and use actions as commands") +def action_command(action: Action) -> Action: + return action + + def var(name: str, bound: type[EXPR]) -> EXPR: """Create a new variable with the given name and type.""" return cast(EXPR, _var(name, bound)) @@ -1898,9 +1556,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR: def _var(name: str, bound: object) -> RuntimeExpr: """Create a new variable with the given name and type.""" - if not isinstance(bound, RuntimeClass | RuntimeParamaterizedClass): + if not isinstance(bound, RuntimeClass): raise TypeError(f"Unexpected type {type(bound)}") - return RuntimeExpr(bound.__egg_decls__, TypedExprDecl(class_to_ref(bound), VarDecl(name))) + return RuntimeExpr.__from_value__(bound.__egg_decls__, TypedExprDecl(bound.__egg_tp__.to_just(), VarDecl(name))) def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]: @@ -1915,15 +1573,27 @@ class _RewriteBuilder(Generic[EXPR]): ruleset: Ruleset | None subsume: bool - def to(self, rhs: EXPR, *conditions: FactLike) -> Rewrite: + def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule: lhs = to_runtime_expr(self.lhs) - rule = Rewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions), 
self.subsume) + facts = _fact_likes(conditions) + rhs = convert_to_same_type(rhs, lhs) + rule = RewriteOrRule( + Declarations.create(lhs, rhs, *facts, self.ruleset), + RewriteDecl( + lhs.__egg_typed_expr__.tp, + lhs.__egg_typed_expr__.expr, + rhs.__egg_typed_expr__.expr, + tuple(f.fact for f in facts), + self.subsume, + ), + ) if self.ruleset: self.ruleset.append(rule) return rule def __str__(self) -> str: - return f"rewrite({self.lhs})" + lhs = to_runtime_expr(self.lhs) + return lhs.__egg_pretty__("rewrite") @dataclass @@ -1931,15 +1601,26 @@ class _BirewriteBuilder(Generic[EXPR]): lhs: EXPR ruleset: Ruleset | None - def to(self, rhs: EXPR, *conditions: FactLike) -> Command: + def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule: lhs = to_runtime_expr(self.lhs) - rule = BiRewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions), False) + facts = _fact_likes(conditions) + rhs = convert_to_same_type(rhs, lhs) + rule = RewriteOrRule( + Declarations.create(lhs, rhs, *facts, self.ruleset), + BiRewriteDecl( + lhs.__egg_typed_expr__.tp, + lhs.__egg_typed_expr__.expr, + rhs.__egg_typed_expr__.expr, + tuple(f.fact for f in facts), + ), + ) if self.ruleset: self.ruleset.append(rule) return rule def __str__(self) -> str: - return f"birewrite({self.lhs})" + lhs = to_runtime_expr(self.lhs) + return lhs.__egg_pretty__("birewrite") @dataclass @@ -1948,52 +1629,84 @@ class _EqBuilder(Generic[EXPR]): def to(self, *exprs: EXPR) -> Fact: expr = to_runtime_expr(self.expr) - return Eq([expr] + [convert_to_same_type(e, expr) for e in exprs]) + args = [expr, *(convert_to_same_type(e, expr) for e in exprs)] + return Fact( + Declarations.create(*args), + EqDecl(expr.__egg_typed_expr__.tp, tuple(a.__egg_typed_expr__.expr for a in args)), + ) + + def __repr__(self) -> str: + return str(self) def __str__(self) -> str: - return f"eq({self.expr})" + expr = to_runtime_expr(self.expr) + return expr.__egg_pretty__("eq") @dataclass class 
_NeBuilder(Generic[EXPR]): - expr: EXPR + lhs: EXPR - def to(self, expr: EXPR) -> Unit: - assert isinstance(self.expr, RuntimeExpr) - args = (self.expr, convert_to_same_type(expr, self.expr)) - decls = Declarations.create(*args) - res = RuntimeExpr( - decls, - TypedExprDecl(JustTypeRef("Unit"), CallDecl(FunctionRef("!="), tuple(a.__egg_typed_expr__ for a in args))), + def to(self, rhs: EXPR) -> Unit: + lhs = to_runtime_expr(self.lhs) + rhs = convert_to_same_type(rhs, lhs) + assert isinstance(Unit, RuntimeClass) + res = RuntimeExpr.__from_value__( + Declarations.create(Unit, lhs, rhs), + TypedExprDecl( + JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__)) + ), ) return cast(Unit, res) + def __repr__(self) -> str: + return str(self) + def __str__(self) -> str: - return f"ne({self.expr})" + expr = to_runtime_expr(self.lhs) + return expr.__egg_pretty__("ne") @dataclass class _SetBuilder(Generic[EXPR]): - lhs: Expr + lhs: EXPR - def to(self, rhs: EXPR) -> Set: + def to(self, rhs: EXPR) -> Action: lhs = to_runtime_expr(self.lhs) - return Set(lhs, convert_to_same_type(rhs, lhs)) + rhs = convert_to_same_type(rhs, lhs) + lhs_expr = lhs.__egg_typed_expr__.expr + assert isinstance(lhs_expr, CallDecl), "Can only set function calls" + return Action( + Declarations.create(lhs, rhs), + SetDecl(lhs.__egg_typed_expr__.tp, lhs_expr, rhs.__egg_typed_expr__.expr), + ) + + def __repr__(self) -> str: + return str(self) def __str__(self) -> str: - return f"set_({self.lhs})" + lhs = to_runtime_expr(self.lhs) + return lhs.__egg_pretty__("set_") @dataclass class _UnionBuilder(Generic[EXPR]): - lhs: Expr + lhs: EXPR def with_(self, rhs: EXPR) -> Action: lhs = to_runtime_expr(self.lhs) - return Union_(lhs, convert_to_same_type(rhs, lhs)) + rhs = convert_to_same_type(rhs, lhs) + return Action( + Declarations.create(lhs, rhs), + UnionDecl(lhs.__egg_typed_expr__.tp, lhs.__egg_typed_expr__.expr, rhs.__egg_typed_expr__.expr), + ) + + def 
__repr__(self) -> str: + return str(self) def __str__(self) -> str: - return f"union({self.lhs})" + lhs = to_runtime_expr(self.lhs) + return lhs.__egg_pretty__("union") @dataclass @@ -2002,12 +1715,25 @@ class _RuleBuilder: name: str | None ruleset: Ruleset | None - def then(self, *actions: ActionLike) -> Rule: - rule = Rule(_action_likes(actions), self.facts, self.name or "", self.ruleset) + def then(self, *actions: ActionLike) -> RewriteOrRule: + actions = _action_likes(actions) + rule = RewriteOrRule( + Declarations.create(self.ruleset, *actions, *self.facts), + RuleDecl(tuple(a.action for a in actions), tuple(f.fact for f in self.facts), self.name), + ) if self.ruleset: self.ruleset.append(rule) return rule + def __str__(self) -> str: + # TODO: Figure out how to stringify rulebuilder that preserves statements + args = list(map(str, self.facts)) + if self.name is not None: + args.append(f"name={self.name}") + if ruleset is not None: + args.append(f"ruleset={self.ruleset}") + return f"rule({', '.join(args)})" + def expr_parts(expr: Expr) -> TypedExprDecl: """ @@ -2024,60 +1750,61 @@ def to_runtime_expr(expr: Expr) -> RuntimeExpr: return expr -def run(ruleset: Ruleset | None = None, *until: Fact) -> Run: +def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule: """ Create a run configuration. """ - return Run(ruleset, tuple(until)) + facts = _fact_likes(until) + return Schedule( + Thunk.fn(Declarations.create, ruleset, *facts), + RunDecl(ruleset.__egg_name__ if ruleset else "", tuple(f.fact for f in facts) or None), + ) def seq(*schedules: Schedule) -> Schedule: """ Run a sequence of schedules. 
""" - return Sequence(tuple(schedules)) + return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules))) -CommandLike = Command | Expr +ActionLike: TypeAlias = Action | Expr -def _command_like(command_like: CommandLike) -> Command: - if isinstance(command_like, Expr): - return expr_action(command_like) - return command_like +def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]: + return tuple(map(_action_like, action_likes)) -CommandGenerator = Callable[..., Iterable[Rule | Rewrite]] +def _action_like(action_like: ActionLike) -> Action: + if isinstance(action_like, Expr): + return expr_action(action_like) + return action_like -def _command_generator(gen: CommandGenerator) -> Iterable[Command]: - """ - Calls the function with variables of the type and name of the arguments. - """ - # Get the local scope from where the function is defined, so that we can get any type hints that are in the scope - # but not in the globals - current_frame = inspect.currentframe() - assert current_frame - register_frame = current_frame.f_back - assert register_frame - original_frame = register_frame.f_back - assert original_frame - hints = get_type_hints(gen, gen.__globals__, original_frame.f_locals) - args = (_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()) - return gen(*args) +Command: TypeAlias = Action | RewriteOrRule +CommandLike: TypeAlias = ActionLike | RewriteOrRule -ActionLike = Action | Expr +def _command_like(command_like: CommandLike) -> Command: + if isinstance(command_like, RewriteOrRule): + return command_like + return _action_like(command_like) -def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]: - return tuple(map(_action_like, action_likes)) +RewriteOrRuleGenerator = Callable[..., Iterable[RewriteOrRule]] -def _action_like(action_like: ActionLike) -> Action: - if isinstance(action_like, Expr): - return expr_action(action_like) - return action_like + 
+def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> Iterable[RewriteOrRule]: + """ + Returns a thunk which will call the function with variables of the type and name of the arguments. + """ + # Get the local scope from where the function is defined, so that we can get any type hints that are in the scope + # but not in the globals + + hints = get_type_hints(gen, gen.__globals__, frame.f_locals) + args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()] + return list(gen(*args)) # type: ignore[misc] FactLike = Fact | Expr diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py new file mode 100644 index 00000000..e138070d --- /dev/null +++ b/python/egglog/egraph_state.py @@ -0,0 +1,417 @@ +""" +Implement conversion to/from egglog. +""" + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, overload +from weakref import WeakKeyDictionary + +from typing_extensions import assert_never + +from . import bindings +from .declarations import * +from .pretty import * +from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver + +if TYPE_CHECKING: + from collections.abc import Iterable + +__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"] + +# Create a global sort for python objects, so we can store them without an e-graph instance +# Needed when serializing commands to egg commands when creating modules +GLOBAL_PY_OBJECT_SORT = bindings.PyObjectSort() + + +@dataclass +class EGraphState: + """ + State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined. + + Used for converting to/from egg and for pretty printing. + """ + + egraph: bindings.EGraph + # The decleratons we have added. 
+ __egg_decls__: Declarations = field(default_factory=Declarations) + # Mapping of added rulesets to the added rules + rulesets: dict[str, set[RewriteOrRuleDecl]] = field(default_factory=dict) + + # Bidirectional mapping between egg function names and python callable references. + # Note that there are possibly mutliple callable references for a single egg function name, like `+` + # for both int and rational classes. + egg_fn_to_callable_refs: dict[str, set[CallableRef]] = field( + default_factory=lambda: defaultdict(set, {"!=": {FunctionRef("!=")}}) + ) + callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=lambda: {FunctionRef("!="): "!="}) + + # Bidirectional mapping between egg sort names and python type references. + type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict) + + # Cache of egg expressions for converting to egg + expr_to_egg_cache: WeakKeyDictionary[ExprDecl, bindings._Expr] = field(default_factory=WeakKeyDictionary) + + def copy(self) -> EGraphState: + """ + Returns a copy of the state. Th egraph reference is kept the same. Used for pushing/popping. 
+ """ + return EGraphState( + egraph=self.egraph, + __egg_decls__=self.__egg_decls__.copy(), + rulesets={k: v.copy() for k, v in self.rulesets.items()}, + egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self.egg_fn_to_callable_refs.items()}), + callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(), + type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(), + expr_to_egg_cache=self.expr_to_egg_cache.copy(), + ) + + def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule: + match schedule: + case SaturateDecl(schedule): + return bindings.Saturate(self.schedule_to_egg(schedule)) + case RepeatDecl(schedule, times): + return bindings.Repeat(times, self.schedule_to_egg(schedule)) + case SequenceDecl(schedules): + return bindings.Sequence([self.schedule_to_egg(s) for s in schedules]) + case RunDecl(ruleset_name, until): + self.ruleset_to_egg(ruleset_name) + config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until))) + return bindings.Run(config) + case _: + assert_never(schedule) + + def ruleset_to_egg(self, name: str) -> None: + """ + Registers a ruleset if it's not already registered. 
+ """ + if name not in self.rulesets: + if name: + self.egraph.run_program(bindings.AddRuleset(name)) + rules = self.rulesets[name] = set() + else: + rules = self.rulesets[name] + for rule in self.__egg_decls__._rulesets[name].rules: + if rule in rules: + continue + self.egraph.run_program(self.command_to_egg(rule, name)) + rules.add(rule) + + def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command: + match cmd: + case ActionCommandDecl(action): + return bindings.ActionCommand(self.action_to_egg(action)) + case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions): + self.type_ref_to_egg(tp) + rewrite = bindings.Rewrite( + self.expr_to_egg(lhs), + self.expr_to_egg(rhs), + [self.fact_to_egg(c) for c in conditions], + ) + return ( + bindings.RewriteCommand(ruleset, rewrite, cmd.subsume) + if isinstance(cmd, RewriteDecl) + else bindings.BiRewriteCommand(ruleset, rewrite) + ) + case RuleDecl(head, body, name): + rule = bindings.Rule( + [self.action_to_egg(a) for a in head], + [self.fact_to_egg(f) for f in body], + ) + return bindings.RuleCommand(name or "", ruleset, rule) + case _: + assert_never(cmd) + + def action_to_egg(self, action: ActionDecl) -> bindings._Action: + match action: + case LetDecl(name, typed_expr): + return bindings.Let(name, self.typed_expr_to_egg(typed_expr)) + case SetDecl(tp, call, rhs): + self.type_ref_to_egg(tp) + call_ = self.expr_to_egg(call) + return bindings.Set(call_.name, call_.args, self.expr_to_egg(rhs)) + case ExprActionDecl(typed_expr): + return bindings.Expr_(self.typed_expr_to_egg(typed_expr)) + case ChangeDecl(tp, call, change): + self.type_ref_to_egg(tp) + call_ = self.expr_to_egg(call) + egg_change: bindings._Change + match change: + case "delete": + egg_change = bindings.Delete() + case "subsume": + egg_change = bindings.Subsume() + case _: + assert_never(change) + return bindings.Change(egg_change, call_.name, call_.args) + case UnionDecl(tp, lhs, rhs): + self.type_ref_to_egg(tp) 
+ return bindings.Union(self.expr_to_egg(lhs), self.expr_to_egg(rhs)) + case PanicDecl(name): + return bindings.Panic(name) + case _: + assert_never(action) + + def fact_to_egg(self, fact: FactDecl) -> bindings._Fact: + match fact: + case EqDecl(tp, exprs): + self.type_ref_to_egg(tp) + return bindings.Eq([self.expr_to_egg(e) for e in exprs]) + case ExprFactDecl(typed_expr): + return bindings.Fact(self.typed_expr_to_egg(typed_expr)) + case _: + assert_never(fact) + + def callable_ref_to_egg(self, ref: CallableRef) -> str: + """ + Returns the egg function name for a callable reference, registering it if it is not already registered. + """ + if ref in self.callable_ref_to_egg_fn: + return self.callable_ref_to_egg_fn[ref] + decl = self.__egg_decls__.get_callable_decl(ref) + self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _generate_callable_egg_name(ref) + self.egg_fn_to_callable_refs[egg_name].add(ref) + match decl: + case RelationDecl(arg_types, _, _): + self.egraph.run_program(bindings.Relation(egg_name, [self.type_ref_to_egg(a) for a in arg_types])) + case ConstantDecl(tp, _): + # Use function decleration instead of constant b/c constants cannot be extracted + # https://github.com/egraphs-good/egglog/issues/334 + self.egraph.run_program( + bindings.Function(bindings.FunctionDecl(egg_name, bindings.Schema([], self.type_ref_to_egg(tp)))) + ) + case FunctionDecl(): + if not decl.builtin: + egg_fn_decl = bindings.FunctionDecl( + egg_name, + bindings.Schema( + [self.type_ref_to_egg(a.to_just()) for a in decl.arg_types], + self.type_ref_to_egg(decl.semantic_return_type.to_just()), + ), + self.expr_to_egg(decl.default) if decl.default else None, + self.expr_to_egg(decl.merge) if decl.merge else None, + [self.action_to_egg(a) for a in decl.on_merge], + decl.cost, + decl.unextractable, + ) + self.egraph.run_program(bindings.Function(egg_fn_decl)) + case _: + assert_never(decl) + return egg_name + + def type_ref_to_egg(self, ref: JustTypeRef) -> str: + """ + 
Returns the egg sort name for a type reference, registering it if it is not already registered. + """ + try: + return self.type_ref_to_egg_sort[ref] + except KeyError: + pass + decl = self.__egg_decls__._classes[ref.name] + self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref) + if not decl.builtin or ref.args: + self.egraph.run_program( + bindings.Sort( + egg_name, + ( + ( + self.type_ref_to_egg(JustTypeRef(ref.name)), + [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args], + ) + if ref.args + else None + ), + ) + ) + # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because + # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted + # even if you never use that function. + if decl.builtin: + for method in decl.class_methods: + self.callable_ref_to_egg(ClassMethodRef(ref.name, method)) + + return egg_name + + def op_mapping(self) -> dict[str, str]: + """ + Create a mapping of egglog function name to Python function name, for use in the serialized format + for better visualization. + """ + return { + k: pretty_callable_ref(self.__egg_decls__, next(iter(v))) + for k, v in self.egg_fn_to_callable_refs.items() + if len(v) == 1 + } + + def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr: + self.type_ref_to_egg(typed_expr_decl.tp) + return self.expr_to_egg(typed_expr_decl.expr) + + @overload + def expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ... + + @overload + def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ... + + def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: + """ + Convert an ExprDecl to an egg expression. + + Cached using weakrefs to avoid memory leaks. 
+ """ + try: + return self.expr_to_egg_cache[expr_decl] + except KeyError: + pass + + res: bindings._Expr + match expr_decl: + case VarDecl(name): + res = bindings.Var(name) + case LitDecl(value): + l: bindings._Literal + match value: + case None: + l = bindings.Unit() + case bool(i): + l = bindings.Bool(i) + case int(i): + l = bindings.Int(i) + case float(f): + l = bindings.F64(f) + case str(s): + l = bindings.String(s) + case _: + assert_never(value) + res = bindings.Lit(l) + case CallDecl(ref, args, _): + egg_fn = self.callable_ref_to_egg(ref) + egg_args = [self.typed_expr_to_egg(a) for a in args] + res = bindings.Call(egg_fn, egg_args) + case PyObjectDecl(value): + res = GLOBAL_PY_OBJECT_SORT.store(value) + case _: + assert_never(expr_decl.expr) + + self.expr_to_egg_cache[expr_decl] = res + return res + + def exprs_from_egg( + self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef + ) -> Iterable[TypedExprDecl]: + """ + Create a function that can convert from an egg term to a typed expr. + """ + state = FromEggState(self, termdag) + return [state.from_expr(tp, term) for term in terms] + + def _get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]: + """ + Given a class name, returns all possible registered types that it can be. + """ + return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name) + + +def _generate_type_egg_name(ref: JustTypeRef) -> str: + """ + Generates an egg sort name for this type reference by linearizing the type. + """ + name = ref.name + if not ref.args: + return name + return f"{name}_{'_'.join(map(_generate_type_egg_name, ref.args))}" + + +def _generate_callable_egg_name(ref: CallableRef) -> str: + """ + Generates a valid egg function name for a callable reference. 
+ """ + match ref: + case FunctionRef(name) | ConstantRef(name): + return name + case ( + MethodRef(cls_name, name) + | ClassMethodRef(cls_name, name) + | ClassVariableRef(cls_name, name) + | PropertyRef(cls_name, name) + ): + return f"{cls_name}_{name}" + case _: + assert_never(ref) + + +@dataclass +class FromEggState: + """ + Dataclass containing state used when converting from an egg term to a typed expr. + """ + + state: EGraphState + termdag: bindings.TermDag + # Cache of termdag ID to TypedExprDecl + cache: dict[int, TypedExprDecl] = field(default_factory=dict) + + @property + def decls(self) -> Declarations: + return self.state.__egg_decls__ + + def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl: + """ + Convert an egg term to a typed expr. + """ + expr_decl: ExprDecl + if isinstance(term, bindings.TermVar): + expr_decl = VarDecl(term.name) + elif isinstance(term, bindings.TermLit): + value = term.value + expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value) + elif isinstance(term, bindings.TermApp): + if term.name == "py-object": + call = bindings.termdag_term_to_expr(self.termdag, term) + expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call)) + else: + expr_decl = self.from_call(tp, term) + else: + assert_never(term) + return TypedExprDecl(tp, expr_decl) + + def from_call(self, tp: JustTypeRef, term: bindings.TermApp) -> CallDecl: + """ + Convert a call to a CallDecl. + + There could be Python call refs which match the call, so we need to find the correct one. + """ + # Find the first callable ref that matches the call + for callable_ref in self.state.egg_fn_to_callable_refs[term.name]: + # If this is a classmethod, we might need the type params that were bound for this type + # This could be multiple types if the classmethod is ambiguous, like map create. 
+ possible_types: Iterable[JustTypeRef | None] + fn_decl = self.decls.get_callable_decl(callable_ref).to_function_decl() + if isinstance(callable_ref, ClassMethodRef): + possible_types = self.state._get_possible_types(callable_ref.class_name) + cls_name = callable_ref.class_name + else: + possible_types = [None] + cls_name = None + for possible_type in possible_types: + tcs = TypeConstraintSolver(self.decls) + if possible_type and possible_type.args: + tcs.bind_class(possible_type) + + try: + arg_types, bound_tp_params = tcs.infer_arg_types( + fn_decl.arg_types, fn_decl.semantic_return_type, fn_decl.var_arg_type, tp, cls_name + ) + except TypeConstraintError: + continue + args: list[TypedExprDecl] = [] + for a, tp in zip(term.args, arg_types, strict=False): + try: + res = self.cache[a] + except KeyError: + res = self.cache[a] = self.from_expr(tp, self.termdag.nodes[a]) + args.append(res) + return CallDecl(callable_ref, tuple(args), bound_tp_params) + raise ValueError(f"Could not find callable ref for call {term}") diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 850571f1..8f8ed930 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -24,7 +24,7 @@ # Pretend that exprs are numbers b/c sklearn does isinstance checks numbers.Integral.register(RuntimeExpr) -array_api_ruleset = ruleset() +array_api_ruleset = ruleset(name="array_api_ruleset") array_api_schedule = array_api_ruleset.saturate() @@ -36,10 +36,14 @@ def __bool__(self) -> bool: @property def bool(self) -> Bool: ... - def __or__(self, other: Boolean) -> Boolean: ... + def __or__(self, other: BooleanLike) -> Boolean: ... - def __and__(self, other: Boolean) -> Boolean: ... + def __and__(self, other: BooleanLike) -> Boolean: ... + def if_int(self, true_value: Int, false_value: Int) -> Int: ... 
+ + +BooleanLike = Boolean | bool TRUE = constant("TRUE", Boolean) FALSE = constant("FALSE", Boolean) @@ -47,7 +51,7 @@ def __and__(self, other: Boolean) -> Boolean: ... @array_api_ruleset.register -def _bool(x: Boolean): +def _bool(x: Boolean, i: Int, j: Int): return [ rule(eq(x).to(TRUE)).then(set_(x.bool).to(Bool(True))), rule(eq(x).to(FALSE)).then(set_(x.bool).to(Bool(False))), @@ -55,82 +59,8 @@ def _bool(x: Boolean): rewrite(FALSE | x).to(x), rewrite(TRUE & x).to(x), rewrite(FALSE & x).to(FALSE), - ] - - -class DType(Expr): - float64: ClassVar[DType] - float32: ClassVar[DType] - int64: ClassVar[DType] - int32: ClassVar[DType] - object: ClassVar[DType] - bool: ClassVar[DType] - - def __eq__(self, other: DType) -> Boolean: # type: ignore[override] - ... - - -float64 = DType.float64 -float32 = DType.float32 -int32 = DType.int32 -int64 = DType.int64 - -_DTYPES = [float64, float32, int32, int64, DType.object] - -converter(type, DType, lambda x: convert(np.dtype(x), DType)) -converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload] -array_api_ruleset.register( - *(rewrite(l == r).to(TRUE if l is r else FALSE) for l, r in itertools.product(_DTYPES, repeat=2)) -) - - -class IsDtypeKind(Expr): - NULL: ClassVar[IsDtypeKind] - - @classmethod - def string(cls, s: StringLike) -> IsDtypeKind: ... - - @classmethod - def dtype(cls, d: DType) -> IsDtypeKind: ... - - @method(cost=10) - def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ... - - -# TODO: Make kind more generic to support tuples. -@function -def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ... 
- - -converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x)) -converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x)) -converter( - tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL -) - - -@array_api_ruleset.register -def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind): - return [ - rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE), - rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE), - rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE), - rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE), - rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE), - rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE), - rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE), - rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE), - rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE), - rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE), - rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE), - rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE), - rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE), - rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE), - rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE), - rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE), - rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE), - rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)), - rewrite(k1 | IsDtypeKind.NULL).to(k1), + rewrite(TRUE.if_int(i, j)).to(i), + rewrite(FALSE.if_int(i, j)).to(j), ] @@ -264,10 +194,13 @@ def _int(i: i64, j: i64, r: Boolean, o: Int): class Float(Expr): + # Differentiate costs of three constructors so extraction 
is deterministic if all three are present + @method(cost=3) def __init__(self, value: f64Like) -> None: ... def abs(self) -> Float: ... + @method(cost=2) @classmethod def rational(cls, r: Rational) -> Float: ... @@ -366,6 +299,85 @@ def some(cls, value: Int) -> OptionalInt: ... converter(Int, OptionalInt, OptionalInt.some) +class DType(Expr): + float64: ClassVar[DType] + float32: ClassVar[DType] + int64: ClassVar[DType] + int32: ClassVar[DType] + object: ClassVar[DType] + bool: ClassVar[DType] + + def __eq__(self, other: DType) -> Boolean: # type: ignore[override] + ... + + +float64 = DType.float64 +float32 = DType.float32 +int32 = DType.int32 +int64 = DType.int64 + +_DTYPES = [float64, float32, int32, int64, DType.object] + +converter(type, DType, lambda x: convert(np.dtype(x), DType)) +converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload] + + +@array_api_ruleset.register +def _(): + for l, r in itertools.product(_DTYPES, repeat=2): + yield rewrite(l == r).to(TRUE if l is r else FALSE) + + +class IsDtypeKind(Expr): + NULL: ClassVar[IsDtypeKind] + + @classmethod + def string(cls, s: StringLike) -> IsDtypeKind: ... + + @classmethod + def dtype(cls, d: DType) -> IsDtypeKind: ... + + @method(cost=10) + def __or__(self, other: IsDtypeKind) -> IsDtypeKind: ... + + +# TODO: Make kind more generic to support tuples. +@function +def isdtype(dtype: DType, kind: IsDtypeKind) -> Boolean: ... 
+ + +converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x)) +converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x)) +converter( + tuple, IsDtypeKind, lambda x: convert(x[0], IsDtypeKind) | convert(x[1:], IsDtypeKind) if x else IsDtypeKind.NULL +) + + +@array_api_ruleset.register +def _isdtype(d: DType, k1: IsDtypeKind, k2: IsDtypeKind): + return [ + rewrite(isdtype(DType.float32, IsDtypeKind.string("integral"))).to(FALSE), + rewrite(isdtype(DType.float64, IsDtypeKind.string("integral"))).to(FALSE), + rewrite(isdtype(DType.object, IsDtypeKind.string("integral"))).to(FALSE), + rewrite(isdtype(DType.int64, IsDtypeKind.string("integral"))).to(TRUE), + rewrite(isdtype(DType.int32, IsDtypeKind.string("integral"))).to(TRUE), + rewrite(isdtype(DType.float32, IsDtypeKind.string("real floating"))).to(TRUE), + rewrite(isdtype(DType.float64, IsDtypeKind.string("real floating"))).to(TRUE), + rewrite(isdtype(DType.object, IsDtypeKind.string("real floating"))).to(FALSE), + rewrite(isdtype(DType.int64, IsDtypeKind.string("real floating"))).to(FALSE), + rewrite(isdtype(DType.int32, IsDtypeKind.string("real floating"))).to(FALSE), + rewrite(isdtype(DType.float32, IsDtypeKind.string("complex floating"))).to(FALSE), + rewrite(isdtype(DType.float64, IsDtypeKind.string("complex floating"))).to(FALSE), + rewrite(isdtype(DType.object, IsDtypeKind.string("complex floating"))).to(FALSE), + rewrite(isdtype(DType.int64, IsDtypeKind.string("complex floating"))).to(FALSE), + rewrite(isdtype(DType.int32, IsDtypeKind.string("complex floating"))).to(FALSE), + rewrite(isdtype(d, IsDtypeKind.NULL)).to(FALSE), + rewrite(isdtype(d, IsDtypeKind.dtype(d))).to(TRUE), + rewrite(isdtype(d, k1 | k2)).to(isdtype(d, k1) | isdtype(d, k2)), + rewrite(k1 | IsDtypeKind.NULL).to(k1), + ] + + class Slice(Expr): def __init__( self, diff --git a/python/egglog/exp/array_api_numba.py b/python/egglog/exp/array_api_numba.py index 66d5b3f8..5d45c2b3 100644 --- a/python/egglog/exp/array_api_numba.py +++ 
b/python/egglog/exp/array_api_numba.py @@ -31,7 +31,12 @@ def _std(y: NDArray, x: NDArray, i: Int): axis = OptionalIntOrTuple.some(IntOrTuple.int(i)) # https://numpy.org/doc/stable/reference/generated/numpy.std.html # "std = sqrt(mean(x)), where x = abs(a - a.mean())**2." - yield rewrite(std(x, axis), subsume=True).to(sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis))) + yield rewrite( + std(x, axis), + subsume=True, + ).to( + sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)), + ) # rewrite unique_counts to count each value one by one, since numba doesn't support np.unique(..., return_counts=True) diff --git a/python/egglog/exp/siu_examples.py b/python/egglog/exp/siu_examples.py new file mode 100644 index 00000000..45b6c9b4 --- /dev/null +++ b/python/egglog/exp/siu_examples.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import egglog + +from .array_api import Int + +# https://github.com/sklam/pyasir/blob/c363ff4f8f91177700ad4108dd5042b9b97d8289/pyasir/tests/test_fib.py + +# In progress - should be able to re-create this +# @df.func +# def fib_ir(n: pyasir.Int64) -> pyasir.Int64: +# @df.switch(n <= 1) +# def swt(n): +# @df.case(1) +# def case0(n): +# return 1 + +# @df.case(0) +# def case1(n): +# return fib_ir(n - 1) + fib_ir(n - 2) + +# yield case0 +# yield case1 + +# r = swt(n) +# return r + + +# With something like this: +@egglog.function +def fib(n: Int) -> Int: + return (n <= Int(1)).if_int( + Int(1), + fib(n - Int(1)) + fib(n - Int(2)), + ) diff --git a/python/egglog/pretty.py b/python/egglog/pretty.py new file mode 100644 index 00000000..cd2af6f6 --- /dev/null +++ b/python/egglog/pretty.py @@ -0,0 +1,418 @@ +""" +Pretty printing for declerations. 
+""" + +from __future__ import annotations + +from collections import Counter, defaultdict +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, TypeAlias + +import black +from typing_extensions import assert_never + +from .declarations import * + +if TYPE_CHECKING: + from collections.abc import Mapping + +__all__ = [ + "pretty_decl", + "pretty_callable_ref", + "BINARY_METHODS", + "UNARY_METHODS", +] +MAX_LINE_LENGTH = 110 +LINE_DIFFERENCE = 10 +BLACK_MODE = black.Mode(line_length=180) + +# Use this special character in place of the args, so that if the args are inlined +# in the viz, they will replace it +ARG_STR = "·" + +# Special methods which we might want to use as functions +# Mapping to the operator they represent for pretty printing them +# https://docs.python.org/3/reference/datamodel.html +BINARY_METHODS = { + "__lt__": "<", + "__le__": "<=", + "__eq__": "==", + "__ne__": "!=", + "__gt__": ">", + "__ge__": ">=", + # Numeric + "__add__": "+", + "__sub__": "-", + "__mul__": "*", + "__matmul__": "@", + "__truediv__": "/", + "__floordiv__": "//", + "__mod__": "%", + # TODO: Support divmod, with tuple return value + # "__divmod__": "divmod", + # TODO: Three arg power + "__pow__": "**", + "__lshift__": "<<", + "__rshift__": ">>", + "__and__": "&", + "__xor__": "^", + "__or__": "|", +} + + +UNARY_METHODS = { + "__pos__": "+", + "__neg__": "-", + "__invert__": "~", +} + +AllDecls: TypeAlias = RulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl + + +def pretty_decl( + decls: Declarations, decl: AllDecls, *, wrapping_fn: str | None = None, ruleset_name: str | None = None +) -> str: + """ + Pretty print a decleration. + + This will use re-format the result and put the expression on the last line, preceeded by the statements. 
+ """ + traverse = TraverseContext() + traverse(decl, toplevel=True) + pretty = traverse.pretty(decls) + expr = pretty(decl, ruleset_name=ruleset_name) + if wrapping_fn: + expr = f"{wrapping_fn}({expr})" + program = "\n".join([*pretty.statements, expr]) + try: + # TODO: Try replacing with ruff for speed + # https://github.com/amyreese/ruff-api + return black.format_str(program, mode=BLACK_MODE).strip() + except black.parsing.InvalidInput: + return program + + +def pretty_callable_ref( + decls: Declarations, + ref: CallableRef, + first_arg: ExprDecl | None = None, + bound_tp_params: tuple[JustTypeRef, ...] | None = None, +) -> str: + """ + Pretty print a callable reference, using a dummy value for + the args if the function is not in the form `f(x, ...)`. + + To be used in the visualization. + """ + # Pass in three dummy args, which are the max used for any operation that + # is not a generic function call + args: list[ExprDecl] = [LitDecl(ARG_STR)] * 3 + if first_arg: + args.insert(0, first_arg) + res = PrettyContext(decls, defaultdict(lambda: 0))._call_inner( + ref, args, bound_tp_params=bound_tp_params, parens=False + ) + # Either returns a function or a function with args. If args are provided, they would just be called, + # on the function, so return them, because they are dummies + return res[0] if isinstance(res, tuple) else res + + +@dataclass +class TraverseContext: + """ + State for traversing expressions (or declerations that contain expressions), so we can know how many parents each + expression has. + """ + + # All expressions we have seen (incremented the parent counts of all children) + _seen: set[AllDecls] = field(default_factory=set) + # The number of parents for each expressions + parents: Counter[AllDecls] = field(default_factory=Counter) + + def pretty(self, decls: Declarations) -> PrettyContext: + """ + Create a pretty context from the state of this traverse context. 
+ """ + return PrettyContext(decls, self.parents) + + def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901 + if not toplevel: + self.parents[decl] += 1 + if decl in self._seen: + return + match decl: + case RewriteDecl(_, lhs, rhs, conditions) | BiRewriteDecl(_, lhs, rhs, conditions): + self(lhs) + self(rhs) + for cond in conditions: + self(cond) + case RuleDecl(head, body, _): + for action in head: + self(action) + for fact in body: + self(fact) + case SetDecl(_, lhs, rhs) | UnionDecl(_, lhs, rhs): + self(lhs) + self(rhs) + case LetDecl(_, d) | ExprActionDecl(d) | ExprFactDecl(d): + self(d.expr) + case ChangeDecl(_, d, _) | SaturateDecl(d) | RepeatDecl(d, _) | ActionCommandDecl(d): + self(d) + case PanicDecl(_) | VarDecl(_) | LitDecl(_) | PyObjectDecl(_): + pass + case EqDecl(_, decls) | SequenceDecl(decls) | RulesetDecl(decls): + for de in decls: + self(de) + case CallDecl(_, exprs, _): + for e in exprs: + self(e.expr) + case RunDecl(_, until): + if until: + for f in until: + self(f) + case _: + assert_never(decl) + + self._seen.add(decl) + + +@dataclass +class PrettyContext: + """ + + We need to build up a list of all the expressions we are pretty printing, so that we can see who has parents and who is mutated + and create temp variables for them. 
+ + """ + + decls: Declarations + parents: Mapping[AllDecls, int] + + # All the expressions we have saved as names + names: dict[AllDecls, str] = field(default_factory=dict) + # A list of statements assigning variables or calling destructive ops + statements: list[str] = field(default_factory=list) + # Mapping of type to the number of times we have generated a name for that type, used to generate unique names + _gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0)) + + def __call__( + self, decl: AllDecls, *, unwrap_lit: bool = False, parens: bool = False, ruleset_name: str | None = None + ) -> str: + if decl in self.names: + return self.names[decl] + expr, tp_name = self.uncached(decl, unwrap_lit=unwrap_lit, parens=parens, ruleset_name=ruleset_name) + # We use a heuristic to decide whether to name this sub-expression as a variable + # The rough goal is to reduce the number of newlines, given our line length of ~180 + # We determine it's worth making a new line for this expression if the total characters + # it would take up is > than some constant (~ line length). 
+ line_diff: int = len(expr) - LINE_DIFFERENCE + n_parents = self.parents[decl] + if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH: + self.names[decl] = expr_name = self._name_expr(tp_name, expr, copy_identifier=False) + return expr_name + return expr + + def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: PLR0911 + match decl: + case LitDecl(value): + match value: + case None: + return "Unit()", "Unit" + case bool(b): + return str(b) if unwrap_lit else f"Bool({b})", "Bool" + case int(i): + return str(i) if unwrap_lit else f"i64({i})", "i64" + case float(f): + return str(f) if unwrap_lit else f"f64({f})", "f64" + case str(s): + return repr(s) if unwrap_lit else f"String({s!r})", "String" + assert_never(value) + case VarDecl(name): + return name, name + case CallDecl(_, _, _): + return self._call(decl, parens) + case PyObjectDecl(value): + return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject" + case ActionCommandDecl(action): + return self(action), "action" + case RewriteDecl(_, lhs, rhs, conditions) | BiRewriteDecl(_, lhs, rhs, conditions): + args = ", ".join(map(self, (rhs, *conditions))) + fn = "rewrite" if isinstance(decl, RewriteDecl) else "birewrite" + return f"{fn}({self(lhs)}).to({args})", "rewrite" + case RuleDecl(head, body, name): + l = ", ".join(map(self, body)) + if name: + l += f", name={name}" + r = ", ".join(map(self, head)) + return f"rule({l}).then({r})", "rule" + case SetDecl(_, lhs, rhs): + return f"set_({self(lhs)}).to({self(rhs)})", "action" + case UnionDecl(_, lhs, rhs): + return f"union({self(lhs)}).with_({self(rhs)})", "action" + case LetDecl(name, expr): + return f"let({name!r}, {self(expr.expr)})", "action" + case ExprActionDecl(expr): + return self(expr.expr), "action" + case ExprFactDecl(expr): + return self(expr.expr), "fact" + case ChangeDecl(_, expr, change): + return f"{change}({self(expr)})", "action" + case PanicDecl(s): + return 
f"panic({s!r})", "action" + case EqDecl(_, exprs): + first, *rest = exprs + return f"eq({self(first)}).to({', '.join(map(self, rest))})", "fact" + case RulesetDecl(rules): + if ruleset_name: + return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}" + args = ", ".join(map(self, rules)) + return f"ruleset({args})", "ruleset" + case SaturateDecl(schedule): + return f"{self(schedule, parens=True)}.saturate()", "schedule" + case RepeatDecl(schedule, times): + return f"{self(schedule, parens=True)} * {times}", "schedule" + case SequenceDecl(schedules): + if len(schedules) == 2: + return f"{self(schedules[0], parens=True)} + {self(schedules[1], parens=True)}", "schedule" + args = ", ".join(map(self, schedules)) + return f"seq({args})", "schedule" + case RunDecl(ruleset_name, until): + ruleset = self.decls._rulesets[ruleset_name] + ruleset_str = self(ruleset, ruleset_name=ruleset_name) + if not until: + return ruleset_str, "schedule" + args = ", ".join(map(self, until)) + return f"run({ruleset_str}, {args})", "schedule" + assert_never(decl) + + def _call( + self, + decl: CallDecl, + parens: bool, + ) -> tuple[str, str]: + """ + Pretty print the call. Also returns if it was saved as a name. + + :param parens: If true, wrap the call in parens if it is a binary method call. 
+ """ + args = [a.expr for a in decl.args] + ref = decl.callable + # Special case != + if decl.callable == FunctionRef("!="): + l, r = self(args[0]), self(args[1]) + return f"ne({l}).to({r})", "Unit" + function_decl = self.decls.get_callable_decl(ref).to_function_decl() + # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default + n_defaults = 0 + for arg, default in zip( + reversed(args), reversed(function_decl.arg_defaults), strict=not function_decl.var_arg_type + ): + if arg != default: + break + n_defaults += 1 + if n_defaults: + args = args[:-n_defaults] + + tp_name = function_decl.semantic_return_type.name + if function_decl.mutates: + first_arg = args[0] + expr_str = self(first_arg) + # copy an identifier expression iff it has multiple parents (b/c then we can't mutate it directly) + has_multiple_parents = self.parents[first_arg] > 1 + self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents) + # Set the first arg to be the name of the mutated arg and return the name + args[0] = VarDecl(expr_name) + else: + expr_name = None + res = self._call_inner(ref, args, decl.bound_tp_params, parens) + expr = ( + f"{res[0]}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})" + if isinstance(res, tuple) + else res + ) + # If we have a name, then we mutated + if expr_name: + self.statements.append(expr) + return expr_name, tp_name + return expr, tp_name + + def _call_inner( # noqa: PLR0911 + self, ref: CallableRef, args: list[ExprDecl], bound_tp_params: tuple[JustTypeRef, ...] | None, parens: bool + ) -> tuple[str, list[ExprDecl]] | str: + """ + Pretty print the call, returning either the full function call or a tuple of the function and the args. 
+ """ + match ref: + case FunctionRef(name): + return name, args + case ClassMethodRef(class_name, method_name): + fn_str = str(JustTypeRef(class_name, bound_tp_params or ())) + if method_name != "__init__": + fn_str += f".{method_name}" + return fn_str, args + case MethodRef(_class_name, method_name): + slf, *args = args + slf = self(slf, parens=True) + match method_name: + case _ if method_name in UNARY_METHODS: + expr = f"{UNARY_METHODS[method_name]}{slf}" + return f"({expr})" if parens else expr + case _ if method_name in BINARY_METHODS: + expr = f"{slf} {BINARY_METHODS[method_name]} {self(args[0], parens=True, unwrap_lit=True)}" + return f"({expr})" if parens else expr + case "__getitem__": + return f"{slf}[{self(args[0], unwrap_lit=True)}]" + case "__call__": + return slf, args + case "__delitem__": + return f"del {slf}[{self(args[0], unwrap_lit=True)}]" + case "__setitem__": + return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}" + case _: + return f"{slf}.{method_name}", args + case ConstantRef(name): + return name + case ClassVariableRef(class_name, variable_name): + return f"{class_name}.{variable_name}" + case PropertyRef(_class_name, property_name): + return f"{self(args[0], parens=True)}.{property_name}" + assert_never(ref) + + def _generate_name(self, typ: str) -> str: + self._gen_name_types[typ] += 1 + return f"_{typ}_{self._gen_name_types[typ]}" + + def _name_expr(self, tp_name: str, expr_str: str, copy_identifier: bool) -> str: + # tp_name = + # If the thing we are naming is already a variable, we don't need to name it + if expr_str.isidentifier(): + if copy_identifier: + name = self._generate_name(tp_name) + self.statements.append(f"{name} = copy({expr_str})") + else: + name = expr_str + else: + name = self._generate_name(tp_name) + self.statements.append(f"{name} = {expr_str}") + return name + + +def _plot_line_length(expr: object): # pragma: no cover + """ + Plots the number of line lengths based on different max 
lengths + """ + global MAX_LINE_LENGTH, LINE_DIFFERENCE + import altair as alt + import pandas as pd + + sizes = [] + for line_length in range(40, 180, 10): + MAX_LINE_LENGTH = line_length + for diff in range(0, 40, 5): + LINE_DIFFERENCE = diff + new_l = len(str(expr).split()) + sizes.append((line_length, diff, new_l)) + + df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"]) # noqa: PD901 + + return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q") diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 97814512..c13e2253 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -11,172 +11,57 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass +from inspect import Parameter, Signature from itertools import zip_longest from typing import TYPE_CHECKING, NoReturn, TypeVar, Union, cast, get_args, get_origin -import black -import black.parsing -from typing_extensions import assert_never - -from . 
import bindings, config from .declarations import * -from .declarations import BINARY_METHODS, REFLECTED_BINARY_METHODS, UNARY_METHODS +from .pretty import * +from .thunk import Thunk from .type_constraint_solver import * if TYPE_CHECKING: - from collections.abc import Callable, Collection, Iterable + from collections.abc import Callable, Iterable from .egraph import Expr __all__ = [ "LIT_CLASS_NAMES", - "class_to_ref", - "resolve_literal", "resolve_callable", "resolve_type_annotation", - "convert_to_same_type", "RuntimeClass", - "RuntimeParamaterizedClass", - "RuntimeClassMethod", "RuntimeExpr", "RuntimeFunction", - "convert", - "converter", + "REFLECTED_BINARY_METHODS", ] -BLACK_MODE = black.Mode(line_length=180) - UNIT_CLASS_NAME = "Unit" UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"} LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"} +REFLECTED_BINARY_METHODS = { + "__radd__": "__add__", + "__rsub__": "__sub__", + "__rmul__": "__mul__", + "__rmatmul__": "__matmul__", + "__rtruediv__": "__truediv__", + "__rfloordiv__": "__floordiv__", + "__rmod__": "__mod__", + "__rpow__": "__pow__", + "__rlshift__": "__lshift__", + "__rrshift__": "__rshift__", + "__rand__": "__and__", + "__rxor__": "__xor__", + "__ror__": "__or__", +} + # Set this globally so we can get access to PyObject when we have a type annotation of just object. # This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically. 
_PY_OBJECT_CLASS: RuntimeClass | None = None -## -# Converters -## - -# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target -CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {} -# Global declerations to store all convertable types so we can query if they have certain methods or not -CONVERSIONS_DECLS = Declarations() - T = TypeVar("T") -V = TypeVar("V", bound="Expr") - - -class ConvertError(Exception): - pass - - -def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: int = 1) -> None: - """ - Register a converter from some type to an egglog type. - """ - to_type_name = process_tp(to_type) - if not isinstance(to_type_name, JustTypeRef): - raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}") - _register_converter(process_tp(from_type), to_type_name, fn, cost) - - -def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None: - """ - Registers a converter from some type to an egglog type, if not already registered. - - Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered. - Also, if registering A->B and there is already D->A, then D->B will be registered. - """ - if a == b: - return - if (a, b) in CONVERSIONS and CONVERSIONS[(a, b)][0] <= cost: - return - CONVERSIONS[(a, b)] = (cost, a_b) - for (c, d), (other_cost, c_d) in list(CONVERSIONS.items()): - if b == c: - _register_converter(a, d, _ComposedConverter(a_b, c_d), cost + other_cost) - if a == d: - _register_converter(c, b, _ComposedConverter(c_d, a_b), cost + other_cost) - - -@dataclass -class _ComposedConverter: - """ - A converter which is composed of multiple converters. - - _ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x)) - - We use the dataclass instead of the lambda to make it easier to debug. 
- """ - - a_b: Callable - b_c: Callable - - def __call__(self, x: object) -> object: - return self.b_c(self.a_b(x)) - - def __str__(self) -> str: - return f"{self.b_c} ∘ {self.a_b}" - - -def convert(source: object, target: type[V]) -> V: - """ - Convert a source object to a target type. - """ - target_ref = class_to_ref(cast(RuntimeTypeArgType, target)) - return cast(V, resolve_literal(target_ref.to_var(), source)) - - -def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr: - """ - Convert a source object to the same type as the target. - """ - tp = target.__egg_typed_expr__.tp - return resolve_literal(tp.to_var(), source) - - -def process_tp(tp: type | RuntimeTypeArgType) -> JustTypeRef | type: - """ - Process a type before converting it, to add it to the global declerations and resolve to a ref. - """ - global CONVERSIONS_DECLS - if isinstance(tp, RuntimeClass | RuntimeParamaterizedClass): - CONVERSIONS_DECLS |= tp - return class_to_ref(tp) - return tp - - -def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef: - """ - Returns the minimum convertable type between a and b, that has a method `name`, raising a TypeError if no such type exists. 
- """ - a_tp = _get_tp(a) - b_tp = _get_tp(b) - a_converts_to = { - to: c - for ((from_, to), (c, _)) in CONVERSIONS.items() - if from_ == a_tp and CONVERSIONS_DECLS.has_method(to.name, name) - } - b_converts_to = { - to: c - for ((from_, to), (c, _)) in CONVERSIONS.items() - if from_ == b_tp and CONVERSIONS_DECLS.has_method(to.name, name) - } - if isinstance(a_tp, JustTypeRef): - a_converts_to[a_tp] = 0 - if isinstance(b_tp, JustTypeRef): - b_converts_to[b_tp] = 0 - common = set(a_converts_to) & set(b_converts_to) - if not common: - raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type") - return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp]) - - -def identity(x: object) -> object: - return x def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef: @@ -195,99 +80,62 @@ def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef: assert _PY_OBJECT_CLASS return resolve_type_annotation(decls, _PY_OBJECT_CLASS) if isinstance(tp, RuntimeClass): - decls |= tp - return tp.__egg_tp__.to_var() - if isinstance(tp, RuntimeParamaterizedClass): decls |= tp return tp.__egg_tp__ raise TypeError(f"Unexpected type annotation {tp}") -def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr: - arg_type = _get_tp(arg) - - # If we have any type variables, dont bother trying to resolve the literal, just return the arg - try: - tp_just = tp.to_just() - except NotImplementedError: - # If this is a var, it has to be a runtime exprssions - assert isinstance(arg, RuntimeExpr) - return arg - if arg_type == tp_just: - # If the type is an egg type, it has to be a runtime expr - assert isinstance(arg, RuntimeExpr) - return arg - # Try all parent types as well, if we are converting from a Python type - for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]: - try: - fn = CONVERSIONS[(cast(JustTypeRef | type, arg_type_instance), tp_just)][1] - except KeyError: - continue - break - 
else: - arg_type_str = arg_type.pretty() if isinstance(arg_type, JustTypeRef) else arg_type.__name__ - raise ConvertError(f"Cannot convert {arg_type_str} to {tp_just.pretty()}") - return fn(arg) - - -def _get_tp(x: object) -> JustTypeRef | type: - if isinstance(x, RuntimeExpr): - return x.__egg_typed_expr__.tp - tp = type(x) - # If this value has a custom metaclass, let's use that as our index instead of the type - if type(tp) != type: - return type(tp) - return tp - - ## # Runtime objects ## @dataclass -class RuntimeClass: - # Pass in a constructor to make the declarations lazy, so we can have classes reference each other in their type constructors - # This function should mutate the declerations and add to them - # Used this instead of a lazy property so we can have a reference to the decls in the class as its computing - lazy_decls: Callable[[Declarations], None] = field(repr=False) - # Cached declerations - _inner_decls: Declarations | None = field(init=False, repr=False, default=None) - __egg_name__: str +class RuntimeClass(DelayedDeclerations): + __egg_tp__: TypeRefWithVars def __post_init__(self) -> None: global _PY_OBJECT_CLASS - if self.__egg_name__ == "PyObject": + if self.__egg_tp__.name == "PyObject": _PY_OBJECT_CLASS = self - @property - def __egg_decls__(self) -> Declarations: - if self._inner_decls is None: - # Set it like this so we can have a reference to the decls in the class as its computing - self._inner_decls = Declarations() - self.lazy_decls(self._inner_decls) - return self._inner_decls + def verify(self) -> None: + if not self.__egg_tp__.args: + return + + # Raise error if we have args, but they are the wrong number + desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars + if len(self.__egg_tp__.args) != len(desired_args): + raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}") def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None: """ Create an instance of 
this kind by calling the __init__ classmethod """ # If this is a literal type, initializing it with a literal should return a literal - if self.__egg_name__ == "PyObject": + if self.__egg_tp__.name == "PyObject": assert len(args) == 1 - return RuntimeExpr(self.__egg_decls__, TypedExprDecl(self.__egg_tp__, PyObjectDecl(args[0]))) - if self.__egg_name__ in UNARY_LIT_CLASS_NAMES: + return RuntimeExpr.__from_value__( + self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0])) + ) + if self.__egg_tp__.name in UNARY_LIT_CLASS_NAMES: assert len(args) == 1 assert isinstance(args[0], int | float | str | bool) - return RuntimeExpr(self.__egg_decls__, TypedExprDecl(self.__egg_tp__, LitDecl(args[0]))) - if self.__egg_name__ == UNIT_CLASS_NAME: + return RuntimeExpr.__from_value__( + self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0])) + ) + if self.__egg_tp__.name == UNIT_CLASS_NAME: assert len(args) == 0 - return RuntimeExpr(self.__egg_decls__, TypedExprDecl(self.__egg_tp__, LitDecl(None))) + return RuntimeExpr.__from_value__( + self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None)) + ) - return RuntimeClassMethod(self.__egg_decls__, self.__egg_tp__, "__init__")(*args, **kwargs) + return RuntimeFunction( + Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, "__init__"), self.__egg_tp__.to_just() + )(*args, **kwargs) def __dir__(self) -> list[str]: - cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__) + cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name) possible_methods = ( list(cls_decl.class_methods) + list(cls_decl.class_variables) + list(cls_decl.preserved_methods) ) @@ -296,14 +144,19 @@ def __dir__(self) -> list[str]: possible_methods.append("__call__") return possible_methods - def __getitem__(self, args: object) -> RuntimeParamaterizedClass: + def __getitem__(self, args: object) -> RuntimeClass: + if self.__egg_tp__.args: + raise TypeError(f"Cannot 
index into a paramaterized class {self}") if not isinstance(args, tuple): args = (args,) decls = self.__egg_decls__.copy() - tp = TypeRefWithVars(self.__egg_name__, tuple(resolve_type_annotation(decls, arg) for arg in args)) - return RuntimeParamaterizedClass(self.__egg_decls__, tp) + tp = TypeRefWithVars(self.__egg_tp__.name, tuple(resolve_type_annotation(decls, arg) for arg in args)) + return RuntimeClass(Thunk.value(decls), tp) + + def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable: + if name == "__origin__" and self.__egg_tp__.args: + return RuntimeClass(self.__egg_decls_thunk__, TypeRefWithVars(self.__egg_tp__.name)) - def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeExpr | Callable: # Special case some names that don't exist so we can exit early without resolving decls # Important so if we take union of RuntimeClass it won't try to resolve decls if name in { @@ -314,7 +167,7 @@ def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeExpr | Callable: }: raise AttributeError - cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__) + cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name] preserved_methods = cls_decl.preserved_methods if name in preserved_methods: @@ -323,159 +176,107 @@ def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeExpr | Callable: # if this is a class variable, return an expr for it, otherwise, assume it's a method if name in cls_decl.class_variables: return_tp = cls_decl.class_variables[name] - return RuntimeExpr( - self.__egg_decls__, TypedExprDecl(return_tp, CallDecl(ClassVariableRef(self.__egg_name__, name))) + return RuntimeExpr.__from_value__( + self.__egg_decls__, + TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))), ) - return RuntimeClassMethod(self.__egg_decls__, self.__egg_tp__, name) + if name in cls_decl.class_methods: + return RuntimeFunction( + Thunk.value(self.__egg_decls__), 
ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just() + ) + msg = f"Class {self.__egg_tp__.name} has no method {name}" + if name == "__ne__": + msg += ". Did you mean to use the ne(...).to(...)?" + raise AttributeError(msg) from None def __str__(self) -> str: - return self.__egg_name__ + return str(self.__egg_tp__) # Make hashable so can go in Union def __hash__(self) -> int: - return hash((id(self.lazy_decls), self.__egg_name__)) + return hash((id(self.__egg_decls_thunk__), self.__egg_tp__)) # Support unioning like types def __or__(self, __value: type) -> object: return Union[self, __value] # noqa: UP007 - @property - def __egg_tp__(self) -> JustTypeRef: - return JustTypeRef(self.__egg_name__) - - -@dataclass -class RuntimeParamaterizedClass: - __egg_decls__: Declarations - __egg_tp__: TypeRefWithVars - - def __post_init__(self) -> None: - desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars - if len(self.__egg_tp__.args) != len(desired_args): - raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}") - - def __call__(self, *args: object) -> RuntimeExpr | None: - return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), "__init__")(*args) - - def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeClass: - # Special case so when get_type_annotations proccessed it can work - if name == "__origin__": - return RuntimeClass(self.__egg_decls__.update_other, self.__egg_tp__.name) - return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), name) - - def __str__(self) -> str: - return self.__egg_tp__.pretty() - - # Support unioning - def __or__(self, __value: type) -> object: - return Union[self, __value] # noqa: UP007 - - -# Type args can either be typevars or classes -RuntimeTypeArgType = RuntimeClass | RuntimeParamaterizedClass - - -def class_to_ref(cls: RuntimeTypeArgType) -> JustTypeRef: - if isinstance(cls, RuntimeClass): - return JustTypeRef(cls.__egg_name__) - if 
isinstance(cls, RuntimeParamaterizedClass): - # Currently this is used when calling methods on a parametrized class, which is only possible when we - # have actualy types currently, not typevars, currently. - return cls.__egg_tp__.to_just() - assert_never(cls) - @dataclass -class RuntimeFunction: - __egg_decls__: Declarations - __egg_name__: str - __egg_fn_ref__: FunctionRef = field(init=False) - __egg_fn_decl__: FunctionDecl = field(init=False) - - def __post_init__(self) -> None: - self.__egg_fn_ref__ = FunctionRef(self.__egg_name__) - self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_fn_ref__) +class RuntimeFunction(DelayedDeclerations): + __egg_ref__: CallableRef + # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self + __egg_bound__: JustTypeRef | RuntimeExpr | None = None def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None: - return _call(self.__egg_decls__, self.__egg_fn_ref__, self.__egg_fn_decl__, args, kwargs) - - def __str__(self) -> str: - return self.__egg_name__ - - -def _call( - decls_from_fn: Declarations, - callable_ref: CallableRef, - fn_decl: FunctionDecl, - args: Collection[object], - kwargs: dict[str, object], - bound_class: JustTypeRef | None = None, -) -> RuntimeExpr | None: - # Turn all keyword args into positional args - bound = fn_decl.to_signature(lambda expr: RuntimeExpr(decls_from_fn, expr)).bind(*args, **kwargs) - bound.apply_defaults() - assert not bound.kwargs - del args, kwargs - - upcasted_args = [ - resolve_literal(cast(TypeOrVarRef, tp), arg) - for arg, tp in zip_longest(bound.args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type) - ] - - arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args) - decls = Declarations.create(decls_from_fn, *upcasted_args) - - tcs = TypeConstraintSolver(decls) - if bound_class is not None and bound_class.args: - tcs.bind_class(bound_class) - - if fn_decl is not None: + from .conversion 
import resolve_literal + + if isinstance(self.__egg_bound__, RuntimeExpr): + args = (self.__egg_bound__, *args) + fn_decl = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl() + # Turn all keyword args into positional args + bound = callable_decl_to_signature(fn_decl, self.__egg_decls__).bind(*args, **kwargs) + bound.apply_defaults() + assert not bound.kwargs + del args, kwargs + + upcasted_args = [ + resolve_literal(cast(TypeOrVarRef, tp), arg) + for arg, tp in zip_longest(bound.args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type) + ] + + decls = Declarations.create(self, *upcasted_args) + + tcs = TypeConstraintSolver(decls) + bound_tp = ( + None + if self.__egg_bound__ is None + else self.__egg_bound__.__egg_typed_expr__.tp + if isinstance(self.__egg_bound__, RuntimeExpr) + else self.__egg_bound__ + ) + if bound_tp and bound_tp.args: + tcs.bind_class(bound_tp) + arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args) arg_types = [expr.tp for expr in arg_exprs] - cls_name = bound_class.name if bound_class is not None else None + cls_name = bound_tp.name if bound_tp else None return_tp = tcs.infer_return_type( - fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types, cls_name + fn_decl.arg_types, fn_decl.return_type or fn_decl.arg_types[0], fn_decl.var_arg_type, arg_types, cls_name ) - else: - return_tp = JustTypeRef("Unit") - bound_params = cast(JustTypeRef, bound_class).args if isinstance(callable_ref, ClassMethodRef) else None - expr_decl = CallDecl(callable_ref, arg_exprs, bound_params) - typed_expr_decl = TypedExprDecl(return_tp, expr_decl) - # Register return type sort in case it's a variadic generic that needs to be created - decls.register_sort(return_tp, False) - if fn_decl.mutates_first_arg: - first_arg = upcasted_args[0] - first_arg.__egg_typed_expr__ = typed_expr_decl - first_arg.__egg_decls__ = decls - return None - return RuntimeExpr(decls, TypedExprDecl(return_tp, expr_decl)) - - -@dataclass 
-class RuntimeClassMethod: - __egg_decls__: Declarations - __egg_tp__: JustTypeRef - __egg_method_name__: str - __egg_callable_ref__: ClassMethodRef = field(init=False) - __egg_fn_decl__: FunctionDecl = field(init=False) - - def __post_init__(self) -> None: - self.__egg_callable_ref__ = ClassMethodRef(self.class_name, self.__egg_method_name__) - try: - self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__) - except KeyError as e: - raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}") from e - - def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None: - return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, kwargs, self.__egg_tp__) + bound_params = cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef) else None + expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params) + typed_expr_decl = TypedExprDecl(return_tp, expr_decl) + # If there is not return type, we are mutating the first arg + if not fn_decl.return_type: + first_arg = upcasted_args[0] + first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl)) + return None + return RuntimeExpr.__from_value__(decls, typed_expr_decl) def __str__(self) -> str: - return f"{self.class_name}.{self.__egg_method_name__}" - - @property - def class_name(self) -> str: - if isinstance(self.__egg_tp__, str): - return self.__egg_tp__ - return self.__egg_tp__.name + first_arg, bound_tp_params = None, None + match self.__egg_bound__: + case RuntimeExpr(_): + first_arg = self.__egg_bound__.__egg_typed_expr__.expr + case JustTypeRef(_, args): + bound_tp_params = args + return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params) + + +def callable_decl_to_signature( + decl: FunctionDecl, + decls: Declarations, +) -> Signature: + parameters = [ + Parameter( + n, + Parameter.POSITIONAL_OR_KEYWORD, + default=RuntimeExpr.__from_value__(decls, 
TypedExprDecl(t.to_just(), d)) if d else Parameter.empty, + ) + for n, d, t in zip(decl.arg_names, decl.arg_defaults, decl.arg_types, strict=True) + ] + if isinstance(decl, FunctionDecl) and decl.var_arg_type is not None: + parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL)) + return Signature(parameters) # All methods which should return NotImplemented if they fail to resolve @@ -505,63 +306,34 @@ def class_name(self) -> str: @dataclass -class RuntimeMethod: - __egg_self__: RuntimeExpr - __egg_method_name__: str - __egg_callable_ref__: MethodRef | PropertyRef = field(init=False) - __egg_fn_decl__: FunctionDecl = field(init=False, repr=False) - __egg_decls__: Declarations = field(init=False) - - def __post_init__(self) -> None: - self.__egg_decls__ = self.__egg_self__.__egg_decls__ - if self.__egg_method_name__ in self.__egg_decls__.get_class_decl(self.class_name).properties: - self.__egg_callable_ref__ = PropertyRef(self.class_name, self.__egg_method_name__) - else: - self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__) - try: - self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__) - except KeyError: - msg = f"Class {self.class_name} does not have method {self.__egg_method_name__}" - if self.__egg_method_name__ == "__ne__": - msg += ". Did you mean to use the ne(...).to(...)?" 
- raise AttributeError(msg) from None +class RuntimeExpr: + # Defer needing decls/expr so we can make constants that don't resolve their class types + __egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]] - def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None: - args = (self.__egg_self__, *args) - try: - return _call( - self.__egg_decls__, - self.__egg_callable_ref__, - self.__egg_fn_decl__, - args, - kwargs, - self.__egg_self__.__egg_typed_expr__.tp, - ) - except ConvertError as e: - name = self.__egg_method_name__ - raise TypeError(f"Wrong types for {self.__egg_self__.__egg_typed_expr__.tp.pretty()}.{name}") from e + @classmethod + def __from_value__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr: + return cls(Thunk.value((d, e))) @property - def class_name(self) -> str: - return self.__egg_self__.__egg_typed_expr__.tp.name - + def __egg_decls__(self) -> Declarations: + return self.__egg_thunk__()[0] -@dataclass -class RuntimeExpr: - __egg_decls__: Declarations - __egg_typed_expr__: TypedExprDecl + @property + def __egg_typed_expr__(self) -> TypedExprDecl: + return self.__egg_thunk__()[1] - def __getattr__(self, name: str) -> RuntimeMethod | RuntimeExpr | Callable | None: - class_decl = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name) + def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None: + cls_name = self.__egg_class_name__ + class_decl = self.__egg_class_decl__ - preserved_methods = class_decl.preserved_methods - if name in preserved_methods: + if name in (preserved_methods := class_decl.preserved_methods): return preserved_methods[name].__get__(self) - method = RuntimeMethod(self, name) - if isinstance(method.__egg_callable_ref__, PropertyRef): - return method() - return method + if name in class_decl.methods: + return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(cls_name, name), self) + if name in class_decl.properties: + return 
RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(cls_name, name), self)() + raise AttributeError(f"{cls_name} has no method {name}") from None def __repr__(self) -> str: """ @@ -570,18 +342,10 @@ def __repr__(self) -> str: return str(self) def __str__(self) -> str: - context = PrettyContext(self.__egg_decls__) - context.traverse_for_parents(self.__egg_typed_expr__.expr) - pretty_expr = self.__egg_typed_expr__.expr.pretty(context, parens=False) - try: - if config.SHOW_TYPES: - raise NotImplementedError - # s = f"_: {self.__egg_typed_expr__.tp.pretty()} = {pretty_expr}" - # return black.format_str(s, mode=black.FileMode()).strip() - pretty_statements = context.render(pretty_expr) - return black.format_str(pretty_statements, mode=BLACK_MODE).strip() - except black.parsing.InvalidInput: - return pretty_expr + return self.__egg_pretty__(None) + + def __egg_pretty__(self, wrapping_fn: str | None) -> str: + return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn) def _ipython_display_(self) -> None: from IPython.display import Code, display @@ -589,28 +353,32 @@ def _ipython_display_(self) -> None: display(Code(str(self), language="python")) def __dir__(self) -> Iterable[str]: - return list(self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).methods) + class_decl = self.__egg_class_decl__ + return list(class_decl.methods) + list(class_decl.properties) + list(class_decl.preserved_methods) @property - def __egg__(self) -> bindings._Expr: - return self.__egg_typed_expr__.to_egg(self.__egg_decls__) + def __egg_class_name__(self) -> str: + return self.__egg_typed_expr__.tp.name + + @property + def __egg_class_decl__(self) -> ClassDecl: + return self.__egg_decls__.get_class_decl(self.__egg_class_name__) # Have __eq__ take no NoReturn (aka Never https://docs.python.org/3/library/typing.html#typing.Never) because # we don't wany any type that MyPy thinks is an expr to be used with __eq__. 
# That's because we want to reserve __eq__ for domain specific equality checks, overloading this method. # To check if two exprs are equal, use the expr_eq method. - def __eq__(self, other: NoReturn) -> Expr: # type: ignore[override] - msg = "Do not use == on RuntimeExpr. Compare the __egg_typed_expr__ attribute instead for structural equality." - raise NotImplementedError(msg) + # At runtime, this will resolve if there is a defined egg function for `__eq__` + def __eq__(self, other: NoReturn) -> Expr: ... # type: ignore[override, empty-body] # Implement these so that copy() works on this object # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion def __getstate__(self) -> tuple[Declarations, TypedExprDecl]: - return (self.__egg_decls__, self.__egg_typed_expr__) + return self.__egg_thunk__() def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None: - self.__egg_decls__, self.__egg_typed_expr__ = d + self.__egg_thunk__ = Thunk.value(d) def __hash__(self) -> int: return hash(self.__egg_typed_expr__) @@ -625,12 +393,17 @@ def _special_method( __name: str = name, **kwargs: object, ) -> RuntimeExpr | None: + from .conversion import ConvertError + + class_name = self.__egg_class_name__ + class_decl = self.__egg_class_decl__ # First, try to resolve as preserved method try: - method = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).preserved_methods[__name] - return method(self, *args, **kwargs) + method = class_decl.preserved_methods[__name] except KeyError: pass + else: + return method(self, *args, **kwargs) # If this is a "partial" method meaning that it can return NotImplemented, # we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just # using the arg type of the self arg. 
@@ -640,7 +413,10 @@ def _special_method( return call_method_min_conversion(self, args[0], __name) except ConvertError: return NotImplemented - return RuntimeMethod(self, __name)(*args, **kwargs) + if __name in class_decl.methods: + fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self) + return fn(*args, **kwargs) + raise TypeError(f"{class_name!r} object does not support {__name}") setattr(RuntimeExpr, name, _special_method) @@ -655,12 +431,14 @@ def _reflected_method(self: RuntimeExpr, other: object, __non_reflected: str = n def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None: + from .conversion import min_convertable_tp, resolve_literal + # find a minimum type that both can be converted to # This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats. min_tp = min_convertable_tp(slf, other, name) slf = resolve_literal(min_tp.to_var(), slf) other = resolve_literal(min_tp.to_var(), other) - method = RuntimeMethod(slf, name) + method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf) return method(other) @@ -680,21 +458,9 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]: """ Resolves a runtime callable into a ref """ - # TODO: Fix these typings. 
- ref: CallableRef - decls: Declarations - if isinstance(callable, RuntimeFunction): - ref = FunctionRef(callable.__egg_name__) - decls = callable.__egg_decls__ - elif isinstance(callable, RuntimeClassMethod): - ref = ClassMethodRef(callable.class_name, callable.__egg_method_name__) - decls = callable.__egg_decls__ - elif isinstance(callable, RuntimeMethod): - ref = MethodRef(callable.__egg_self__.__egg_typed_expr__.tp.name, callable.__egg_method_name__) - decls = callable.__egg_decls__ - elif isinstance(callable, RuntimeClass): - ref = ClassMethodRef(callable.__egg_name__, "__init__") - decls = callable.__egg_decls__ - else: - raise NotImplementedError(f"Cannot turn {callable} into a callable ref") - return (ref, decls) + match callable: + case RuntimeFunction(decls, ref, _): + return ref, decls() + case RuntimeClass(thunk, tp): + return ClassMethodRef(tp.name, "__init__"), thunk() + raise NotImplementedError(f"Cannot turn {callable} into a callable ref") diff --git a/python/egglog/thunk.py b/python/egglog/thunk.py new file mode 100644 index 00000000..5c1a903d --- /dev/null +++ b/python/egglog/thunk.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generic, TypeVar + +from typing_extensions import ParamSpec, TypeVarTuple, Unpack + +if TYPE_CHECKING: + from collections.abc import Callable + + +__all__ = ["Thunk"] + +T = TypeVar("T") +P = ParamSpec("P") +TS = TypeVarTuple("TS") + + +@dataclass +class Thunk(Generic[T, Unpack[TS]]): + """ + Cached delayed function call. + """ + + state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving[T] + + @classmethod + def fn( + cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], fallback: Callable[[], T] | None = None + ) -> Thunk[T, Unpack[TS]]: + """ + Create a thunk based on some functions and some partial args. + + If the function is called while it is being resolved recursively, will instead return the fallback, if provided. 
+ """ + return cls(Unresolved(fn, args, fallback)) + + @classmethod + def value(cls, value: T) -> Thunk[T]: + return Thunk(Resolved(value)) + + def __call__(self) -> T: + match self.state: + case Resolved(value): + return value + case Unresolved(fn, args, fallback): + self.state = Resolving(fallback) + res = fn(*args) + self.state = Resolved(res) + return res + case Resolving(fallback): + if fallback is None: + msg = "Recursively resolving thunk without fallback" + raise ValueError(msg) + return fallback() + + +@dataclass +class Resolved(Generic[T]): + value: T + + +@dataclass +class Unresolved(Generic[T, Unpack[TS]]): + fn: Callable[[Unpack[TS]], T] + args: tuple[Unpack[TS]] + fallback: Callable[[], T] | None + + +@dataclass +class Resolving(Generic[T]): + fallback: Callable[[], T] | None diff --git a/python/egglog/type_constraint_solver.py b/python/egglog/type_constraint_solver.py index 0b475bbf..7f68223b 100644 --- a/python/egglog/type_constraint_solver.py +++ b/python/egglog/type_constraint_solver.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from collections.abc import Collection, Iterable + __all__ = ["TypeConstraintSolver", "TypeConstraintError"] @@ -38,10 +39,11 @@ def bind_class(self, ref: JustTypeRef) -> None: Bind the typevars of a class to the given types. Used for a situation like Map[int, str].create(). 
""" - cls_typevars = self._decls.get_class_decl(ref.name).type_vars + name = ref.name + cls_typevars = self._decls.get_class_decl(name).type_vars if len(cls_typevars) != len(ref.args): raise TypeConstraintError(f"Mismatch of typevars {cls_typevars} and {ref}") - bound_typevars = self._cls_typevar_index_to_type[ref.name] + bound_typevars = self._cls_typevar_index_to_type[name] for i, arg in enumerate(ref.args): bound_typevars[cls_typevars[i]] = arg @@ -117,6 +119,7 @@ def _infer_typevars(self, fn_arg: TypeOrVarRef, arg: JustTypeRef, cls_name: str if cls_name is None: msg = "Cannot infer typevar without class name" raise RuntimeError(msg) + class_typevars = self._cls_typevar_index_to_type[cls_name] if typevar in class_typevars: if class_typevars[typevar] != arg: diff --git a/python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py b/python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py index 6e0f306c..9817e122 100644 --- a/python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py +++ b/python/tests/__snapshots__/test_array_api/TestLDA.test_optimize.py @@ -12,7 +12,7 @@ + (TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(1)))).to_value()) + TupleValue(sum(_NDArray_2 == NDArray.scalar(Value.int(Int(2)))).to_value())) ), DType.float64, -) / NDArray.scalar(Value.float(Float(150.0))) +) / NDArray.scalar(Value.float(Float.rational(Rational(150, 1)))) _NDArray_4 = zeros(TupleInt(Int(3)) + TupleInt(Int(4)), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device)) _MultiAxisIndexKey_1 = MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(Slice())) _IndexKey_1 = IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.int(Int(0))) + _MultiAxisIndexKey_1) @@ -32,7 +32,7 @@ _NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)])))) _NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)]))) 
_NDArray_11 = copy(_NDArray_10) -_NDArray_11[ndarray_index(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0))) +_NDArray_11[ndarray_index(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float.rational(Rational(1, 1)))) _TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.float(Float.rational(Rational(1, 147))))) * (_NDArray_8 / _NDArray_11), FALSE) _Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int)) _NDArray_12 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_11).T / _TupleNDArray_1[ diff --git a/python/tests/__snapshots__/test_array_api/TestLDA.test_source_optimized.py b/python/tests/__snapshots__/test_array_api/TestLDA.test_source_optimized.py index a5c6cce3..b75e4422 100644 --- a/python/tests/__snapshots__/test_array_api/TestLDA.test_source_optimized.py +++ b/python/tests/__snapshots__/test_array_api/TestLDA.test_source_optimized.py @@ -12,7 +12,7 @@ def __fn(X, y): _4 = y == np.array(2) _5 = np.sum(_4) _6 = np.array((_1, _3, _5,)).astype(np.dtype(np.float64)) - _7 = _6 / np.array(150.0) + _7 = _6 / np.array(float(150)) _8 = np.zeros((3, 4,), dtype=np.dtype(np.float64)) _9 = np.sum(X[_0], axis=0) _10 = _9 / np.array(X[_0].shape[0]) @@ -39,7 +39,7 @@ def __fn(X, y): _28 = _27 / np.array(_26.shape[0]) _29 = np.sqrt(_28) _30 = _29 == np.array(0) - _29[_30] = np.array(1.0) + _29[_30] = np.array(float(1)) _31 = _21 / _29 _32 = _17 * _31 _33 = np.linalg.svd(_32, full_matrices=False) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index cd344b10..d965ca57 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -6,11 +6,11 @@ @pytest.fixture(autouse=True) def _reset_conversions(): - import egglog.runtime + import egglog.conversion - old_conversions = 
copy.copy(egglog.runtime.CONVERSIONS) + old_conversions = copy.copy(egglog.conversion.CONVERSIONS) yield - egglog.runtime.CONVERSIONS = old_conversions + egglog.conversion.CONVERSIONS = old_conversions class PythonSnapshotExtension(SingleFileSnapshotExtension): diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 1cc3eccf..07552a20 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -103,49 +103,54 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any: return globals[var] -def load_source(expr): - egraph = EGraph() - fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y"))) - egraph.run(array_api_program_gen_schedule) - # cast b/c issue with it not recognizing py_object as property - cast(Any, egraph.eval(fn_program.py_object)) - assert np.allclose(res, run_lda(X_np, y_np)) - return egraph.eval(fn_program.statements) +def load_source(expr, egraph: EGraph): + with egraph: + fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y"))) + egraph.run(array_api_program_gen_schedule) + # cast b/c issue with it not recognizing py_object as property + cast(Any, egraph.eval(fn_program.py_object)) + assert np.allclose(res, run_lda(X_np, y_np)) + return egraph.eval(fn_program.statements) -@pytest.mark.benchmark(min_rounds=3) -class TestLDA: - def test_trace(self, snapshot_py, benchmark): - @benchmark - def X_r2(): - X_arr = NDArray.var("X") - assume_dtype(X_arr, X_np.dtype) - assume_shape(X_arr, X_np.shape) - assume_isfinite(X_arr) +def trace_lda(egraph: EGraph): + X_arr = NDArray.var("X") + assume_dtype(X_arr, X_np.dtype) + assume_shape(X_arr, X_np.shape) + assume_isfinite(X_arr) - y_arr = NDArray.var("y") - assume_dtype(y_arr, y_np.dtype) - assume_shape(y_arr, y_np.shape) - assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type] + y_arr = NDArray.var("y") + assume_dtype(y_arr, 
y_np.dtype) + assume_shape(y_arr, y_np.shape) + assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type] - with EGraph(): - return run_lda(X_arr, y_arr) + with egraph: + return run_lda(X_arr, y_arr) + +@pytest.mark.benchmark(min_rounds=3) +class TestLDA: + def test_trace(self, snapshot_py, benchmark): + X_r2 = benchmark(trace_lda, EGraph()) assert str(X_r2) == snapshot_py def test_optimize(self, snapshot_py, benchmark): - expr = _load_py_snapshot(self.test_trace) - simplified = benchmark(lambda: EGraph().simplify(expr, array_api_numba_schedule)) + egraph = EGraph() + expr = trace_lda(egraph) + simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule) assert str(simplified) == snapshot_py @pytest.mark.xfail(reason="Original source is not working") def test_source(self, snapshot_py, benchmark): - expr = _load_py_snapshot(self.test_trace) - assert benchmark(load_source, expr) == snapshot_py + egraph = EGraph() + expr = trace_lda(egraph) + assert benchmark(load_source, expr, egraph) == snapshot_py def test_source_optimized(self, snapshot_py, benchmark): - expr = _load_py_snapshot(self.test_optimize) - assert benchmark(load_source, expr) == snapshot_py + egraph = EGraph() + expr = trace_lda(egraph) + optimized_expr = egraph.simplify(expr, array_api_numba_schedule) + assert benchmark(load_source, optimized_expr, egraph) == snapshot_py @pytest.mark.parametrize( "fn", diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 57ee0d6f..211a2839 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -361,7 +361,6 @@ def test_setitem_defaults(self): class Foo(Expr): def __init__(self) -> None: ... - def __setitem__(self, key: i64Like, value: i64Like) -> None: ... foo = Foo() @@ -478,7 +477,6 @@ def f(x: i64Like) -> i64: ... def test_upcast_self_lower_cost(): # Verifies that self will be upcasted, if that upcast has a lower cast than converting the other arg # i.e. 
Int(x) + NDArray(y) -> NDArray(Int(x)) + NDArray(y) instead of Int(x) + NDArray(y).to_int() - EGraph() class Int(Expr): def __init__(self, name: StringLike) -> None: ... @@ -610,3 +608,24 @@ class ZX(Expr): with pytest.raises(NotImplementedError): ZX.__egg_decls__ # type: ignore[attr-defined] + + +def test_deferred_ruleset(): + @ruleset + def rules(x: AA): + yield rewrite(first(x)).to(second(x)) + + class AA(Expr): + def __init__(self) -> None: ... + + @function + def first(x: AA) -> AA: ... + + @function + def second(x: AA) -> AA: ... + + check( + eq(first(AA())).to(second(AA())), + rules, + first(AA()), + ) diff --git a/python/tests/test_pretty.py b/python/tests/test_pretty.py new file mode 100644 index 00000000..b6217ea4 --- /dev/null +++ b/python/tests/test_pretty.py @@ -0,0 +1,149 @@ +# mypy: disable-error-code="empty-body" +from __future__ import annotations + +from copy import copy +from typing import TYPE_CHECKING, ClassVar + +import pytest + +from egglog import * + +if TYPE_CHECKING: + from egglog.runtime import RuntimeExpr + + +class A(Expr): + V: ClassVar[A] + + def __init__(self) -> None: ... + @classmethod + def cls_method(cls) -> A: ... + def method(self) -> A: ... + def __neg__(self) -> A: ... + def __add__(self, other: A) -> A: ... + def __getitem__(self, key: A) -> A: ... + def __call__(self) -> A: ... + def __delitem__(self, key: A) -> None: ... + def __setitem__(self, key: A, value: A) -> None: ... + @property + def prop(self) -> A: ... + + +@function +def f(x: A) -> A: ... + + +@function +def g() -> A: ... + + +@function +def h() -> A: ... + + +@function +def p() -> i64: ... + + +@function +def has_default(x: A = A()) -> A: ... + + +del_a = A() +del del_a[g()] + +del_del_a = copy(del_a) +del del_del_a[h()] + +del_del_a_two = copy(del_a) +del del_del_a_two[A()] + +setitem_a = A() +setitem_a[g()] = h() + +b = constant("b", A) + + +@function +def my_very_long_function_name() -> A: ... 
+ + +long_line = my_very_long_function_name() + my_very_long_function_name() + my_very_long_function_name() + +r = ruleset(name="r") + +PARAMS = [ + # expression function calls + pytest.param(A(), "A()", id="init"), + pytest.param(f(A()), "f(A())", id="call"), + pytest.param(A.cls_method(), "A.cls_method()", id="class method"), + pytest.param(A().method(), "A().method()", id="instance method"), + pytest.param(-A(), "-A()", id="unary operator"), + pytest.param(A() + g(), "A() + g()", id="binary operator"), + pytest.param(A()[g()], "A()[g()]", id="getitem"), + pytest.param(A()(), "A()()", id="call"), + pytest.param(del_a, "_A_1 = A()\ndel _A_1[g()]\n_A_1", id="delitem"), + pytest.param(setitem_a, "_A_1 = A()\n_A_1[g()] = h()\n_A_1", id="setitem"), + pytest.param(del_a + del_a, "_A_1 = A()\ndel _A_1[g()]\n_A_1 + _A_1", id="existing de-duplicate"), + pytest.param(del_del_a, "_A_1 = A()\ndel _A_1[g()]\ndel _A_1[h()]\n_A_1", id="re-use variable"), + pytest.param( + del_del_a + del_del_a_two, + """_A_1 = A() +del _A_1[g()] +_A_2 = copy(_A_1) +del _A_2[h()] +_A_3 = copy(_A_1) +del _A_3[A()] +_A_2 + _A_3""", + id="copy name", + ), + pytest.param(b, "b", id="constant"), + pytest.param(A.V, "A.V", id="class variable"), + pytest.param(A().prop, "A().prop", id="property"), + pytest.param(ne(A()).to(g()), "ne(A()).to(g())", id="ne"), + pytest.param(has_default(A()), "has_default()", id="has default"), + pytest.param( + rewrite(long_line).to(long_line), + "_A_1 = (my_very_long_function_name() + my_very_long_function_name()) + my_very_long_function_name()\nrewrite(_A_1).to(_A_1)", + id="wrap long line", + ), + # primitives + pytest.param(Unit(), "Unit()", id="unit"), + pytest.param(Bool(True), "Bool(True)", id="bool"), + pytest.param(i64(42), "i64(42)", id="i64"), + pytest.param(f64(42.1), "f64(42.1)", id="f64"), + pytest.param(String("hello"), 'String("hello")', id="string"), + pytest.param(PyObject("hi"), 'PyObject("hi")', id="pyobject"), + pytest.param(var("x", A), "x", 
id="variable"), + # commands + pytest.param(rewrite(g()).to(h(), A()), "rewrite(g()).to(h(), A())", id="rewrite"), + pytest.param(rule(g()).then(h()), "rule(g()).then(h())", id="rule"), + # Actions + pytest.param(expr_action(A()), "A()", id="action"), + pytest.param(set_(p()).to(i64(1)), "set_(p()).to(i64(1))", id="set"), + pytest.param(union(g()).with_(h()), "union(g()).with_(h())", id="union"), + pytest.param(let("x", A()), 'let("x", A())', id="let"), + pytest.param(expr_action(A()), "A()", id="expr action"), + pytest.param(delete(p()), "delete(p())", id="delete"), + pytest.param(panic("oh no"), 'panic("oh no")', id="panic"), + # Fact + pytest.param(expr_fact(A()), "A()", id="expr fact"), + pytest.param(eq(g()).to(h(), A()), "eq(g()).to(h(), A())", id="eq"), + # Ruleset + pytest.param(ruleset(rewrite(g()).to(h())), "ruleset(rewrite(g()).to(h()))", id="ruleset"), + # Schedules + pytest.param(r, 'ruleset(name="r")', id="ruleset with name"), + pytest.param(r.saturate(), 'ruleset(name="r").saturate()', id="saturate"), + pytest.param(r * 10, 'ruleset(name="r") * 10', id="repeat"), + pytest.param(r + r, 'ruleset(name="r") + ruleset(name="r")', id="sequence"), + pytest.param(seq(r, r, r), 'seq(ruleset(name="r"), ruleset(name="r"), ruleset(name="r"))', id="seq"), + pytest.param(run(r, h()), 'run(ruleset(name="r"), h())', id="run"), + # Functions + pytest.param(f, "f", id="function"), + pytest.param(A().method, "A().method", id="method"), +] + + +@pytest.mark.parametrize(("x", "s"), PARAMS) +def test_str(x: RuntimeExpr, s: str) -> None: + assert str(x) == s diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py index f3b61424..5e119ee6 100644 --- a/python/tests/test_runtime.py +++ b/python/tests/test_runtime.py @@ -4,6 +4,7 @@ from egglog.declarations import * from egglog.runtime import * +from egglog.thunk import * from egglog.type_constraint_solver import * @@ -14,8 +15,8 @@ def test_type_str(): "Map": ClassDecl(type_vars=("K", "V")), } ) - i64 = 
RuntimeClass(decls.update_other, "i64") - Map = RuntimeClass(decls.update_other, "Map") + i64 = RuntimeClass(Thunk.value(decls), TypeRefWithVars("i64")) + Map = RuntimeClass(Thunk.value(decls), TypeRefWithVars("Map")) assert str(i64) == "i64" assert str(Map[i64, i64]) == "Map[i64, i64]" @@ -26,19 +27,13 @@ def test_function_call(): "i64": ClassDecl(), }, _functions={ - "one": FunctionDecl( - (), - (), - (), - TypeRefWithVars("i64"), - False, - ), + "one": FunctionDecl((), (), (), TypeRefWithVars("i64")), }, ) - one = RuntimeFunction(decls, "one") + one = RuntimeFunction(Thunk.value(decls), FunctionRef("one")) assert ( one().__egg_typed_expr__ # type: ignore[union-attr] - == RuntimeExpr(decls, TypedExprDecl(JustTypeRef("i64"), CallDecl(FunctionRef("one")))).__egg_typed_expr__ + == TypedExprDecl(JustTypeRef("i64"), CallDecl(FunctionRef("one"))) ) @@ -50,37 +45,20 @@ def test_classmethod_call(): "unit": ClassDecl(), "Map": ClassDecl( type_vars=("K", "V"), - class_methods={ - "create": FunctionDecl( - (), - (), - (), - TypeRefWithVars("Map", (K, V)), - False, - ) - }, + class_methods={"create": FunctionDecl((), (), (), TypeRefWithVars("Map", (K, V)))}, ), }, - _type_ref_to_egg_sort={ - JustTypeRef("i64"): "i64", - JustTypeRef("unit"): "unit", - JustTypeRef("Map"): "Map", - }, ) - Map = RuntimeClass(decls.update_other, "Map") + Map = RuntimeClass(Thunk.value(decls), TypeRefWithVars("Map")) with pytest.raises(TypeConstraintError): Map.create() # type: ignore[operator] - i64 = RuntimeClass(decls.update_other, "i64") - unit = RuntimeClass(decls.update_other, "unit") + i64 = RuntimeClass(Thunk.value(decls), TypeRefWithVars("i64")) + unit = RuntimeClass(Thunk.value(decls), TypeRefWithVars("unit")) assert ( - Map[i64, unit].create().__egg_typed_expr__ # type: ignore[union-attr] + Map[i64, unit].create().__egg_typed_expr__ # type: ignore[union-attr, operator] == TypedExprDecl( JustTypeRef("Map", (JustTypeRef("i64"), JustTypeRef("unit"))), - CallDecl( - ClassMethodRef("Map", 
"create"), - (), - (JustTypeRef("i64"), JustTypeRef("unit")), - ), + CallDecl(ClassMethodRef("Map", "create"), (), (JustTypeRef("i64"), JustTypeRef("unit"))), ) ) @@ -92,52 +70,39 @@ def test_expr_special(): methods={ "__add__": FunctionDecl( (TypeRefWithVars("i64"), TypeRefWithVars("i64")), - (), + ( + "a", + "b", + ), (None, None), TypeRefWithVars("i64"), - False, ) }, class_methods={ - "__init__": FunctionDecl( - (TypeRefWithVars("i64"),), - (), - (None,), - TypeRefWithVars("i64"), - False, - ) + "__init__": FunctionDecl((TypeRefWithVars("i64"),), ("self",), (None,), TypeRefWithVars("i64")) }, ), }, ) - i64 = RuntimeClass(decls.update_other, "i64") + i64 = RuntimeClass(Thunk.value(decls), TypeRefWithVars("i64")) one = i64(1) res = one + one # type: ignore[operator] - expected_res = RuntimeExpr( - decls, - TypedExprDecl( - JustTypeRef("i64"), - CallDecl( - MethodRef("i64", "__add__"), - (TypedExprDecl(JustTypeRef("i64"), LitDecl(1)), TypedExprDecl(JustTypeRef("i64"), LitDecl(1))), - ), + assert res.__egg_typed_expr__ == TypedExprDecl( + JustTypeRef("i64"), + CallDecl( + MethodRef("i64", "__add__"), + (TypedExprDecl(JustTypeRef("i64"), LitDecl(1)), TypedExprDecl(JustTypeRef("i64"), LitDecl(1))), ), ) - assert res.__egg_typed_expr__ == expected_res.__egg_typed_expr__ def test_class_variable(): decls = Declarations( _classes={ - "i64": ClassDecl(class_variables={"one": JustTypeRef("i64")}), + "i64": ClassDecl(class_variables={"one": ConstantDecl(JustTypeRef("i64"))}), }, ) - i64 = RuntimeClass(decls.update_other, "i64") + i64 = RuntimeClass(Thunk.value(decls), TypeRefWithVars("i64")) one = i64.one assert isinstance(one, RuntimeExpr) - assert ( - one.__egg_typed_expr__ - == RuntimeExpr( - decls, TypedExprDecl(JustTypeRef("i64"), CallDecl(ClassVariableRef("i64", "one"))) - ).__egg_typed_expr__ - ) + assert one.__egg_typed_expr__ == TypedExprDecl(JustTypeRef("i64"), CallDecl(ClassVariableRef("i64", "one"))) diff --git a/src/egraph.rs b/src/egraph.rs index 
ea4713f4..3e62af52 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -20,20 +20,22 @@ use std::sync::Arc; pub struct EGraph { egraph: egglog::EGraph, py_object_arcsort: Option>, + cmds: Option, } #[pymethods] impl EGraph { #[new] #[pyo3( - signature = (py_object_sort=None, *, fact_directory=None, seminaive=true, terms_encoding=false), - text_signature = "(py_object_sort=None, *, fact_directory=None, seminaive=True, terms_encoding=False)" + signature = (py_object_sort=None, *, fact_directory=None, seminaive=true, terms_encoding=false, record=false), + text_signature = "(py_object_sort=None, *, fact_directory=None, seminaive=True, terms_encoding=False, record=False)" )] fn new( py_object_sort: Option, fact_directory: Option, seminaive: bool, terms_encoding: bool, + record: bool, ) -> Self { let mut egraph = egglog::EGraph::default(); egraph.fact_directory = fact_directory; @@ -50,6 +52,7 @@ impl EGraph { Self { egraph, py_object_arcsort, + cmds: if record { Some(String::new()) } else { None }, } } @@ -68,10 +71,22 @@ impl EGraph { fn run_program(&mut self, commands: Vec) -> EggResult> { let commands: Vec = commands.into_iter().map(|x| x.into()).collect(); info!("Running commands {:?}", commands); + if let Some(cmds) = &mut self.cmds { + for cmd in &commands { + cmds.push_str(&cmd.to_string()); + cmds.push('\n'); + } + } let res = self.egraph.run_program(commands)?; Ok(res) } + /// Returns the text of the commands that have been run so far, if `record` was passed. + #[pyo3(signature = ())] + fn commands(&self) -> Option { + self.cmds.clone() + } + /// Gets the last expressions extracted from the EGraph, if the last command /// was a Simplify or Extract command. #[pyo3(signature = ())]