From e2d22336a0206e1c2ecdf65ba87fde01ee3bf8ba Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 26 Mar 2024 15:51:10 -0400 Subject: [PATCH] Make all objects besides egraphs threadsafe --- docs/changelog.md | 1 + python/tests/test_bindings.py | 15 ++++++++++++--- src/serialize.rs | 1 - src/utils.rs | 2 +- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 4975e8b2..9db2d725 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,7 @@ _This project uses semantic versioning_ - Upgrade [egglog](https://github.com/egraphs-good/egglog/compare/4cc011f6b48029dd72104a38a2ca0c7657846e0b...0113af1d6476b75d4319591cc3d675f96a71cdc5) - Adds subsume action +- Makes all objects besides EGraphs "sendable" aka threadsafe ([#129](https://github.com/egraphs-good/egglog-python/pull/129)) ## 6.0.1 (2024-02-28) diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index c4877b64..bc669e52 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -8,6 +8,7 @@ import pytest from egglog.bindings import * +from egglog.bindings import Datatype, RewriteCommand, RunSchedule def get_egglog_folder() -> pathlib.Path: @@ -214,12 +215,20 @@ class TestThreads: Verify that objects can be accessed from multiple threads at the same time. """ - def test_run_program(self): + def test_cmds(self): cmds = ( Datatype("Math", [Variant("Add", ["Math", "Math"])]), RewriteCommand("", Rewrite(Call("Add", [Var("a"), Var("b")]), Call("Add", [Var("b"), Var("a")])), False), RunSchedule(Repeat(10, Run(RunConfig("")))), ) - _thread.start_new_thread(EGraph().run_program, cmds) - _thread.start_new_thread(EGraph().run_program, cmds) + _thread.start_new_thread(print, cmds) + + @pytest.mark.xfail(reason="egraphs are unsendable") + def test_egraph(self): + _thread.start_new_thread(EGraph().run_program, (Datatype("Math", [Variant("Add", ["Math", "Math"])]),)) + + def test_serialized_egraph(self): + egraph = EGraph() + serialized = egraph.serialize([]) + _thread.start_new_thread(print, (serialized,)) diff --git a/src/serialize.rs b/src/serialize.rs index 405263dd..63d46373 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use pyo3::prelude::*; #[pyclass( - unsendable, text_signature = "(py_object_sort=None, *, fact_directory=None, seminaive=True, terms_encoding=False)" )] pub struct SerializedEGraph { diff --git a/src/utils.rs b/src/utils.rs index 201b5001..b5409f6e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -38,7 +38,7 @@ macro_rules! convert_enums { } );*) => { $($( - #[pyclass(unsendable, frozen, module="egg_smol.bindings"$(, name=$py_name)?)] + #[pyclass(frozen, module="egg_smol.bindings"$(, name=$py_name)?)] #[derive(Clone, PartialEq, Eq$(, $trait_inner)?)] pub struct $variant { $(