diff --git a/docs/changelog.md b/docs/changelog.md index 4975e8b2..9db2d725 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,7 @@ _This project uses semantic versioning_ - Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/4cc011f6b48029dd72104a38a2ca0c7657846e0b...0113af1d6476b75d4319591cc3d675f96a71cdc5) - Adds subsume action +- Makes all objects besides EGraphs "sendable" aka threadsafe ([#129](https://github.com/egraphs-good/egglog-python/pull/129)) ## 6.0.1 (2024-02-28) diff --git a/pyproject.toml b/pyproject.toml index 73bf7cbf..ed7f5445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,3 +209,8 @@ testpaths = ["python"] python_files = ["test_*.py", "test.py"] markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] norecursedirs = ["__snapshots__"] +filterwarnings = [ + "error", + "ignore::numba.core.errors.NumbaPerformanceWarning", + "ignore::pytest_benchmark.logger.PytestBenchmarkWarning", +] diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index bca80c85..bc669e52 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -1,3 +1,4 @@ +import _thread import fractions import json import os @@ -7,6 +8,7 @@ import pytest from egglog.bindings import * +from egglog.bindings import Datatype, RewriteCommand, RunSchedule def get_egglog_folder() -> pathlib.Path: @@ -206,3 +208,27 @@ def test_rational(self): rational = Call("rational", [Lit(Int(1)), Lit(Int(2))]) egraph.run_program(ActionCommand(Expr_(Call("rational", [Lit(Int(1)), Lit(Int(2))])))) assert egraph.eval_rational(rational) == fractions.Fraction(1, 2) + + +class TestThreads: + """ + Verify that objects can be accessed from multiple threads at the same time. + """ + + def test_cmds(self): + cmds = ( + Datatype("Math", [Variant("Add", ["Math", "Math"])]), + RewriteCommand("", Rewrite(Call("Add", [Var("a"), Var("b")]), Call("Add", [Var("b"), Var("a")])), False), + RunSchedule(Repeat(10, Run(RunConfig("")))), + ) + + _thread.start_new_thread(print, cmds) + + @pytest.mark.xfail(reason="egraphs are unsendable") + def test_egraph(self): + _thread.start_new_thread(EGraph().run_program, (Datatype("Math", [Variant("Add", ["Math", "Math"])]),)) + + def test_serialized_egraph(self): + egraph = EGraph() + serialized = egraph.serialize([]) + _thread.start_new_thread(print, (serialized,)) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 90e32863..57ee0d6f 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -224,18 +224,22 @@ def __init__(self, x: i64Like) -> None: ... def test_modules() -> None: - m = Module() + with pytest.deprecated_call(): + m = Module() @m.class_ class Numeric(Expr): ONE: ClassVar[Numeric] - m2 = Module() + with pytest.deprecated_call(): + m2 = Module() - @m2.class_ - class OtherNumeric(Expr): - @m2.method(cost=10) - def __init__(self, v: i64Like) -> None: ... + with pytest.deprecated_call(): + + @m2.class_ + class OtherNumeric(Expr): + @m2.method(cost=10) + def __init__(self, v: i64Like) -> None: ... egraph = EGraph([m, m2]) diff --git a/python/tests/test_modules.py b/python/tests/test_modules.py index 93c067d1..85b5c38b 100644 --- a/python/tests/test_modules.py +++ b/python/tests/test_modules.py @@ -17,18 +17,21 @@ def test_tree_modules(): # assert _BUILTIN_DECLS # assert BUILTINS._mod_decls == ModuleDeclarations(_BUILTIN_DECLS, []) - A, B, C = Module(), Module(), Module() + with pytest.deprecated_call(): + A, B, C = Module(), Module(), Module() # assert list(A._mod_decls._included_decls) == [_BUILTIN_DECLS] - a = A.relation("a") - b = B.relation("b") - c = C.relation("c") + with pytest.deprecated_call(): + a = A.relation("a") + b = B.relation("b") + c = C.relation("c") A.register(a()) B.register(b()) C.register(c()) - D = Module([A, B]) - d = D.relation("d") + with pytest.deprecated_call(): + D = Module([A, B]) + d = D.relation("d") D.register(d()) assert D._flatted_deps == [A, B] diff --git a/src/serialize.rs b/src/serialize.rs index 405263dd..63d46373 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use pyo3::prelude::*; #[pyclass( - unsendable, text_signature = "(py_object_sort=None, *, fact_directory=None, seminaive=True, terms_encoding=False)" )] pub struct SerializedEGraph { diff --git a/src/utils.rs b/src/utils.rs index 201b5001..b5409f6e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -38,7 +38,7 @@ macro_rules! convert_enums { } );*) => { $($( - #[pyclass(unsendable, frozen, module="egg_smol.bindings"$(, name=$py_name)?)] + #[pyclass(frozen, module="egg_smol.bindings"$(, name=$py_name)?)] #[derive(Clone, PartialEq, Eq$(, $trait_inner)?)] pub struct $variant { $(