From 4ec92e26f99ca6d2c91f618b8e6a20b6a08fc80d Mon Sep 17 00:00:00 2001 From: mwish Date: Tue, 4 Jun 2024 17:08:37 +0800 Subject: [PATCH] Minor enhance code style for FixedShapeTensorType --- cpp/src/arrow/extension/fixed_shape_tensor.cc | 33 ++++++++++--------- cpp/src/arrow/extension/fixed_shape_tensor.h | 4 +-- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc b/cpp/src/arrow/extension/fixed_shape_tensor.cc index 1101b08307332..1d61e929073cd 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor.cc +++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc @@ -208,12 +208,12 @@ std::shared_ptr FixedShapeTensorType::MakeArray( Result> FixedShapeTensorType::MakeTensor( const std::shared_ptr& scalar) { const auto ext_scalar = internal::checked_pointer_cast(scalar); - const auto ext_type = - internal::checked_pointer_cast(scalar->type); + const auto* ext_type = + internal::checked_cast(scalar->type.get()); if (!is_fixed_width(*ext_type->value_type())) { return Status::TypeError("Cannot convert non-fixed-width values to Tensor."); } - const auto array = + const auto& array = internal::checked_pointer_cast(ext_scalar->value)->value; if (array->null_count() > 0) { return Status::Invalid("Cannot convert data with nulls to Tensor."); @@ -244,7 +244,7 @@ Result> FixedShapeTensorType::MakeTensor( const auto buffer = SliceBuffer(array->data()->buffers[1], start_position, size * byte_width); - return Tensor::Make(ext_type->value_type(), buffer, shape, strides, dim_names); + return Tensor::Make(value_type, buffer, shape, strides, dim_names); } Result> FixedShapeTensorArray::FromTensor( @@ -257,6 +257,7 @@ Result> FixedShapeTensorArray::FromTensor permutation.erase(permutation.begin()); std::vector cell_shape; + cell_shape.reserve(permutation.size()); for (auto i : permutation) { cell_shape.emplace_back(tensor->shape()[i]); } @@ -337,9 +338,9 @@ const Result> FixedShapeTensorArray::ToTensor() const { // To convert an array of n dimensional tensors to a n+1 dimensional tensor we // interpret the array's length as the first dimension the new tensor. - const auto ext_type = - internal::checked_pointer_cast(this->type()); - const auto value_type = ext_type->value_type(); + const auto* ext_type = + internal::checked_cast(this->type().get()); + const auto& value_type = ext_type->value_type(); ARROW_RETURN_IF( !is_fixed_width(*value_type), Status::TypeError(value_type->ToString(), " is not valid data type for a tensor")); @@ -374,11 +375,11 @@ const Result> FixedShapeTensorArray::ToTensor() const { internal::Permute(permutation, &shape); std::vector tensor_strides; - const auto fw_value_type = internal::checked_pointer_cast(value_type); + const auto* fw_value_type = internal::checked_cast(value_type.get()); ARROW_RETURN_NOT_OK( - ComputeStrides(*fw_value_type.get(), shape, permutation, &tensor_strides)); + ComputeStrides(*fw_value_type, shape, permutation, &tensor_strides)); - const auto raw_buffer = this->storage()->data()->child_data[0]->buffers[1]; + const auto& raw_buffer = this->storage()->data()->child_data[0]->buffers[1]; ARROW_ASSIGN_OR_RAISE( const auto buffer, SliceBufferSafe(raw_buffer, this->offset() * cell_size * value_type->byte_width())); @@ -389,7 +390,7 @@ const Result> FixedShapeTensorArray::ToTensor() const { Result> FixedShapeTensorType::Make( const std::shared_ptr& value_type, const std::vector& shape, const std::vector& permutation, const std::vector& dim_names) { - const auto ndim = shape.size(); + const size_t ndim = shape.size(); if (!permutation.empty() && ndim != permutation.size()) { return Status::Invalid("permutation size must match shape size. Expected: ", ndim, " Got: ", permutation.size()); @@ -402,18 +403,18 @@ Result> FixedShapeTensorType::Make( RETURN_NOT_OK(internal::IsPermutationValid(permutation)); } - const auto size = std::accumulate(shape.begin(), shape.end(), static_cast(1), - std::multiplies<>()); + const int64_t size = std::accumulate(shape.begin(), shape.end(), + static_cast(1), std::multiplies<>()); return std::make_shared(value_type, static_cast(size), shape, permutation, dim_names); } const std::vector& FixedShapeTensorType::strides() { if (strides_.empty()) { - auto value_type = internal::checked_pointer_cast(this->value_type_); + auto value_type = internal::checked_cast(this->value_type_.get()); std::vector tensor_strides; - ARROW_CHECK_OK(ComputeStrides(*value_type.get(), this->shape(), this->permutation(), - &tensor_strides)); + ARROW_CHECK_OK( + ComputeStrides(*value_type, this->shape(), this->permutation(), &tensor_strides)); strides_ = tensor_strides; } return strides_; diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h b/cpp/src/arrow/extension/fixed_shape_tensor.h index 3fec79b5c2a3c..20ec20a64c2d4 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor.h +++ b/cpp/src/arrow/extension/fixed_shape_tensor.h @@ -67,10 +67,10 @@ class ARROW_EXPORT FixedShapeTensorType : public ExtensionType { size_t ndim() const { return shape_.size(); } /// Shape of tensor elements - const std::vector shape() const { return shape_; } + const std::vector& shape() const { return shape_; } /// Value type of tensor elements - const std::shared_ptr value_type() const { return value_type_; } + const std::shared_ptr& value_type() const { return value_type_; } /// Strides of tensor elements. Strides state offset in bytes between adjacent /// elements along each dimension. In case permutation is non-empty strides are