From 0435e16c50f2a9eb7393a58d1a0a1066f6f45fb9 Mon Sep 17 00:00:00 2001
From: Sam Gross
Date: Mon, 20 Jan 2025 06:13:50 -0500
Subject: [PATCH] Fix race condition involving wrapper lookup (#865)

There's a race condition between wrapper lookup and wrapper deallocation:
a lookup may return a Python wrapper that another thread is concurrently
deallocating. This commit fixes the issue by incrementing a wrapper's
reference count only if it is still nonzero, falling back to creating a
new wrapper otherwise (see #864 for further details).
---
 src/nb_type.cpp       | 74 ++++++++++++++++++++++++++++++++++++++-----
 tests/test_thread.cpp |  5 +++
 tests/test_thread.py  | 15 ++++++++-
 3 files changed, 85 insertions(+), 9 deletions(-)

diff --git a/src/nb_type.cpp b/src/nb_type.cpp
index 94e07d56..f8a27450 100644
--- a/src/nb_type.cpp
+++ b/src/nb_type.cpp
@@ -40,6 +40,62 @@ static PyObject **nb_weaklist_ptr(PyObject *self) {
     return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset)
                           : nullptr;
 }
 
+static void nb_enable_try_inc_ref(PyObject *obj) noexcept {
+#if defined(Py_GIL_DISABLED) && PY_VERSION_HEX >= 0x030E00A5
+    PyUnstable_EnableTryIncRef(obj);
+#elif defined(Py_GIL_DISABLED)
+    // Since this is called during object construction, we know that we have
+    // the only reference to the object and can use a non-atomic write.
+    assert(obj->ob_ref_shared == 0);
+    obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
+#else
+    (void) obj;
+#endif
+}
+
+static bool nb_try_inc_ref(PyObject *obj) noexcept {
+#if defined(Py_GIL_DISABLED) && PY_VERSION_HEX >= 0x030E00A5
+    return PyUnstable_TryIncRef(obj);
+#elif defined(Py_GIL_DISABLED)
+    // See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
+    uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
+    local += 1;
+    if (local == 0) {
+        // immortal
+        return true;
+    }
+    if (_Py_IsOwnedByCurrentThread(obj)) {
+        _Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
+#ifdef Py_REF_DEBUG
+        _Py_INCREF_IncRefTotal();
+#endif
+        return true;
+    }
+    Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
+    for (;;) {
+        // If the shared refcount is zero and the object is either merged
+        // or may not have weak references, then we cannot incref it.
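+        // (A shared value of 0 can only be raised safely by the owning
+        // thread: without the _Py_REF_MAYBE_WEAKREF flag set by
+        // nb_enable_try_inc_ref() above, the owner may free the object
+        // without consulting the shared field. _Py_REF_MERGED means the
+        // local and shared counts were already merged during teardown,
+        // so a zero count there is final.)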
+        if (shared == 0 || shared == _Py_REF_MERGED) {
+            return false;
+        }
+
+        if (_Py_atomic_compare_exchange_ssize(
+                &obj->ob_ref_shared, &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
+#ifdef Py_REF_DEBUG
+            _Py_INCREF_IncRefTotal();
+#endif
+            return true;
+        }
+    }
+#else
+    if (Py_REFCNT(obj) > 0) {
+        Py_INCREF(obj);
+        return true;
+    }
+    return false;
+#endif
+}
+
 static PyGetSetDef inst_getset[] = {
     { "__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr },
     { nullptr, nullptr, nullptr, nullptr, nullptr }
 };
@@ -98,6 +154,7 @@ PyObject *inst_new_int(PyTypeObject *tp, PyObject * /* args */,
     self->clear_keep_alive = 0;
     self->intrusive = intrusive;
     self->unused = 0;
+    nb_enable_try_inc_ref((PyObject *)self);
 
     // Update hash table that maps from C++ to Python instance
     nb_shard &shard = internals->shard((void *) payload);
@@ -163,6 +220,7 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) {
     self->clear_keep_alive = 0;
     self->intrusive = intrusive;
     self->unused = 0;
+    nb_enable_try_inc_ref((PyObject *)self);
 
     nb_shard &shard = internals->shard(value);
     lock_shard guard(shard);
@@ -1766,16 +1824,16 @@ PyObject *nb_type_put(const std::type_info *cpp_type,
             PyTypeObject *tp = Py_TYPE(seq.inst);
 
             if (nb_type_data(tp)->type == cpp_type) {
-                Py_INCREF(seq.inst);
-                return seq.inst;
+                if (nb_try_inc_ref(seq.inst))
+                    return seq.inst;
             }
 
             if (!lookup_type())
                 return nullptr;
 
             if (PyType_IsSubtype(tp, td->type_py)) {
-                Py_INCREF(seq.inst);
-                return seq.inst;
+                if (nb_try_inc_ref(seq.inst))
+                    return seq.inst;
             }
 
             if (seq.next == nullptr)
@@ -1852,8 +1910,8 @@ PyObject *nb_type_put_p(const std::type_info *cpp_type,
             const std::type_info *p = nb_type_data(tp)->type;
 
             if (p == cpp_type || p == cpp_type_p) {
-                Py_INCREF(seq.inst);
-                return seq.inst;
+                if (nb_try_inc_ref(seq.inst))
+                    return seq.inst;
             }
 
             if (!lookup_type())
@@ -1861,8 +1919,8 @@ PyObject *nb_type_put_p(const std::type_info *cpp_type,
 
             if (PyType_IsSubtype(tp, td->type_py) ||
                 (td_p && PyType_IsSubtype(tp, td_p->type_py))) {
-                Py_INCREF(seq.inst);
-                return seq.inst;
+                if (nb_try_inc_ref(seq.inst))
+                    return seq.inst;
             }
 
             if (seq.next == nullptr)
diff --git a/tests/test_thread.cpp b/tests/test_thread.cpp
index 230e1285..54181b7b 100644
--- a/tests/test_thread.cpp
+++ b/tests/test_thread.cpp
@@ -12,6 +12,8 @@ struct Counter {
     }
 };
 
+struct GlobalData {} global_data;
+
 nb::ft_mutex mutex;
 
 NB_MODULE(test_thread_ext, m) {
@@ -34,4 +36,7 @@ NB_MODULE(test_thread_ext, m) {
         nb::ft_lock_guard guard(mutex);
         c.inc();
     }, "counter");
+
+    nb::class_<GlobalData>(m, "GlobalData")
+        .def_static("get", [] { return &global_data; }, nb::rv_policy::reference);
 }
diff --git a/tests/test_thread.py b/tests/test_thread.py
index 832b2fb6..1dd05af9 100644
--- a/tests/test_thread.py
+++ b/tests/test_thread.py
@@ -1,5 +1,5 @@
 import test_thread_ext as t
-from test_thread_ext import Counter
+from test_thread_ext import Counter, GlobalData
 from common import parallelize
 
 def test01_object_creation(n_threads=8):
@@ -75,3 +75,16 @@ def f():
         parallelize(f, n_threads=n_threads)
 
     assert c.value == n * n_threads
+
+
+def test06_global_wrapper(n_threads=8):
+    # Check wrapper lookup racing with wrapper deallocation
+    n = 10000
+    def f():
+        for i in range(n):
+            GlobalData.get()
+            GlobalData.get()
+            GlobalData.get()
+            GlobalData.get()
+
+    parallelize(f, n_threads=n_threads)
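
Note: the core idea behind nb_try_inc_ref() can be illustrated with a small
standalone C++ sketch (hypothetical code, not part of the patch). A wrapper
found in the instance map may already have a reference count of zero because
another thread has entered tp_dealloc; the lookup must therefore raise the
count conditionally and report failure when the object is already dying:

    #include <atomic>

    struct Wrapper {
        std::atomic<long> refcnt{1};
    };

    // Succeeds only while the wrapper is still alive: the count is raised
    // from a provably nonzero value via compare-and-swap. An unconditional
    // increment could "resurrect" an object whose deallocation has already
    // begun on another thread.
    bool try_inc_ref(Wrapper *w) {
        long old = w->refcnt.load(std::memory_order_relaxed);
        while (old > 0) {
            // compare_exchange_weak reloads `old` on failure, so liveness
            // is re-checked before every retry.
            if (w->refcnt.compare_exchange_weak(old, old + 1))
                return true;
        }
        return false; // already dying; the caller must build a fresh wrapper
    }

This is the same pattern the patch applies in nb_type_put() and
nb_type_put_p(): when nb_try_inc_ref() fails, control falls through to the
code path that constructs a new wrapper instead of returning the dying one.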