From bc2824f57c0d4c35dd5d758727666b1ebb3f266b Mon Sep 17 00:00:00 2001 From: Michael Levesque-Dion Date: Mon, 29 Apr 2024 09:13:23 -0700 Subject: [PATCH 1/2] Include output shape in shape mismatch error message As discussed in https://github.com/openxla/stablehlo/pull/2231#discussion_r1571148784, this will likely help with debugging shape mismatches. --- stablehlo/dialect/TypeInference.cpp | 2 +- stablehlo/tests/ops_stablehlo.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index c3b5381ef6..ee785d3c34 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -3795,7 +3795,7 @@ LogicalResult verifyDynamicReshapeOp(std::optional location, if (!isCompatibleForHloTypeInference(outputShape, resultType)) return emitOptionalError( location, "output_shape is incompatible with return type of operation ", - resultType); + resultType, ": ", outputShape); return success(); } diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 1faa8ef5c2..0816912210 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -3165,7 +3165,7 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor, %shape: ten // ----- func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}} + // expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>': %0 = "stablehlo.constant"}} %0 = stablehlo.constant dense<[2, 2]> : tensor<2xi64> %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32> return %1 : tensor<1x4xf32> From d4440be6c0fdf363c47a6fa4b613105c388206fc Mon Sep 17 00:00:00 2001 From: Michael Levesque-Dion Date: Mon, 29 Apr 2024 13:31:32 -0700 Subject: [PATCH 2/2] Print constant value in a nice way --- stablehlo/dialect/Base.h | 4 ++++ stablehlo/dialect/TypeInference.cpp | 16 ++++++++++++---- stablehlo/tests/ops_stablehlo.mlir | 2 +- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h index 36ce89c126..1956fbf062 100644 --- a/stablehlo/dialect/Base.h +++ b/stablehlo/dialect/Base.h @@ -86,6 +86,10 @@ bool isCompatibleForHloTypeInference(TypeRange tp1, TypeRange tp2); // undefined behavior. bool isCompatibleForHloTypeInference(Value shape1, Type tp2); +// Returns true if the given shape, expressed as a slice of integers, is +// compatible with the given type for the purposes of HLO type inference. +bool isCompatibleForHloTypeInference(ArrayRef shape1, Type tp2); + // TODO(zhouxin) Move type inference related methods to TypeInference.cpp std::pair inferConcatenatedDimAndBound(int64_t leftSize, diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index ee785d3c34..a80a4b8877 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -3792,10 +3792,18 @@ LogicalResult verifyDynamicReshapeOp(std::optional location, } } - if (!isCompatibleForHloTypeInference(outputShape, resultType)) - return emitOptionalError( - location, "output_shape is incompatible with return type of operation ", - resultType, ": ", outputShape); + if (SmallVector shape; + succeeded(matchInts(outputShape, shape)) && + !isCompatibleForHloTypeInference(shape, resultType)) { + std::string str; + llvm::raw_string_ostream os(str); + os << "["; + llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; }); + os << "]"; + return emitOptionalError(location, "output_shape ", os.str(), + " is incompatible with return type of operation ", + resultType); + } return success(); } diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 0816912210..240ca8ed92 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -3165,7 +3165,7 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor, %shape: ten // ----- func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>': %0 = "stablehlo.constant"}} + // expected-error@+2 {{output_shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}} %0 = stablehlo.constant dense<[2, 2]> : tensor<2xi64> %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32> return %1 : tensor<1x4xf32>