From 188a3c4654a4706d5d58f6a3a4346a2b74305537 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 26 Mar 2024 15:30:17 -0400 Subject: [PATCH 1/3] Add failing test for sending objects to threads --- pyproject.toml | 2 +- python/tests/test_bindings.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 73bf7cbf..2d193ee1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -204,7 +204,7 @@ exclude = ["__snapshots__", "_build", "^conftest.py$"] python-source = "python" [tool.pytest.ini_options] -addopts = ["--import-mode=importlib"] +addopts = ["--import-mode=importlib", "-Werror"] testpaths = ["python"] python_files = ["test_*.py", "test.py"] markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index bca80c85..c4877b64 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -1,3 +1,4 @@ +import _thread import fractions import json import os @@ -206,3 +207,19 @@ def test_rational(self): rational = Call("rational", [Lit(Int(1)), Lit(Int(2))]) egraph.run_program(ActionCommand(Expr_(Call("rational", [Lit(Int(1)), Lit(Int(2))])))) assert egraph.eval_rational(rational) == fractions.Fraction(1, 2) + + +class TestThreads: + """ + Verify that objects can be accessed from multiple threads at the same time. + """ + + def test_run_program(self): + cmds = ( + Datatype("Math", [Variant("Add", ["Math", "Math"])]), + RewriteCommand("", Rewrite(Call("Add", [Var("a"), Var("b")]), Call("Add", [Var("b"), Var("a")])), False), + RunSchedule(Repeat(10, Run(RunConfig("")))), + ) + + _thread.start_new_thread(EGraph().run_program, cmds) + _thread.start_new_thread(EGraph().run_program, cmds) From ae3757e8b63c6dc93936b8e507d45c9a3852b7f1 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 26 Mar 2024 15:43:31 -0400 Subject: [PATCH 2/3] Fix extraneous error on other pytest warningsfilterwarnings --- pyproject.toml | 7 ++++++- python/tests/test_high_level.py | 16 ++++++++++------ python/tests/test_modules.py | 15 +++++++++------ 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2d193ee1..ed7f5445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -204,8 +204,13 @@ exclude = ["__snapshots__", "_build", "^conftest.py$"] python-source = "python" [tool.pytest.ini_options] -addopts = ["--import-mode=importlib", "-Werror"] +addopts = ["--import-mode=importlib"] testpaths = ["python"] python_files = ["test_*.py", "test.py"] markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] norecursedirs = ["__snapshots__"] +filterwarnings = [ + "error", + "ignore::numba.core.errors.NumbaPerformanceWarning", + "ignore::pytest_benchmark.logger.PytestBenchmarkWarning", +] diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 90e32863..57ee0d6f 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -224,18 +224,22 @@ def __init__(self, x: i64Like) -> None: ... def test_modules() -> None: - m = Module() + with pytest.deprecated_call(): + m = Module() @m.class_ class Numeric(Expr): ONE: ClassVar[Numeric] - m2 = Module() + with pytest.deprecated_call(): + m2 = Module() - @m2.class_ - class OtherNumeric(Expr): - @m2.method(cost=10) - def __init__(self, v: i64Like) -> None: ... + with pytest.deprecated_call(): + + @m2.class_ + class OtherNumeric(Expr): + @m2.method(cost=10) + def __init__(self, v: i64Like) -> None: ... egraph = EGraph([m, m2]) diff --git a/python/tests/test_modules.py b/python/tests/test_modules.py index 93c067d1..85b5c38b 100644 --- a/python/tests/test_modules.py +++ b/python/tests/test_modules.py @@ -17,18 +17,21 @@ def test_tree_modules(): # assert _BUILTIN_DECLS # assert BUILTINS._mod_decls == ModuleDeclarations(_BUILTIN_DECLS, []) - A, B, C = Module(), Module(), Module() + with pytest.deprecated_call(): + A, B, C = Module(), Module(), Module() # assert list(A._mod_decls._included_decls) == [_BUILTIN_DECLS] - a = A.relation("a") - b = B.relation("b") - c = C.relation("c") + with pytest.deprecated_call(): + a = A.relation("a") + b = B.relation("b") + c = C.relation("c") A.register(a()) B.register(b()) C.register(c()) - D = Module([A, B]) - d = D.relation("d") + with pytest.deprecated_call(): + D = Module([A, B]) + d = D.relation("d") D.register(d()) assert D._flatted_deps == [A, B] From e2d22336a0206e1c2ecdf65ba87fde01ee3bf8ba Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 26 Mar 2024 15:51:10 -0400 Subject: [PATCH 3/3] Make all objects besides egraphs threadsafe --- docs/changelog.md | 1 + python/tests/test_bindings.py | 15 ++++++++++++--- src/serialize.rs | 1 - src/utils.rs | 2 +- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 4975e8b2..9db2d725 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,7 @@ _This project uses semantic versioning_ - Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/4cc011f6b48029dd72104a38a2ca0c7657846e0b...0113af1d6476b75d4319591cc3d675f96a71cdc5) - Adds subsume action +- Makes all objects besides EGraphs "sendable" aka threadsafe ([#129](https://github.com/egraphs-good/egglog-python/pull/129)) ## 6.0.1 (2024-02-28) diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index c4877b64..bc669e52 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -8,6 +8,7 @@ import pytest from egglog.bindings import * +from egglog.bindings import Datatype, RewriteCommand, RunSchedule def get_egglog_folder() -> pathlib.Path: @@ -214,12 +215,20 @@ class TestThreads: Verify that objects can be accessed from multiple threads at the same time. """ - def test_run_program(self): + def test_cmds(self): cmds = ( Datatype("Math", [Variant("Add", ["Math", "Math"])]), RewriteCommand("", Rewrite(Call("Add", [Var("a"), Var("b")]), Call("Add", [Var("b"), Var("a")])), False), RunSchedule(Repeat(10, Run(RunConfig("")))), ) - _thread.start_new_thread(EGraph().run_program, cmds) - _thread.start_new_thread(EGraph().run_program, cmds) + _thread.start_new_thread(print, cmds) + + @pytest.mark.xfail(reason="egraphs are unsendable") + def test_egraph(self): + _thread.start_new_thread(EGraph().run_program, (Datatype("Math", [Variant("Add", ["Math", "Math"])]),)) + + def test_serialized_egraph(self): + egraph = EGraph() + serialized = egraph.serialize([]) + _thread.start_new_thread(print, (serialized,)) diff --git a/src/serialize.rs b/src/serialize.rs index 405263dd..63d46373 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use pyo3::prelude::*; #[pyclass( - unsendable, text_signature = "(py_object_sort=None, *, fact_directory=None, seminaive=True, terms_encoding=False)" )] pub struct SerializedEGraph { diff --git a/src/utils.rs b/src/utils.rs index 201b5001..b5409f6e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -38,7 +38,7 @@ macro_rules! convert_enums { } );*) => { $($( - #[pyclass(unsendable, frozen, module="egg_smol.bindings"$(, name=$py_name)?)] + #[pyclass(frozen, module="egg_smol.bindings"$(, name=$py_name)?)] #[derive(Clone, PartialEq, Eq$(, $trait_inner)?)] pub struct $variant { $(