From d216a0f6117c1a34b77058c7d81702a1aa1cc793 Mon Sep 17 00:00:00 2001 From: Dustin Spicuzza Date: Sun, 31 Jan 2021 21:57:45 -0500 Subject: [PATCH] Attach python lifetime to shared_ptr passed to C++ - Reference cycles are possible as a result, but shared_ptr is already susceptible to this in C++ --- include/pybind11/cast.h | 39 ++++++++++++++++++++++++++++++++++- tests/test_smart_ptr.cpp | 32 +++++++++++++++++++++++++++++ tests/test_smart_ptr.py | 44 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 1 deletion(-) diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 72b3a87fd7e..f6f9e02897f 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -10,6 +10,7 @@ #pragma once +#include "gil.h" #include "pytypes.h" #include "detail/typeid.h" #include "detail/descr.h" @@ -1524,6 +1525,42 @@ struct holder_helper { static auto get(const T &p) -> decltype(p.get()) { return p.get(); } }; +/// Another helper class for holders that helps construct derivative holders from +/// the original holder +template +struct holder_retriever { + static auto get_derivative_holder(const value_and_holder &v_h) -> decltype(v_h.template holder()) { + return v_h.template holder(); + } +}; + +template +struct holder_retriever> { + struct shared_ptr_deleter { + // Note: deleter destructor fails on MSVC 2015 and GCC 4.8, so we manually + // call dec_ref here instead + handle ref; + void operator()(T *) { + gil_scoped_acquire gil; + ref.dec_ref(); + } + }; + + static auto get_derivative_holder(const value_and_holder &v_h) -> std::shared_ptr { + // The shared_ptr is always given to C++ code, so construct a new shared_ptr + // that is given a custom deleter. The custom deleter increments the python + // reference count to bind the python instance lifetime with the lifetime + // of the shared_ptr. + // + // This enables things like passing the last python reference of a subclass to a + // C++ function without the python reference dying. + // + // Reference cycles will cause a leak, but this is a limitation of shared_ptr + return std::shared_ptr((T*)v_h.value_ptr(), + shared_ptr_deleter{handle((PyObject*)v_h.inst).inc_ref()}); + } +}; + /// Type caster for holder types like std::shared_ptr, etc. /// The SFINAE hook is provided to help work around the current lack of support /// for smart-pointer interoperability. Please consider it an implementation @@ -1566,7 +1603,7 @@ struct copyable_holder_caster : public type_caster_base { bool load_value(value_and_holder &&v_h) { if (v_h.holder_constructed()) { value = v_h.value_ptr(); - holder = v_h.template holder(); + holder = holder_retriever::get_derivative_holder(v_h); return true; } else { throw cast_error("Unable to cast from non-held to held instance (T& to Holder) " diff --git a/tests/test_smart_ptr.cpp b/tests/test_smart_ptr.cpp index 59996edeb4d..e9a039bb3cd 100644 --- a/tests/test_smart_ptr.cpp +++ b/tests/test_smart_ptr.cpp @@ -397,4 +397,36 @@ TEST_SUBMODULE(smart_ptr, m) { list.append(py::cast(e)); return list; }); + + // For testing whether a python subclass of a C++ object dies when the + // last python reference is lost + struct SpBase { + // returns true if the base virtual function is called + virtual bool is_base_used() { return true; } + + SpBase() = default; + SpBase(const SpBase&) = delete; + virtual ~SpBase() = default; + }; + + struct PySpBase : SpBase { + bool is_base_used() override { PYBIND11_OVERRIDE(bool, SpBase, is_base_used); } + }; + + struct SpBaseTester { + std::shared_ptr get_object() { return m_obj; } + void set_object(std::shared_ptr obj) { m_obj = obj; } + bool is_base_used() { return m_obj->is_base_used(); } + std::shared_ptr m_obj; + }; + + py::class_, PySpBase>(m, "SpBase") + .def(py::init<>()) + .def("is_base_used", &SpBase::is_base_used); + + py::class_(m, "SpBaseTester") + .def(py::init<>()) + .def("get_object", &SpBaseTester::get_object) + .def("set_object", &SpBaseTester::set_object) + .def("is_base_used", &SpBaseTester::is_base_used); } diff --git a/tests/test_smart_ptr.py b/tests/test_smart_ptr.py index 85f61a32236..29e2968f064 100644 --- a/tests/test_smart_ptr.py +++ b/tests/test_smart_ptr.py @@ -316,3 +316,47 @@ def test_shared_ptr_gc(): pytest.gc_collect() for i, v in enumerate(el.get()): assert i == v.value() + + +def test_shared_ptr_cpp_arg(): + import weakref + + class PyChild(m.SpBase): + def is_base_used(self): + return False + + tester = m.SpBaseTester() + + obj = PyChild() + objref = weakref.ref(obj) + + tester.set_object(obj) + del obj + pytest.gc_collect() + + # python reference is still around since C++ has it now + assert objref() is not None + assert tester.is_base_used() is False + assert tester.get_object() is objref() + + +def test_shared_ptr_arg_identity(): + import weakref + + tester = m.SpBaseTester() + + obj = m.SpBase() + objref = weakref.ref(obj) + + tester.set_object(obj) + del obj + pytest.gc_collect() + + # python reference is still around since C++ has it + assert objref() is not None + assert tester.get_object() is objref() + + # python reference disappears once the C++ object releases it + tester.set_object(None) + pytest.gc_collect() + assert objref() is None