Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lightning talk and bump deps #127

Merged
merged 5 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ crate-type = ["cdylib"]
pyo3 = { version = "0.20.2", features = ["extension-module"] }

# https://github.com/egraphs-good/egglog/compare/ceed816e9369570ffed9feeba157b19471dda70d...main
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "4cc011f6b48029dd72104a38a2ca0c7657846e0b" }
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "0113af1d6476b75d4319591cc3d675f96a71cdc5" }
# egglog = { git = "https://github.com/oflatt/egg-smol", branch = "oflatt-fast-terms" }
# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "38b3014b34399cc78887ede09c845b2a5d6c7d19" }
egraph-serialize = { git = "https://github.com/egraphs-good/egraph-serialize", rev = "5838c036623e91540831745b1574539e01c8cb23" }
Expand Down
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ _This project uses semantic versioning_

## UNRELEASED

- Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/4cc011f6b48029dd72104a38a2ca0c7657846e0b...0113af1d6476b75d4319591cc3d675f96a71cdc5)
- Adds subsume action

## 6.0.1 (2024-02-28)

- Upgrade dependencies, including [egglog](https://github.com/egraphs-good/egglog/compare/ceed816e9369570ffed9feeba157b19471dda70d...4cc011f6b48029dd72104a38a2ca0c7657846e0b)
Expand Down
9,538 changes: 9,538 additions & 0 deletions docs/explanation/2024_03_17_community_talk.ipynb

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions docs/explanation/2024_03_17_map.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
26 changes: 22 additions & 4 deletions python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,20 @@ class Fact:

_Fact: TypeAlias = Fact | Eq

##
# Change
##

@final
class Delete:
def __init__(self) -> None: ...

@final
class Subsume:
def __init__(self) -> None: ...

_Change: TypeAlias = Delete | Subsume

##
# Actions
##
Expand All @@ -168,10 +182,11 @@ class Set:
rhs: _Expr

@final
class Delete:
class Change:
change: _Change
sym: str
args: list[_Expr]
def __init__(self, sym: str, args: list[_Expr]) -> None: ...
def __init__(self, change: _Change, sym: str, args: list[_Expr]) -> None: ...

@final
class Union:
Expand All @@ -195,7 +210,7 @@ class Extract:
expr: _Expr
variants: _Expr

_Action: TypeAlias = Let | Set | Delete | Union | Panic | Expr_ | Extract
_Action: TypeAlias = Let | Set | Change | Union | Panic | Expr_ | Extract

##
# Other Structs
Expand All @@ -210,6 +225,7 @@ class FunctionDecl:
merge_action: list[_Action]
cost: int | None
unextractable: bool
ignore_viz: bool

def __init__(
self,
Expand All @@ -220,6 +236,7 @@ class FunctionDecl:
merge_action: list[_Action] = [], # noqa: B006
cost: int | None = None,
unextractable: bool = False,
ignore_viz: bool = False,
) -> None: ...

@final
Expand Down Expand Up @@ -374,7 +391,8 @@ class RewriteCommand:
# TODO: Rename to ruleset
name: str
rewrite: Rewrite
def __init__(self, name: str, rewrite: Rewrite) -> None: ...
subsume: bool
def __init__(self, name: str, rewrite: Rewrite, subsume: bool) -> None: ...

@final
class BiRewriteCommand:
Expand Down
37 changes: 23 additions & 14 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,7 @@ class Rewrite(Command):
_lhs: RuntimeExpr
_rhs: RuntimeExpr
_conditions: tuple[Fact, ...]
_subsume: bool
_fn_name: ClassVar[str] = "rewrite"

def __str__(self) -> str:
Expand All @@ -1473,7 +1474,7 @@ def __str__(self) -> str:

def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
return bindings.RewriteCommand(
self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite(), self._subsume
)

def _to_egg_rewrite(self) -> bindings.Rewrite:
Expand All @@ -1488,7 +1489,7 @@ def __egg_decls__(self) -> Declarations:
return Declarations.create(self._lhs, self._rhs, *self._conditions)

def with_ruleset(self, ruleset: Ruleset) -> Rewrite:
return Rewrite(ruleset, self._lhs, self._rhs, self._conditions)
return Rewrite(ruleset, self._lhs, self._rhs, self._conditions, self._subsume)


@dataclass
Expand Down Expand Up @@ -1652,21 +1653,23 @@ def __egg_decls__(self) -> Declarations:


@dataclass
class Delete(Action):
class Change(Action):
"""
Remove a function call from an EGraph.
Change a function call in an EGraph.
"""

change: Literal["delete", "subsume"]
_call: RuntimeExpr

def __str__(self) -> str:
return f"delete({self._call})"
return f"{self.change}({self._call})"

def _to_egg_action(self) -> bindings.Delete:
def _to_egg_action(self) -> bindings.Change:
egg_call = self._call.__egg__
if not isinstance(egg_call, bindings.Call):
raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004
return bindings.Delete(egg_call.name, egg_call.args)
change: bindings._Change = bindings.Delete() if self.change == "delete" else bindings.Subsume()
return bindings.Change(change, egg_call.name, egg_call.args)

@property
def __egg_decls__(self) -> Declarations:
Expand Down Expand Up @@ -1800,16 +1803,16 @@ def __egg_decls__(self) -> Declarations:

@deprecated("Use <ruleset>.register(<rewrite>) instead of passing rulesets as arguments to rewrites.")
@overload
def rewrite(lhs: EXPR, ruleset: Ruleset) -> _RewriteBuilder[EXPR]: ...
def rewrite(lhs: EXPR, ruleset: Ruleset, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...


@overload
def rewrite(lhs: EXPR, ruleset: None = None) -> _RewriteBuilder[EXPR]: ...
def rewrite(lhs: EXPR, ruleset: None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...


def rewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _RewriteBuilder[EXPR]:
def rewrite(lhs: EXPR, ruleset: Ruleset | None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]:
"""Rewrite the given expression to a new expression."""
return _RewriteBuilder(lhs, ruleset)
return _RewriteBuilder(lhs, ruleset, subsume)


@deprecated("Use <ruleset>.register(<birewrite>) instead of passing rulesets as arguments to birewrites.")
Expand Down Expand Up @@ -1852,7 +1855,12 @@ def expr_action(expr: Expr) -> Action:

def delete(expr: Expr) -> Action:
"""Create a delete expression."""
return Delete(to_runtime_expr(expr))
return Change("delete", to_runtime_expr(expr))


def subsume(expr: Expr) -> Action:
"""Subsume an expression."""
return Change("subsume", to_runtime_expr(expr))


def expr_fact(expr: Expr) -> Fact:
Expand Down Expand Up @@ -1905,10 +1913,11 @@ def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
class _RewriteBuilder(Generic[EXPR]):
lhs: EXPR
ruleset: Ruleset | None
subsume: bool

def to(self, rhs: EXPR, *conditions: FactLike) -> Rewrite:
lhs = to_runtime_expr(self.lhs)
rule = Rewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions))
rule = Rewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions), self.subsume)
if self.ruleset:
self.ruleset.append(rule)
return rule
Expand All @@ -1924,7 +1933,7 @@ class _BirewriteBuilder(Generic[EXPR]):

