Skip to content

Commit

Permalink
Merge pull request #129 from egraphs-good/thread-safety
Browse files Browse the repository at this point in the history
Make low level bindings thread safe
  • Loading branch information
saulshanabrook authored Mar 26, 2024
2 parents b86b36d + e2d2233 commit 2deb470
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _This project uses semantic versioning_

- Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/4cc011f6b48029dd72104a38a2ca0c7657846e0b...0113af1d6476b75d4319591cc3d675f96a71cdc5)
- Adds subsume action
- Makes all objects besides EGraphs "sendable" aka threadsafe ([#129](https://github.com/egraphs-good/egglog-python/pull/129))

## 6.0.1 (2024-02-28)

Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,8 @@ testpaths = ["python"]
python_files = ["test_*.py", "test.py"]
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
norecursedirs = ["__snapshots__"]
filterwarnings = [
"error",
"ignore::numba.core.errors.NumbaPerformanceWarning",
"ignore::pytest_benchmark.logger.PytestBenchmarkWarning",
]
26 changes: 26 additions & 0 deletions python/tests/test_bindings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import _thread
import fractions
import json
import os
Expand All @@ -7,6 +8,7 @@
import pytest

from egglog.bindings import *
from egglog.bindings import Datatype, RewriteCommand, RunSchedule


def get_egglog_folder() -> pathlib.Path:
Expand Down Expand Up @@ -206,3 +208,27 @@ def test_rational(self):
rational = Call("rational", [Lit(Int(1)), Lit(Int(2))])
egraph.run_program(ActionCommand(Expr_(Call("rational", [Lit(Int(1)), Lit(Int(2))]))))
assert egraph.eval_rational(rational) == fractions.Fraction(1, 2)


class TestThreads:
"""
Verify that objects can be accessed from multiple threads at the same time.
"""

def test_cmds(self):
cmds = (
Datatype("Math", [Variant("Add", ["Math", "Math"])]),
RewriteCommand("", Rewrite(Call("Add", [Var("a"), Var("b")]), Call("Add", [Var("b"), Var("a")])), False),
RunSchedule(Repeat(10, Run(RunConfig("")))),
)

_thread.start_new_thread(print, cmds)

@pytest.mark.xfail(reason="egraphs are unsendable")
def test_egraph(self):
_thread.start_new_thread(EGraph().run_program, (Datatype("Math", [Variant("Add", ["Math", "Math"])]),))

def test_serialized_egraph(self):
egraph = EGraph()
serialized = egraph.serialize([])
_thread.start_new_thread(print, (serialized,))
16 changes: 10 additions & 6 deletions python/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,18 +224,22 @@ def __init__(self, x: i64Like) -> None: ...


def test_modules() -> None:
m = Module()
with pytest.deprecated_call():
m = Module()

@m.class_
class Numeric(Expr):
ONE: ClassVar[Numeric]

m2 = Module()
with pytest.deprecated_call():
m2 = Module()

@m2.class_
class OtherNumeric(Expr):
@m2.method(cost=10)
def __init__(self, v: i64Like) -> None: ...
with pytest.deprecated_call():

@m2.class_
class OtherNumeric(Expr):
@m2.method(cost=10)
def __init__(self, v: i64Like) -> None: ...

egraph = EGraph([m, m2])

Expand Down
15 changes: 9 additions & 6 deletions python/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@ def test_tree_modules():
# assert _BUILTIN_DECLS
# assert BUILTINS._mod_decls == ModuleDeclarations(_BUILTIN_DECLS, [])

A, B, C = Module(), Module(), Module()
with pytest.deprecated_call():
A, B, C = Module(), Module(), Module()
# assert list(A._mod_decls._included_decls) == [_BUILTIN_DECLS]

a = A.relation("a")
b = B.relation("b")
c = C.relation("c")
with pytest.deprecated_call():
a = A.relation("a")
b = B.relation("b")
c = C.relation("c")
A.register(a())
B.register(b())
C.register(c())

D = Module([A, B])
d = D.relation("d")
with pytest.deprecated_call():
D = Module([A, B])
d = D.relation("d")
D.register(d())

assert D._flatted_deps == [A, B]
Expand Down
1 change: 0 additions & 1 deletion src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::collections::HashMap;
use pyo3::prelude::*;

#[pyclass(
unsendable,
text_signature = "(py_object_sort=None, *, fact_directory=None, seminaive=True, terms_encoding=False)"
)]
pub struct SerializedEGraph {
Expand Down
2 changes: 1 addition & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ macro_rules! convert_enums {
}
);*) => {
$($(
#[pyclass(unsendable, frozen, module="egg_smol.bindings"$(, name=$py_name)?)]
#[pyclass(frozen, module="egg_smol.bindings"$(, name=$py_name)?)]
#[derive(Clone, PartialEq, Eq$(, $trait_inner)?)]
pub struct $variant {
$(
Expand Down

0 comments on commit 2deb470

Please sign in to comment.