From 3652ad2bbd02d32a9b35f751d34e64c4f9fef914 Mon Sep 17 00:00:00 2001
From: Michael Levesque-Dion
Date: Tue, 30 Apr 2024 11:31:21 -0700
Subject: [PATCH 1/3] Add helper to verify output shape matches result type

And use the helper to verify shape operands in DynamicBroadcastInDim,
DynamicIota and DynamicReshape.

Also, reorganize the isCompatibleForHloTypeInference logic slightly to
avoid duplication.
---
 stablehlo/dialect/Base.cpp          |  7 +---
 stablehlo/dialect/TypeInference.cpp | 59 ++++++++++++++++++-----------
 stablehlo/tests/ops_stablehlo.mlir  | 10 ++---
 3 files changed, 43 insertions(+), 33 deletions(-)

diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp
index e5cf93f97d..de6a222094 100644
--- a/stablehlo/dialect/Base.cpp
+++ b/stablehlo/dialect/Base.cpp
@@ -132,6 +132,7 @@ bool isCompatibleForHloTypeInference(TypeRange tp1, TypeRange tp2) {
 }
 
 bool isCompatibleForHloTypeInference(ArrayRef<int64_t> shape1, Type tp2) {
+  if (llvm::any_of(shape1, [&](int64_t x) { return x < 0; })) return false;
   auto stp2 = dyn_cast<ShapedType>(tp2);
   if (!stp2) return false;
   return isCompatibleForHloTypeInference(
@@ -141,11 +142,7 @@ bool isCompatibleForHloTypeInference(ArrayRef<int64_t> shape1, Type tp2) {
 bool isCompatibleForHloTypeInference(Value shape1, Type tp2) {
   SmallVector<int64_t> shapeVec1;
   if (!succeeded(matchInts(shape1, shapeVec1))) return true;
-  if (llvm::any_of(shapeVec1, [&](int64_t x) { return x < 0; })) return false;
-  auto stp2 = dyn_cast<ShapedType>(tp2);
-  if (!stp2) return false;
-  auto tp1 = RankedTensorType::get(shapeVec1, stp2.getElementType());
-  return isCompatibleForHloTypeInference(tp1, tp2);
+  return isCompatibleForHloTypeInference(shapeVec1, tp2);
 }
 
 LogicalResult matchInt(Value value, int64_t& result) {
diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp
index 7d9e905133..61c9b0c815 100644
--- a/stablehlo/dialect/TypeInference.cpp
+++ b/stablehlo/dialect/TypeInference.cpp
@@ -44,6 +44,7 @@ limitations under the License.
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/Regex.h"
+#include "llvm/Support/raw_ostream.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Quant/QuantTypes.h"
 #include "mlir/IR/Attributes.h"
@@ -66,9 +67,11 @@ limitations under the License.
 namespace mlir {
 namespace hlo {
 namespace {
+
 //===----------------------------------------------------------------------===//
 // Utils for quantization specific verifications
 //===----------------------------------------------------------------------===//
+
 template <typename T>
 bool allQuantized(ArrayRef<T> typeRange) {
   return llvm::all_of(
@@ -468,6 +471,28 @@ LogicalResult verifyAddOp(std::optional<Location> location, Operation* op,
   return success();
 }
 
+// If the shape operand is constant, checks that it matches the result.
+// If not, emits an error.
+LogicalResult verifyShapeOperandMatchesResultType(std::optional<Location> loc,
+                                                  Value shapeOperand,
+                                                  Type resultType) {
+  if (SmallVector<int64_t> shape;
+      succeeded(matchInts(shapeOperand, shape)) &&
+      !isCompatibleForHloTypeInference(shape, resultType)) {
+    std::string str;
+    llvm::raw_string_ostream os(str);
+    llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; });
+    return emitOptionalError(loc, "output shape [", os.str(),
+                             "] is incompatible with return type of operation ",
+                             resultType);
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Verifiers
+//===----------------------------------------------------------------------===//
+
 LogicalResult verifyTransposeOp(std::optional<Location> location,
                                 Type operandType, ArrayRef<int64_t> permutation,
                                 Type resultType) {
@@ -3718,11 +3743,9 @@ LogicalResult verifyDynamicBroadcastInDimOp(
     }
   }
 
-  if (!isCompatibleForHloTypeInference(outputDimensions, resultType))
-    return emitOptionalError(
-        location,
-        "output_dimensions are incompatible with return type of operation ",
-        resultType);
+  if (failed(verifyShapeOperandMatchesResultType(location, outputDimensions,
+                                                 resultType)))
+    return failure();
 
   return success();
 }
@@ -3730,14 +3753,13 @@ LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
                                   Value outputShape, int64_t iotaDimension,
                                   Value result) {
-  auto shape = cast<ShapedType>(result.getType());
+  auto resultType = cast<ShapedType>(result.getType());
 
-  if (!isCompatibleForHloTypeInference(outputShape, shape))
-    return emitOptionalError(
-        location, "output_shape is incompatible with return type of operation ",
-        result.getType());
+  if (failed(verifyShapeOperandMatchesResultType(location, outputShape,
+                                                 resultType)))
+    return failure();
 
-  if (iotaDimension >= shape.getRank() || iotaDimension < 0)
+  if (iotaDimension >= resultType.getRank() || iotaDimension < 0)
     return emitOptionalError(
         location,
         "iota dimension cannot go beyond the output rank or be negative.");
 
@@ -3808,18 +3830,9 @@ LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
     }
   }
 
-  if (SmallVector<int64_t> shape;
-      succeeded(matchInts(outputShape, shape)) &&
-      !isCompatibleForHloTypeInference(shape, resultType)) {
-    std::string str;
-    llvm::raw_string_ostream os(str);
-    os << "[";
-    llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; });
-    os << "]";
-    return emitOptionalError(location, "output_shape ", os.str(),
-                             " is incompatible with return type of operation ",
-                             resultType);
-  }
+  if (failed(verifyShapeOperandMatchesResultType(location, outputShape,
+                                                 resultType)))
+    return failure();
   return success();
 }
diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir
index dba67cdc26..e62adcf0fb 100644
--- a/stablehlo/tests/ops_stablehlo.mlir
+++ b/stablehlo/tests/ops_stablehlo.mlir
@@ -1026,7 +1026,7 @@ func.func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape
 // -----
 
 func.func @dynamic_broadcast_in_dim_output_dimensions_negative_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
-  // @expected-error@+2 {{output_dimensions are incompatible with return type of operation 'tensor<3x4xf32>'}}
+  // @expected-error@+2 {{output shape [-1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
   %0 = stablehlo.constant dense<[-1, 4]> : tensor<2xi64>
   %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
   return %1 : tensor<3x4xf32>
@@ -1035,7 +1035,7 @@ func.func @dynamic_broadcast_in_dim_output_dimensions_negative_size(%arg0: tenso
 // -----
 
 func.func @dynamic_broadcast_in_dim_output_dimensions_mismatching_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
-  // @expected-error@+2 {{output_dimensions are incompatible with return type of operation 'tensor<3x4xf32>'}}
+  // @expected-error@+2 {{output shape [1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
   %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
   %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
   return %1 : tensor<3x4xf32>
@@ -3174,7 +3174,7 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor, %shape: ten
 // -----
 
 func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
-  // expected-error@+2 {{output_shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}}
+  // expected-error@+2 {{output shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}}
   %0 = stablehlo.constant dense<[2, 2]> : tensor<2xi64>
   %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
   return %1 : tensor<1x4xf32>
@@ -5431,7 +5431,7 @@ func.func @dynamic_iota_invalid_iota_dimension_too_big() -> tensor {
 // -----
 
 func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
-  // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}}
+  // @expected-error@+2 {{output shape [-1] is incompatible with return type of operation 'tensor<4xf32>'}}
   %0 = stablehlo.constant dense<[-1]> : tensor<1xi64>
   %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
   func.return %1 : tensor<4xf32>
@@ -5440,7 +5440,7 @@ func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
 // -----
 
 func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> {
-  // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}}
+  // @expected-error@+2 {{output shape [1] is incompatible with return type of operation 'tensor<4xf32>'}}
   %0 = stablehlo.constant dense<[1]> : tensor<1xi64>
   %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
   func.return %1 : tensor<4xf32>

From 635755f904d6e3a99b27670ebaf9a86b44952fde Mon Sep 17 00:00:00 2001
From: Michael Levesque-Dion
Date: Fri, 3 May 2024 15:04:32 -0700
Subject: [PATCH 2/3] Address review comments

---
 stablehlo/dialect/TypeInference.cpp | 12 ++++----
 stablehlo/tests/ops_stablehlo.mlir  | 48 +++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 6 deletions(-)

diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp
index 61c1a28519..09b2ca2d6f 100644
--- a/stablehlo/dialect/TypeInference.cpp
+++ b/stablehlo/dialect/TypeInference.cpp
@@ -471,9 +471,9 @@ LogicalResult verifyAddOp(std::optional<Location> location, Operation* op,
   return success();
 }
 
-// If the shape operand is constant, checks that it matches the result.
-// If not, emits an error.
-LogicalResult verifyShapeOperandMatchesResultType(std::optional<Location> loc,
+// If the shape operand is constant, checks that it is compatible with the
+// result's shape. Emits an error if the shapes are incompatible.
+LogicalResult verifyShapeOperandIsCompatibleWithResultType(std::optional<Location> loc,
                                                   Value shapeOperand,
                                                   Type resultType) {
   if (SmallVector<int64_t> shape;
@@ -3785,7 +3785,7 @@ LogicalResult verifyDynamicBroadcastInDimOp(
     }
   }
 
-  if (failed(verifyShapeOperandMatchesResultType(location, outputDimensions,
+  if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputDimensions,
                                                  resultType)))
     return failure();
 
@@ -3797,7 +3797,7 @@ LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
                                   Value result) {
   auto resultType = cast<ShapedType>(result.getType());
 
-  if (failed(verifyShapeOperandMatchesResultType(location, outputShape,
+  if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
                                                  resultType)))
     return failure();
 
@@ -3872,7 +3872,7 @@ LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
     }
   }
 
-  if (failed(verifyShapeOperandMatchesResultType(location, outputShape,
+  if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
                                                  resultType)))
     return failure();
   return success();
diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir
index e62adcf0fb..abd22f23f2 100644
--- a/stablehlo/tests/ops_stablehlo.mlir
+++ b/stablehlo/tests/ops_stablehlo.mlir
@@ -1043,6 +1043,22 @@ func.func @dynamic_broadcast_in_dim_output_dimensions_mismatching_size(%arg0: te
 
 // -----
 
+func.func @dynamic_broadcast_in_dim_output_dimensions_match_result(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
+  %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
+  %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+}
+
+// -----
+
+func.func @dynamic_broadcast_in_dim_output_dimensions_compatible_with_result(%arg0: tensor<4xf32>) -> tensor<?x4xf32> {
+  %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
+  %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<?x4xf32>
+  return %1 : tensor<?x4xf32>
+}
+
+// -----
+
 func.func @dynamic_broadcast_in_dim_negative_size(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
   // expected-error@+1 {{broadcast_dimensions contains invalid value -1 for result with rank 3}}
   %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: -1>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
@@ -3182,6 +3198,22 @@ func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -
 
 // -----
 
+func.func @dynamic_reshape_output_shape_matches_result(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
+  %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
+  %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
+  return %1 : tensor<1x4xf32>
+}
+
+// -----
+
+func.func @dynamic_reshape_output_shape_compatible_with_result(%arg0: tensor<4xf32>) -> tensor<?x4xf32> {
+  %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
+  %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<?x4xf32>
+  return %1 : tensor<?x4xf32>
+}
+
+// -----
+
 func.func @dynamic_reshape_dynamic_output_shape(%arg0: tensor<?xf32>, %shape: tensor<?xi64>) -> tensor<1x4xf32> {
   // expected-error@+1 {{op operand #1 must be statically shaped}}
   %0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<?xi64>) -> tensor<1x4xf32>
@@ -5448,6 +5480,22 @@ func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> {
 
 // -----
 
+func.func @dynamic_iota_output_shape_matches_result() -> tensor<4xf32> {
+  %0 = stablehlo.constant dense<[4]> : tensor<1xi64>
+  %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
+  func.return %1 : tensor<4xf32>
+}
+
+// -----
+
+func.func @dynamic_iota_output_shape_compatible_with_result() -> tensor<?xf32> {
+  %0 = stablehlo.constant dense<[4]> : tensor<1xi64>
+  %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<?xf32>
+  func.return %1 : tensor<?xf32>
+}
+
+// -----
+
 func.func @first(%arg0: tensor, %arg1: tensor) -> tensor {
   func.return %arg0 : tensor
 }

From 3697e4624537af0c869be31c81e8bca5c80b40d5 Mon Sep 17 00:00:00 2001
From: Michael Levesque-Dion
Date: Fri, 3 May 2024 15:10:44 -0700
Subject: [PATCH 3/3] Fix formatting

---
 stablehlo/dialect/TypeInference.cpp | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp
index 09b2ca2d6f..fabc132483 100644
--- a/stablehlo/dialect/TypeInference.cpp
+++ b/stablehlo/dialect/TypeInference.cpp
@@ -473,9 +473,8 @@ LogicalResult verifyAddOp(std::optional<Location> location, Operation* op,
 
 // If the shape operand is constant, checks that it is compatible with the
 // result's shape. Emits an error if the shapes are incompatible.
-LogicalResult verifyShapeOperandIsCompatibleWithResultType(std::optional<Location> loc,
-                                                  Value shapeOperand,
-                                                  Type resultType) {
+LogicalResult verifyShapeOperandIsCompatibleWithResultType(
+    std::optional<Location> loc, Value shapeOperand, Type resultType) {
   if (SmallVector<int64_t> shape;
       succeeded(matchInts(shapeOperand, shape)) &&
       !isCompatibleForHloTypeInference(shape, resultType)) {
@@ -3785,8 +3784,8 @@ LogicalResult verifyDynamicBroadcastInDimOp(
     }
   }
 
-  if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputDimensions,
-                                                 resultType)))
+  if (failed(verifyShapeOperandIsCompatibleWithResultType(
+          location, outputDimensions, resultType)))
     return failure();
 
   return success();
@@ -3798,7 +3797,7 @@ LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
   auto resultType = cast<ShapedType>(result.getType());
 
   if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
-                                                 resultType)))
+                                                           resultType)))
     return failure();
 
   if (iotaDimension >= resultType.getRank() || iotaDimension < 0)
@@ -3873,7 +3872,7 @@ LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
   }
 
   if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
-                                                 resultType)))
+                                                           resultType)))
    return failure();
   return success();
 }
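
Note: the sketch below is not part of the patches above. It is a hypothetical
illustration of how another verifier with a constant-foldable shape operand
could reuse the helper introduced in PATCH 1/3 and renamed in PATCH 2/3; the
op and the function name verifyDynamicFooOp do not exist in StableHLO.

// Hypothetical usage sketch, assumed to sit in TypeInference.cpp alongside
// the other verifiers, after the helper's definition.
LogicalResult verifyDynamicFooOp(std::optional<Location> location,
                                 Value outputShape, Value result) {
  // The check only fires when outputShape folds to constants via matchInts;
  // otherwise the helper returns success(), so fully dynamic shapes still
  // verify, matching the behavior of the three ops updated in this series.
  return verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
                                                      result.getType());
}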