def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
lhs = to_runtime_expr(self.lhs)
rule = BiRewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions))
rule = BiRewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions), False)
if self.ruleset:
self.ruleset.append(rule)
return rule
Expand Down
8 changes: 4 additions & 4 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ def _astype(x: NDArray, dtype: DType, i: i64):
]


@function(cost=500)
@function
def unique_counts(x: NDArray) -> TupleNDArray: ...


Expand Down Expand Up @@ -1016,7 +1016,7 @@ def _abs(f: Float):
]


@function(cost=100)
@function
def unique_inverse(x: NDArray) -> TupleNDArray: ...


Expand All @@ -1039,7 +1039,7 @@ def zeros(
def expand_dims(x: NDArray, axis: Int = Int(0)) -> NDArray: ...


@function(cost=100000)
@function
def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdims: Boolean = FALSE) -> NDArray: ...


Expand All @@ -1048,7 +1048,7 @@ def mean(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none, keepdim
def sqrt(x: NDArray) -> NDArray: ...


@function(cost=100000)
@function
def std(x: NDArray, axis: OptionalIntOrTuple = OptionalIntOrTuple.none) -> NDArray: ...


Expand Down
19 changes: 8 additions & 11 deletions python/egglog/exp/array_api_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,9 @@

array_api_numba_ruleset = ruleset()
array_api_numba_schedule = (array_api_ruleset + array_api_numba_ruleset).saturate()
# For these rules, we not only wanna rewrite, we also want to delete the original expression,
# For these rules, we not only wanna rewrite, we also want to subsume the original expression,
# so that the rewritten one is used, even if the original one is simpler.

# TODO: Try deleting instead if we support that in the future, and remove high cost
# https://egraphs.zulipchat.com/#narrow/stream/375765-egglog/topic/replacing.20an.20expression.20with.20delete


# Rewrite mean(x, <int>, <expand dims>) to use sum b/c numba cant do mean with axis
# https://github.com/numba/numba/issues/1269
Expand All @@ -24,8 +21,8 @@ def _mean(y: NDArray, x: NDArray, i: Int):
axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
res = sum(x, axis) / NDArray.scalar(Value.int(x.shape[i]))

yield rewrite(mean(x, axis, FALSE)).to(res)
yield rewrite(mean(x, axis, TRUE)).to(expand_dims(res, i))
yield rewrite(mean(x, axis, FALSE), subsume=True).to(res)
yield rewrite(mean(x, axis, TRUE), subsume=True).to(expand_dims(res, i))


# Rewrite std(x, <int>) to use mean and sum b/c numba cant do std with axis
Expand All @@ -34,7 +31,7 @@ def _std(y: NDArray, x: NDArray, i: Int):
axis = OptionalIntOrTuple.some(IntOrTuple.int(i))
# https://numpy.org/doc/stable/reference/generated/numpy.std.html
# "std = sqrt(mean(x)), where x = abs(a - a.mean())**2."
yield rewrite(std(x, axis)).to(sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)))
yield rewrite(std(x, axis), subsume=True).to(sqrt(mean(square(x - mean(x, axis, keepdims=TRUE)), axis)))


