diff --git a/clif/python/BUILD b/clif/python/BUILD index ece0314..aa91ff3 100644 --- a/clif/python/BUILD +++ b/clif/python/BUILD @@ -52,7 +52,9 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":pickle_support_lib", + "//net/proto2/proto:descriptor_cc_proto", "//third_party/protobuf/io", + "//third_party/pybind11_abseil/compat:py_base_utilities", "@com_google_absl//absl/base:config", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", diff --git a/clif/python/pyproto.cc b/clif/python/pyproto.cc index 872c4ba..8f8d29c 100644 --- a/clif/python/pyproto.cc +++ b/clif/python/pyproto.cc @@ -22,13 +22,16 @@ headers are included. #include +#include #include #include "absl/log/check.h" +#include "absl/log/log.h" #include "clif/python/runtime.h" #include "google/protobuf/io/coded_stream.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/dynamic_message.h" +#include "third_party/pybind11_abseil/compat/py_base_utilities.h" namespace { @@ -95,6 +98,138 @@ PyObject* GetMessageName(PyObject* py) { } return fn; } + +class ClifDescriptorDatabase : public proto2::DescriptorDatabase { + public: + ClifDescriptorDatabase() { + using pybind11_abseil::compat::py_base_utilities:: + PyExcFetchMaybeErrOccurred; + PyObject* descriptor_pool = + PyImport_ImportModule("google.protobuf.descriptor_pool"); + if (descriptor_pool == nullptr) { + LOG(FATAL) << "Failed to import google.protobuf.descriptor_pool module: " + << PyExcFetchMaybeErrOccurred().FlatMessage(); + } + + pool_ = PyObject_CallMethod(descriptor_pool, "Default", nullptr); + if (pool_ == nullptr) { + LOG(FATAL) << "Failed to get python Default pool: " + << PyExcFetchMaybeErrOccurred().FlatMessage(); + } + Py_DECREF(descriptor_pool); + }; + + ~ClifDescriptorDatabase() { + // Objects of this class are meant to be `static`ally initialized and + // never destroyed. This is a commonly used approach, because the order + // in which destructors of static objects run is unpredictable. In + // particular, it is possible that the Python interpreter may have been + // finalized already. + DLOG(FATAL) << "MEANT TO BE UNREACHABLE."; + }; + + bool FindFileByName(const std::string& filename, + proto2::FileDescriptorProto* output) override { + PyObject* pyfile_name = + PyUnicode_FromStringAndSize(filename.data(), filename.size()); + if (pyfile_name == nullptr) { + // Ideally this would be raise from. + PyErr_Format(PyExc_TypeError, "Fail to convert proto file name"); + return false; + } + + PyObject* pyfile = + PyObject_CallMethod(pool_, "FindFileByName", "O", pyfile_name); + Py_DECREF(pyfile_name); + if (pyfile == nullptr) { + // Ideally this would be raise from. + PyErr_Format(PyExc_TypeError, "Default python pool fail to find %s", + filename.data()); + return false; + } + + PyObject* pyfile_serialized = + PyObject_GetAttrString(pyfile, "serialized_pb"); + Py_DECREF(pyfile); + if (pyfile_serialized == nullptr) { + // Ideally this would be raise from. + PyErr_Format(PyExc_TypeError, + "Python file has no attribute 'serialized_pb'"); + return false; + } + + bool ok = output->ParseFromArray( + reinterpret_cast(PyBytes_AS_STRING(pyfile_serialized)), + PyBytes_GET_SIZE(pyfile_serialized)); + if (!ok) { + LOG(ERROR) << "Failed to parse descriptor for " << filename; + } + Py_DECREF(pyfile_serialized); + return ok; + } + + bool FindFileContainingSymbol(const std::string& symbol_name, + proto2::FileDescriptorProto* output) override { + return false; + } + + bool FindFileContainingExtension( + const std::string& containing_type, int field_number, + proto2::FileDescriptorProto* output) override { + return false; + } + + PyObject* pool() { return pool_; } + + private: + PyObject* pool_; +}; + +const proto2::Descriptor* FindMessageDescriptor( + PyObject* pyfile, const char* descritor_full_name) { + static auto* database = new ClifDescriptorDatabase(); + static auto* pool = new proto2::DescriptorPool(database); + PyObject* pyfile_name = PyObject_GetAttrString(pyfile, "name"); + if (pyfile_name == nullptr) { + // Ideally this would be raise from. + PyErr_Format(PyExc_TypeError, "FileDescriptor has no attribute 'name'"); + return nullptr; + } + PyObject* pyfile_pool = PyObject_GetAttrString(pyfile, "pool"); + if (pyfile_pool == nullptr) { + Py_DECREF(pyfile_name); + // Ideally this would be raise from. + PyErr_Format(PyExc_TypeError, "FileDescriptor has no attribute 'pool'"); + return nullptr; + } + bool is_from_generated_pool = database->pool() == pyfile_pool; + Py_DECREF(pyfile_pool); + const char* pyfile_name_char_ptr = PyUnicode_AsUTF8(pyfile_name); + if (pyfile_name_char_ptr == nullptr) { + Py_DECREF(pyfile_name); + // Ideally this would be raise from. + PyErr_Format(PyExc_TypeError, + "FileDescriptor 'name' PyUnicode_AsUTF8() failure."); + return nullptr; + } + if (!is_from_generated_pool) { + PyErr_Format(PyExc_TypeError, "%s is not from generated pool", + pyfile_name_char_ptr); + Py_DECREF(pyfile_name); + return nullptr; + } + pool->FindFileByName(pyfile_name_char_ptr); + Py_DECREF(pyfile_name); + + return pool->FindMessageTypeByName(descritor_full_name); +} + +proto2::DynamicMessageFactory* GetFactory() { + static proto2::DynamicMessageFactory* factory = + new proto2::DynamicMessageFactory; + return factory; +} + } // namespace proto bool Internal_Clif_PyObjAs(PyObject* py, std::unique_ptr<::proto2::Message>* c, @@ -110,15 +245,42 @@ bool Internal_Clif_PyObjAs(PyObject* py, std::unique_ptr<::proto2::Message>* c, } const proto2::Descriptor* d = dp->FindMessageTypeByName( PyUnicode_AsUTF8(fn)); + proto2::Message* m; if (d == nullptr) { - PyErr_Format(PyExc_TypeError, "DESCRIPTOR.full_name %s not found", - PyUnicode_AsUTF8(fn)); - Py_DECREF(fn); - return false; + PyObject* pyd = PyObject_GetAttrString(py, "DESCRIPTOR"); + if (pyd == nullptr) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Format(PyExc_TypeError, "'%s' %s has no attribute 'DESCRIPTOR'", + ClassName(py), ClassType(py)); + } + return false; + } + + PyObject* pyfile = PyObject_GetAttrString(pyd, "file"); + Py_DECREF(pyd); + if (pyfile == nullptr) { + PyErr_Format(PyExc_TypeError, "'%s.DESCRIPTOR' has no attribute 'file'", + ClassName(py)); + return false; + } + + const char* descritor_full_name = PyUnicode_AsUTF8(fn); + if (descritor_full_name == nullptr) { + PyErr_Format(PyExc_ValueError, "Fail to convert descriptor full name"); + } + + d = proto::FindMessageDescriptor(pyfile, descritor_full_name); + Py_DECREF(pyfile); + if (d == nullptr) { + PyErr_Format(PyExc_ValueError, "Fail to find descriptor %s.", + descritor_full_name); + return false; + } + m = proto::GetFactory()->GetPrototype(d)->New(); + } else { + m = proto2::MessageFactory::generated_factory()->GetPrototype(d)->New(); } Py_DECREF(fn); - proto2::Message* m = proto2::MessageFactory::generated_factory()-> - GetPrototype(d)->New(); if (m == nullptr) { PyErr_SetNone(PyExc_MemoryError); return false;