diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 14318908..28591fa3 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -9,7 +9,7 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.11" + python: "3.12" # # You can also specify other tool versions: # # nodejs: "16" rust: "1.70" diff --git a/Cargo.lock b/Cargo.lock index c88c9686..fc56effe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -312,7 +312,7 @@ dependencies = [ [[package]] name = "egglog-python" -version = "4.0.1" +version = "5.0.0" dependencies = [ "egglog", "egraph-serialize", diff --git a/docs/changelog.md b/docs/changelog.md index b2395fb5..6d86a5bb 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,28 @@ _This project uses semantic versioning_ ## UNRELEASED +### Auto Register Function and Class Definitions + +_not yet implemented_ + +This is a large breaking change that moves the function and class decorators to the top level `egglog` module, +from the `EGraph` and `Module` classes. Rulesets are also moved to be defined globally instead of on the `EGraph` class. + +The goal of this change is to remove the complexity of `Module`s and remove the need to think about what functions/classes +need to be registered for each `EGraph`. Instead, we will implicitly register and functions/classes that are used +in any rules or added in any commands. + +- `egraph.class_` -> Simply subclass from `egglog.Expr` +- `egraph.method` -> `egglog.method` +- `egraph.function` -> `egglog.function` +- `egraph.relation` -> `egglog.relation` +- `egraph.ruleset` -> `egglog.Ruleset` + +The `EGraph` class can take an optional `default_ruleset` argument to set the default ruleset for the `EGraph`. Otherwise, +there is a global default ruleset that is used, `egglog.Ruleset`. + +This also adds support for classes with methods that are mutually recursive, by making type analysis more lazy. + ## 5.0.0 (2024-01-16) - Move egglog `!=` function to be called with `ne(x).to(y)` instead of `x != y` so that user defined expressions diff --git a/docs/conf.py b/docs/conf.py index 1bb85f9d..aef67543 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,7 +26,13 @@ # Myst ## -myst_enable_extensions = ["attrs_inline", "smartquotes", "strikethrough", "html_image"] +myst_enable_extensions = [ + # "attrs_inline", + "smartquotes", + "strikethrough", + "html_image", + "deflist", +] myst_fence_as_directive = ["mermaid"] ## @@ -38,7 +44,7 @@ output_dir = cwd / "presentations" subprocess.run( - [ # noqa: S603,S607 + [ # noqa: S607, S603 "jupyter", "nbconvert", str(presentation_file), diff --git a/docs/new_reference.md b/docs/new_reference.md new file mode 100644 index 00000000..c2728993 --- /dev/null +++ b/docs/new_reference.md @@ -0,0 +1,93 @@ +# Reference Documentation + +```{module} egglog + +``` + +_TODO: this isn't done yet_ +This is a definitive reference of `egglog` module and the concepts in it. + +## Terms + +Ruleset +: A colleciton of rules + +Rule +: Updates an EGraph by matching on a number of facts and then running a number of actions. Any variables in the facts can be used in the actions. + +Fact +: A query on an EGraph, either by an expression or an equivalence between multiple expressions. + +Action +: A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking. +Union +: Merges two equivalence classes of two expressions. +Set +: Similar to union, except can be used on primitive expressions, whereas union can only be used on user defined expressions. +Delete +: Remove a function call from an EGraph. + +Schedule +: A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met + +EGraph ([](egglog.EGraph)) +: An equivalence relation over a set of expressions. +: A collection of expressions where each expression is part of a distinct equivalence class. +: Can run actions, check facts, run schedules, or extract minimal cost expressions. + +Expression ([](egglog.Expr)) +: Either a function called with some number of argument expressions or a literal integer, float, or string, with a particular type. + +Function ([](egglog.function)) +: Defined by a unique name and a typing relation that will specify the return type based on the types of the argument expressions. +: These can either be builtin functions, which are implemented in Rust, or user defined function which have types for each argument and the return type. +: Relations ([](egglog.relation)), constants ([](egglog.constant)), methods ([](egglog.method)), classmethods, and class variables are all syntactic sugar for defining functions. + +Type (called a "sort" in the rust bindings) +: A uniquely named entity, that is used to constraint the composition of functions. +: Can be either a primitive type if it is defined in Rust or a user defined type. + +## Classes + +```{class} EGraph + +``` + +https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#module-sphinx.ext.autodoc +https://www.sphinx-doc.org/en/master/usage/domains/python.html#signatures +https://myst-parser.readthedocs.io/en/latest/syntax/roles-and-directives.html + +```{class} Expr + +Subclass `Expr` to create a new type. Only direct subclasses are supported, subclasses of subclasses are not for now. + + +``` + +```{decorator} function + +``` + +```{decorator} method + +Any method can be decorated with this to customize it's behavior. This is only supported in classes which subclass [](egglog.Expr). + +``` + +```{function} relation + +Creates a function whose return type is [](egglog.Unit) and whose default value is `[](egglog.Unit)`. + +``` + +```{function} constant + +A "constant" is implemented as the instantiation of a value that takes no args. +This creates a function with `name` and return type `tp` and returns a value of it being called. + +``` + +```{class} Unit + +Primitive type with only one possible value. +``` diff --git a/docs/reference.md b/docs/reference.md index 0333133b..fa0521f2 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -22,4 +22,5 @@ reference/high-level reference/egglog-translation reference/python-integration reference/bindings +new_reference ``` diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index 66e90b43..53a365b5 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -105,7 +105,7 @@ assert egraph.eval(evalled) == 3 ### Simpler Eval -Instead of using the above low level primitive for evaluating, there is a higher level wrapper function, `eval_fn`. +Instead of using the above low level primitive for evaluating, there is a higher level wrapper function, `py_eval_fn`. It takes in a Python function and converts it to a function of PyObjects, by using `py_eval` under the hood. @@ -115,7 +115,7 @@ The above code code be re-written like this: def my_add(a, b): return a + b -evalled = eval_fn(lambda a: my_add(a, 2))(1) +evalled = py_eval_fn(lambda a: my_add(a, 2))(1) assert egraph.eval(evalled) == 3 ``` diff --git a/pyproject.toml b/pyproject.toml index 67b92b5c..b85f8ffa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = ["typing-extensions", "black", "graphviz"] [project.optional-dependencies] -array = ["scikit-learn", "array_api_compat", 'numba; python_version<"3.12"'] +array = ["scikit-learn", "array_api_compat", "numba==0.59.0rc1", "llvmlite==0.42.0rc1"] dev = ["pre-commit", "ruff", "mypy", "anywidget[dev]", "egglog[docs,test]"] test = ["pytest", "mypy", "syrupy", "egglog[array]"] diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index fc950b9c..fdb33132 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -4,13 +4,12 @@ from __future__ import annotations -import itertools from collections import defaultdict from dataclasses import dataclass, field from inspect import Parameter, Signature -from typing import TYPE_CHECKING, TypeAlias +from typing import TYPE_CHECKING, Protocol, TypeAlias, Union, runtime_checkable -from typing_extensions import assert_never +from typing_extensions import Self, assert_never from . import bindings @@ -20,7 +19,8 @@ __all__ = [ "Declarations", - "ModuleDeclarations", + "DeclerationsLike", + "upcast_decleratioons", "JustTypeRef", "ClassTypeVarRef", "TypeRefWithVars", @@ -100,6 +100,30 @@ } +@runtime_checkable +class HasDeclerations(Protocol): + @property + def __egg_decls__(self) -> Declarations: + ... + + +DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"] + + +def upcast_decleratioons(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]: + d = [] + for l in declerations_like: + if l is None: + continue + if isinstance(l, HasDeclerations): + d.append(l.__egg_decls__) + elif isinstance(l, Declarations): + d.append(l) + else: + assert_never(l) + return d + + @dataclass class Declarations: _functions: dict[str, FunctionDecl] = field(default_factory=dict) @@ -116,6 +140,74 @@ class Declarations: _egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict) _type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict) + # Mapping from egg name (of sort or function) to command to create it. + _cmds: dict[str, bindings._Command] = field(default_factory=dict) + + def __post_init__(self) -> None: + if "!=" not in self._egg_fn_to_callable_refs: + self.register_callable_ref(FunctionRef("!="), "!=") + + @classmethod + def create(cls, *others: DeclerationsLike) -> Declarations: + others = upcast_decleratioons(others) + if not others: + return Declarations() + first, *rest = others + new = first.copy() + new.update(*rest) + return new + + def copy(self) -> Declarations: + return Declarations( + _functions=self._functions.copy(), + _classes=self._classes.copy(), + _constants=self._constants.copy(), + _egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self._egg_fn_to_callable_refs.items()}), + _callable_ref_to_egg_fn=self._callable_ref_to_egg_fn.copy(), + _egg_sort_to_type_ref=self._egg_sort_to_type_ref.copy(), + _type_ref_to_egg_sort=self._type_ref_to_egg_sort.copy(), + _cmds=self._cmds.copy(), + ) + + def __deepcopy__(self, memo: dict) -> Declarations: + return self.copy() + + def add_cmd(self, name: str, cmd: bindings._Command) -> None: + self._cmds[name] = cmd + + def list_cmds(self) -> list[bindings._Command]: + return list(self._cmds.values()) + + def update(self, *others: DeclerationsLike) -> None: + for other in others: + self |= other + + def __or__(self, other: DeclerationsLike) -> Declarations: + result = Declarations() + result |= self + result |= other + return result + + def __ior__(self, other: DeclerationsLike) -> Self: + if other is None: + return self + if isinstance(other, HasDeclerations): + other = other.__egg_decls__ + # If cmds are == skip unioning for time savings + # if set(self._cmds) == set(other._cmds) and self.record_cmds and other.record_cmds: + # return self + + self._functions |= other._functions + self._classes |= other._classes + self._constants |= other._constants + self._egg_sort_to_type_ref |= other._egg_sort_to_type_ref + self._type_ref_to_egg_sort |= other._type_ref_to_egg_sort + self._cmds |= other._cmds + self._callable_ref_to_egg_fn |= other._callable_ref_to_egg_fn + for egg_fn, callable_refs in other._egg_fn_to_callable_refs.items(): + self._egg_fn_to_callable_refs[egg_fn] |= callable_refs + return self + def set_function_decl(self, ref: FunctionCallableRef, decl: FunctionDecl) -> None: """ Sets a function declaration for the given callable reference. @@ -164,26 +256,6 @@ def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None: self._callable_ref_to_egg_fn[ref] = egg_name self._egg_fn_to_callable_refs[egg_name].add(ref) - def get_function_decl(self, ref: FunctionCallableRef) -> FunctionDecl: - match ref: - case FunctionRef(name): - return self._functions[name] - case MethodRef(class_name, method_name): - return self._classes[class_name].methods[method_name] - case ClassMethodRef(class_name, method_name): - return self._classes[class_name].class_methods[method_name] - case PropertyRef(class_name, property_name): - return self._classes[class_name].properties[property_name] - assert_never(ref) - - def get_constant_type(self, ref: ConstantCallableRef) -> JustTypeRef: - match ref: - case ConstantRef(name): - return self._constants[name] - case ClassVariableRef(class_name, variable_name): - return self._classes[class_name].class_variables[variable_name] - assert_never(ref) - def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]: return self._egg_fn_to_callable_refs[egg_name] @@ -193,120 +265,58 @@ def get_egg_fn(self, ref: CallableRef) -> str: def get_egg_sort(self, ref: JustTypeRef) -> str: return self._type_ref_to_egg_sort[ref] - def op_mapping(self) -> dict[str, str]: - return {k: str(next(iter(v))) for k, v in self._egg_fn_to_callable_refs.items() if len(v) == 1} - - -@dataclass -class ModuleDeclarations: - """ - A set of working declerations for a module. - """ - - # The modules declarations we have, which we can edit - _decl: Declarations - # A list of other declarations we can use, but not edit - _included_decls: list[Declarations] = field(default_factory=list, repr=False) - def op_mapping(self) -> dict[str, str]: """ Create a mapping of egglog function name to Python function name, for use in the serialized format for better visualization. """ - mapping = self._decl.op_mapping() - for decl in self._included_decls: - mapping.update(decl.op_mapping()) - return mapping - - @classmethod - def parent_decl(cls, a: ModuleDeclarations, b: ModuleDeclarations) -> ModuleDeclarations: - """ - Returns the declerations which has the other as a child. - """ - if b._decl in a.all_decls: - return a - if a._decl in b.all_decls: - return b - msg = "No parent decl found" - raise ValueError(msg) - - @property - def all_decls(self) -> Iterable[Declarations]: - return itertools.chain([self._decl], self._included_decls) + return {k: str(next(iter(v))) for k, v in self._egg_fn_to_callable_refs.items() if len(v) == 1} def has_method(self, class_name: str, method_name: str) -> bool | None: """ Returns whether the given class has the given method, or None if we cant find the class. """ - for decl in self.all_decls: - if class_name in decl._classes: - return method_name in decl._classes[class_name].methods + if class_name in self._classes: + return method_name in self._classes[class_name].methods return None def get_function_decl(self, ref: CallableRef) -> FunctionDecl: - if isinstance(ref, ClassVariableRef | ConstantRef): - for decls in self.all_decls: - try: - return decls.get_constant_type(ref).to_constant_function_decl() - except KeyError: - pass - raise KeyError(f"Constant {ref} not found") - if isinstance(ref, FunctionRef | MethodRef | ClassMethodRef | PropertyRef): - for decls in self.all_decls: - try: - return decls.get_function_decl(ref) - except KeyError: - pass - raise KeyError(f"Function {ref} not found") + match ref: + case ConstantRef(name): + return self._constants[name].to_constant_function_decl() + case ClassVariableRef(class_name, variable_name): + return self._classes[class_name].class_variables[variable_name].to_constant_function_decl() + case FunctionRef(name): + return self._functions[name] + case MethodRef(class_name, method_name): + return self._classes[class_name].methods[method_name] + case ClassMethodRef(class_name, method_name): + return self._classes[class_name].class_methods[method_name] + case PropertyRef(class_name, property_name): + return self._classes[class_name].properties[property_name] assert_never(ref) - def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]: - return itertools.chain.from_iterable(decls.get_callable_refs(egg_name) for decls in self.all_decls) - - def get_egg_fn(self, ref: CallableRef) -> str: - for decls in self.all_decls: - try: - return decls.get_egg_fn(ref) - except KeyError: - pass - raise KeyError(f"Callable ref {ref!r} not found") - - def get_egg_sort(self, ref: JustTypeRef) -> str: - for decls in self.all_decls: - try: - return decls.get_egg_sort(ref) - except KeyError: - pass - raise KeyError(f"Type {ref} not found") - def get_class_decl(self, name: str) -> ClassDecl: - for decls in self.all_decls: - try: - return decls._classes[name] - except KeyError: - pass - raise KeyError(f"Class {name} not found") + return self._classes[name] def get_registered_class_args(self, cls_name: str) -> tuple[JustTypeRef, ...]: """ Given a class name, returns the first typevar regsisted with args of that class. """ - for decl in self.all_decls: - for tp in decl._type_ref_to_egg_sort: - if tp.name == cls_name and tp.args: - return tp.args + for tp in self._type_ref_to_egg_sort: + if tp.name == cls_name and tp.args: + return tp.args return () - def register_class(self, name: str, n_type_vars: int, egg_sort: str | None) -> Iterable[bindings._Command]: + def register_class(self, name: str, n_type_vars: int, builtin: bool, egg_sort: str | None) -> None: # Register class first - if name in self._decl._classes: + if name in self._classes: raise ValueError(f"Class {name} already registered") decl = ClassDecl(n_type_vars=n_type_vars) - self._decl._classes[name] = decl - _egg_sort, cmds = self.register_sort(JustTypeRef(name), egg_sort) - return cmds + self._classes[name] = decl + self.register_sort(JustTypeRef(name), builtin, egg_sort) - def register_sort(self, ref: JustTypeRef, egg_name: str | None = None) -> tuple[str, Iterable[bindings._Command]]: + def register_sort(self, ref: JustTypeRef, builtin: bool, egg_name: str | None = None) -> str: """ Register a sort with the given name. If no name is given, one is generated. @@ -318,13 +328,27 @@ def register_sort(self, ref: JustTypeRef, egg_name: str | None = None) -> tuple[ except KeyError: pass else: - return (egg_sort, []) + return egg_sort egg_name = egg_name or ref.generate_egg_name() - if egg_name in self._decl._egg_sort_to_type_ref: + if egg_name in self._egg_sort_to_type_ref: raise ValueError(f"Sort {egg_name} is already registered.") - self._decl._egg_sort_to_type_ref[egg_name] = ref - self._decl._type_ref_to_egg_sort[ref] = egg_name - return egg_name, ref.to_commands(self) + self._egg_sort_to_type_ref[egg_name] = ref + self._type_ref_to_egg_sort[ref] = egg_name + if not builtin: + self.add_cmd( + egg_name, + bindings.Sort( + egg_name, + ( + self.get_egg_sort(JustTypeRef(ref.name)), + [bindings.Var(self.register_sort(arg, False)) for arg in ref.args], + ) + if ref.args + else None, + ), + ) + + return egg_name def register_function_callable( self, @@ -336,30 +360,61 @@ def register_function_callable( merge: ExprDecl | None, merge_action: list[bindings._Action], unextractable: bool, + builtin: bool, is_relation: bool = False, - ) -> Iterable[bindings._Command]: + ) -> None: """ Registers a callable with the given egg name. The callable's function needs to be registered first. """ egg_name = egg_name or ref.generate_egg_name() - self._decl.register_callable_ref(ref, egg_name) - self._decl.set_function_decl(ref, fn_decl) - return fn_decl.to_commands(self, egg_name, cost, default, merge, merge_action, is_relation, unextractable) - - def register_constant_callable( - self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: str | None - ) -> Iterable[bindings._Command]: - egg_function = ref.generate_egg_name() - self._decl.register_callable_ref(ref, egg_function) - self._decl.set_constant_type(ref, type_ref) - # Create a function decleartion for a constant function. This is similar to how egglog compiles - # the `declare` command. - return FunctionDecl((), (), (), type_ref.to_var(), False).to_commands(self, egg_name or ref.generate_egg_name()) + self.register_callable_ref(ref, egg_name) + self.set_function_decl(ref, fn_decl) + + # Skip generating the cmds if we don't want to record them, like for the builtins + if builtin: + return + + if fn_decl.var_arg_type is not None: + msg = "egglog does not support variable arguments yet." + raise NotImplementedError(msg) + # Remove all vars from the type refs, raising an errory if we find one, + # since we cannot create egg functions with vars + arg_sorts = [self.register_sort(a.to_just(), False) for a in fn_decl.arg_types] + cmd: bindings._Command + if is_relation: + assert not default + assert not merge + assert not merge_action + assert not cost + cmd = bindings.Relation(egg_name, arg_sorts) + else: + egg_fn_decl = bindings.FunctionDecl( + egg_name, + bindings.Schema(arg_sorts, self.register_sort(fn_decl.return_type.to_just(), False)), + default.to_egg(self) if default else None, + merge.to_egg(self) if merge else None, + merge_action, + cost, + unextractable, + ) + cmd = bindings.Function(egg_fn_decl) + self.add_cmd(egg_name, cmd) + + def register_constant_callable(self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: str | None) -> None: + egg_name = egg_name or ref.generate_egg_name() + self.register_callable_ref(ref, egg_name) + self.set_constant_type(ref, type_ref) + egg_sort = self.register_sort(type_ref, False) + # self.add_cmd(egg_name, bindings.Declare(egg_name, self.get_egg_sort(type_ref))) + # Use function decleration instead of constant b/c constants cannot be extracted + # https://github.com/egraphs-good/egglog/issues/334 + fn_decl = bindings.FunctionDecl(egg_name, bindings.Schema([], egg_sort)) + self.add_cmd(egg_name, bindings.Function(fn_decl)) def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None: - self._decl._classes[class_].preserved_methods[method] = fn + self._classes[class_].preserved_methods[method] = fn # Have two different types of type refs, one that can include vars recursively and one that cannot. @@ -379,18 +434,6 @@ def generate_egg_name(self) -> str: args = "_".join(a.generate_egg_name() for a in self.args) return f"{self.name}_{args}" - def to_commands(self, mod_decls: ModuleDeclarations) -> Iterable[bindings._Command]: - """ - Returns commands to register this as a sort, as well as for any of its arguments. - """ - egg_name = mod_decls.get_egg_sort(self) - arg_sorts: list[bindings._Expr] = [] - for arg in self.args: - egg_sort, cmds = mod_decls.register_sort(arg) - arg_sorts.append(bindings.Var(egg_sort)) - yield from cmds - yield bindings.Sort(egg_name, (self.name, arg_sorts) if arg_sorts else None) - def to_var(self) -> TypeRefWithVars: return TypeRefWithVars(self.name, tuple(a.to_var() for a in self.args)) @@ -487,9 +530,6 @@ class ClassMethodRef: class_name: str method_name: str - def to_egg(self, decls: Declarations) -> str: - return decls.get_egg_fn(self) - def generate_egg_name(self) -> str: return f"{self.class_name}_{self.method_name}" @@ -568,50 +608,6 @@ def to_signature(self, transform_default: Callable[[TypedExprDecl], object]) -> parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL)) return Signature(parameters) - def to_commands( - self, - mod_decls: ModuleDeclarations, - egg_name: str, - cost: int | None = None, - default: ExprDecl | None = None, - merge: ExprDecl | None = None, - merge_action: list[bindings._Action] | None = None, - is_relation: bool = False, - unextractable: bool = False, - ) -> Iterable[bindings._Command]: - if merge_action is None: - merge_action = [] - if self.var_arg_type is not None: - msg = "egglog does not support variable arguments yet." - raise NotImplementedError(msg) - arg_sorts: list[str] = [] - for a in self.arg_types: - # Remove all vars from the type refs, raising an errory if we find one, - # since we cannot create egg functions with vars - arg_sort, cmds = mod_decls.register_sort(a.to_just()) - yield from cmds - arg_sorts.append(arg_sort) - return_sort, cmds = mod_decls.register_sort(self.return_type.to_just()) - yield from cmds - if is_relation: - assert not default - assert not merge - assert not merge_action - assert not cost - assert return_sort == "Unit" - yield bindings.Relation(egg_name, arg_sorts) - return - egg_fn_decl = bindings.FunctionDecl( - egg_name, - bindings.Schema(arg_sorts, return_sort), - default.to_egg(mod_decls) if default else None, - merge.to_egg(mod_decls) if merge else None, - merge_action, - cost, - unextractable, - ) - yield bindings.Function(egg_fn_decl) - @dataclass(frozen=True) class VarDecl: @@ -622,7 +618,7 @@ def from_egg(cls, var: bindings.Var) -> TypedExprDecl: msg = "Cannot turn var into egg type because typing unknown." raise NotImplementedError(msg) - def to_egg(self, _decls: ModuleDeclarations) -> bindings.Var: + def to_egg(self, _decls: Declarations) -> bindings.Var: return bindings.Var(self.name) def pretty(self, context: PrettyContext, **kwargs) -> str: @@ -644,7 +640,7 @@ def __hash__(self) -> int: def from_egg(cls, egraph: bindings.EGraph, call: bindings.Call) -> TypedExprDecl: return TypedExprDecl(JustTypeRef("PyObject"), cls(egraph.eval_py_object(call))) - def to_egg(self, _decls: ModuleDeclarations) -> bindings._Expr: + def to_egg(self, _decls: Declarations) -> bindings._Expr: return GLOBAL_PY_OBJECT_SORT.store(self.value) def pretty(self, context: PrettyContext, **kwargs) -> str: @@ -674,7 +670,7 @@ def from_egg(cls, lit: bindings.Lit) -> TypedExprDecl: return TypedExprDecl(JustTypeRef("Unit"), cls(None)) assert_never(lit.value) - def to_egg(self, _decls: ModuleDeclarations) -> bindings.Lit: + def to_egg(self, _decls: Declarations) -> bindings.Lit: if self.value is None: return bindings.Lit(bindings.Unit()) if isinstance(self.value, bool): @@ -735,31 +731,41 @@ def __eq__(self, other: object) -> bool: return hash(self) == hash(other) @classmethod - def from_egg(cls, egraph: bindings.EGraph, mod_decls: ModuleDeclarations, call: bindings.Call) -> TypedExprDecl: + def from_egg(cls, egraph: bindings.EGraph, decls: Declarations, call: bindings.Call) -> TypedExprDecl: + """ + Convert an egg expression into a typed expression by using the declerations. + + For use in extract + """ from .type_constraint_solver import TypeConstraintSolver - results = tuple(TypedExprDecl.from_egg(egraph, mod_decls, a) for a in call.args) + results = tuple(TypedExprDecl.from_egg(egraph, decls, a) for a in call.args) arg_types = tuple(r.tp for r in results) # Find the first callable ref that matches the call - for callable_ref in mod_decls.get_callable_refs(call.name): + for callable_ref in decls.get_callable_refs(call.name): # If this is a classmethod, we might need the type params that were bound for this type # egglog currently only allows one instantiated type of any generic sort to be used in any program # So we just lookup what args were registered for this sort if isinstance(callable_ref, ClassMethodRef): - cls_args = mod_decls.get_registered_class_args(callable_ref.class_name) + cls_args = decls.get_registered_class_args(callable_ref.class_name) tcs = TypeConstraintSolver.from_type_parameters(cls_args) else: tcs = TypeConstraintSolver() - fn_decl = mod_decls.get_function_decl(callable_ref) + fn_decl = decls.get_function_decl(callable_ref) return_tp = tcs.infer_return_type(fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types) return TypedExprDecl(return_tp, cls(callable_ref, tuple(results))) raise ValueError(f"Could not find callable ref for call {call}") - def to_egg(self, mod_decls: ModuleDeclarations) -> bindings.Call: + def to_egg(self, decls: Declarations) -> bindings._Expr: """Convert a Call to an egg Call.""" - egg_fn = mod_decls.get_egg_fn(self.callable) - return bindings.Call(egg_fn, [a.to_egg(mod_decls) for a in self.args]) + # This was removed when we replaced declerations constants with our b/c of unextractable constants + # # If this is a constant, then emit it just as a var, not as a call + # if isinstance(self.callable, ConstantRef | ClassVariableRef): + # decls.get_egg_fn + # return bindings.Var(egg_fn) + egg_fn = decls.get_egg_fn(self.callable) + return bindings.Call(egg_fn, [a.to_egg(decls) for a in self.args]) def pretty(self, context: PrettyContext, parens: bool = True, **kwargs) -> str: # noqa: C901 """ @@ -770,7 +776,10 @@ def pretty(self, context: PrettyContext, parens: bool = True, **kwargs) -> str: if self in context.names: return context.names[self] ref, args = self.callable, [a.expr for a in self.args] - function_decl = context.mod_decls.get_function_decl(ref) + # Special case != + if ref == FunctionRef("!="): + return f"ne({args[0].pretty(context, parens=False, unwrap_lit=False)}).to({args[1].pretty(context, parens=False, unwrap_lit=False)})" + function_decl = context.decls.get_function_decl(ref) # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default n_defaults = 0 for arg, default in zip( @@ -882,7 +891,7 @@ def _pretty_call(context: PrettyContext, fn: str, args: Iterable[ExprDecl]) -> s @dataclass class PrettyContext: - mod_decls: ModuleDeclarations + decls: Declarations # List of statements of "context" setting variable for the expr statements: list[str] = field(default_factory=list) @@ -924,51 +933,51 @@ def traverse_for_parents(self, expr: ExprDecl) -> None: self.traverse_for_parents(arg.expr) -def test_expr_pretty(): - context = PrettyContext(ModuleDeclarations(Declarations())) - assert VarDecl("x").pretty(context) == "x" - assert LitDecl(42).pretty(context) == "i64(42)" - assert LitDecl("foo").pretty(context) == 'String("foo")' - assert LitDecl(None).pretty(context) == "unit()" +# def test_expr_pretty(): +# context = PrettyContext(ModuleDeclarations(Declarations())) +# assert VarDecl("x").pretty(context) == "x" +# assert LitDecl(42).pretty(context) == "i64(42)" +# assert LitDecl("foo").pretty(context) == 'String("foo")' +# assert LitDecl(None).pretty(context) == "unit()" - def v(x: str) -> TypedExprDecl: - return TypedExprDecl(JustTypeRef(""), VarDecl(x)) +# def v(x: str) -> TypedExprDecl: +# return TypedExprDecl(JustTypeRef(""), VarDecl(x)) - assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(context) == "foo(x)" - assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(context) == "foo(x, y, z)" - assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(context) == "x + y" - assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(context) == "x[y]" - assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(context) == "foo(x, y)" - assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(context) == "foo.bar(x, y)" - assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(context) == "x(y)" - assert ( - CallDecl( - ClassMethodRef("Map", "__init__"), - (), - (JustTypeRef("i64"), JustTypeRef("Unit")), - ).pretty(context) - == "Map[i64, Unit]()" - ) +# assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(context) == "foo(x)" +# assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(context) == "foo(x, y, z)" +# assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(context) == "x + y" +# assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(context) == "x[y]" +# assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(context) == "foo(x, y)" +# assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(context) == "foo.bar(x, y)" +# assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(context) == "x(y)" +# assert ( +# CallDecl( +# ClassMethodRef("Map", "__init__"), +# (), +# (JustTypeRef("i64"), JustTypeRef("Unit")), +# ).pretty(context) +# == "Map[i64, Unit]()" +# ) -def test_setitem_pretty(): - context = PrettyContext(ModuleDeclarations(Declarations())) +# def test_setitem_pretty(): +# context = PrettyContext(ModuleDeclarations(Declarations())) - def v(x: str) -> TypedExprDecl: - return TypedExprDecl(JustTypeRef("typ"), VarDecl(x)) +# def v(x: str) -> TypedExprDecl: +# return TypedExprDecl(JustTypeRef("typ"), VarDecl(x)) - final_expr = CallDecl(MethodRef("foo", "__setitem__"), (v("x"), v("y"), v("z"))).pretty(context) - assert context.render(final_expr) == "_typ_1 = x\n_typ_1[y] = z\n_typ_1" +# final_expr = CallDecl(MethodRef("foo", "__setitem__"), (v("x"), v("y"), v("z"))).pretty(context) +# assert context.render(final_expr) == "_typ_1 = x\n_typ_1[y] = z\n_typ_1" -def test_delitem_pretty(): - context = PrettyContext(ModuleDeclarations(Declarations())) +# def test_delitem_pretty(): +# context = PrettyContext(ModuleDeclarations(Declarations())) - def v(x: str) -> TypedExprDecl: - return TypedExprDecl(JustTypeRef("typ"), VarDecl(x)) +# def v(x: str) -> TypedExprDecl: +# return TypedExprDecl(JustTypeRef("typ"), VarDecl(x)) - final_expr = CallDecl(MethodRef("foo", "__delitem__"), (v("x"), v("y"))).pretty(context) - assert context.render(final_expr) == "_typ_1 = x\ndel _typ_1[y]\n_typ_1" +# final_expr = CallDecl(MethodRef("foo", "__delitem__"), (v("x"), v("y"))).pretty(context) +# assert context.render(final_expr) == "_typ_1 = x\ndel _typ_1[y]\n_typ_1" # TODO: Multiple mutations, @@ -982,7 +991,7 @@ class TypedExprDecl: expr: ExprDecl @classmethod - def from_egg(cls, egraph: bindings.EGraph, mod_decls: ModuleDeclarations, expr: bindings._Expr) -> TypedExprDecl: + def from_egg(cls, egraph: bindings.EGraph, decls: Declarations, expr: bindings._Expr) -> TypedExprDecl: if isinstance(expr, bindings.Var): return VarDecl.from_egg(expr) if isinstance(expr, bindings.Lit): @@ -990,10 +999,10 @@ def from_egg(cls, egraph: bindings.EGraph, mod_decls: ModuleDeclarations, expr: if isinstance(expr, bindings.Call): if expr.name == "py-object": return PyObjectDecl.from_egg(egraph, expr) - return CallDecl.from_egg(egraph, mod_decls, expr) + return CallDecl.from_egg(egraph, decls, expr) assert_never(expr) - def to_egg(self, decls: ModuleDeclarations) -> bindings._Expr: + def to_egg(self, decls: Declarations) -> bindings._Expr: return self.expr.to_egg(decls) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 66b45bf6..1c53765c 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable from contextvars import ContextVar, Token -from copy import copy, deepcopy +from copy import deepcopy from dataclasses import InitVar, dataclass, field from inspect import Parameter, currentframe, signature from types import FunctionType @@ -37,7 +37,7 @@ from .declarations import * from .ipython_magic import IN_IPYTHON from .runtime import * -from .runtime import _resolve_callable, class_to_ref, convert_to_same_type +from .runtime import _resolve_callable, _resolve_literal, class_to_ref, convert_to_same_type if TYPE_CHECKING: import ipywidgets @@ -96,13 +96,13 @@ } -_BUILTIN_DECLS: Declarations | None = None - ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"} +_PY_OBJECT_CLASS: RuntimeClass | None = None + @dataclass -class _BaseModule(ABC): +class _BaseModule: """ Base Module which provides methods to register sorts, expressions, actions etc. @@ -112,28 +112,37 @@ class _BaseModule(ABC): - Module: Stores a list of commands and additional declerations """ - # Any modules you want to depend on + is_builtin: ClassVar[bool] = False + + # TODO: If we want to preserve existing semantics, then we use the module to find the default schedules + # and add them to the + modules: InitVar[list[Module]] = [] # noqa: RUF008 - # All dependencies flattened + + # TODO: Move commands to Decleraration instance. Pass in is_builtins to declerations so we can skip adding commands for those. Pass in from module, set as argument of module and subclcass + + # Any modules you want to depend on + # # All dependencies flattened _flatted_deps: list[Module] = field(init=False, default_factory=list) - _mod_decls: ModuleDeclarations = field(init=False) + # _mod_decls: ModuleDeclarations = field(init=False) def __post_init__(self, modules: list[Module]) -> None: - included_decls = [_BUILTIN_DECLS] if _BUILTIN_DECLS else [] - # Traverse all the included modules to flatten all their dependencies and add to the included declerations + # included_decls = [_BUILTIN_DECLS] if _BUILTIN_DECLS else [] + # # Traverse all the included modules to flatten all their dependencies and add to the included declerations for mod in modules: for child_mod in [*mod._flatted_deps, mod]: if child_mod not in self._flatted_deps: self._flatted_deps.append(child_mod) - included_decls.append(child_mod._mod_decls._decl) - self._mod_decls = ModuleDeclarations(Declarations(), included_decls) - @abstractmethod - def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: - """ - Process the commands generated by this module. - """ - raise NotImplementedError + # self._mod_decls = ModuleDeclarations(Declarations(), included_decls) + + # # TODO: Move to EGraph itself + # @abstractmethod + # def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: + # """ + # Process the commands generated by this module. + # """ + # raise NotImplementedError @overload def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]: @@ -147,6 +156,7 @@ def class_(self, *args, **kwargs) -> Any: """ Registers a class. """ + # Get locals and globals from parent frame so we can infer types from it. frame = currentframe() assert frame prev_frame = frame.f_back @@ -158,7 +168,7 @@ def class_(self, *args, **kwargs) -> Any: assert len(args) == 1 return self._class(args[0], prev_frame.f_locals, prev_frame.f_globals) - def _class( # noqa: PLR0912 + def _class( # noqa: PLR0912, C901 self, cls: type[Expr], hint_locals: dict[str, Any], @@ -168,7 +178,12 @@ def _class( # noqa: PLR0912 """ Registers a class. """ + global _PY_OBJECT_CLASS + decls = Declarations() cls_name = cls.__name__ + runtime_class = RuntimeClass(decls, cls_name) + if cls_name == "PyObject": + _PY_OBJECT_CLASS = runtime_class # Get all the methods from the class cls_dict: dict[str, Any] = { k: v for k, v in cls.__dict__.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod) @@ -176,7 +191,7 @@ def _class( # noqa: PLR0912 parameters: list[TypeVar] = cls_dict.pop("__parameters__", []) n_type_vars = len(parameters) - self._process_commands(self._mod_decls.register_class(cls_name, n_type_vars, egg_sort)) + decls.register_class(cls_name, n_type_vars, self.is_builtin, egg_sort) # The type ref of self is paramterized by the type vars slf_type_ref = TypeRefWithVars(cls_name, tuple(ClassTypeVarRef(i) for i in range(n_type_vars))) @@ -186,7 +201,7 @@ def _class( # noqa: PLR0912 for k, v in get_type_hints(cls, globalns=hint_globals, localns=hint_locals).items(): if v.__origin__ == ClassVar: (inner_tp,) = v.__args__ - self._register_constant(ClassVariableRef(cls_name, k), inner_tp, None, (cls, cls_name)) + self._register_constant(decls, ClassVariableRef(cls_name, k), inner_tp, None, (cls, cls_name)) else: msg = "The only supported annotations on class attributes are class vars" raise NotImplementedError(msg) @@ -207,7 +222,7 @@ def _class( # noqa: PLR0912 mutates_first_arg = method.mutates_self unextractable = method.unextractable if method.preserve: - self._mod_decls.register_preserved_method(cls_name, method_name, fn) + decls.register_preserved_method(cls_name, method_name, fn) continue else: fn = method @@ -237,6 +252,7 @@ def _class( # noqa: PLR0912 else MethodRef(cls_name, method_name) ) self._register_function( + decls, ref, egg_fn, fn, @@ -253,13 +269,13 @@ def _class( # noqa: PLR0912 # Otherwise, this might be a Map in which case pass in the original cls so that we # can do Map[T, V] on it, which is not allowed on the runtime class cls_type_and_name=( - RuntimeClass(self._mod_decls, cls_name) if cls_name in {"i64", "String"} else cls, + RuntimeClass(decls, cls_name) if cls_name in {"i64", "String"} else cls, cls_name, ), unextractable=unextractable, ) - - return RuntimeClass(self._mod_decls, cls_name) + # self._process_commands(decls.list_cmds()) + return runtime_class # We seperate the function and method overloads to make it simpler to know if we are modifying a function or method, # So that we can add the functions eagerly to the registry and wait on the methods till we process the class. @@ -379,8 +395,10 @@ def _function( Uncurried version of function decorator """ name = fn.__name__ + decls = Declarations() # Save function decleartion self._register_function( + decls, FunctionRef(name), egg_fn, fn, @@ -392,11 +410,13 @@ def _function( mutates_first_arg, unextractable=unextractable, ) + # self._process_commands(decls.list_cmds()) # Return a runtime function which will act like the decleration - return RuntimeFunction(self._mod_decls, name) + return RuntimeFunction(decls, name) - def _register_function( # noqa: C901, PLR0912 + def _register_function( self, + decls: Declarations, ref: FunctionCallableRef, egg_name: str | None, fn: object, @@ -412,6 +432,7 @@ def _register_function( # noqa: C901, PLR0912 first_arg: Literal["cls"] | TypeOrVarRef | None = None, cls_typevars: list[TypeVar] | None = None, is_init: bool = False, + # We need this for very weird case around typevar identity, I forget the details :( cls_type_and_name: tuple[type | RuntimeClass, str] | None = None, unextractable: bool = False, ) -> None: @@ -427,42 +448,32 @@ def _register_function( # noqa: C901, PLR0912 hints = get_type_hints(fn, hint_globals, hint_locals) params = list(signature(fn).parameters.values()) - arg_names = tuple(t.name for t in params) - arg_defaults = tuple(expr_parts(p.default).expr if p.default is not Parameter.empty else None for p in params) + # If this is an init function, or a classmethod, remove the first arg name if is_init or first_arg == "cls": - arg_names = arg_names[1:] - arg_defaults = arg_defaults[1:] - # Remove first arg if this is a classmethod or a method, since it won't have an annotation - if first_arg is not None: - first, *params = params - if first.annotation != Parameter.empty: - raise ValueError(f"First arg of a method must not have an annotation, not {first.annotation}") - - # Check that all the params are positional or keyword, and that there is only one var arg at the end - found_var_arg = False - for param in params: - if found_var_arg: - msg = "Can only have a single var arg at the end" - raise ValueError(msg) - kind = param.kind - if kind == Parameter.VAR_POSITIONAL: - found_var_arg = True - elif kind != Parameter.POSITIONAL_OR_KEYWORD: - raise ValueError(f"Can only register functions with positional or keyword args, not {param.kind}") + params = params[1:] - if found_var_arg: + if _last_param_variable(params): *params, var_arg_param = params # For now, we don't use the variable arg name - arg_names = arg_names[:-1] - arg_defaults = arg_defaults[:-1] - var_arg_type = self._resolve_type_annotation(hints[var_arg_param.name], cls_typevars, cls_type_and_name) + var_arg_type = _resolve_type_annotation(decls, hints[var_arg_param.name], cls_typevars, cls_type_and_name) else: var_arg_type = None - arg_types = tuple(self._resolve_type_annotation(hints[t.name], cls_typevars, cls_type_and_name) for t in params) - # If the first arg is a self, and this not an __init__ fn, add this as a typeref - if isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars) and not is_init: - arg_types = (first_arg, *arg_types) + arg_types = tuple( + first_arg + # If the first arg is a self, and this not an __init__ fn, add this as a typeref + if i == 0 and isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars) and not is_init + else _resolve_type_annotation(decls, hints[t.name], cls_typevars, cls_type_and_name) + for i, t in enumerate(params) + ) + + # Resolve all default values as arg types + arg_defaults = [ + _resolve_literal(t, p.default) if p.default is not Parameter.empty else None + for (t, p) in zip(arg_types, params, strict=True) + ] + + decls.update(*arg_defaults) # If this is an init fn use the first arg as the return type if is_init: @@ -474,94 +485,52 @@ def _register_function( # noqa: C901, PLR0912 elif mutates_first_arg: return_type = arg_types[0] else: - return_type = self._resolve_type_annotation(hints["return"], cls_typevars, cls_type_and_name) + return_type = _resolve_type_annotation(decls, hints["return"], cls_typevars, cls_type_and_name) - default_decl = None if default is None else default.__egg_typed_expr__.expr - merge_decl = ( + decls |= default + merged = ( None if merge is None else merge( - RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))), - RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))), - ).__egg_typed_expr__.expr + RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))), + RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))), + ) ) + decls |= merged + merge_action = ( [] if on_merge is None else _action_likes( on_merge( - RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))), - RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))), + RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))), + RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))), ) ) ) + decls.update(*merge_action) fn_decl = FunctionDecl( return_type=return_type, var_arg_type=var_arg_type, arg_types=arg_types, - arg_names=arg_names, - arg_defaults=arg_defaults, + arg_names=tuple(t.name for t in params), + arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults), mutates_first_arg=mutates_first_arg, ) - self._process_commands( - self._mod_decls.register_function_callable( - ref, - fn_decl, - egg_name, - cost, - default_decl, - merge_decl, - [a._to_egg_action(self._mod_decls) for a in merge_action], - unextractable, - ) + decls.register_function_callable( + ref, + fn_decl, + egg_name, + cost, + None if default is None else default.__egg_typed_expr__.expr, + merged.__egg_typed_expr__.expr if merged is not None else None, + [a._to_egg_action() for a in merge_action], + unextractable, + self.is_builtin, ) - def _resolve_type_annotation( - self, - tp: object, - cls_typevars: list[TypeVar], - cls_type_and_name: tuple[type | RuntimeClass, str] | None, - ) -> TypeOrVarRef: - if isinstance(tp, TypeVar): - return ClassTypeVarRef(cls_typevars.index(tp)) - # If there is a union, then we assume the first item is the type we want, and the others are types that can be converted to that type. - if get_origin(tp) == Union: - first, *_rest = get_args(tp) - return self._resolve_type_annotation(first, cls_typevars, cls_type_and_name) - # If the type is `object` then this is assumed to be a PyObjetLike, i.e. converted into a PyObject - if tp == object: - return TypeRefWithVars("PyObject") - # from .builtins import PyObject - - # tp = PyObject - # If this is the type for the class, use the class name - if cls_type_and_name and tp == cls_type_and_name[0]: - return TypeRefWithVars(cls_type_and_name[1]) - - # If this is the class for this method and we have a paramaterized class, recurse - if cls_type_and_name and isinstance(tp, _GenericAlias) and tp.__origin__ == cls_type_and_name[0]: - return TypeRefWithVars( - cls_type_and_name[1], - tuple(self._resolve_type_annotation(a, cls_typevars, cls_type_and_name) for a in tp.__args__), - ) - - if isinstance(tp, RuntimeClass | RuntimeParamaterizedClass): - return class_to_ref(tp).to_var() - raise TypeError(f"Unexpected type annotation {tp}") - - def register(self, command_or_generator: CommandLike | CommandGenerator, *commands: CommandLike) -> None: - """ - Registers any number of rewrites or rules. - """ - if isinstance(command_or_generator, FunctionType): - assert not commands - commands = tuple(_command_generator(command_or_generator)) - else: - commands = (cast(CommandLike, command_or_generator), *commands) - self._process_commands(_command_like(command)._to_egg_command(self._mod_decls) for command in commands) - def ruleset(self, name: str) -> Ruleset: - self._process_commands([bindings.AddRuleset(name)]) + # self._process_commands([bindings.AddRuleset(name)]) return Ruleset(name) # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value @@ -591,11 +560,13 @@ def relation(self, name: str, /, *tps: type, egg_fn: str | None = None) -> Calla """ Defines a relation, which is the same as a function which returns unit. """ - arg_types = tuple(self._resolve_type_annotation(cast(object, tp), [], None) for tp in tps) + decls = Declarations() + decls |= cast(RuntimeClass, Unit) + arg_types = tuple(_resolve_type_annotation(decls, cast(object, tp), [], None) for tp in tps) fn_decl = FunctionDecl( arg_types, None, tuple(None for _ in tps), TypeRefWithVars("Unit"), mutates_first_arg=False ) - commands = self._mod_decls.register_function_callable( + decls.register_function_callable( FunctionRef(name), fn_decl, egg_fn, @@ -604,30 +575,28 @@ def relation(self, name: str, /, *tps: type, egg_fn: str | None = None) -> Calla merge=None, merge_action=[], unextractable=False, + builtin=False, is_relation=True, ) - self._process_commands(commands) - return cast(Callable[..., Unit], RuntimeFunction(self._mod_decls, name)) - - def input(self, fn: Callable[..., String], path: str) -> None: - """ - Loads a CSV file and sets it as *input, output of the function. - """ - fn_name = self._mod_decls.get_egg_fn(_resolve_callable(fn)) - self._process_commands([bindings.Input(fn_name, path)]) + # self._process_commands(decls.list_cmds()) + return cast(Callable[..., Unit], RuntimeFunction(decls, name)) def constant(self, name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR: """ + Defines a named constant of a certain type. This is the same as defining a nullary function with a high cost. + # TODO: Rename as declare to match eggglog? """ ref = ConstantRef(name) - type_ref = self._register_constant(ref, tp, egg_name, None) - return cast(EXPR, RuntimeExpr(self._mod_decls, TypedExprDecl(type_ref, CallDecl(ref)))) + decls = Declarations() + type_ref = self._register_constant(decls, ref, tp, egg_name, None) + return cast(EXPR, RuntimeExpr(decls, TypedExprDecl(type_ref, CallDecl(ref)))) def _register_constant( self, + decls: Declarations, ref: ConstantRef | ClassVariableRef, tp: object, egg_name: str | None, @@ -636,100 +605,179 @@ def _register_constant( """ Register a constant, returning its typeref(). """ - type_ref = self._resolve_type_annotation(tp, [], cls_type_and_name).to_just() - self._process_commands(self._mod_decls.register_constant_callable(ref, type_ref, egg_name)) + type_ref = _resolve_type_annotation(decls, tp, [], cls_type_and_name).to_just() + decls.register_constant_callable(ref, type_ref, egg_name) return type_ref - def let(self, name: str, expr: EXPR) -> EXPR: + def register(self, command_or_generator: CommandLike | CommandGenerator, *command_likes: CommandLike) -> None: """ - Define a new expression in the egraph and return a reference to it. + Registers any number of rewrites or rules. """ - typed_expr = expr_parts(expr) - self._process_commands([bindings.ActionCommand(bindings.Let(name, typed_expr.to_egg(self._mod_decls)))]) - return cast(EXPR, RuntimeExpr(self._mod_decls, TypedExprDecl(typed_expr.tp, VarDecl(name)))) + if isinstance(command_or_generator, FunctionType): + assert not command_likes + command_likes = tuple(_command_generator(command_or_generator)) + else: + command_likes = (cast(CommandLike, command_or_generator), *command_likes) + + self._register_commands(list(map(_command_like, command_likes))) + + @abstractmethod + def _register_commands(self, cmds: list[Command]) -> None: + raise NotImplementedError + + +def _last_param_variable(params: list[Parameter]) -> bool: + """ + Checks if the last paramater is a variable arg. + + Raises an error if any of the other params are not positional or keyword. + """ + found_var_arg = False + for param in params: + if found_var_arg: + msg = "Can only have a single var arg at the end" + raise ValueError(msg) + kind = param.kind + if kind == Parameter.VAR_POSITIONAL: + found_var_arg = True + elif kind != Parameter.POSITIONAL_OR_KEYWORD: + raise ValueError(f"Can only register functions with positional or keyword args, not {param.kind}") + return found_var_arg + + +def _resolve_type_annotation( + decls: Declarations, + tp: object, + cls_typevars: list[TypeVar], + cls_type_and_name: tuple[type | RuntimeClass, str] | None, +) -> TypeOrVarRef: + """ + Resolves a type object into a type reference. + + The cls_typevars should be a list of type variables that were defined for that type, i.e. for class Dict(Generic[T, K]), they would be [T, K] + + The cls_type_and_name is the type of the current class being traversed, in case it hasn't been added yet. + """ + if isinstance(tp, TypeVar): + return ClassTypeVarRef(cls_typevars.index(tp)) + # If there is a union, then we assume the first item is the type we want, and the others are types that can be converted to that type. + if get_origin(tp) == Union: + first, *_rest = get_args(tp) + return _resolve_type_annotation(decls, first, cls_typevars, cls_type_and_name) + # If this is the type for the class, use the class name + if cls_type_and_name and tp == cls_type_and_name[0]: + return TypeRefWithVars(cls_type_and_name[1]) + + # If the type is `object` then this is assumed to be a PyObjetLike, i.e. converted into a PyObject + if tp == object: + assert _PY_OBJECT_CLASS + return _resolve_type_annotation(decls, _PY_OBJECT_CLASS, [], None) + + # If this is the class for this method and we have a paramaterized class, recurse + if cls_type_and_name and isinstance(tp, _GenericAlias) and tp.__origin__ == cls_type_and_name[0]: + return TypeRefWithVars( + cls_type_and_name[1], + tuple(_resolve_type_annotation(decls, a, cls_typevars, cls_type_and_name) for a in tp.__args__), + ) + + if isinstance(tp, RuntimeClass | RuntimeParamaterizedClass): + decls |= tp + return class_to_ref(tp).to_var() + raise TypeError(f"Unexpected type annotation {tp}") @dataclass class _Builtins(_BaseModule): - def __post_init__(self, modules: list[Module]) -> None: - """ - Register these declarations as builtins, so others can use them. - """ - assert not modules - super().__post_init__(modules) - global _BUILTIN_DECLS - if _BUILTIN_DECLS is not None: - msg = "Builtins already initialized" - raise RuntimeError(msg) - _BUILTIN_DECLS = self._mod_decls._decl - # Register != operator - _BUILTIN_DECLS.register_callable_ref(FunctionRef("!="), "!=") - - def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: - """ - Commands which would have been used to create the builtins are discarded, since they are already registered. - """ + is_builtin: ClassVar[bool] = True + # def __post_init__(self, modules: list[Module]) -> None: + # """ + # Register these declarations as builtins, so others can use them. + # """ + # assert not modules + # super().__post_init__(modules) + # global _BUILTIN_DECLS + # if _BUILTIN_DECLS is not None: + # msg = "Builtins already initialized" + # raise RuntimeError(msg) + # _BUILTIN_DECLS = self._mod_decls._decl + # # Register != operator + # _BUILTIN_DECLS.register_callable_ref(FunctionRef("!="), "!=") + + # def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: + # """ + # Commands which would have been used to create the builtins are discarded, since they are already registered. + # """ + def _register_commands(self, cmds: list[Command]) -> None: + raise NotImplementedError @dataclass class Module(_BaseModule): - _cmds: list[bindings._Command] = field(default_factory=list, repr=False) + cmds: list[Command] = field(default_factory=list) - @property - def as_egglog_string(self) -> str: - """ - Returns the egglog string for this module. - """ - return "\n".join(str(c) for c in self._cmds) + def _register_commands(self, cmds: list[Command]) -> None: + self.cmds.extend(cmds) - def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: - self._cmds.extend(cmds) + def without_rules(self) -> Module: + return Module() - def unextractable(self) -> Module: - """ - Makes a copy of this module with all functions marked as un-extractable - """ - return self._map_functions( - lambda decl: bindings.FunctionDecl( - decl.name, - decl.schema, - decl.default, - decl.merge, - decl.merge_action, - decl.cost, - True, - ) - ) + # _cmds: list[bindings._Command] = field(default_factory=list, repr=False) - def increase_cost(self, x: int = 10000000) -> Module: - """ - Make a copy of this module with all function costs increased by x - """ - return self._map_functions( - lambda decl, x=x: bindings.FunctionDecl( # type: ignore[misc] - decl.name, - decl.schema, - decl.default, - decl.merge, - decl.merge_action, - (decl.cost or 1) + x, - decl.unextractable, - ) - ) + # @property + # def as_egglog_string(self) -> str: + # """ + # Returns the egglog string for this module. + # """ + # return "\n".join(str(c) for c in self._cmds) - def without_rules(self) -> Module: - """ - Makes a copy of this module with all rules removed. - """ - new = copy(self) - new._cmds = [ - c - for c in new._cmds - if not isinstance(c, bindings.RuleCommand) - and not isinstance(c, bindings.RewriteCommand) - and not isinstance(c, bindings.BiRewriteCommand) - ] - return new + # def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: + # self._cmds.extend(cmds) + + # def unextractable(self) -> Module: + # """ + # Makes a copy of this module with all functions marked as un-extractable + # """ + # return self._map_functions( + # lambda decl: bindings.FunctionDecl( + # decl.name, + # decl.schema, + # decl.default, + # decl.merge, + # decl.merge_action, + # decl.cost, + # True, + # ) + # ) + + # def increase_cost(self, x: int = 10000000) -> Module: + # """ + # Make a copy of this module with all function costs increased by x + # """ + # return self._map_functions( + # lambda decl, x=x: bindings.FunctionDecl( # type: ignore[misc] + # decl.name, + # decl.schema, + # decl.default, + # decl.merge, + # decl.merge_action, + # (decl.cost or 1) + x, + # decl.unextractable, + # ) + # ) + + # def without_rules(self) -> Module: + # """ + # Makes a copy of this module with all rules removed. + # """ + # new = copy(self) + # new._cmds = [ + # c + # for c in new._cmds + # if not isinstance(c, bindings.RuleCommand) + # and not isinstance(c, bindings.RewriteCommand) + # and not isinstance(c, bindings.BiRewriteCommand) + # ] + # return new # def rename_ruleset(self, new_r: str) -> Module: # """ @@ -749,13 +797,13 @@ def without_rules(self) -> Module: # new._cmds.insert(0, bindings.AddRuleset(new_r)) # return new - def _map_functions(self, fn: Callable[[bindings.FunctionDecl], bindings.FunctionDecl]) -> Module: - """ - Returns a copy where all the functions have been mapped with the given function. - """ - new = copy(self) - new._cmds = [bindings.Function(fn(c.decl)) if isinstance(c, bindings.Function) else c for c in new._cmds] - return new + # def _map_functions(self, fn: Callable[[bindings.FunctionDecl], bindings.FunctionDecl]) -> Module: + # """ + # Returns a copy where all the functions have been mapped with the given function. + # """ + # new = copy(self) + # new._cmds = [bindings.Function(fn(c.decl)) if isinstance(c, bindings.Function) else c for c in new._cmds] + # return new class GraphvizKwargs(TypedDict, total=False): @@ -765,6 +813,32 @@ class GraphvizKwargs(TypedDict, total=False): split_primitive_outputs: bool +@dataclass +class _EGraphState: + """ + State of the EGraph declerations, so we know what to + """ + + # The decleratons we have added. The _cmds represent all the symbols we have added + decls: Declarations = field(default_factory=Declarations) + # List of rulesets already added, so we don't re-add them if they are passed again + + added_rulesets: set[str] = field(default_factory=set) + + def add_decls(self, new_decls: Declarations) -> Iterable[bindings._Command]: + new_cmds = [v for k, v in new_decls._cmds.items() if k not in self.decls._cmds] + self.decls |= new_decls + return new_cmds + + def add_rulesets(self, new_rulesets: dict[str, list[bindings._Command]]) -> Iterable[bindings._Command]: + new_cmds = [] + for k, v in new_rulesets.items(): + if k not in self.added_rulesets: + new_cmds.extend(v) + self.added_rulesets.add(k) + return new_cmds + + @dataclass class EGraph(_BaseModule): """ @@ -775,25 +849,44 @@ class EGraph(_BaseModule): save_egglog_string: InitVar[bool] = False _egraph: bindings.EGraph = field(repr=False, init=False) - # The current declarations which have been pushed to the stack - _decl_stack: list[Declarations] = field(default_factory=list, repr=False) - _token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False) _egglog_string: str | None = field(default=None, repr=False, init=False) + _state: _EGraphState = field(default_factory=_EGraphState, repr=False) + # For pushing/popping with egglog + _state_stack: list[_EGraphState] = field(default_factory=list, repr=False) + # For storing the global "current" egraph + _token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False) def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None: super().__post_init__(modules) self._egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive) for m in self._flatted_deps: - self._process_commands(m._cmds) + self._add_decls(*m.cmds) + self._register_commands(m.cmds) if save_egglog_string: self._egglog_string = "" + def _register_commands(self, commands: list[Command]) -> None: + for c in commands: + if c.ruleset: + self._add_schedule(run(c.ruleset)) + + self._add_decls(*commands) + self._process_commands(command._to_egg_command() for command in commands) + def _process_commands(self, commands: Iterable[bindings._Command]) -> None: commands = list(commands) self._egraph.run_program(*commands) if isinstance(self._egglog_string, str): self._egglog_string += "\n".join(str(c) for c in commands) + "\n" + def _add_decls(self, *decls: DeclerationsLike) -> None: + for d in upcast_decleratioons(decls): + self._process_commands(self._state.add_decls(d)) + + def _add_schedule(self, schedule: Schedule) -> None: + self._add_decls(schedule) + self._process_commands(self._state.add_rulesets(schedule._rulesets())) + @property def as_egglog_string(self) -> str: """ @@ -815,7 +908,7 @@ def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source: kwargs.setdefault("split_primitive_outputs", True) n_inline = kwargs.pop("n_inline_leaves", 0) serialized = self._egraph.serialize(**kwargs) # type: ignore[misc] - serialized.map_ops(self._mod_decls.op_mapping()) + serialized.map_ops(self._state.decls.op_mapping()) for _ in range(n_inline): serialized.inline_leaves() original = serialized.to_dot() @@ -872,6 +965,23 @@ def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None: else: graphviz.render(view=True, format="svg", quiet=True) + def input(self, fn: Callable[..., String], path: str) -> None: + """ + Loads a CSV file and sets it as *input, output of the function. + """ + ref, decls = _resolve_callable(fn) + fn_name = decls.get_egg_fn(ref) + self._process_commands(decls.list_cmds()) + self._process_commands([bindings.Input(fn_name, path)]) + + def let(self, name: str, expr: EXPR) -> EXPR: + """ + Define a new expression in the egraph and return a reference to it. + """ + self._register_commands([let(name, expr)]) + expr = to_runtime_expr(expr) + return cast(EXPR, RuntimeExpr(expr.__egg_decls__, TypedExprDecl(expr.__egg_typed_expr__.tp, VarDecl(name)))) + @overload def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ... @@ -886,19 +996,22 @@ def simplify( """ Simplifies the given expression. """ - if isinstance(limit_or_schedule, int): - limit_or_schedule = run(ruleset, *until) * limit_or_schedule - typed_expr = expr_parts(expr) - egg_expr = typed_expr.to_egg(self._mod_decls) - self._process_commands([bindings.Simplify(egg_expr, limit_or_schedule._to_egg_schedule(self._mod_decls))]) + schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule + del limit_or_schedule + expr = to_runtime_expr(expr) + self._add_decls(expr) + self._add_schedule(schedule) + + # decls = Declarations.create(expr, schedule) + self._process_commands([bindings.Simplify(expr.__egg__, schedule._to_egg_schedule())]) extract_report = self._egraph.extract_report() if not isinstance(extract_report, bindings.Best): msg = "No extract report saved" raise ValueError(msg) # noqa: TRY004 new_typed_expr = TypedExprDecl.from_egg( - self._egraph, self._mod_decls, bindings.termdag_term_to_expr(extract_report.termdag, extract_report.term) + self._egraph, self._state.decls, bindings.termdag_term_to_expr(extract_report.termdag, extract_report.term) ) - return cast(EXPR, RuntimeExpr(self._mod_decls, new_typed_expr)) + return cast(EXPR, RuntimeExpr(self._state.decls.copy(), new_typed_expr)) def include(self, path: str) -> None: """ @@ -930,7 +1043,8 @@ def run( return self._run_schedule(limit_or_schedule) def _run_schedule(self, schedule: Schedule) -> bindings.RunReport: - self._process_commands([bindings.RunSchedule(schedule._to_egg_schedule(self._mod_decls))]) + self._add_schedule(schedule) + self._process_commands([bindings.RunSchedule(schedule._to_egg_schedule())]) run_report = self._egraph.run_report() if not run_report: msg = "No run report saved" @@ -950,7 +1064,9 @@ def check_fail(self, *facts: FactLike) -> None: self._process_commands([bindings.Fail(self._facts_to_check(facts))]) def _facts_to_check(self, facts: Iterable[FactLike]) -> bindings.Check: - egg_facts = [f._to_egg_fact(self._mod_decls) for f in _fact_likes(facts)] + facts = _fact_likes(facts) + self._add_decls(*facts) + egg_facts = [f._to_egg_fact() for f in _fact_likes(facts)] return bindings.Check(egg_facts) @overload @@ -965,18 +1081,16 @@ def extract(self, expr: EXPR, include_cost: bool = False) -> EXPR | tuple[EXPR, """ Extract the lowest cost expression from the egraph. """ - typed_expr = expr_parts(expr) - egg_expr = typed_expr.to_egg(self._mod_decls) - extract_report = self._run_extract(egg_expr, 0) + assert isinstance(expr, RuntimeExpr) + self._add_decls(expr) + extract_report = self._run_extract(expr.__egg__, 0) if not isinstance(extract_report, bindings.Best): msg = "No extract report saved" raise ValueError(msg) # noqa: TRY004 new_typed_expr = TypedExprDecl.from_egg( - self._egraph, self._mod_decls, bindings.termdag_term_to_expr(extract_report.termdag, extract_report.term) + self._egraph, self._state.decls, bindings.termdag_term_to_expr(extract_report.termdag, extract_report.term) ) - if new_typed_expr.tp != typed_expr.tp: - raise RuntimeError(f"Type mismatch: {new_typed_expr.tp} != {typed_expr.tp}") - res = cast(EXPR, RuntimeExpr(self._mod_decls, new_typed_expr)) + res = cast(EXPR, RuntimeExpr(self._state.decls.copy(), new_typed_expr)) if include_cost: return res, extract_report.cost return res @@ -985,19 +1099,20 @@ def extract_multiple(self, expr: EXPR, n: int) -> list[EXPR]: """ Extract multiple expressions from the egraph. """ - typed_expr = expr_parts(expr) - egg_expr = typed_expr.to_egg(self._mod_decls) - extract_report = self._run_extract(egg_expr, n) + assert isinstance(expr, RuntimeExpr) + self._add_decls(expr) + + extract_report = self._run_extract(expr.__egg__, n) if not isinstance(extract_report, bindings.Variants): msg = "Wrong extract report type" raise ValueError(msg) # noqa: TRY004 new_exprs = [ TypedExprDecl.from_egg( - self._egraph, self._mod_decls, bindings.termdag_term_to_expr(extract_report.termdag, term) + self._egraph, self._state.decls, bindings.termdag_term_to_expr(extract_report.termdag, term) ) for term in extract_report.terms ] - return [cast(EXPR, RuntimeExpr(self._mod_decls, expr)) for expr in new_exprs] + return [cast(EXPR, RuntimeExpr(self._state.decls.copy(), expr)) for expr in new_exprs] def _run_extract(self, expr: bindings._Expr, n: int) -> bindings._ExtractReport: self._process_commands([bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n))))]) @@ -1012,15 +1127,15 @@ def push(self) -> None: Push the current state of the egraph, so that it can be popped later and reverted back. """ self._process_commands([bindings.Push(1)]) - self._decl_stack.append(self._mod_decls._decl) - self._decls = deepcopy(self._mod_decls._decl) + self._state_stack.append(self._state) + self._state = deepcopy(self._state) def pop(self) -> None: """ Pop the current state of the egraph, reverting back to the previous state. """ self._process_commands([bindings.Pop(1)]) - self._mod_decls._decl = self._decl_stack.pop() + self._state = self._state_stack.pop() def __enter__(self) -> Self: """ @@ -1060,8 +1175,9 @@ def eval(self, expr: Expr) -> object: """ Evaluates the given expression (which must be a primitive type), returning the result. """ - typed_expr = expr_parts(expr) - egg_expr = typed_expr.to_egg(self._mod_decls) + assert isinstance(expr, RuntimeExpr) + typed_expr = expr.__egg_typed_expr__ + egg_expr = expr.__egg__ match typed_expr.tp: case JustTypeRef("i64"): return self._egraph.eval_i64(egg_expr) @@ -1144,6 +1260,18 @@ class _ExprMetaclass(type): Used to override isistance checks, so that runtime expressions are instances of Expr at runtime. """ + # def __new__( + # cls: type[_ExprMetaclass], + # name: str, + # bases: tuple[type, ...], + # namespace: dict[str, Any], + # egg_name: str | None = None, + # ) -> Self: + # for attr_name, attr_value in attrs.items(): + # if isinstance(attr_value, _WrappedMethod): + # attrs[attr_name] = attr_value.fn + # return super().__new__(cls, name, bases, attrs) + def __instancecheck__(cls, instance: object) -> bool: return isinstance(instance, RuntimeExpr) @@ -1173,9 +1301,22 @@ def __init__(self) -> None: ... -@dataclass(frozen=True) +@dataclass class Ruleset: name: str + __egg_decls__: Declarations = field(default_factory=Declarations, repr=False) + _cmds: list[bindings._Command] = field(default_factory=list, repr=False) + + def __post_init__(self) -> None: + if self.name: + self._cmds.append(bindings.AddRuleset(self.name)) + + def append(self, rule: Rule | Rewrite) -> None: + """ + Register a rule with the ruleset. + """ + self._cmds.append(rule._to_egg_command()) + self.__egg_decls__ |= rule def _ruleset_name(ruleset: Ruleset | None) -> str: @@ -1191,8 +1332,15 @@ class Command(ABC): Anything that can be passed to the `register` function in a Module is a Command. """ + ruleset: Ruleset | None + + @property + @abstractmethod + def __egg_decls__(self) -> Declarations: + raise NotImplementedError + @abstractmethod - def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command: + def _to_egg_command(self) -> bindings._Command: raise NotImplementedError @abstractmethod @@ -1202,7 +1350,7 @@ def __str__(self) -> str: @dataclass class Rewrite(Command): - _ruleset: str + ruleset: Ruleset | None _lhs: RuntimeExpr _rhs: RuntimeExpr _conditions: tuple[Fact, ...] @@ -1212,23 +1360,27 @@ def __str__(self) -> str: args_str = ", ".join(map(str, [self._rhs, *self._conditions])) return f"{self._fn_name}({self._lhs}).to({args_str})" - def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command: - return bindings.RewriteCommand(self._ruleset, self._to_egg_rewrite(mod_decls)) + def _to_egg_command(self) -> bindings._Command: + return bindings.RewriteCommand(_ruleset_name(self.ruleset), self._to_egg_rewrite()) - def _to_egg_rewrite(self, mod_decls: ModuleDeclarations) -> bindings.Rewrite: + def _to_egg_rewrite(self) -> bindings.Rewrite: return bindings.Rewrite( - self._lhs.__egg_typed_expr__.expr.to_egg(mod_decls), - self._rhs.__egg_typed_expr__.expr.to_egg(mod_decls), - [c._to_egg_fact(mod_decls) for c in self._conditions], + self._lhs.__egg_typed_expr__.expr.to_egg(self._lhs.__egg_decls__), + self._rhs.__egg_typed_expr__.expr.to_egg(self._rhs.__egg_decls__), + [c._to_egg_fact() for c in self._conditions], ) + @property + def __egg_decls__(self) -> Declarations: + return Declarations.create(self._lhs, self._rhs, *self._conditions) + @dataclass class BiRewrite(Rewrite): _fn_name: ClassVar[str] = "birewrite" - def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command: - return bindings.BiRewriteCommand(self._ruleset, self._to_egg_rewrite(mod_decls)) + def _to_egg_command(self) -> bindings._Command: + return bindings.BiRewriteCommand(_ruleset_name(self.ruleset), self._to_egg_rewrite()) @dataclass @@ -1238,7 +1390,12 @@ class Fact(ABC): """ @abstractmethod - def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings._Fact: + def _to_egg_fact(self) -> bindings._Fact: + raise NotImplementedError + + @property + @abstractmethod + def __egg_decls__(self) -> Declarations: raise NotImplementedError @@ -1251,8 +1408,12 @@ def __str__(self) -> str: args_str = ", ".join(map(str, rest)) return f"eq({first}).to({args_str})" - def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings.Eq: - return bindings.Eq([e.__egg_typed_expr__.expr.to_egg(mod_decls) for e in self._exprs]) + def _to_egg_fact(self) -> bindings.Eq: + return bindings.Eq([e.__egg__ for e in self._exprs]) + + @property + def __egg_decls__(self) -> Declarations: + return Declarations.create(*self._exprs) @dataclass @@ -1262,8 +1423,12 @@ class ExprFact(Fact): def __str__(self) -> str: return str(self._expr) - def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings.Fact: - return bindings.Fact(self._expr.__egg_typed_expr__.expr.to_egg(mod_decls)) + def _to_egg_fact(self) -> bindings.Fact: + return bindings.Fact(self._expr.__egg__) + + @property + def __egg_decls__(self) -> Declarations: + return self._expr.__egg_decls__ @dataclass @@ -1271,31 +1436,39 @@ class Rule(Command): head: tuple[Action, ...] body: tuple[Fact, ...] name: str - ruleset: str + ruleset: Ruleset | None def __str__(self) -> str: head_str = ", ".join(map(str, self.head)) body_str = ", ".join(map(str, self.body)) return f"rule({body_str}).then({head_str})" - def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings.RuleCommand: + def _to_egg_command(self) -> bindings.RuleCommand: return bindings.RuleCommand( self.name, - self.ruleset, + _ruleset_name(self.ruleset), bindings.Rule( - [a._to_egg_action(mod_decls) for a in self.head], - [f._to_egg_fact(mod_decls) for f in self.body], + [a._to_egg_action() for a in self.head], + [f._to_egg_fact() for f in self.body], ), ) + @property + def __egg_decls__(self) -> Declarations: + return Declarations.create(*self.head, *self.body) + class Action(Command, ABC): @abstractmethod - def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings._Action: + def _to_egg_action(self) -> bindings._Action: raise NotImplementedError - def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command: - return bindings.ActionCommand(self._to_egg_action(mod_decls)) + def _to_egg_command(self) -> bindings._Command: + return bindings.ActionCommand(self._to_egg_action()) + + @property + def ruleset(self) -> None: + return None @dataclass @@ -1306,8 +1479,12 @@ class Let(Action): def __str__(self) -> str: return f"let({self._name}, {self._value})" - def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Let: - return bindings.Let(self._name, self._value.__egg_typed_expr__.expr.to_egg(mod_decls)) + def _to_egg_action(self) -> bindings.Let: + return bindings.Let(self._name, self._value.__egg__) + + @property + def __egg_decls__(self) -> Declarations: + return self._value.__egg_decls__ @dataclass @@ -1318,16 +1495,20 @@ class Set(Action): def __str__(self) -> str: return f"set({self._call}).to({self._rhs})" - def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Set: - egg_call = self._call.__egg_typed_expr__.expr.to_egg(mod_decls) + def _to_egg_action(self) -> bindings.Set: + egg_call = self._call.__egg__ if not isinstance(egg_call, bindings.Call): raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") # noqa: TRY004 return bindings.Set( egg_call.name, egg_call.args, - self._rhs.__egg_typed_expr__.expr.to_egg(mod_decls), + self._rhs.__egg__, ) + @property + def __egg_decls__(self) -> Declarations: + return Declarations.create(self._call, self._rhs) + @dataclass class ExprAction(Action): @@ -1336,8 +1517,12 @@ class ExprAction(Action): def __str__(self) -> str: return str(self._expr) - def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Expr_: - return bindings.Expr_(self._expr.__egg_typed_expr__.expr.to_egg(mod_decls)) + def _to_egg_action(self) -> bindings.Expr_: + return bindings.Expr_(self._expr.__egg__) + + @property + def __egg_decls__(self) -> Declarations: + return self._expr.__egg_decls__ @dataclass @@ -1347,12 +1532,16 @@ class Delete(Action): def __str__(self) -> str: return f"delete({self._call})" - def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Delete: - egg_call = self._call.__egg_typed_expr__.expr.to_egg(mod_decls) + def _to_egg_action(self) -> bindings.Delete: + egg_call = self._call.__egg__ if not isinstance(egg_call, bindings.Call): raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004 return bindings.Delete(egg_call.name, egg_call.args) + @property + def __egg_decls__(self) -> Declarations: + return self._call.__egg_decls__ + @dataclass class Union_(Action): # noqa: N801 @@ -1362,10 +1551,12 @@ class Union_(Action): # noqa: N801 def __str__(self) -> str: return f"union({self._lhs}).with_({self._rhs})" - def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Union: - return bindings.Union( - self._lhs.__egg_typed_expr__.expr.to_egg(mod_decls), self._rhs.__egg_typed_expr__.expr.to_egg(mod_decls) - ) + def _to_egg_action(self) -> bindings.Union: + return bindings.Union(self._lhs.__egg__, self._rhs.__egg__) + + @property + def __egg_decls__(self) -> Declarations: + return Declarations.create(self._lhs, self._rhs) @dataclass @@ -1375,9 +1566,13 @@ class Panic(Action): def __str__(self) -> str: return f"panic({self.message})" - def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Panic: + def _to_egg_action(self) -> bindings.Panic: return bindings.Panic(self.message) + @property + def __egg_decls__(self) -> Declarations: + return Declarations() + class Schedule(ABC): def __mul__(self, length: int) -> Schedule: @@ -1403,7 +1598,19 @@ def __str__(self) -> str: raise NotImplementedError @abstractmethod - def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule: + def _to_egg_schedule(self) -> bindings._Schedule: + raise NotImplementedError + + @abstractmethod + def _rulesets(self) -> dict[str, list[bindings._Command]]: + """ + Mapping of all the rulesets used to commands. + """ + raise NotImplementedError + + @property + @abstractmethod + def __egg_decls__(self) -> Declarations: raise NotImplementedError @@ -1411,22 +1618,33 @@ def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule: class Run(Schedule): """Configuration of a run""" - ruleset: str + # None if using default ruleset + ruleset: Ruleset | None until: tuple[Fact, ...] def __str__(self) -> str: args_str = ", ".join(map(str, [self.ruleset, *self.until])) return f"run({args_str})" - def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule: - return bindings.Run(self._to_egg_config(mod_decls)) + def _to_egg_schedule(self) -> bindings._Schedule: + return bindings.Run(self._to_egg_config()) - def _to_egg_config(self, mod_decls: ModuleDeclarations) -> bindings.RunConfig: + def _to_egg_config(self) -> bindings.RunConfig: return bindings.RunConfig( - self.ruleset, - [fact._to_egg_fact(mod_decls) for fact in self.until] if self.until else None, + _ruleset_name(self.ruleset), [fact._to_egg_fact() for fact in self.until] if self.until else None ) + def _rulesets(self) -> dict[str, list[bindings._Command]]: + if not self.ruleset: + return {} + return {self.ruleset.name: self.ruleset._cmds} + + @property + def __egg_decls__(self) -> Declarations: + decls = self.ruleset.__egg_decls__.copy() if self.ruleset else Declarations() + decls.update(*self.until) + return decls + @dataclass class Saturate(Schedule): @@ -1435,8 +1653,15 @@ class Saturate(Schedule): def __str__(self) -> str: return f"{self.schedule}.saturate()" - def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule: - return bindings.Saturate(self.schedule._to_egg_schedule(mod_decls)) + def _to_egg_schedule(self) -> bindings._Schedule: + return bindings.Saturate(self.schedule._to_egg_schedule()) + + def _rulesets(self) -> dict[str, list[bindings._Command]]: + return self.schedule._rulesets() + + @property + def __egg_decls__(self) -> Declarations: + return self.schedule.__egg_decls__ @dataclass @@ -1447,8 +1672,15 @@ class Repeat(Schedule): def __str__(self) -> str: return f"{self.schedule} * {self.length}" - def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule: - return bindings.Repeat(self.length, self.schedule._to_egg_schedule(mod_decls)) + def _to_egg_schedule(self) -> bindings._Schedule: + return bindings.Repeat(self.length, self.schedule._to_egg_schedule()) + + def _rulesets(self) -> dict[str, list[bindings._Command]]: + return self.schedule._rulesets() + + @property + def __egg_decls__(self) -> Declarations: + return self.schedule.__egg_decls__ @dataclass @@ -1458,8 +1690,17 @@ class Sequence(Schedule): def __str__(self) -> str: return f"sequence({', '.join(map(str, self.schedules))})" - def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule: - return bindings.Sequence([schedule._to_egg_schedule(mod_decls) for schedule in self.schedules]) + def _to_egg_schedule(self) -> bindings._Schedule: + return bindings.Sequence([schedule._to_egg_schedule() for schedule in self.schedules]) + + def _rulesets(self) -> dict[str, list[bindings._Command]]: + return {k: v for d in self.schedules for k, v in d._rulesets().items()} + + @property + def __egg_decls__(self) -> Declarations: + decls = Declarations() + decls.update(*self.schedules) + return decls # We use these builders so that when creating these structures we can type check @@ -1549,12 +1790,10 @@ class _RewriteBuilder(Generic[EXPR]): def to(self, rhs: EXPR, *conditions: FactLike) -> Command: lhs = to_runtime_expr(self.lhs) - return Rewrite( - _ruleset_name(self.ruleset), - lhs, - convert_to_same_type(rhs, lhs), - _fact_likes(conditions), - ) + rule = Rewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions)) + if self.ruleset: + self.ruleset.append(rule) + return rule def __str__(self) -> str: return f"rewrite({self.lhs})" @@ -1567,12 +1806,10 @@ class _BirewriteBuilder(Generic[EXPR]): def to(self, rhs: EXPR, *conditions: FactLike) -> Command: lhs = to_runtime_expr(self.lhs) - return BiRewrite( - _ruleset_name(self.ruleset), - lhs, - convert_to_same_type(rhs, lhs), - _fact_likes(conditions), - ) + rule = BiRewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions)) + if self.ruleset: + self.ruleset.append(rule) + return rule def __str__(self) -> str: return f"birewrite({self.lhs})" @@ -1595,20 +1832,14 @@ class _NeBuilder(Generic[EXPR]): expr: EXPR def to(self, expr: EXPR) -> Unit: - l_expr = cast(RuntimeExpr, self.expr) - return cast( - Unit, - RuntimeExpr( - BUILTINS._mod_decls, - TypedExprDecl( - JustTypeRef("Unit"), - CallDecl( - FunctionRef("!="), - (l_expr.__egg_typed_expr__, convert_to_same_type(expr, l_expr).__egg_typed_expr__), - ), - ), - ), + assert isinstance(self.expr, RuntimeExpr) + args = (self.expr, convert_to_same_type(expr, self.expr)) + decls = Declarations.create(*args) + res = RuntimeExpr( + decls, + TypedExprDecl(JustTypeRef("Unit"), CallDecl(FunctionRef("!="), tuple(a.__egg_typed_expr__ for a in args))), ) + return cast(Unit, res) def __str__(self) -> str: return f"ne({self.expr})" @@ -1645,7 +1876,10 @@ class _RuleBuilder: ruleset: Ruleset | None def then(self, *actions: ActionLike) -> Command: - return Rule(_action_likes(actions), self.facts, self.name or "", _ruleset_name(self.ruleset)) + rule = Rule(_action_likes(actions), self.facts, self.name or "", self.ruleset) + if self.ruleset: + self.ruleset.append(rule) + return rule def expr_parts(expr: Expr) -> TypedExprDecl: @@ -1667,7 +1901,7 @@ def run(ruleset: Ruleset | None = None, *until: Fact) -> Run: """ Create a run configuration. """ - return Run(_ruleset_name(ruleset), tuple(until)) + return Run(ruleset, tuple(until)) def seq(*schedules: Schedule) -> Schedule: diff --git a/python/egglog/examples/matrix.py b/python/egglog/examples/matrix.py index d96a7dfa..ffbb1264 100644 --- a/python/egglog/examples/matrix.py +++ b/python/egglog/examples/matrix.py @@ -127,17 +127,17 @@ def kron(a: Matrix, b: Matrix) -> Matrix: # type: ignore[empty-body] egraph.register( # demand rows and columns when we multiply matrices rule(eq(C).to(A @ B)).then( - let("demand1", A.ncols()), - let("demand2", A.nrows()), - let("demand3", B.ncols()), - let("demand4", B.nrows()), + A.ncols(), + A.nrows(), + B.ncols(), + B.nrows(), ), # demand rows and columns when we take the kronecker product rule(eq(C).to(kron(A, B))).then( - let("demand1", A.ncols()), - let("demand2", A.nrows()), - let("demand3", B.ncols()), - let("demand4", B.nrows()), + A.ncols(), + A.nrows(), + B.ncols(), + B.nrows(), ), ) diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index a8e933a1..76af0f8f 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -85,10 +85,7 @@ def __eq__(self, other: DType) -> Boolean: # type: ignore[override] converter(type, DType, lambda x: convert(np.dtype(x), DType)) converter(type(np.dtype), DType, lambda x: getattr(DType, x.name)) # type: ignore[call-overload] array_api_module.register( - *( - rewrite(l == r).to(TRUE if expr_parts(l) == expr_parts(r) else FALSE) - for l, r in itertools.product(_DTYPES, repeat=2) - ) + *(rewrite(l == r).to(TRUE if l is r else FALSE) for l, r in itertools.product(_DTYPES, repeat=2)) ) diff --git a/python/egglog/exp/array_api_numba.py b/python/egglog/exp/array_api_numba.py index 4f053e5a..63468b6f 100644 --- a/python/egglog/exp/array_api_numba.py +++ b/python/egglog/exp/array_api_numba.py @@ -5,8 +5,6 @@ from __future__ import annotations -import operator - from egglog import * from egglog.exp.array_api import * @@ -71,30 +69,3 @@ def _unique_inverse(x: NDArray, i: Int): x == NDArray.scalar(unique_values(x).index(TupleInt(i))) ), ] - - -# Inline these changes until this PR is released to add suport for checking dtypes equal -# https://github.com/numba/numba/pull/9249 -try: - from llvmlite import ir - from numba.core import types - from numba.core.imputils import impl_ret_untracked, lower_builtin - from numba.core.typing.templates import AbstractTemplate, infer_global, signature -except ImportError: - pass -else: - - @infer_global(operator.eq) - class DtypeEq(AbstractTemplate): - def generic(self, args, kws): # noqa: ANN201, ANN001 - [lhs, rhs] = args - if isinstance(lhs, types.DType) and isinstance(rhs, types.DType): - return signature(types.boolean, lhs, rhs) - return None - - @lower_builtin(operator.eq, types.DType, types.DType) - def const_eq_impl(context, builder, sig, args): # noqa: ANN201, ANN001 - arg1, arg2 = sig.args - val = 1 if arg1 == arg2 else 0 - res = ir.Constant(ir.IntType(1), val) - return impl_ret_untracked(context, builder, sig.return_type, res) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index 19f5a9a7..d478b8ab 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -126,7 +126,11 @@ def _py_object(p: Program, expr: String, statements: String, g: PyObject): # When we evaluate a program, we first want to compile to a string yield rule(p.eval_py_object(g)).then(p.compile()) # Then we want to evaluate the statements/expr - yield rule(p.eval_py_object(g), eq(p.statements).to(statements), eq(p.expr).to(expr)).then( + yield rule( + p.eval_py_object(g), + eq(p.statements).to(statements), + eq(p.expr).to(expr), + ).then( set_(p.py_object).to( py_eval( "l['___res']", diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index c0b0bd51..77fa1e37 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -19,7 +19,7 @@ import black.parsing from typing_extensions import assert_never -from . import bindings, config # noqa: F401 +from . import bindings, config from .declarations import * from .declarations import BINARY_METHODS, REFLECTED_BINARY_METHODS, UNARY_METHODS from .type_constraint_solver import * @@ -53,6 +53,8 @@ # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {} +# Global declerations to store all convertable types so we can query if they have certain methods or not +CONVERSIONS_DECLS = Declarations() T = TypeVar("T") V = TypeVar("V", bound="Expr") @@ -128,22 +130,28 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr: def process_tp(tp: type | RuntimeTypeArgType) -> JustTypeRef | type: + global CONVERSIONS_DECLS if isinstance(tp, RuntimeClass | RuntimeParamaterizedClass): + CONVERSIONS_DECLS |= tp return class_to_ref(tp) return tp -def min_convertable_tp(decls: ModuleDeclarations, a: object, b: object, name: str) -> JustTypeRef: +def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef: """ Returns the minimum convertable type between a and b, that has a method `name`, raising a TypeError if no such type exists. """ a_tp = _get_tp(a) b_tp = _get_tp(b) a_converts_to = { - to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name) + to: c + for ((from_, to), (c, _)) in CONVERSIONS.items() + if from_ == a_tp and CONVERSIONS_DECLS.has_method(to.name, name) } b_converts_to = { - to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name) + to: c + for ((from_, to), (c, _)) in CONVERSIONS.items() + if from_ == b_tp and CONVERSIONS_DECLS.has_method(to.name, name) } if isinstance(a_tp, JustTypeRef): a_converts_to[a_tp] = 0 @@ -203,7 +211,7 @@ def _get_tp(x: object) -> JustTypeRef | type: @dataclass class RuntimeClass: - __egg_decls__: ModuleDeclarations + __egg_decls__: Declarations __egg_name__: str def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None: @@ -265,7 +273,7 @@ def __hash__(self) -> int: @dataclass class RuntimeParamaterizedClass: - __egg_decls__: ModuleDeclarations + __egg_decls__: Declarations # Note that this will never be a typevar because we don't use RuntimeParamaterizedClass for maps on their own methods # which is the only time we define function which take typevars __egg_tp__: JustTypeRef @@ -299,7 +307,7 @@ def class_to_ref(cls: RuntimeTypeArgType) -> JustTypeRef: @dataclass class RuntimeFunction: - __egg_decls__: ModuleDeclarations + __egg_decls__: Declarations __egg_name__: str __egg_fn_ref__: FunctionRef = field(init=False) __egg_fn_decl__: FunctionDecl = field(init=False) @@ -316,36 +324,25 @@ def __str__(self) -> str: def _call( - decls: ModuleDeclarations, + decls_from_fn: Declarations, callable_ref: CallableRef, - # Not included if this is the != method - fn_decl: FunctionDecl | None, + fn_decl: FunctionDecl, args: Collection[object], kwargs: dict[str, object], bound_params: tuple[JustTypeRef, ...] | None = None, ) -> RuntimeExpr | None: # Turn all keyword args into positional args + bound = fn_decl.to_signature(lambda expr: RuntimeExpr(decls_from_fn, expr)).bind(*args, **kwargs) + bound.apply_defaults() + assert not bound.kwargs + del args, kwargs - if fn_decl: - bound = fn_decl.to_signature(lambda expr: RuntimeExpr(decls, expr)).bind(*args, **kwargs) - bound.apply_defaults() - assert not bound.kwargs - args = bound.args - mutates_first_arg = fn_decl.mutates_first_arg - else: - assert not kwargs - mutates_first_arg = False - upcasted_args: list[RuntimeExpr] - if fn_decl is not None: - upcasted_args = [ - _resolve_literal(cast(TypeOrVarRef, tp), arg) - for arg, tp in zip_longest(args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type) - ] - else: - upcasted_args = cast("list[RuntimeExpr]", args) - arg_decls = tuple(arg.__egg_typed_expr__ for arg in upcasted_args) + upcasted_args = [ + _resolve_literal(cast(TypeOrVarRef, tp), arg) + for arg, tp in zip_longest(bound.args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type) + ] - arg_types = [decl.tp for decl in arg_decls] + arg_decls = tuple(arg.__egg_typed_expr__ for arg in upcasted_args) if bound_params is not None: tcs = TypeConstraintSolver.from_type_parameters(bound_params) @@ -353,13 +350,16 @@ def _call( tcs = TypeConstraintSolver() if fn_decl is not None: + arg_types = [decl.tp for decl in arg_decls] return_tp = tcs.infer_return_type(fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types) else: return_tp = JustTypeRef("Unit") - expr_decl = CallDecl(callable_ref, arg_decls, bound_params) typed_expr_decl = TypedExprDecl(return_tp, expr_decl) - if mutates_first_arg: + decls = Declarations.create(decls_from_fn, *upcasted_args) + # Register return type sort in case it's a variadic generic that needs to be created + decls.register_sort(return_tp, False) + if fn_decl.mutates_first_arg: first_arg = upcasted_args[0] first_arg.__egg_typed_expr__ = typed_expr_decl first_arg.__egg_decls__ = decls @@ -369,7 +369,7 @@ def _call( @dataclass class RuntimeClassMethod: - __egg_decls__: ModuleDeclarations + __egg_decls__: Declarations # Either a string if it isn't bound or a tp if it s __egg_tp__: JustTypeRef | str __egg_method_name__: str @@ -428,15 +428,17 @@ class RuntimeMethod: __egg_self__: RuntimeExpr __egg_method_name__: str __egg_callable_ref__: MethodRef | PropertyRef = field(init=False) - __egg_fn_decl__: FunctionDecl | None = field(init=False) + __egg_fn_decl__: FunctionDecl = field(init=False) + __egg_decls__: Declarations = field(init=False) def __post_init__(self) -> None: - if self.__egg_method_name__ in self.__egg_self__.__egg_decls__.get_class_decl(self.class_name).properties: + self.__egg_decls__ = self.__egg_self__.__egg_decls__ + if self.__egg_method_name__ in self.__egg_decls__.get_class_decl(self.class_name).properties: self.__egg_callable_ref__ = PropertyRef(self.class_name, self.__egg_method_name__) else: self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__) try: - self.__egg_fn_decl__ = self.__egg_self__.__egg_decls__.get_function_decl(self.__egg_callable_ref__) + self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__) except KeyError as e: msg = f"Class {self.class_name} does not have method {self.__egg_method_name__}" if self.__egg_method_name__ == "__ne__": @@ -446,7 +448,7 @@ def __post_init__(self) -> None: def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None: args = (self.__egg_self__, *args) try: - return _call(self.__egg_self__.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, kwargs) + return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, kwargs) except ConvertError as e: name = self.__egg_method_name__ raise TypeError(f"Wrong types for {self.__egg_self__.__egg_typed_expr__.tp.pretty()}.{name}") from e @@ -458,7 +460,7 @@ def class_name(self) -> str: @dataclass class RuntimeExpr: - __egg_decls__: ModuleDeclarations + __egg_decls__: Declarations __egg_typed_expr__: TypedExprDecl def __getattr__(self, name: str) -> RuntimeMethod | RuntimeExpr | Callable | None: @@ -501,6 +503,10 @@ def _ipython_display_(self) -> None: def __dir__(self) -> Iterable[str]: return list(self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).methods) + @property + def __egg__(self) -> bindings._Expr: + return self.__egg_typed_expr__.to_egg(self.__egg_decls__) + # Have __eq__ take no NoReturn (aka Never https://docs.python.org/3/library/typing.html#typing.Never) because # we don't wany any type that MyPy thinks is an expr to be used with __eq__. # That's because we want to reserve __eq__ for domain specific equality checks, overloading this method. @@ -512,10 +518,10 @@ def __eq__(self, other: NoReturn) -> Expr: # type: ignore[override] # Implement these so that copy() works on this object # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion - def __getstate__(self) -> tuple[ModuleDeclarations, TypedExprDecl]: + def __getstate__(self) -> tuple[Declarations, TypedExprDecl]: return (self.__egg_decls__, self.__egg_typed_expr__) - def __setstate__(self, d: tuple[ModuleDeclarations, TypedExprDecl]) -> None: + def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None: self.__egg_decls__, self.__egg_typed_expr__ = d def __hash__(self) -> int: @@ -561,29 +567,15 @@ def _reflected_method(self: RuntimeExpr, other: object, __non_reflected: str = n def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None: - # Use the mod decls that is most general between the args, if both of them are expressions - mod_decls = get_general_decls(slf, other) # find a minimum type that both can be converted to # This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats. - min_tp = min_convertable_tp(mod_decls, slf, other, name) + min_tp = min_convertable_tp(slf, other, name) slf = _resolve_literal(min_tp.to_var(), slf) other = _resolve_literal(min_tp.to_var(), other) method = RuntimeMethod(slf, name) return method(other) -def get_general_decls(a: object, b: object) -> ModuleDeclarations: - """ - Returns the more general module declerations between the two, if both are expressions. - """ - if isinstance(a, RuntimeExpr) and isinstance(b, RuntimeExpr): - return ModuleDeclarations.parent_decl(a.__egg_decls__, b.__egg_decls__) - if isinstance(a, RuntimeExpr): - return a.__egg_decls__ - assert isinstance(b, RuntimeExpr) - return b.__egg_decls__ - - for name in ["__bool__", "__len__", "__complex__", "__int__", "__float__", "__iter__", "__index__"]: def _preserved_method(self: RuntimeExpr, __name: str = name): @@ -596,16 +588,25 @@ def _preserved_method(self: RuntimeExpr, __name: str = name): setattr(RuntimeExpr, name, _preserved_method) -def _resolve_callable(callable: object) -> CallableRef: +def _resolve_callable(callable: object) -> tuple[CallableRef, Declarations]: """ Resolves a runtime callable into a ref """ + # TODO: Fix these typings. + ref: CallableRef + decls: Declarations if isinstance(callable, RuntimeFunction): - return FunctionRef(callable.__egg_name__) - if isinstance(callable, RuntimeClassMethod): - return ClassMethodRef(callable.class_name, callable.__egg_method_name__) - if isinstance(callable, RuntimeMethod): - return MethodRef(callable.__egg_self__.__egg_typed_expr__.tp.name, callable.__egg_method_name__) - if isinstance(callable, RuntimeClass): - return ClassMethodRef(callable.__egg_name__, "__init__") - raise NotImplementedError(f"Cannot turn {callable} into a callable ref") + ref = FunctionRef(callable.__egg_name__) + decls = callable.__egg_decls__ + elif isinstance(callable, RuntimeClassMethod): + ref = ClassMethodRef(callable.class_name, callable.__egg_method_name__) + decls = callable.__egg_decls__ + elif isinstance(callable, RuntimeMethod): + ref = MethodRef(callable.__egg_self__.__egg_typed_expr__.tp.name, callable.__egg_method_name__) + decls = callable.__egg_decls__ + elif isinstance(callable, RuntimeClass): + ref = ClassMethodRef(callable.__egg_name__, "__init__") + decls = callable.__egg_decls__ + else: + raise NotImplementedError(f"Cannot turn {callable} into a callable ref") + return (ref, decls) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 081aaa3c..e13fac1f 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -31,6 +31,8 @@ def test_unwrap_lit(self): assert str(i64(1) + 1) == "i64(1) + 1" assert str(i64(1).max(2)) == "i64(1).max(2)" + def test_ne(self): + assert str(ne(i64(1)).to(i64(2))) == "ne(i64(1)).to(i64(2))" def test_eqsat_basic(): egraph = EGraph() @@ -153,21 +155,26 @@ def foo() -> i64: def test_constants(): egraph = EGraph() - one = egraph.constant("one", i64) - egraph.register(set_(one).to(i64(1))) - egraph.check(eq(one).to(i64(1))) + @egraph.class_ + class A(Expr): + pass + one = egraph.constant("one", A) + two = egraph.constant("two", A) + + egraph.register(union(one).with_(two)) + egraph.check(eq(one).to(two)) def test_class_vars(): egraph = EGraph() @egraph.class_ - class Numeric(Expr): - ONE: ClassVar[i64] - - egraph.register(set_(Numeric.ONE).to(i64(1))) - egraph.check(eq(Numeric.ONE).to(i64(1))) + class A(Expr): + ONE: ClassVar[A] + two = egraph.constant("two", A) + egraph.register(union(A.ONE).with_(two)) + egraph.check(eq(A.ONE).to(two)) def test_simplify_constant(): egraph = EGraph() @@ -210,9 +217,6 @@ def test_relation(): def test_variable_args(): egraph = EGraph() - # Create dummy function with type so its registered - egraph.relation("_", Set[i64]) - egraph.check(Set(i64(1), i64(2)).contains(i64(1))) @@ -514,6 +518,15 @@ def test_rewrite_upcasts(): rewrite(i64(1)).to(0) # type: ignore +def test_function_default_upcasts(): + egraph = EGraph() + + @egraph.function + def f(x: i64Like) -> i64: + ... + + assert expr_parts(f(1)) == expr_parts(f(i64(1))) + def test_upcast_self_lower_cost(): # Verifies that self will be upcasted, if that upcast has a lower cast than converting the other arg # i.e. Int(x) + NDArray(y) -> NDArray(Int(x)) + NDArray(y) instead of Int(x) + NDArray(y).to_int() @@ -563,16 +576,16 @@ def test_eval(): assert egraph.eval(PyObject((1, 2))) == (1, 2) -def test_egglog_string(): - egraph = EGraph(save_egglog_string=True) - egraph.register((i64(1))) - assert egraph.as_egglog_string +# def test_egglog_string(): +# egraph = EGraph(save_egglog_string=True) +# egraph.register((i64(1))) +# assert egraph.as_egglog_string -def test_no_egglog_string(): - egraph = EGraph() - egraph.register((i64(1))) - with pytest.raises(ValueError): - egraph.as_egglog_string +# def test_no_egglog_string(): +# egraph = EGraph() +# egraph.register((i64(1))) +# with pytest.raises(ValueError): +# egraph.as_egglog_string diff --git a/python/tests/test_modules.py b/python/tests/test_modules.py deleted file mode 100644 index d5497ab5..00000000 --- a/python/tests/test_modules.py +++ /dev/null @@ -1,38 +0,0 @@ -import pytest -from egglog.declarations import ModuleDeclarations -from egglog.egraph import * -from egglog.egraph import _BUILTIN_DECLS, BUILTINS - - -def test_tree_modules(): - """ - BUILTINS - / | \ - A B C - | / - D - """ - assert _BUILTIN_DECLS - assert BUILTINS._mod_decls == ModuleDeclarations(_BUILTIN_DECLS, []) - - A, B, C = Module(), Module(), Module() - assert list(A._mod_decls._included_decls) == [_BUILTIN_DECLS] - - a = A.relation("a") - b = B.relation("b") - c = C.relation("c") - A.register(a()) - B.register(b()) - C.register(c()) - - D = Module([A, B]) - d = D.relation("d") - D.register(d()) - - assert D._flatted_deps == [A, B] - - egraph = EGraph([D, B]) - assert egraph._flatted_deps == [A, B, D] - egraph.check(a(), b(), d()) - with pytest.raises(Exception): - egraph.check(c()) diff --git a/python/tests/test_runtime.py b/python/tests/test_runtime.py index 10263320..7a2ddf91 100644 --- a/python/tests/test_runtime.py +++ b/python/tests/test_runtime.py @@ -6,14 +6,12 @@ def test_type_str(): - decls = ModuleDeclarations( - Declarations( + decls = Declarations( _classes={ "i64": ClassDecl(), "Map": ClassDecl(n_type_vars=2), } ) - ) i64 = RuntimeClass(decls, "i64") Map = RuntimeClass(decls, "Map") assert str(i64) == "i64" @@ -21,8 +19,7 @@ def test_type_str(): def test_function_call(): - decls = ModuleDeclarations( - Declarations( + decls = Declarations( _classes={ "i64": ClassDecl(), }, @@ -36,7 +33,6 @@ def test_function_call(): ), }, ) - ) one = RuntimeFunction(decls, "one") assert ( one().__egg_typed_expr__ # type: ignore @@ -48,8 +44,7 @@ def test_classmethod_call(): from pytest import raises K, V = ClassTypeVarRef(0), ClassTypeVarRef(1) - decls = ModuleDeclarations( - Declarations( + decls = Declarations( _classes={ "i64": ClassDecl(), "unit": ClassDecl(), @@ -65,9 +60,13 @@ def test_classmethod_call(): ) }, ), + }, + _type_ref_to_egg_sort={ + JustTypeRef("i64"): "i64", + JustTypeRef("unit"): "unit", + JustTypeRef("Map"): "Map", } ) - ) Map = RuntimeClass(decls, "Map") with raises(TypeConstraintError): Map.create() # type: ignore @@ -75,23 +74,19 @@ def test_classmethod_call(): unit = RuntimeClass(decls, "unit") assert ( Map[i64, unit].create().__egg_typed_expr__ # type: ignore - == RuntimeExpr( - decls, - TypedExprDecl( + == TypedExprDecl( JustTypeRef("Map", (JustTypeRef("i64"), JustTypeRef("unit"))), CallDecl( ClassMethodRef("Map", "create"), (), (JustTypeRef("i64"), JustTypeRef("unit")), ), - ), - ).__egg_typed_expr__ + ) ) def test_expr_special(): - decls = ModuleDeclarations( - Declarations( + decls = Declarations( _classes={ "i64": ClassDecl( methods={ @@ -115,7 +110,6 @@ def test_expr_special(): ), }, ) - ) i64 = RuntimeClass(decls, "i64") one = i64(1) # type: ignore res = one + one # type: ignore @@ -133,13 +127,11 @@ def test_expr_special(): def test_class_variable(): - decls = ModuleDeclarations( - Declarations( + decls = Declarations( _classes={ "i64": ClassDecl(class_variables={"one": JustTypeRef("i64")}), }, ) - ) i64 = RuntimeClass(decls, "i64") one = i64.one assert isinstance(one, RuntimeExpr)