# rewrite unique_counts to count each value one by one, since numba doesn't support np.unique(..., return_counts=True)
Expand All @@ -49,11 +46,11 @@ def count_values(x: NDArray, values: NDArray) -> TupleValue:
def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
return [
# The unique counts are the count of all the unique values
rewrite(unique_counts(x)[Int(1)]).to(NDArray.vector(count_values(x, unique_values(x)))),
rewrite(count_values(x, NDArray.vector(TupleValue(v) + tv))).to(
rewrite(unique_counts(x)[Int(1)], subsume=True).to(NDArray.vector(count_values(x, unique_values(x)))),
rewrite(count_values(x, NDArray.vector(TupleValue(v) + tv)), subsume=True).to(
TupleValue(sum(x == NDArray.scalar(v)).to_value()) + count_values(x, NDArray.vector(tv))
),
rewrite(count_values(x, NDArray.vector(TupleValue(v)))).to(
rewrite(count_values(x, NDArray.vector(TupleValue(v))), subsume=True).to(
TupleValue(sum(x == NDArray.scalar(v)).to_value()),
),
]
Expand All @@ -64,7 +61,7 @@ def _unique_counts(x: NDArray, c: NDArray, tv: TupleValue, v: Value):
def _unique_inverse(x: NDArray, i: Int):
return [
# Creating a mask array of when the unique inverse is a value is the same as a mask array for when the value is that index of the unique values
rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i))).to(
rewrite(unique_inverse(x)[Int(1)] == NDArray.scalar(Value.int(i)), subsume=True).to(
x == NDArray.scalar(unique_values(x).index(TupleInt(i)))
),
]
2 changes: 1 addition & 1 deletion python/tests/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_run_rules(self):
egraph = EGraph()
egraph.run_program(
Datatype("Math", [Variant("Add", ["Math", "Math"])]),
RewriteCommand("", Rewrite(Call("Add", [Var("a"), Var("b")]), Call("Add", [Var("b"), Var("a")]))),
RewriteCommand("", Rewrite(Call("Add", [Var("a"), Var("b")]), Call("Add", [Var("b"), Var("a")])), False),
RunSchedule(Repeat(10, Run(RunConfig("")))),
)

