Skip to content

Commit

Permalink
Merge pull request #109 from egraphs-good/change
Browse files Browse the repository at this point in the history
ChangeFix module handling and rephrase
  • Loading branch information
saulshanabrook authored Jan 30, 2024
2 parents dea4390 + a50344c commit ff934ab
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 165 deletions.
6 changes: 4 additions & 2 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ _This project uses semantic versioning_

_not yet implemented_

This is a large breaking change that moves the function and class decorators to the top level `egglog` module,
This is a large change that moves the function and class decorators to the top level `egglog` module,
from the `EGraph` and `Module` classes. Rulesets are also moved to be defined globally instead of on the `EGraph` class.

The goal of this change is to remove the complexity of `Module`s and remove the need to think about what functions/classes
Expand All @@ -20,11 +20,13 @@ in any rules or added in any commands.
- `egraph.function` -> `egglog.function`
- `egraph.relation` -> `egglog.relation`
- `egraph.ruleset` -> `egglog.Ruleset`
- `egraph.Module` -> Removed

The `EGraph` class can take an optional `default_ruleset` argument to set the default ruleset for the `EGraph`. Otherwise,
there is a global default ruleset that is used, `egglog.Ruleset`.

This also adds support for classes with methods that are mutually recursive, by making type analysis more lazy.
For backwards compatability, the existing methods and functions are preserved, to make this easier to adopt. They will
all now raise deprication warnings.

## 5.0.0 (2024-01-16)

Expand Down
49 changes: 0 additions & 49 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,55 +933,6 @@ def traverse_for_parents(self, expr: ExprDecl) -> None:
self.traverse_for_parents(arg.expr)


# def test_expr_pretty():
# context = PrettyContext(ModuleDeclarations(Declarations()))
# assert VarDecl("x").pretty(context) == "x"
# assert LitDecl(42).pretty(context) == "i64(42)"
# assert LitDecl("foo").pretty(context) == 'String("foo")'
# assert LitDecl(None).pretty(context) == "unit()"

# def v(x: str) -> TypedExprDecl:
# return TypedExprDecl(JustTypeRef(""), VarDecl(x))

# assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(context) == "foo(x)"
# assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(context) == "foo(x, y, z)"
# assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(context) == "x + y"
# assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(context) == "x[y]"
# assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(context) == "foo(x, y)"
# assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(context) == "foo.bar(x, y)"
# assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(context) == "x(y)"
# assert (
# CallDecl(
# ClassMethodRef("Map", "__init__"),
# (),
# (JustTypeRef("i64"), JustTypeRef("Unit")),
# ).pretty(context)
# == "Map[i64, Unit]()"
# )


# def test_setitem_pretty():
# context = PrettyContext(ModuleDeclarations(Declarations()))

# def v(x: str) -> TypedExprDecl:
# return TypedExprDecl(JustTypeRef("typ"), VarDecl(x))

# final_expr = CallDecl(MethodRef("foo", "__setitem__"), (v("x"), v("y"), v("z"))).pretty(context)
# assert context.render(final_expr) == "_typ_1 = x\n_typ_1[y] = z\n_typ_1"


# def test_delitem_pretty():
# context = PrettyContext(ModuleDeclarations(Declarations()))

# def v(x: str) -> TypedExprDecl:
# return TypedExprDecl(JustTypeRef("typ"), VarDecl(x))

# final_expr = CallDecl(MethodRef("foo", "__delitem__"), (v("x"), v("y"))).pretty(context)
# assert context.render(final_expr) == "_typ_1 = x\ndel _typ_1[y]\n_typ_1"


# TODO: Multiple mutations,

ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl


Expand Down
124 changes: 10 additions & 114 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,11 @@ class _BaseModule:
# _mod_decls: ModuleDeclarations = field(init=False)

def __post_init__(self, modules: list[Module]) -> None:
# included_decls = [_BUILTIN_DECLS] if _BUILTIN_DECLS else []
# # Traverse all the included modules to flatten all their dependencies and add to the included declerations
for mod in modules:
for child_mod in [*mod._flatted_deps, mod]:
if child_mod not in self._flatted_deps:
self._flatted_deps.append(child_mod)

# self._mod_decls = ModuleDeclarations(Declarations(), included_decls)

# # TODO: Move to EGraph itself
# @abstractmethod
# def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
# """
# Process the commands generated by this module.
# """
# raise NotImplementedError

