Skip to content

Commit

Permalink
Merge pull request #130 from egraphs-good/refactor
Browse files Browse the repository at this point in the history
Make rulesets delayed (and refactor)
  • Loading branch information
saulshanabrook authored Mar 28, 2024
2 parents df4e1d8 + 2b25475 commit f460d74
Show file tree
Hide file tree
Showing 24 changed files with 2,546 additions and 2,155 deletions.
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ _This project uses semantic versioning_

## UNRELEASED

- Defers adding rules in functions until they are used, so that you can use types that are not present yet.
- Removes the ability to set a custom default ruleset for an egraph. Either use the empty default ruleset or explicitly set it for every run.
- Automatically mark Python builtin operators as preserved if they must return a real Python value
- Properly pretty print all items (rewrites, actions, exprs, etc) so that expressions are de-duplicated and state is handled correctly.

## 6.1.0 (2024-03-06)

- Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/4cc011f6b48029dd72104a38a2ca0c7657846e0b...0113af1d6476b75d4319591cc3d675f96a71cdc5)
Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,13 @@ filterwarnings = [
"ignore::numba.core.errors.NumbaPerformanceWarning",
"ignore::pytest_benchmark.logger.PytestBenchmarkWarning",
]

[tool.coverage.report]
exclude_also = [
"def __repr__",
"raise NotImplementedError",
"if TYPE_CHECKING:",
"class .*\\bProtocol\\):",
"@(abc\\.)?abstractmethod",
"assert_never\\(",
]
2 changes: 1 addition & 1 deletion python/egglog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from . import config, ipython_magic # noqa: F401
from .builtins import * # noqa: UP029
from .conversion import convert, converter # noqa: F401
from .egraph import *
from .runtime import convert, converter # noqa: F401

del ipython_magic
2 changes: 2 additions & 0 deletions python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ class EGraph:
fact_directory: str | Path | None = None,
seminaive: bool = True,
terms_encoding: bool = False,
record: bool = False,
) -> None: ...
def commands(self) -> str | None: ...
def parse_program(self, __input: str, /) -> list[_Command]: ...
def run_program(self, *commands: _Command) -> list[str]: ...
def extract_report(self) -> _ExtractReport | None: ...
Expand Down
2 changes: 1 addition & 1 deletion python/egglog/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union

from .conversion import converter
from .egraph import Expr, Unit, function, method
from .runtime import converter

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down
172 changes: 172 additions & 0 deletions python/egglog/conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar, cast

from .declarations import *
from .pretty import *
from .runtime import *
from .thunk import *

if TYPE_CHECKING:
from collections.abc import Callable

from .declarations import HasDeclerations
from .egraph import Expr

__all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"]
# Mapping from (source type, target type) to a (cost, function) pair, where the function
# takes the runtime value of the source type and returns a value of the target type.
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
# Global declarations that store all convertible types, so we can query whether they have certain methods or not.
# Deferred as a thunk so we can register conversions without triggering type signature loading.
CONVERSIONS_DECLS: Callable[[], Declarations] = Thunk.value(Declarations())

# Source type accepted by a converter registered through `converter`.
T = TypeVar("T")
# Target type produced by a conversion; bound to `Expr`.
V = TypeVar("V", bound="Expr")


class ConvertError(Exception):
    """Raised when no registered conversion can produce the requested type."""


def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: int = 1) -> None:
    """
    Register `fn` as the conversion from `from_type` to the egglog type `to_type`.

    `cost` ranks this conversion against others between the same pair of types;
    only the cheapest one is kept.

    Raises a TypeError if `to_type` does not resolve to an egglog type reference.
    """
    # The target is processed first, so an invalid target fails before the source is touched.
    target_ref = process_tp(to_type)
    if isinstance(target_ref, JustTypeRef):
        _register_converter(process_tp(from_type), target_ref, fn, cost)
        return
    raise TypeError(f"Expected return type to be a egglog type, got {target_ref}")


def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
    """
    Record `a_b` as the conversion from `a` to `b`, unless an equally cheap or cheaper one exists.

    The table is also closed transitively around the new entry:

    * for every known B->C, a composed A->C is registered;
    * for every known D->A, a composed D->B is registered.

    Each composed conversion costs the sum of the costs of its two parts.
    """
    # A type never needs a conversion to itself.
    if a == b:
        return
    key = (a, b)
    existing = CONVERSIONS.get(key)
    # Keep the current entry unless the new conversion is strictly cheaper.
    if existing is not None and existing[0] <= cost:
        return
    CONVERSIONS[key] = (cost, a_b)
    # Iterate over a snapshot, because the recursive calls below insert into the table.
    for (src, dst), (known_cost, known_fn) in list(CONVERSIONS.items()):
        if b == src:
            # a -> b followed by b -> dst
            _register_converter(a, dst, _ComposedConverter(a_b, known_fn), cost + known_cost)
        if a == dst:
            # src -> a followed by a -> b
            _register_converter(src, b, _ComposedConverter(known_fn, a_b), cost + known_cost)


