diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 295d89d351b93..b862aa5e9046c 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2693,6 +2693,12 @@ cdef extern from "arrow/extension_type.h" namespace "arrow": c_string extension_name() shared_ptr[CDataType] storage_type() + c_string Serialize() + CResult[shared_ptr[CDataType]] Deserialize(shared_ptr[CDataType] storage_type, + const c_string & serialized_data) + + bint ExtensionEquals(CExtensionType other) + @staticmethod shared_ptr[CArray] WrapArray(shared_ptr[CDataType] ext_type, shared_ptr[CArray] storage) diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index fa7ece5bc24c7..8b5ad9103e283 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1115,9 +1115,16 @@ def test_cpp_extension_in_python(tmpdir): assert uuid_type.extension_name == "uuid" assert uuid_type.storage_type == pa.binary(16) + array_cls = uuid_type.__arrow_ext_class__() + scalar_cls = uuid_type.__arrow_ext_scalar_class__() + assert array_cls.__name__ == "ExtensionArray(uuid)" + assert scalar_cls.__name__ == "ExtensionScalar(uuid)" + array = mod._make_uuid_array() assert array.type == uuid_type + assert isinstance(array, array_cls) assert array.to_pylist() == [b'abcdefghijklmno0', b'0onmlkjihgfedcba'] + assert isinstance(array[0], scalar_cls) assert array[0].as_py() == b'abcdefghijklmno0' assert array[1].as_py() == b'0onmlkjihgfedcba' @@ -1127,3 +1134,12 @@ def test_cpp_extension_in_python(tmpdir): reconstructed_array = batch.column(0) assert reconstructed_array.type == uuid_type assert reconstructed_array == array + + storage = pa.array([b'0onmlkjihgfedcba']*4, pa.binary(16)) + ext_array = pa.ExtensionArray.from_storage(uuid_type, storage) + assert isinstance(ext_array, array_cls) + assert len(ext_array) == 4 + + for i in range(4): + assert isinstance(ext_array[i], scalar_cls) + assert ext_array[i].as_py() == b'0onmlkjihgfedcba' diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 29b397c04255e..197a87fcef278 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -20,6 +20,7 @@ from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer import atexit from collections.abc import Mapping import re +from threading import Lock import sys import warnings @@ -831,22 +832,62 @@ cdef class BaseExtensionType(DataType): """ Concrete base class for extension types. """ + _cache_lock = Lock() + _ext_array_cache = {} + _ext_scalar_cache = {} cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) self.ext_type = type.get() + def __arrow_ext_serialize__(self): + """Serialized representation of metadata to reconstruct the type object.""" + return self.ext_type.Serialize() + + @classmethod + def __arrow_ext_deserialize__(self, storage_type, serialized): + """Return an extension type instance from the storage type and serialized metadata.""" + return self.ext_type.Deserialize(storage_type, serialized) + + def __eq__(self, other): + cdef: + const CExtensionType * c_other_ext + BaseExtensionType base_type + + if not isinstance(other, BaseExtensionType): + return False + + base_type = other + c_other_ext = base_type.ext_type + return deref(self.ext_type).ExtensionEquals(deref(c_other_ext)) + def __arrow_ext_class__(self): """ The associated array extension class """ - return ExtensionArray + key = (self.extension_name, self.storage_type) + + with self._cache_lock: + try: + return self._ext_array_cache[key] + except KeyError: + self._ext_array_cache[key] = cls = type( + f"ExtensionArray({key[0]})", (ExtensionArray,), {}) + return cls def __arrow_ext_scalar_class__(self): """ - The associated scalar class + The associated scalar extension class """ - return ExtensionScalar + key = (self.extension_name, self.storage_type) + + with self._cache_lock: + try: + return self._ext_scalar_cache[key] + except KeyError: + self._ext_scalar_cache[key] = cls = type( + f"ExtensionScalar({key[0]})", (ExtensionScalar,), {}) + return cls @property def extension_name(self):