Skip to content

Commit

Permalink
Use explicit template type to avoid Windows build errors (#2222)
Browse files Browse the repository at this point in the history
Co-authored-by: Kevin Gleason <gleasonk@google.com>
  • Loading branch information
mlevesquedion and GleasonK authored Apr 16, 2024
1 parent e87d2f4 commit fd0c20a
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "stablehlo/reference/Ops.h"

#include <algorithm>
#include <cstdint>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
Expand Down Expand Up @@ -1428,23 +1429,23 @@ Tensor evalConvolutionOp(
inputSpatialDimensions.end());
lhsPermutation.push_back(inputFeatureDimension);

auto lhsWindowDimensions =
concatAndPermute(lhs.getShape()[inputBatchDimension],
extractElements(rhs.getShape(), kernelSpatialDimensions),
lhs.getShape()[inputFeatureDimension], lhsPermutation);
auto lhsWindowDimensions = concatAndPermute<int64_t>(
lhs.getShape()[inputBatchDimension],
extractElements(rhs.getShape(), kernelSpatialDimensions),
lhs.getShape()[inputFeatureDimension], lhsPermutation);

auto lhsWindowStrides =
concatAndPermute(1L, llvm::to_vector(windowStrides), 1L, lhsPermutation);
auto lhsWindowStrides = concatAndPermute<int64_t>(
1L, llvm::to_vector(windowStrides), 1L, lhsPermutation);

auto lhsBaseDilations =
concatAndPermute(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation);
concatAndPermute<int64_t>(0L, Sizes(lhsDilation) - 1, 0L, lhsPermutation);

auto lhsWindowDilations =
concatAndPermute(1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation);
auto lhsWindowDilations = concatAndPermute<int64_t>(
1L, llvm::to_vector(rhsDilation), 1L, lhsPermutation);

Sizes lhsPaddingLow, lhsPaddingHigh;
for (auto paddingPair : concatAndPermute({0, 0}, llvm::to_vector(padding),
{0, 0}, lhsPermutation)) {
for (auto paddingPair : concatAndPermute<std::pair<int64_t, int64_t>>(
{0, 0}, llvm::to_vector(padding), {0, 0}, lhsPermutation)) {
lhsPaddingLow.push_back(paddingPair.first);
lhsPaddingHigh.push_back(paddingPair.second);
}
Expand All @@ -1461,8 +1462,8 @@ Tensor evalConvolutionOp(
for (; outputSpatialIndexIt != outputSpatialIndexItEnd;
++outputSpatialIndexIt) {
Sizes lhsWindowStart;
for (auto [i, offset] : llvm::enumerate(
concatAndPermute(0L, *outputSpatialIndexIt, 0L, lhsPermutation)))
for (auto [i, offset] : llvm::enumerate(concatAndPermute<int64_t>(
0L, *outputSpatialIndexIt, 0L, lhsPermutation)))
lhsWindowStart.push_back(lhsWindowStrides[i] * offset);

Sizes limitIndices;
Expand Down Expand Up @@ -1507,9 +1508,9 @@ Tensor evalConvolutionOp(
for (auto dotProductIt = dotProduct.index_begin();
dotProductIt != dotProduct.index_end();
++dotProductIt, ++resultNonSpatialIt) {
Index resultIndex(
concatAndPermute((*resultNonSpatialIt)[0], *outputSpatialIndexIt,
(*resultNonSpatialIt)[1], resultPermutation));
Index resultIndex(concatAndPermute<int64_t>(
(*resultNonSpatialIt)[0], *outputSpatialIndexIt,
(*resultNonSpatialIt)[1], resultPermutation));
result.set(resultIndex, dotProduct.get(*dotProductIt));
}
}
Expand Down

0 comments on commit fd0c20a

Please sign in to comment.