diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 3ef6608623..9710fbbf34 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "stablehlo/reference/Ops.h" #include +#include #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" @@ -1428,23 +1429,23 @@ Tensor evalConvolutionOp( inputSpatialDimensions.end()); lhsPermutation.push_back(inputFeatureDimension); - auto lhsWindowDimensions = - concatAndPermute(lhs.getShape()[inputBatchDimension], - extractElements(rhs.getShape(), kernelSpatialDimensions), - lhs.getShape()[inputFeatureDimension], lhsPermutation); + auto lhsWindowDimensions = concatAndPermute( + lhs.getShape()[inputBatchDimension], + extractElements(rhs.getShape(), kernelSpatialDimensions), + lhs.getShape()[inputFeatureDimension], lhsPermutation); - auto lhsWindowStrides = - concatAndPermute(1L, llvm::to_vector(windowStrides), 1L, lhsPermutation); + auto lhsWindowStrides = concatAndPermute( + 1L, llvm::to_vector(windowStrides), 1L, lhsPermutation); auto lhsBaseDilations = - concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation); + concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation); - auto lhsWindowDilations = - concatAndPermute(1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation); + auto lhsWindowDilations = concatAndPermute( + 1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation); Sizes lhsPaddingLow, lhsPaddingHigh; - for (auto paddingPair : concatAndPermute({0, 0}, llvm::to_vector(padding), - {0, 0}, lhsPermutation)) { + for (auto paddingPair : concatAndPermute>( + {0, 0}, llvm::to_vector(padding), {0, 0}, lhsPermutation)) { lhsPaddingLow.push_back(paddingPair.first); lhsPaddingHigh.push_back(paddingPair.second); } @@ -1461,8 +1462,8 @@ Tensor evalConvolutionOp( for (; outputSpatialIndexIt != outputSpatialIndexItEnd; ++outputSpatialIndexIt) { Sizes lhsWindowStart; - for (auto [i, offset] : llvm::enumerate( - concatAndPermute(0L, *outputSpatialIndexIt, 0L, lhsPermutation))) + for (auto [i, offset] : llvm::enumerate(concatAndPermute( + 0L, *outputSpatialIndexIt, 0L, lhsPermutation))) lhsWindowStart.push_back(lhsWindowStrides[i] * offset); Sizes limitIndices; @@ -1507,9 +1508,9 @@ Tensor evalConvolutionOp( for (auto dotProductIt = dotProduct.index_begin(); dotProductIt != dotProduct.index_end(); ++dotProductIt, ++resultNonSpatialIt) { - Index resultIndex( - concatAndPermute((*resultNonSpatialIt)[0], *outputSpatialIndexIt, - (*resultNonSpatialIt)[1], resultPermutation)); + Index resultIndex(concatAndPermute( + (*resultNonSpatialIt)[0], *outputSpatialIndexIt, + (*resultNonSpatialIt)[1], resultPermutation)); result.set(resultIndex, dotProduct.get(*dotProductIt)); } }