Skip to content

Commit

Permalink
Minor enhance code style for FixedShapeTensorType
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Jun 4, 2024
1 parent 4ec1c98 commit 4ec92e2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
33 changes: 17 additions & 16 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,12 @@ std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
const std::shared_ptr<ExtensionScalar>& scalar) {
const auto ext_scalar = internal::checked_pointer_cast<ExtensionScalar>(scalar);
const auto ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(scalar->type);
const auto* ext_type =
internal::checked_cast<FixedShapeTensorType*>(scalar->type.get());
if (!is_fixed_width(*ext_type->value_type())) {
return Status::TypeError("Cannot convert non-fixed-width values to Tensor.");
}
const auto array =
const auto& array =
internal::checked_pointer_cast<const FixedSizeListScalar>(ext_scalar->value)->value;
if (array->null_count() > 0) {
return Status::Invalid("Cannot convert data with nulls to Tensor.");
Expand Down Expand Up @@ -244,7 +244,7 @@ Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
const auto buffer =
SliceBuffer(array->data()->buffers[1], start_position, size * byte_width);

return Tensor::Make(ext_type->value_type(), buffer, shape, strides, dim_names);
return Tensor::Make(value_type, buffer, shape, strides, dim_names);
}

Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor(
Expand All @@ -257,6 +257,7 @@ Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor
permutation.erase(permutation.begin());

std::vector<int64_t> cell_shape;
cell_shape.reserve(permutation.size());
for (auto i : permutation) {
cell_shape.emplace_back(tensor->shape()[i]);
}
Expand Down Expand Up @@ -337,9 +338,9 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
// To convert an array of n dimensional tensors to a n+1 dimensional tensor we
// interpret the array's length as the first dimension the new tensor.

const auto ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
const auto value_type = ext_type->value_type();
const auto* ext_type =
internal::checked_cast<FixedShapeTensorType*>(this->type().get());
const auto& value_type = ext_type->value_type();
ARROW_RETURN_IF(
!is_fixed_width(*value_type),
Status::TypeError(value_type->ToString(), " is not valid data type for a tensor"));
Expand Down Expand Up @@ -374,11 +375,11 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
internal::Permute<int64_t>(permutation, &shape);

std::vector<int64_t> tensor_strides;
const auto fw_value_type = internal::checked_pointer_cast<FixedWidthType>(value_type);
const auto* fw_value_type = internal::checked_cast<FixedWidthType*>(value_type.get());
ARROW_RETURN_NOT_OK(
ComputeStrides(*fw_value_type.get(), shape, permutation, &tensor_strides));
ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides));

const auto raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1];
ARROW_ASSIGN_OR_RAISE(
const auto buffer,
SliceBufferSafe(raw_buffer, this->offset() * cell_size * value_type->byte_width()));
Expand All @@ -389,7 +390,7 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
const auto ndim = shape.size();
const size_t ndim = shape.size();
if (!permutation.empty() && ndim != permutation.size()) {
return Status::Invalid("permutation size must match shape size. Expected: ", ndim,
" Got: ", permutation.size());
Expand All @@ -402,18 +403,18 @@ Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
RETURN_NOT_OK(internal::IsPermutationValid(permutation));
}

const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
const int64_t size = std::accumulate(shape.begin(), shape.end(),
static_cast<int64_t>(1), std::multiplies<>());
return std::make_shared<FixedShapeTensorType>(value_type, static_cast<int32_t>(size),
shape, permutation, dim_names);
}

const std::vector<int64_t>& FixedShapeTensorType::strides() {
if (strides_.empty()) {
auto value_type = internal::checked_pointer_cast<FixedWidthType>(this->value_type_);
auto value_type = internal::checked_cast<FixedWidthType*>(this->value_type_.get());
std::vector<int64_t> tensor_strides;
ARROW_CHECK_OK(ComputeStrides(*value_type.get(), this->shape(), this->permutation(),
&tensor_strides));
ARROW_CHECK_OK(
ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides));
strides_ = tensor_strides;
}
return strides_;
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/extension/fixed_shape_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
size_t ndim() const { return shape_.size(); }

/// Shape of tensor elements
const std::vector<int64_t> shape() const { return shape_; }
const std::vector<int64_t>& shape() const { return shape_; }

/// Value type of tensor elements
const std::shared_ptr<DataType> value_type() const { return value_type_; }
const std::shared_ptr<DataType>& value_type() const { return value_type_; }

/// Strides of tensor elements. Strides state offset in bytes between adjacent
/// elements along each dimension. In case permutation is non-empty strides are
Expand Down

0 comments on commit 4ec92e2

Please sign in to comment.