@overload
def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]:
...
Expand Down Expand Up @@ -689,24 +677,7 @@ def _resolve_type_annotation(
@dataclass
class _Builtins(_BaseModule):
is_builtin: ClassVar[bool] = True
# def __post_init__(self, modules: list[Module]) -> None:
# """
# Register these declarations as builtins, so others can use them.
# """
# assert not modules
# super().__post_init__(modules)
# global _BUILTIN_DECLS
# if _BUILTIN_DECLS is not None:
# msg = "Builtins already initialized"
# raise RuntimeError(msg)
# _BUILTIN_DECLS = self._mod_decls._decl
# # Register != operator
# _BUILTIN_DECLS.register_callable_ref(FunctionRef("!="), "!=")

# def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
# """
# Commands which would have been used to create the builtins are discarded, since they are already registered.
# """

def _register_commands(self, cmds: list[Command]) -> None:
raise NotImplementedError

Expand All @@ -721,89 +692,14 @@ def _register_commands(self, cmds: list[Command]) -> None:
def without_rules(self) -> Module:
return Module()

# _cmds: list[bindings._Command] = field(default_factory=list, repr=False)

# @property
# def as_egglog_string(self) -> str:
# """
# Returns the egglog string for this module.
# """
# return "\n".join(str(c) for c in self._cmds)

# def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
# self._cmds.extend(cmds)

# def unextractable(self) -> Module:
# """
# Makes a copy of this module with all functions marked as un-extractable
# """
# return self._map_functions(
# lambda decl: bindings.FunctionDecl(
# decl.name,
# decl.schema,
# decl.default,
# decl.merge,
# decl.merge_action,
# decl.cost,
# True,
# )
# )

# def increase_cost(self, x: int = 10000000) -> Module:
# """
# Make a copy of this module with all function costs increased by x
# """
# return self._map_functions(
# lambda decl, x=x: bindings.FunctionDecl( # type: ignore[misc]
# decl.name,
# decl.schema,
# decl.default,
# decl.merge,
# decl.merge_action,
# (decl.cost or 1) + x,
# decl.unextractable,
# )
# )

# def without_rules(self) -> Module:
# """
# Makes a copy of this module with all rules removed.
# """
# new = copy(self)
# new._cmds = [
# c
# for c in new._cmds
# if not isinstance(c, bindings.RuleCommand)
# and not isinstance(c, bindings.RewriteCommand)
# and not isinstance(c, bindings.BiRewriteCommand)
# ]
# return new

# def rename_ruleset(self, new_r: str) -> Module:
# """
# Makes a copy of this module with all default rulsets changed to the new one.
# """
# new = copy(self)
# new._cmds = [
# bindings.RuleCommand(c.name, new_r, c.rule)
# if isinstance(c, bindings.RuleCommand) and not c.ruleset
# else bindings.RewriteCommand(new_r, c.rewrite)
# if isinstance(c, bindings.RewriteCommand) and not c.name
# else bindings.BiRewriteCommand(new_r, c.rewrite)
# if isinstance(c, bindings.BiRewriteCommand) and not c.name
# else c
# for c in new._cmds
# ]
# new._cmds.insert(0, bindings.AddRuleset(new_r))
# return new

# def _map_functions(self, fn: Callable[[bindings.FunctionDecl], bindings.FunctionDecl]) -> Module:
# """
# Returns a copy where all the functions have been mapped with the given function.
# """
# new = copy(self)
# new._cmds = [bindings.Function(fn(c.decl)) if isinstance(c, bindings.Function) else c for c in new._cmds]
# return new
# Use identity for hash and equility, so we don't have to compare commands and compare expressions
def __hash__(self) -> int:
return id(self)

def __eq__(self, other: object) -> bool:
if not isinstance(other, Module):
return NotImplemented
return self is other


class GraphvizKwargs(TypedDict, total=False):
Expand Down Expand Up @@ -1467,7 +1363,7 @@ def _to_egg_command(self) -> bindings._Command:
return bindings.ActionCommand(self._to_egg_action())

@property
def ruleset(self) -> None:
def ruleset(self) -> None | Ruleset:
return None


Expand Down
38 changes: 38 additions & 0 deletions python/tests/test_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
# from egglog.declarations import ModuleDeclarations
from egglog.egraph import *
# from egglog.egraph import _BUILTIN_DECLS, BUILTINS


def test_tree_modules():
"""
BUILTINS
/ | \
A B C
| /
D
"""
# assert _BUILTIN_DECLS
# assert BUILTINS._mod_decls == ModuleDeclarations(_BUILTIN_DECLS, [])

A, B, C = Module(), Module(), Module()
# assert list(A._mod_decls._included_decls) == [_BUILTIN_DECLS]

a = A.relation("a")
b = B.relation("b")
c = C.relation("c")
A.register(a())
B.register(b())
C.register(c())

D = Module([A, B])
d = D.relation("d")
D.register(d())

assert D._flatted_deps == [A, B]

egraph = EGraph([D, B])
# assert egraph._flatted_deps == [A, B, D]
egraph.check(a(), b(), d())
with pytest.raises(Exception):
egraph.check(c())

0 comments on commit ff934ab

Please sign in to comment.