From 51709ac7a5a6399c8ee5030414fc29c4a98a6607 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 6 Mar 2023 16:26:16 +0200 Subject: [PATCH] Further expose C++ Extension Types in Python --- python/pyarrow/includes/libarrow.pxd | 6 ++ python/pyarrow/public-api.pxi | 3 +- python/pyarrow/tests/test_extension_type.py | 16 ++++++ python/pyarrow/types.pxi | 61 ++++++++++++++++++++- 4 files changed, 83 insertions(+), 3 deletions(-) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 295d89d351b93..b862aa5e9046c 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2693,6 +2693,12 @@ cdef extern from "arrow/extension_type.h" namespace "arrow": c_string extension_name() shared_ptr[CDataType] storage_type() + c_string Serialize() + CResult[shared_ptr[CDataType]] Deserialize(shared_ptr[CDataType] storage_type, + const c_string & serialized_data) + + bint ExtensionEquals(CExtensionType other) + @staticmethod shared_ptr[CArray] WrapArray(shared_ptr[CDataType] ext_type, shared_ptr[CArray] storage) diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 1849ecab096ca..43ac26fdce6e0 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -117,7 +117,8 @@ cdef api object pyarrow_wrap_data_type( if cpy_ext_type != nullptr: return cpy_ext_type.GetInstance() else: - out = BaseExtensionType.__new__(BaseExtensionType) + cls = get_cpp_extension_type(ext_type) + out = cls.__new__(cls) else: out = DataType.__new__(DataType) diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index fa7ece5bc24c7..8b5ad9103e283 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1115,9 +1115,16 @@ def test_cpp_extension_in_python(tmpdir): assert uuid_type.extension_name == "uuid" assert uuid_type.storage_type == pa.binary(16) + array_cls = uuid_type.__arrow_ext_class__() + scalar_cls = uuid_type.__arrow_ext_scalar_class__() + assert array_cls.__name__ == "ExtensionArray(uuid)" + assert scalar_cls.__name__ == "ExtensionScalar(uuid)" + array = mod._make_uuid_array() assert array.type == uuid_type + assert isinstance(array, array_cls) assert array.to_pylist() == [b'abcdefghijklmno0', b'0onmlkjihgfedcba'] + assert isinstance(array[0], scalar_cls) assert array[0].as_py() == b'abcdefghijklmno0' assert array[1].as_py() == b'0onmlkjihgfedcba' @@ -1127,3 +1134,12 @@ def test_cpp_extension_in_python(tmpdir): reconstructed_array = batch.column(0) assert reconstructed_array.type == uuid_type assert reconstructed_array == array + + storage = pa.array([b'0onmlkjihgfedcba']*4, pa.binary(16)) + ext_array = pa.ExtensionArray.from_storage(uuid_type, storage) + assert isinstance(ext_array, array_cls) + assert len(ext_array) == 4 + + for i in range(4): + assert isinstance(ext_array[i], scalar_cls) + assert ext_array[i].as_py() == b'0onmlkjihgfedcba' diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 29b397c04255e..c72f5803db432 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -20,6 +20,7 @@ from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer import atexit from collections.abc import Mapping import re +from threading import Lock import sys import warnings @@ -831,7 +832,6 @@ cdef class BaseExtensionType(DataType): """ Concrete base class for extension types. """ - cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) self.ext_type = type.get() @@ -844,7 +844,7 @@ cdef class BaseExtensionType(DataType): def __arrow_ext_scalar_class__(self): """ - The associated scalar class + The associated scalar extension class """ return ExtensionScalar @@ -902,6 +902,63 @@ cdef class BaseExtensionType(DataType): self.sp_type, ( storage).sp_chunked_array)) +cdef dict _cpp_extension_type = {} + + +cdef get_cpp_extension_type(const CExtensionType * ext_type): + """ + Generates and caches a Python Extension Type wrapping a C++ Extension Type. + """ + ext_name = frombytes(deref(ext_type).extension_name()) + + try: + return _cpp_extension_type[ext_name] + except KeyError: + acls = type(f"ExtensionArray({ext_name})", (ExtensionArray,), {}) + scls = type(f"ExtensionScalar({ext_name})", (ExtensionScalar,), {}) + + _cpp_extension_type[ext_name] = tcls = type( + f"ExtensionType({ext_name})", + (CppExtensionType,), + { + "__arrow_ext_class__": lambda s: acls, + "__arrow_ext_scalar_class__": lambda s: scls + }) + + return tcls + + +cdef class CppExtensionType(BaseExtensionType): + cdef void init(self, const shared_ptr[CDataType]& data_type) except *: + BaseExtensionType.init(self, data_type) + + def __arrow_ext_serialize__(self): + """Serialized representation of metadata to reconstruct the type object.""" + return self.ext_type.Serialize() + + @classmethod + def __arrow_ext_deserialize__(self, storage_type, serialized): + """Return an extension type instance from the storage type and serialized metadata.""" + return self.ext_type.Deserialize(storage_type, serialized) + + def __repr__(self): + return self.__class__.__name__ + + __str__ = __repr__ + + def __eq__(self, other): + cdef: + const CExtensionType * c_other_ext + BaseExtensionType base_type + + if not isinstance(other, BaseExtensionType): + return False + + base_type = other + c_other_ext = base_type.ext_type + return deref(self.ext_type).ExtensionEquals(deref(c_other_ext)) + + cdef class ExtensionType(BaseExtensionType): """ Concrete base class for Python-defined extension types.