Skip to content

Commit

Permalink
Use matchInts to reduce boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Levesque-Dion committed Apr 18, 2024
1 parent 96447ff commit cdcde6b
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3790,14 +3790,11 @@ LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
"elements in output_shape");

auto operandType = cast<RankedTensorType>(operand.getType());
if (DenseIntElementsAttr shape;
operandType.hasStaticShape() &&
matchPattern(outputShape, m_Constant(&shape))) {
if (SmallVector<int64_t> shape; operandType.hasStaticShape() &&
matchInts(outputShape, shape).succeeded()) {
int64_t operandCount = operandType.getNumElements();
auto shapeValues = shape.getValues<APInt>();
int64_t shapeCount = std::accumulate(
shapeValues.begin(), shapeValues.end(), 1,
[](int64_t lhs, APInt rhs) { return lhs * rhs.getSExtValue(); });
int64_t shapeCount = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
if (operandCount != shapeCount) {
return emitOptionalError(location,
"output_shape is incompatible with input type "
Expand Down

0 comments on commit cdcde6b

Please sign in to comment.