-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GH-33984: [C++][Python] DLPack implementation for Arrow Arrays (produ…
…cer) (#38472) ### Rationale for this change DLPack is selected for Array API protocol so it is important to have it implemented for Arrow/PyArrow Arrays also. This is possible for primitive type arrays (int, uint and float) with no validity buffer. Device support is not in scope of this PR (CPU only). ### What changes are included in this PR? - `ExportArray` and `ExportDevice` methods on Arrow C++ Arrays - `__dlpack__` method on the base PyArrow Array class exposing `ExportArray` method - `__dlpack_device__` method on the base PyArrow Array class exposing `ExportDevice` method ### Are these changes tested? Yes, tests are added to `dlpack_test.cc` and `test_array.py`. ### Are there any user-facing changes? No. * Closes: #33984 Lead-authored-by: AlenkaF <frim.alenka@gmail.com> Co-authored-by: Alenka Frim <AlenkaF@users.noreply.github.com> Co-authored-by: Antoine Pitrou <antoine@python.org> Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com> Signed-off-by: Antoine Pitrou <antoine@python.org>
- Loading branch information
1 parent
3e182f2
commit 6c326db
Showing
15 changed files
with
982 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
#include "arrow/c/dlpack.h" | ||
|
||
#include "arrow/array/array_base.h" | ||
#include "arrow/c/dlpack_abi.h" | ||
#include "arrow/device.h" | ||
#include "arrow/type.h" | ||
#include "arrow/type_traits.h" | ||
|
||
namespace arrow::dlpack { | ||
|
||
namespace { | ||
|
||
Result<DLDataType> GetDLDataType(const DataType& type) { | ||
DLDataType dtype; | ||
dtype.lanes = 1; | ||
dtype.bits = type.bit_width(); | ||
switch (type.id()) { | ||
case Type::INT8: | ||
case Type::INT16: | ||
case Type::INT32: | ||
case Type::INT64: | ||
dtype.code = DLDataTypeCode::kDLInt; | ||
return dtype; | ||
case Type::UINT8: | ||
case Type::UINT16: | ||
case Type::UINT32: | ||
case Type::UINT64: | ||
dtype.code = DLDataTypeCode::kDLUInt; | ||
return dtype; | ||
case Type::HALF_FLOAT: | ||
case Type::FLOAT: | ||
case Type::DOUBLE: | ||
dtype.code = DLDataTypeCode::kDLFloat; | ||
return dtype; | ||
case Type::BOOL: | ||
// DLPack supports byte-packed boolean values | ||
return Status::TypeError("Bit-packed boolean data type not supported by DLPack."); | ||
default: | ||
return Status::TypeError("DataType is not compatible with DLPack spec: ", | ||
type.ToString()); | ||
} | ||
} | ||
|
||
struct ManagerCtx { | ||
std::shared_ptr<ArrayData> array; | ||
DLManagedTensor tensor; | ||
}; | ||
|
||
} // namespace | ||
|
||
Result<DLManagedTensor*> ExportArray(const std::shared_ptr<Array>& arr) { | ||
// Define DLDevice struct nad check if array type is supported | ||
// by the DLPack protocol at the same time. Raise TypeError if not. | ||
// Supported data types: int, uint, float with no validity buffer. | ||
ARROW_ASSIGN_OR_RAISE(auto device, ExportDevice(arr)) | ||
|
||
// Define the DLDataType struct | ||
const DataType& type = *arr->type(); | ||
std::shared_ptr<ArrayData> data = arr->data(); | ||
ARROW_ASSIGN_OR_RAISE(auto dlpack_type, GetDLDataType(type)); | ||
|
||
// Create ManagerCtx that will serve as the owner of the DLManagedTensor | ||
std::unique_ptr<ManagerCtx> ctx(new ManagerCtx); | ||
|
||
// Define the data pointer to the DLTensor | ||
// If array is of length 0, data pointer should be NULL | ||
if (arr->length() == 0) { | ||
ctx->tensor.dl_tensor.data = NULL; | ||
} else { | ||
const auto data_offset = data->offset * type.byte_width(); | ||
ctx->tensor.dl_tensor.data = | ||
const_cast<uint8_t*>(data->buffers[1]->data() + data_offset); | ||
} | ||
|
||
ctx->tensor.dl_tensor.device = device; | ||
ctx->tensor.dl_tensor.ndim = 1; | ||
ctx->tensor.dl_tensor.dtype = dlpack_type; | ||
ctx->tensor.dl_tensor.shape = const_cast<int64_t*>(&data->length); | ||
ctx->tensor.dl_tensor.strides = NULL; | ||
ctx->tensor.dl_tensor.byte_offset = 0; | ||
|
||
ctx->array = std::move(data); | ||
ctx->tensor.manager_ctx = ctx.get(); | ||
ctx->tensor.deleter = [](struct DLManagedTensor* self) { | ||
delete reinterpret_cast<ManagerCtx*>(self->manager_ctx); | ||
}; | ||
return &ctx.release()->tensor; | ||
} | ||
|
||
Result<DLDevice> ExportDevice(const std::shared_ptr<Array>& arr) { | ||
// Check if array is supported by the DLPack protocol. | ||
if (arr->null_count() > 0) { | ||
return Status::TypeError("Can only use DLPack on arrays with no nulls."); | ||
} | ||
const DataType& type = *arr->type(); | ||
if (type.id() == Type::BOOL) { | ||
return Status::TypeError("Bit-packed boolean data type not supported by DLPack."); | ||
} | ||
if (!is_integer(type.id()) && !is_floating(type.id())) { | ||
return Status::TypeError("DataType is not compatible with DLPack spec: ", | ||
type.ToString()); | ||
} | ||
|
||
// Define DLDevice struct | ||
DLDevice device; | ||
if (arr->data()->buffers[1]->device_type() == DeviceAllocationType::kCPU) { | ||
device.device_id = 0; | ||
device.device_type = DLDeviceType::kDLCPU; | ||
return device; | ||
} else { | ||
return Status::NotImplemented( | ||
"DLPack support is implemented only for buffers on CPU device."); | ||
} | ||
} | ||
|
||
} // namespace arrow::dlpack |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
#pragma once | ||
|
||
#include "arrow/array/array_base.h" | ||
#include "arrow/c/dlpack_abi.h" | ||
|
||
namespace arrow::dlpack { | ||
|
||
/// \brief Export Arrow array as DLPack tensor. | ||
/// | ||
/// DLMangedTensor is produced as defined by the DLPack protocol, | ||
/// see https://dmlc.github.io/dlpack/latest/. | ||
/// | ||
/// Data types for which the protocol is supported are | ||
/// integer and floating-point data types. | ||
/// | ||
/// DLPack protocol only supports arrays with one contiguous | ||
/// memory region which means Arrow Arrays with validity buffers | ||
/// are not supported. | ||
/// | ||
/// \param[in] arr Arrow array | ||
/// \return DLManagedTensor struct | ||
ARROW_EXPORT | ||
Result<DLManagedTensor*> ExportArray(const std::shared_ptr<Array>& arr); | ||
|
||
/// \brief Get DLDevice with enumerator specifying the | ||
/// type of the device data is stored on and index of the | ||
/// device which is 0 by default for CPU. | ||
/// | ||
/// \param[in] arr Arrow array | ||
/// \return DLDevice struct | ||
ARROW_EXPORT | ||
Result<DLDevice> ExportDevice(const std::shared_ptr<Array>& arr); | ||
|
||
} // namespace arrow::dlpack |
Oops, something went wrong.