diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h index 36ce89c126..1956fbf062 100644 --- a/stablehlo/dialect/Base.h +++ b/stablehlo/dialect/Base.h @@ -86,6 +86,10 @@ bool isCompatibleForHloTypeInference(TypeRange tp1, TypeRange tp2); // undefined behavior. bool isCompatibleForHloTypeInference(Value shape1, Type tp2); +// Returns true if the given shape, expressed as a slice of integers, is +// compatible with the given type for the purposes of HLO type inference. +bool isCompatibleForHloTypeInference(ArrayRef shape1, Type tp2); + // TODO(zhouxin) Move type inference related methods to TypeInference.cpp std::pair inferConcatenatedDimAndBound(int64_t leftSize, diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index c3b5381ef6..a80a4b8877 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -3792,10 +3792,18 @@ LogicalResult verifyDynamicReshapeOp(std::optional location, } } - if (!isCompatibleForHloTypeInference(outputShape, resultType)) - return emitOptionalError( - location, "output_shape is incompatible with return type of operation ", - resultType); + if (SmallVector shape; + succeeded(matchInts(outputShape, shape)) && + !isCompatibleForHloTypeInference(shape, resultType)) { + std::string str; + llvm::raw_string_ostream os(str); + os << "["; + llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; }); + os << "]"; + return emitOptionalError(location, "output_shape ", os.str(), + " is incompatible with return type of operation ", + resultType); + } return success(); } diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 1faa8ef5c2..240ca8ed92 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -3165,7 +3165,7 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor, %shape: ten // ----- func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}} + // expected-error@+2 {{output_shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}} %0 = stablehlo.constant dense<[2, 2]> : tensor<2xi64> %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32> return %1 : tensor<1x4xf32>