Skip to content

Commit

Permalink
Use ref rather than ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Jun 4, 2024
1 parent 8c132a8 commit bd092e7
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,44 +207,44 @@ std::shared_ptr<Array> FixedShapeTensorType::MakeArray(

Result<std::shared_ptr<Tensor>> FixedShapeTensorType::MakeTensor(
const std::shared_ptr<ExtensionScalar>& scalar) {
const auto* ext_scalar = internal::checked_cast<const ExtensionScalar*>(scalar.get());
const auto* ext_type =
internal::checked_cast<FixedShapeTensorType*>(scalar->type.get());
if (!is_fixed_width(*ext_type->value_type())) {
const auto& ext_scalar = internal::checked_cast<const ExtensionScalar&>(*scalar);
const auto& ext_type =
internal::checked_cast<const FixedShapeTensorType&>(*scalar->type);
if (!is_fixed_width(*ext_type.value_type())) {
return Status::TypeError("Cannot convert non-fixed-width values to Tensor.");
}
const auto& array =
internal::checked_cast<const FixedSizeListScalar*>(ext_scalar->value.get())->value;
internal::checked_cast<const FixedSizeListScalar*>(ext_scalar.value.get())->value;
if (array->null_count() > 0) {
return Status::Invalid("Cannot convert data with nulls to Tensor.");
}
const auto* value_type =
internal::checked_cast<const FixedWidthType*>(ext_type->value_type().get());
const auto byte_width = value_type->byte_width();
const auto& value_type =
internal::checked_cast<const FixedWidthType&>(*ext_type.value_type());
const auto byte_width = value_type.byte_width();

std::vector<int64_t> permutation = ext_type->permutation();
std::vector<int64_t> permutation = ext_type.permutation();
if (permutation.empty()) {
permutation.resize(ext_type->ndim());
permutation.resize(ext_type.ndim());
std::iota(permutation.begin(), permutation.end(), 0);
}

std::vector<int64_t> shape = ext_type->shape();
std::vector<int64_t> shape = ext_type.shape();
internal::Permute<int64_t>(permutation, &shape);

std::vector<std::string> dim_names = ext_type->dim_names();
std::vector<std::string> dim_names = ext_type.dim_names();
if (!dim_names.empty()) {
internal::Permute<std::string>(permutation, &dim_names);
}

std::vector<int64_t> strides;
RETURN_NOT_OK(ComputeStrides(*value_type, shape, permutation, &strides));
RETURN_NOT_OK(ComputeStrides(value_type, shape, permutation, &strides));
const auto start_position = array->offset() * byte_width;
const auto size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
const auto buffer =
SliceBuffer(array->data()->buffers[1], start_position, size * byte_width);

return Tensor::Make(ext_type->value_type(), buffer, shape, strides, dim_names);
return Tensor::Make(ext_type.value_type(), buffer, shape, strides, dim_names);
}

Result<std::shared_ptr<FixedShapeTensorArray>> FixedShapeTensorArray::FromTensor(
Expand Down Expand Up @@ -339,9 +339,9 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
// To convert an array of n dimensional tensors to a n+1 dimensional tensor we
// interpret the array's length as the first dimension the new tensor.

const auto* ext_type =
internal::checked_cast<FixedShapeTensorType*>(this->type().get());
const auto& value_type = ext_type->value_type();
const auto& ext_type =
internal::checked_cast<const FixedShapeTensorType&>(*this->type());
const auto& value_type = ext_type.value_type();
ARROW_RETURN_IF(
!is_fixed_width(*value_type),
Status::TypeError(value_type->ToString(), " is not valid data type for a tensor"));
Expand All @@ -352,24 +352,24 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
// will get permutation index 0 and remaining values from ext_type->permutation() need
// to be shifted to fill the [1, ndim+1) range. Computed permutation will be used to
// generate the new tensor's shape, strides and dim_names.
std::vector<int64_t> permutation = ext_type->permutation();
std::vector<int64_t> permutation = ext_type.permutation();
if (permutation.empty()) {
permutation.resize(ext_type->ndim() + 1);
permutation.resize(ext_type.ndim() + 1);
std::iota(permutation.begin(), permutation.end(), 0);
} else {
for (auto i = 0; i < static_cast<int64_t>(ext_type->ndim()); i++) {
for (auto i = 0; i < static_cast<int64_t>(ext_type.ndim()); i++) {
permutation[i] += 1;
}
permutation.insert(permutation.begin(), 1, 0);
}

std::vector<std::string> dim_names = ext_type->dim_names();
std::vector<std::string> dim_names = ext_type.dim_names();
if (!dim_names.empty()) {
dim_names.insert(dim_names.begin(), 1, "");
internal::Permute<std::string>(permutation, &dim_names);
}

std::vector<int64_t> shape = ext_type->shape();
std::vector<int64_t> shape = ext_type.shape();
auto cell_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
std::multiplies<>());
shape.insert(shape.begin(), 1, this->length());
Expand Down

0 comments on commit bd092e7

Please sign in to comment.