Skip to content

Commit

Permalink
Merge main into refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Mar 27, 2024
2 parents c941440 + df4e1d8 commit c4975be
Show file tree
Hide file tree
Showing 23 changed files with 9,719 additions and 64 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "egglog-python"
version = "6.0.0"
version = "6.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -12,7 +12,7 @@ crate-type = ["cdylib"]
pyo3 = { version = "0.20.2", features = ["extension-module"] }

# https://github.com/egraphs-good/egglog/compare/ceed816e9369570ffed9feeba157b19471dda70d...main
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "4cc011f6b48029dd72104a38a2ca0c7657846e0b" }
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "0113af1d6476b75d4319591cc3d675f96a71cdc5" }
# egglog = { git = "https://github.com/oflatt/egg-smol", branch = "oflatt-fast-terms" }
# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "38b3014b34399cc78887ede09c845b2a5d6c7d19" }
egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "5838c036623e91540831745b1574539e01c8cb23" }
Expand Down
13 changes: 11 additions & 2 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@

_This project uses semantic versioning_

## Unreleased
## UNRELEASED

- Defers adding rules in functions until they are used, so that you can use types that are not present yet.
- Removes ability to set custom default ruleset for egraph. Either just use the empty default ruleset or explicitly set it for every run
- Automatically mark Python builtin operators as preserved if they must return a real Python value

## 6.1.0 (2024-03-06)

- Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/4cc011f6b48029dd72104a38a2ca0c7657846e0b...0113af1d6476b75d4319591cc3d675f96a71cdc5)
- Adds subsume action
- Makes all objects besides EGraphs "sendable" aka threadsafe ([#129](https://github.com/egraphs-good/egglog-python/pull/129))

## 6.0.1 (2024-02-28)

- Upgrade dependencies, including [egglog](https://github.com/egraphs-good/egglog/compare/ceed816e9369570ffed9feeba157b19471dda70d...4cc011f6b48029dd72104a38a2ca0c7657846e0b)
- Fix bug where saturate wasn't properly getting translated.
- Automatically mark Python builtin operators as preserved if they must return a real Python value

## 6.0.0 (2024-02-06)

Expand Down
9,538 changes: 9,538 additions & 0 deletions docs/explanation/2024_03_17_community_talk.ipynb

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions docs/explanation/2024_03_17_map.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ testpaths = ["python"]
python_files = ["test_*.py", "test.py"]
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
norecursedirs = ["__snapshots__"]
filterwarnings = [
"error",
"ignore::numba.core.errors.NumbaPerformanceWarning",
"ignore::pytest_benchmark.logger.PytestBenchmarkWarning",
]

[tool.coverage.report]
exclude_also = [
Expand Down
26 changes: 22 additions & 4 deletions python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ class Fact:

_Fact: TypeAlias = Fact | Eq

##
# Change
##

@final
class Delete:
def __init__(self) -> None: ...

@final
class Subsume:
def __init__(self) -> None: ...

_Change: TypeAlias = Delete | Subsume

##
# Actions
##
Expand All @@ -170,10 +184,11 @@ class Set:
rhs: _Expr

@final
class Delete:
class Change:
change: _Change
sym: str
args: list[_Expr]
def __init__(self, sym: str, args: list[_Expr]) -> None: ...
def __init__(self, change: _Change, sym: str, args: list[_Expr]) -> None: ...

@final
class Union:
Expand All @@ -197,7 +212,7 @@ class Extract:
expr: _Expr
variants: _Expr

_Action: TypeAlias = Let | Set | Delete | Union | Panic | Expr_ | Extract
_Action: TypeAlias = Let | Set | Change | Union | Panic | Expr_ | Extract

##
# Other Structs
Expand All @@ -212,6 +227,7 @@ class FunctionDecl:
merge_action: list[_Action]
cost: int | None
unextractable: bool
ignore_viz: bool

def __init__(
self,
Expand All @@ -222,6 +238,7 @@ class FunctionDecl:
merge_action: list[_Action] = [], # noqa: B006
cost: int | None = None,
unextractable: bool = False,
ignore_viz: bool = False,
) -> None: ...

@final
Expand Down Expand Up @@ -376,7 +393,8 @@ class RewriteCommand:
# TODO: Rename to ruleset
name: str
rewrite: Rewrite
def __init__(self, name: str, rewrite: Rewrite) -> None: ...
subsume: bool
def __init__(self, name: str, rewrite: Rewrite, subsume: bool) -> None: ...

@final
class BiRewriteCommand:
Expand Down
10 changes: 6 additions & 4 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Protocol, TypeAlias, Union, runtime_checkable
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_checkable

from typing_extensions import Self, assert_never

Expand Down Expand Up @@ -57,7 +57,7 @@
"LetDecl",
"SetDecl",
"ExprActionDecl",
"DeleteDecl",
"ChangeDecl",
"UnionDecl",
"PanicDecl",
"ActionDecl",
Expand Down Expand Up @@ -553,9 +553,10 @@ class ExprActionDecl:


@dataclass(frozen=True)
class DeleteDecl:
class ChangeDecl:
tp: JustTypeRef
call: CallDecl
change: Literal["delete", "subsume"]


@dataclass(frozen=True)
Expand All @@ -570,7 +571,7 @@ class PanicDecl:
msg: str


ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | DeleteDecl | UnionDecl | PanicDecl
ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl


##
Expand All @@ -584,6 +585,7 @@ class RewriteDecl:
lhs: ExprDecl
rhs: ExprDecl
conditions: tuple[FactDecl, ...]
subsume: bool


@dataclass(frozen=True)
Expand Down
22 changes: 17 additions & 5 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"let",
"constant",
"delete",
"subsume",
"union",
"set_",
"rule",
Expand Down Expand Up @@ -1444,16 +1445,16 @@ def __repr__(self) -> str:

@deprecated("Use <ruleset>.register(<rewrite>) instead of passing rulesets as arguments to rewrites.")
@overload
def rewrite(lhs: EXPR, ruleset: Ruleset) -> _RewriteBuilder[EXPR]: ...
def rewrite(lhs: EXPR, ruleset: Ruleset, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...


@overload
def rewrite(lhs: EXPR, ruleset: None = None) -> _RewriteBuilder[EXPR]: ...
def rewrite(lhs: EXPR, ruleset: None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...


def rewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _RewriteBuilder[EXPR]:
def rewrite(lhs: EXPR, ruleset: Ruleset | None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]:
"""Rewrite the given expression to a new expression."""
return _RewriteBuilder(lhs, ruleset)
return _RewriteBuilder(lhs, ruleset, subsume)


@deprecated("Use <ruleset>.register(<birewrite>) instead of passing rulesets as arguments to birewrites.")
Expand Down Expand Up @@ -1502,7 +1503,16 @@ def delete(expr: Expr) -> Action:
typed_expr = runtime_expr.__egg_typed_expr__
call_decl = typed_expr.expr
assert isinstance(call_decl, CallDecl), "Can only delete calls, not literals or vars"
return Action(runtime_expr.__egg_decls__, DeleteDecl(typed_expr.tp, call_decl))
return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "delete"))


def subsume(expr: Expr) -> Action:
"""Subsume an expression so it cannot be matched against or extracted"""
runtime_expr = to_runtime_expr(expr)
typed_expr = runtime_expr.__egg_typed_expr__
call_decl = typed_expr.expr
assert isinstance(call_decl, CallDecl), "Can only subsume calls, not literals or vars"
return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "subsume"))


def expr_fact(expr: Expr) -> Fact:
Expand Down Expand Up @@ -1561,6 +1571,7 @@ def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
class _RewriteBuilder(Generic[EXPR]):
lhs: EXPR
ruleset: Ruleset | None
subsume: bool

def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
lhs = to_runtime_expr(self.lhs)
Expand All @@ -1573,6 +1584,7 @@ def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
lhs.__egg_typed_expr__.expr,
rhs.__egg_typed_expr__.expr,
tuple(f.fact for f in facts),
self.subsume,
),
)
if self.ruleset:
Expand Down
14 changes: 11 additions & 3 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
[self.fact_to_egg(c) for c in conditions],
)
return (
bindings.RewriteCommand(ruleset, rewrite)
bindings.RewriteCommand(ruleset, rewrite, cmd.subsume)
if isinstance(cmd, RewriteDecl)
else bindings.BiRewriteCommand(ruleset, rewrite)
)
Expand All @@ -132,10 +132,18 @@ def action_to_egg(self, action: ActionDecl) -> bindings._Action:
return bindings.Set(call_.name, call_.args, self.expr_to_egg(rhs))
case ExprActionDecl(typed_expr):
return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
case DeleteDecl(tp, call):
case ChangeDecl(tp, call, change):
self.type_ref_to_egg(tp)
call_ = self.expr_to_egg(call)
return bindings.Delete(call_.name, call_.args)
egg_change: bindings._Change
match change:
case "delete":
egg_change = bindings.Delete()
case "subsume":
egg_change = bindings.Subsume()
case _:
assert_never(change)
return bindings.Change(egg_change, call_.name, call_.args)
case UnionDecl(tp, lhs, rhs):
self.type_ref_to_egg(tp)
return bindings.Union(self.expr_to_egg(lhs), self.expr_to_egg(rhs))
Expand Down
1 change: 1 addition & 0 deletions python/egglog/examples/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Resolution theorem proving.
===========================
"""

from __future__ import annotations

from typing import ClassVar
Expand Down
1 change: 1 addition & 0 deletions python/egglog/examples/schedule_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Schedule demo
=============
"""

