From 341eb7b0533f61063b1c4b892841eeb5db59b2b5 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Mon, 5 Feb 2024 07:05:18 -0800 Subject: [PATCH] snapshot of cl/602231983 --- clif/python/gen.py | 21 +++---- clif/python/runtime.cc | 55 +++++++++++++++++++ clif/python/runtime.h | 8 +++ .../python_multiple_inheritance_test.py | 15 +++++ 4 files changed, 89 insertions(+), 10 deletions(-) diff --git a/clif/python/gen.py b/clif/python/gen.py index 7f60c826..d6c659d7 100644 --- a/clif/python/gen.py +++ b/clif/python/gen.py @@ -408,7 +408,7 @@ def TypeObject(ht_qualname, tracked_slot_groups, 'static int tp_init_impl(PyObject* self, PyObject* args, PyObject* kw);' ) yield ( - 'static int tp_init_intercepted(' + 'static int tp_init_with_safety_checks(' 'PyObject* self, PyObject* args, PyObject* kw);' ) if not iterator: @@ -446,7 +446,11 @@ def TypeObject(ht_qualname, tracked_slot_groups, yield I+'Py_END_ALLOW_THREADS' if not iterator and enable_instance_dict: yield I+'Py_CLEAR(%s(self)->instance_dict);' % _Cast(wname) - yield I+'Py_TYPE(self)->tp_free(self);' + yield I+'PyTypeObject* type_self = Py_TYPE(self);' + yield I+'type_self->tp_free(self);' + yield '#if PY_VERSION_HEX >= 0x03080000 // python/cpython#79991 (BPO 35810)' + yield I+'Py_DECREF((PyObject*) type_self);' + yield '#endif' yield '}' if not iterator: # Use delete for static types (not derived), allocated with tp_alloc_impl. @@ -562,7 +566,7 @@ def TypeObject(ht_qualname, tracked_slot_groups, yield '}' yield '' yield ( - 'static int tp_init_intercepted(' + 'static int tp_init_with_safety_checks(' 'PyObject* self, PyObject* args, PyObject* kw) {' ) yield I+'DCHECK(PyType_Check(self) == 0);' @@ -574,7 +578,6 @@ def TypeObject(ht_qualname, tracked_slot_groups, yield I+'int status = (*derived_tp_init->second)(self, args, kw);' yield I+'if (status == 0 &&' yield I+' reinterpret_cast(self)->cpp.get() == nullptr) {' - yield I+' Py_DECREF(self);' yield I+' PyErr_Format(PyExc_TypeError,' yield I+' "%s.__init__() must be called when"' yield I+' " overriding __init__", wrapper_Type->tp_name);' @@ -601,18 +604,16 @@ def TypeObject(ht_qualname, tracked_slot_groups, ' PyObject* kwds) {' ) if ctor: - yield I+'if (type->tp_init != tp_init_impl &&' - yield I+' derived_tp_init_registry->count(type) == 0) {' - yield I+I+'(*derived_tp_init_registry)[type] = type->tp_init;' - yield I+I+'type->tp_init = tp_init_intercepted;' - yield I+'}' + yield I + I + 'return tp_new_impl_with_tp_init_safety_checks(' + yield I + I + ' type, args, kwds, derived_tp_init_registry,' + yield I + I + ' tp_init_impl, tp_init_with_safety_checks);' else: yield I + 'if (type->tp_init != Clif_PyType_Inconstructible) {' yield I + ' clif::SetErrorWrappedTypeCannotBeUsedAsBase(' yield I + ' wrapper_Type, type);' yield I + ' return nullptr;' yield I + '}' - yield I+'return PyType_GenericNew(type, args, kwds);' + yield I + 'return PyType_GenericNew(type, args, kwds);' yield '}' diff --git a/clif/python/runtime.cc b/clif/python/runtime.cc index 425abb41..624ad285 100644 --- a/clif/python/runtime.cc +++ b/clif/python/runtime.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include // NOLINT(build/c++11) @@ -29,6 +30,7 @@ #include "absl/log/check.h" #include "absl/log/log.h" #include "clif/python/pickle_support.h" +#include "clif/python/stltypes.h" extern "C" int Clif_PyType_Inconstructible(PyObject* self, PyObject* a, PyObject* kw) { @@ -558,4 +560,57 @@ PyObject* ModuleCreateAndSetPyClifCodeGenMode(PyModuleDef* module_def) { return module; } +namespace { + +extern "C" PyObject* FunctionCapsuleOneArgPyCFunction(PyObject* cap, + PyObject* arg) { + using function_type = std::function; + void* fp = PyCapsule_GetPointer(cap, typeid(function_type).name()); + if (fp == nullptr) { + return nullptr; + } + (*static_cast(fp))(arg); + Py_RETURN_NONE; +} + +static PyMethodDef FunctionCapsuleOneArgPyMethodDef = { + "", FunctionCapsuleOneArgPyCFunction, METH_O, nullptr}; + +} // namespace + +PyObject* tp_new_impl_with_tp_init_safety_checks( + PyTypeObject* type, PyObject* args, PyObject* kwds, + derived_tp_init_registry_type* derived_tp_init_registry, + initproc tp_init_impl, initproc tp_init_with_safety_checks) { + if (type->tp_init != tp_init_impl && + type->tp_init != tp_init_with_safety_checks && + derived_tp_init_registry->count(type) == 0) { + PyObject* wr_cb_fc = FunctionCapsule( + std::function([type, derived_tp_init_registry](PyObject* wr) { + CHECK_EQ(PyWeakref_CheckRef(wr), 1); + auto num_erased = derived_tp_init_registry->erase(type); + CHECK_EQ(num_erased, 1); + Py_DECREF(wr); + })); + if (wr_cb_fc == nullptr) { + return nullptr; + } + PyObject* wr_cb = + PyCFunction_New(&FunctionCapsuleOneArgPyMethodDef, wr_cb_fc); + Py_DECREF(wr_cb_fc); + if (wr_cb == nullptr) { + return nullptr; + } + PyObject* wr = PyWeakref_NewRef((PyObject*)type, wr_cb); + Py_DECREF(wr_cb); + if (wr == nullptr) { + return nullptr; + } + CHECK_NE(wr, Py_None); + (*derived_tp_init_registry)[type] = type->tp_init; + type->tp_init = tp_init_with_safety_checks; + } + return PyType_GenericNew(type, args, kwds); +} + } // namespace clif diff --git a/clif/python/runtime.h b/clif/python/runtime.h index 48b48006..164c94e0 100644 --- a/clif/python/runtime.h +++ b/clif/python/runtime.h @@ -216,6 +216,14 @@ void SetIsNotConvertibleError(PyObject* py_obj, const char* cpp_type); PyObject* ModuleCreateAndSetPyClifCodeGenMode(PyModuleDef* module_def); +using derived_tp_init_registry_type = + std::unordered_map; + +PyObject* tp_new_impl_with_tp_init_safety_checks( + PyTypeObject* type, PyObject* args, PyObject* kwds, + derived_tp_init_registry_type* derived_tp_init_registry, + initproc tp_init_impl, initproc tp_init_with_safety_checks); + } // namespace clif #endif // CLIF_PYTHON_RUNTIME_H_ diff --git a/clif/testing/python/python_multiple_inheritance_test.py b/clif/testing/python/python_multiple_inheritance_test.py index a1f82f9d..386859c4 100644 --- a/clif/testing/python/python_multiple_inheritance_test.py +++ b/clif/testing/python/python_multiple_inheritance_test.py @@ -84,6 +84,21 @@ def testPCExplicitInitMissingSuper(self, derived_type): " overriding __init__", ) + def testDerivedTpInitRegistryWeakrefBasedCleanup(self): + def NestedFunction(i): + class NestedClass(tm.CppBase): + + def __init__(self, value): + super().__init__(value + 3) + + d1 = NestedClass(i + 7) + d2 = NestedClass(i + 8) + return (d1.get_base_value(), d2.get_base_value()) + + for _ in range(100): + self.assertEqual(NestedFunction(0), (10, 11)) + self.assertEqual(NestedFunction(3), (13, 14)) + if __name__ == "__main__": absltest.main()