diff --git a/docs/changelog.md b/docs/changelog.md index 6d86a5bb..293f285a 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -8,7 +8,7 @@ _This project uses semantic versioning_ _not yet implemented_ -This is a large breaking change that moves the function and class decorators to the top level `egglog` module, +This is a large change that moves the function and class decorators to the top level `egglog` module, from the `EGraph` and `Module` classes. Rulesets are also moved to be defined globally instead of on the `EGraph` class. The goal of this change is to remove the complexity of `Module`s and remove the need to think about what functions/classes @@ -20,11 +20,13 @@ in any rules or added in any commands. - `egraph.function` -> `egglog.function` - `egraph.relation` -> `egglog.relation` - `egraph.ruleset` -> `egglog.Ruleset` +- `egraph.Module` -> Removed The `EGraph` class can take an optional `default_ruleset` argument to set the default ruleset for the `EGraph`. Otherwise, there is a global default ruleset that is used, `egglog.Ruleset`. -This also adds support for classes with methods that are mutually recursive, by making type analysis more lazy. +For backwards compatability, the existing methods and functions are preserved, to make this easier to adopt. They will +all now raise deprication warnings. ## 5.0.0 (2024-01-16) diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index fdb33132..76e6db4c 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -933,55 +933,6 @@ def traverse_for_parents(self, expr: ExprDecl) -> None: self.traverse_for_parents(arg.expr) -# def test_expr_pretty(): -# context = PrettyContext(ModuleDeclarations(Declarations())) -# assert VarDecl("x").pretty(context) == "x" -# assert LitDecl(42).pretty(context) == "i64(42)" -# assert LitDecl("foo").pretty(context) == 'String("foo")' -# assert LitDecl(None).pretty(context) == "unit()" - -# def v(x: str) -> TypedExprDecl: -# return TypedExprDecl(JustTypeRef(""), VarDecl(x)) - -# assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(context) == "foo(x)" -# assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(context) == "foo(x, y, z)" -# assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(context) == "x + y" -# assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(context) == "x[y]" -# assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(context) == "foo(x, y)" -# assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(context) == "foo.bar(x, y)" -# assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(context) == "x(y)" -# assert ( -# CallDecl( -# ClassMethodRef("Map", "__init__"), -# (), -# (JustTypeRef("i64"), JustTypeRef("Unit")), -# ).pretty(context) -# == "Map[i64, Unit]()" -# ) - - -# def test_setitem_pretty(): -# context = PrettyContext(ModuleDeclarations(Declarations())) - -# def v(x: str) -> TypedExprDecl: -# return TypedExprDecl(JustTypeRef("typ"), VarDecl(x)) - -# final_expr = CallDecl(MethodRef("foo", "__setitem__"), (v("x"), v("y"), v("z"))).pretty(context) -# assert context.render(final_expr) == "_typ_1 = x\n_typ_1[y] = z\n_typ_1" - - -# def test_delitem_pretty(): -# context = PrettyContext(ModuleDeclarations(Declarations())) - -# def v(x: str) -> TypedExprDecl: -# return TypedExprDecl(JustTypeRef("typ"), VarDecl(x)) - -# final_expr = CallDecl(MethodRef("foo", "__delitem__"), (v("x"), v("y"))).pretty(context) -# assert context.render(final_expr) == "_typ_1 = x\ndel _typ_1[y]\n_typ_1" - - -# TODO: Multiple mutations, - ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 1c53765c..c6c2fba3 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -127,23 +127,11 @@ class _BaseModule: # _mod_decls: ModuleDeclarations = field(init=False) def __post_init__(self, modules: list[Module]) -> None: - # included_decls = [_BUILTIN_DECLS] if _BUILTIN_DECLS else [] - # # Traverse all the included modules to flatten all their dependencies and add to the included declerations for mod in modules: for child_mod in [*mod._flatted_deps, mod]: if child_mod not in self._flatted_deps: self._flatted_deps.append(child_mod) - # self._mod_decls = ModuleDeclarations(Declarations(), included_decls) - - # # TODO: Move to EGraph itself - # @abstractmethod - # def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: - # """ - # Process the commands generated by this module. - # """ - # raise NotImplementedError - @overload def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]: ... @@ -689,24 +677,7 @@ def _resolve_type_annotation( @dataclass class _Builtins(_BaseModule): is_builtin: ClassVar[bool] = True - # def __post_init__(self, modules: list[Module]) -> None: - # """ - # Register these declarations as builtins, so others can use them. - # """ - # assert not modules - # super().__post_init__(modules) - # global _BUILTIN_DECLS - # if _BUILTIN_DECLS is not None: - # msg = "Builtins already initialized" - # raise RuntimeError(msg) - # _BUILTIN_DECLS = self._mod_decls._decl - # # Register != operator - # _BUILTIN_DECLS.register_callable_ref(FunctionRef("!="), "!=") - - # def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: - # """ - # Commands which would have been used to create the builtins are discarded, since they are already registered. - # """ + def _register_commands(self, cmds: list[Command]) -> None: raise NotImplementedError @@ -721,89 +692,14 @@ def _register_commands(self, cmds: list[Command]) -> None: def without_rules(self) -> Module: return Module() - # _cmds: list[bindings._Command] = field(default_factory=list, repr=False) - - # @property - # def as_egglog_string(self) -> str: - # """ - # Returns the egglog string for this module. - # """ - # return "\n".join(str(c) for c in self._cmds) - - # def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: - # self._cmds.extend(cmds) - - # def unextractable(self) -> Module: - # """ - # Makes a copy of this module with all functions marked as un-extractable - # """ - # return self._map_functions( - # lambda decl: bindings.FunctionDecl( - # decl.name, - # decl.schema, - # decl.default, - # decl.merge, - # decl.merge_action, - # decl.cost, - # True, - # ) - # ) - - # def increase_cost(self, x: int = 10000000) -> Module: - # """ - # Make a copy of this module with all function costs increased by x - # """ - # return self._map_functions( - # lambda decl, x=x: bindings.FunctionDecl( # type: ignore[misc] - # decl.name, - # decl.schema, - # decl.default, - # decl.merge, - # decl.merge_action, - # (decl.cost or 1) + x, - # decl.unextractable, - # ) - # ) - - # def without_rules(self) -> Module: - # """ - # Makes a copy of this module with all rules removed. - # """ - # new = copy(self) - # new._cmds = [ - # c - # for c in new._cmds - # if not isinstance(c, bindings.RuleCommand) - # and not isinstance(c, bindings.RewriteCommand) - # and not isinstance(c, bindings.BiRewriteCommand) - # ] - # return new - - # def rename_ruleset(self, new_r: str) -> Module: - # """ - # Makes a copy of this module with all default rulsets changed to the new one. - # """ - # new = copy(self) - # new._cmds = [ - # bindings.RuleCommand(c.name, new_r, c.rule) - # if isinstance(c, bindings.RuleCommand) and not c.ruleset - # else bindings.RewriteCommand(new_r, c.rewrite) - # if isinstance(c, bindings.RewriteCommand) and not c.name - # else bindings.BiRewriteCommand(new_r, c.rewrite) - # if isinstance(c, bindings.BiRewriteCommand) and not c.name - # else c - # for c in new._cmds - # ] - # new._cmds.insert(0, bindings.AddRuleset(new_r)) - # return new - - # def _map_functions(self, fn: Callable[[bindings.FunctionDecl], bindings.FunctionDecl]) -> Module: - # """ - # Returns a copy where all the functions have been mapped with the given function. - # """ - # new = copy(self) - # new._cmds = [bindings.Function(fn(c.decl)) if isinstance(c, bindings.Function) else c for c in new._cmds] - # return new + # Use identity for hash and equility, so we don't have to compare commands and compare expressions + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Module): + return NotImplemented + return self is other class GraphvizKwargs(TypedDict, total=False): @@ -1467,7 +1363,7 @@ def _to_egg_command(self) -> bindings._Command: return bindings.ActionCommand(self._to_egg_action()) @property - def ruleset(self) -> None: + def ruleset(self) -> None | Ruleset: return None diff --git a/python/tests/test_modules.py b/python/tests/test_modules.py new file mode 100644 index 00000000..87a3bf75 --- /dev/null +++ b/python/tests/test_modules.py @@ -0,0 +1,38 @@ +import pytest +# from egglog.declarations import ModuleDeclarations +from egglog.egraph import * +# from egglog.egraph import _BUILTIN_DECLS, BUILTINS + + +def test_tree_modules(): + """ + BUILTINS + / | \ + A B C + | / + D + """ + # assert _BUILTIN_DECLS + # assert BUILTINS._mod_decls == ModuleDeclarations(_BUILTIN_DECLS, []) + + A, B, C = Module(), Module(), Module() + # assert list(A._mod_decls._included_decls) == [_BUILTIN_DECLS] + + a = A.relation("a") + b = B.relation("b") + c = C.relation("c") + A.register(a()) + B.register(b()) + C.register(c()) + + D = Module([A, B]) + d = D.relation("d") + D.register(d()) + + assert D._flatted_deps == [A, B] + + egraph = EGraph([D, B]) + # assert egraph._flatted_deps == [A, B, D] + egraph.check(a(), b(), d()) + with pytest.raises(Exception): + egraph.check(c())