@dataclass
class _ComposedConverter:
"""
A converter which is composed of multiple converters.
_ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x))
We use the dataclass instead of the lambda to make it easier to debug.
"""

a_b: Callable
b_c: Callable

def __call__(self, x: object) -> object:
return self.b_c(self.a_b(x))

def __str__(self) -> str:
return f"{self.b_c}{self.a_b}"


def convert(source: object, target: type[V]) -> V:
    """
    Convert `source` into an expression of the egglog class `target`.

    `target` must be a `RuntimeClass` at runtime.
    """
    assert isinstance(target, RuntimeClass)
    converted = resolve_literal(target.__egg_tp__, source)
    return cast(V, converted)


def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
    """
    Convert `source` into an expression with the same type as `target`.
    """
    target_tp = target.__egg_typed_expr__.tp.to_var()
    return resolve_literal(target_tp, source)


def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
    """
    Prepare a type for use as a conversion source or target.

    A `RuntimeClass` is chained onto the deferred global conversion declarations
    and resolved to its type reference. Any other type is returned unchanged.
    """
    global CONVERSIONS_DECLS
    if not isinstance(tp, RuntimeClass):
        return tp
    # Wrap the previous thunk so the declarations are only combined when they are requested.
    CONVERSIONS_DECLS = Thunk.fn(_combine_decls, CONVERSIONS_DECLS, tp)
    return tp.__egg_tp__.to_just()


def _combine_decls(d: Callable[[], Declarations], x: HasDeclerations) -> Declarations:
    """
    Evaluate the deferred declarations `d` and combine the result with those of `x`.
    """
    existing = d()
    return Declarations.create(existing, x)


def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
    """
    Find the cheapest type that both `a` and `b` can be converted to and that has a method `name`.

    The cost of a candidate is the sum of the cost of converting `a` to it and
    the cost of converting `b` to it. A value that is already an egglog
    expression can stay at its own type at a cost of zero.

    Raises a ConvertError if the two values share no candidate type.
    """
    decls = CONVERSIONS_DECLS()
    a_tp = _get_tp(a)
    b_tp = _get_tp(b)

    def reachable(source_tp: JustTypeRef | type) -> dict:
        # Targets reachable from `source_tp` that have the requested method, with their cost.
        costs = {
            to: c
            for ((from_, to), (c, _)) in CONVERSIONS.items()
            if from_ == source_tp and decls.has_method(to.name, name)
        }
        # An egglog type can always stay as it is, for free.
        if isinstance(source_tp, JustTypeRef):
            costs[source_tp] = 0
        return costs

    a_converts_to = reachable(a_tp)
    b_converts_to = reachable(b_tp)
    common = set(a_converts_to) & set(b_converts_to)
    if not common:
        raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")

    def total_cost(tp: JustTypeRef) -> int:
        return a_converts_to[tp] + b_converts_to[tp]

    return min(common, key=total_cost)


def identity(x: object) -> object:
    """Return `x` unchanged."""
    return x


def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
    """
    Turn `arg` into a runtime expression of type `tp`, converting it if needed.

    `arg` is returned as is when `tp` cannot be resolved to a concrete type
    (it contains type variables) or when `arg` already has that type. Otherwise
    the registered conversion to `tp` is applied; for Python values the parent
    classes of the value's type are tried as well, in MRO order.

    Raises a ConvertError if no conversion is registered.
    """
    arg_type = _get_tp(arg)

    # With type variables present, nothing can be resolved, so the arg is passed through.
    try:
        tp_just = tp.to_just()
    except NotImplementedError:
        # A var can only be satisfied by a runtime expression.
        assert isinstance(arg, RuntimeExpr)
        return arg
    if arg_type == tp_just:
        # A value whose type is an egg type has to be a runtime expression.
        assert isinstance(arg, RuntimeExpr)
        return arg
    # Python types also match conversions registered for any of their parent classes.
    candidates = arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]
    for candidate in candidates:
        entry = CONVERSIONS.get((cast(JustTypeRef | type, candidate), tp_just))
        if entry is not None:
            # Entries are (cost, function) pairs; the first match in MRO order wins.
            return entry[1](arg)
    raise ConvertError(f"Cannot convert {arg_type} to {tp_just}")


def _get_tp(x: object) -> JustTypeRef | type:
    """
    Return the key used to look up conversions for the value `x`.

    * For a runtime expression, this is the type of its typed expression.
    * For a value whose class has a custom metaclass, this is the metaclass.
    * For any other value, this is its Python class.
    """
    if isinstance(x, RuntimeExpr):
        return x.__egg_typed_expr__.tp
    tp = type(x)
    metaclass = type(tp)
    # If this value has a custom metaclass, use that as the index instead of the type.
    # Type objects are compared by identity (`is not`), not equality.
    if metaclass is not type:
        return metaclass
    return tp
Loading

0 comments on commit f460d74

Please sign in to comment.