from __future__ import annotations

from egglog import *
Expand Down
8 changes: 4 additions & 4 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def _astype(x: NDArray, dtype: DType, i: i64):
]


@function(cost=500)
@function
def unique_counts(x: NDArray) -> TupleNDArray: ...


Expand Down Expand Up @@ -1028,7 +1028,7 @@ def _abs(f: Float):
]


@function(cost=100)
@function
def unique_inverse(x: NDArray) -> TupleNDArray: ...


Expand All @@ -1051,7 +1051,7 @@ def zeros(
def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray: ...


@function(cost=100000)
@function
def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: ...


Expand All @@ -1060,7 +1060,7 @@ def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdim
def sqrt(x: NDArray) -> NDArray: ...


@function(cost=100000)
@function
def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: ...


Expand Down
18 changes: 8 additions & 10 deletions python/egglog/exp/array_api_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,9 @@

array_api_numba_ruleset = ruleset()
array_api_numba_schedule = (array_api_ruleset + array_api_numba_ruleset).saturate()
# For these rules, we not only wanna rewrite, we also want to delete the original expression,
# For these rules, we not only wanna rewrite, we also want to subsume the original expression,
# so that the rewritten one is used, even if the original one is simpler.

# TODO: Try deleting instead if we support that in the future, and remove high cost
# https://egraphs.zulipchat.com/#narrow/stream/375765-egglog/topic/replacing.20an.20expression.20with.20delete


# Rewrite mean(x, <int>, <expand dims>) to use sum b/c numba cant do mean with axis
# https://github.com/numba/numba/issues/1269
Expand All @@ -24,8 +21,8 @@ def _mean(y: NDArray, x: NDArray, i: Int):
axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
res = sum(x, axis) / NDArray.scalar(Value.int(x.shape[i]))

yield rewrite(mean(x, axis, FALSE)).to(res)
yield rewrite(mean(x, axis, TRUE)).to(expand_dims(res, i))
yield rewrite(mean(x, axis, FALSE), subsume=True).to(res)
yield rewrite(mean(x, axis, TRUE), subsume=True).to(expand_dims(res, i))


# Rewrite std(x, <int>) to use mean and sum b/c numba cant do std with axis
Expand All @@ -36,6 +33,7 @@ def _std(y: NDArray, x: NDArray, i: Int):
# "std = sqrt(mean(x)), where x = abs(a - a.mean())**2."
yield rewrite(
std(x, axis),
subsume=True,
).to(
sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)),
)
Expand All @@ -53,11 +51,11 @@ def count_values(x: NDArray, values: NDArray) -> TupleValue:
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
return [
# The unique counts are the count of all the unique values
rewrite(unique_counts(x)[Int(1)]).to(NDArray.vector(count_values(x, unique_values(x)))),
rewrite(count_values(x, NDArray.vector(TupleValue(v) + tv))).to(
rewrite(unique_counts(x)[Int(1)], subsume=True).to(NDArray.vector(count_values(x, unique_values(x)))),
rewrite(count_values(x, NDArray.vector(TupleValue(v) + tv)), subsume=True).to(
TupleValue(sum(x == NDArray.scalar(v)).to_value()) + count_values(x, NDArray.vector(tv))
),
rewrite(count_values(x, NDArray.vector(TupleValue(v)))).to(
rewrite(count_values(x, NDArray.vector(TupleValue(v))), subsume=True).to(
TupleValue(sum(x == NDArray.scalar(v)).to_value()),
),
]
Expand All @@ -68,7 +66,7 @@ def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
def _unique_inverse(x: NDArray, i: Int):
return [
# Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values
rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i))).to(
rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i)), subsume=True).to(
x == NDArray.scalar(unique_values(x).index(TupleInt(i)))
),
]
1 change: 1 addition & 0 deletions python/egglog/exp/program_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
Builds up imperative string expressions from a functional expression.
"""

from __future__ import annotations

from typing import Union
Expand Down
Loading

0 comments on commit c4975be

Please sign in to comment.