Expand Down
33 changes: 23 additions & 10 deletions src/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ convert_enums!(
f -> egglog::ast::Fact::Fact((&f.expr).into()),
egglog::ast::Fact::Fact(e) => Fact { expr: e.into() }
};
egglog::ast::Change: "{:?}" => _Change {
Delete()
_d -> egglog::ast::Change::Delete,
egglog::ast::Change::Delete => Delete {};
Subsume()
_d -> egglog::ast::Change::Subsume,
egglog::ast::Change::Subsume => Subsume {}
};
egglog::ast::Action: "{}" => Action {
Let(lhs: String, rhs: Expr)
d -> egglog::ast::Action::Let((), (&d.lhs).into(), (&d.rhs).into()),
Expand All @@ -58,9 +66,10 @@ convert_enums!(
args: a.iter().map(|e| e.into()).collect(),
rhs: e.into()
};
Delete(sym: String, args: Vec<Expr>)
d -> egglog::ast::Action::Delete((), (&d.sym).into(), d.args.iter().map(|e| e.into()).collect()),
egglog::ast::Action::Delete(_, n, a) => Delete {
Change(change: _Change, sym: String, args: Vec<Expr>)
d -> egglog::ast::Action::Change((), (&d.change).into(), (&d.sym).into(), d.args.iter().map(|e| e.into()).collect()),
egglog::ast::Action::Change(_, c, n, a) => Change {
change: c.into(),
sym: n.to_string(),
args: a.iter().map(|e| e.into()).collect()
};
Expand Down Expand Up @@ -162,11 +171,12 @@ convert_enums!(
ruleset: ruleset.to_string(),
rule: rule.into()
};
RewriteCommand(name: String, rewrite: Rewrite)
r -> egglog::ast::Command::Rewrite((&r.name).into(), (&r.rewrite).into()),
egglog::ast::Command::Rewrite(name, r) => RewriteCommand {
RewriteCommand(name: String, rewrite: Rewrite, subsume: bool)
r -> egglog::ast::Command::Rewrite((&r.name).into(), (&r.rewrite).into(), r.subsume),
egglog::ast::Command::Rewrite(name, r, subsume) => RewriteCommand {
name: name.to_string(),
rewrite: r.into()
rewrite: r.into(),
subsume: *subsume
};
BiRewriteCommand(name: String, rewrite: Rewrite)
r -> egglog::ast::Command::BiRewrite((&r.name).into(), (&r.rewrite).into()),
Expand Down Expand Up @@ -303,7 +313,8 @@ convert_struct!(
merge: Option<Expr> = None,
merge_action: Vec<Action> = Vec::new(),
cost: Option<usize> = None,
unextractable: bool = false
unextractable: bool = false,
ignore_viz: bool = false
)
f -> egglog::ast::FunctionDecl {
name: (&f.name).into(),
Expand All @@ -312,7 +323,8 @@ convert_struct!(
merge: f.merge.as_ref().map(|e| e.into()),
merge_action: egglog::ast::GenericActions(f.merge_action.iter().map(|a| a.into()).collect()),
cost: f.cost,
unextractable: f.unextractable
unextractable: f.unextractable,
ignore_viz: f.ignore_viz
},
f -> FunctionDecl {
name: f.name.to_string(),
Expand All @@ -321,7 +333,8 @@ convert_struct!(
merge: f.merge.as_ref().map(|e| e.into()),
merge_action: f.merge_action.0.iter().map(|a| a.into()).collect(),
cost: f.cost,
unextractable: f.unextractable
unextractable: f.unextractable,
ignore_viz: f.ignore_viz
};
egglog::ast::Variant: "{:?}" => Variant(
name: String,
Expand Down
2 changes: 2 additions & 0 deletions stubtest_allow
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
.*egglog.bindings.CheckProof.__init__.*
.*egglog.bindings.PrintOverallStatistics.__init__.*
.*egglog.bindings.PyObjectSort.__init__.*
.*egglog.bindings.Delete.__init__.*
.*egglog.bindings.Subsume.__init__.*
Loading