Add helper to verify output shape matches result type (#2269)
And use the helper to verify shape operands in DynamicBroadcastInDim,
DynamicIota, and DynamicReshape.

Also, reorganize the isCompatibleForHloTypeInference logic slightly
to avoid duplication.

This is a follow-up/generalization of
#2264.
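
For illustration, here is a hypothetical test case in the style of the ones
added below; it is not part of this commit. When a shape operand is a
constant whose values contradict the static result type, verification now
fails with the new diagnostic:

func.func @dynamic_iota_output_shape_contradicts_result() -> tensor<4xf32> {
  // Hypothetical example: the constant output shape [5] cannot produce
  // tensor<4xf32>, so the verifier reports:
  //   output shape [5] is incompatible with return type of operation 'tensor<4xf32>'
  %0 = stablehlo.constant dense<[5]> : tensor<1xi64>
  %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
  func.return %1 : tensor<4xf32>
}
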
mlevesquedion authored May 3, 2024
1 parent 40e70d2 commit ab92ade
Showing 3 changed files with 90 additions and 33 deletions.
7 changes: 2 additions & 5 deletions stablehlo/dialect/Base.cpp
@@ -132,6 +132,7 @@ bool isCompatibleForHloTypeInference(TypeRange tp1, TypeRange tp2) {
}

bool isCompatibleForHloTypeInference(ArrayRef<int64_t> shape1, Type tp2) {
if (llvm::any_of(shape1, [&](int64_t x) { return x < 0; })) return false;
auto stp2 = dyn_cast<ShapedType>(tp2);
if (!stp2) return false;
return isCompatibleForHloTypeInference(
@@ -141,11 +142,7 @@ bool isCompatibleForHloTypeInference(ArrayRef<int64_t> shape1, Type tp2) {
bool isCompatibleForHloTypeInference(Value shape1, Type tp2) {
SmallVector<int64_t> shapeVec1;
if (!succeeded(matchInts(shape1, shapeVec1))) return true;
if (llvm::any_of(shapeVec1, [&](int64_t x) { return x < 0; })) return false;
auto stp2 = dyn_cast<ShapedType>(tp2);
if (!stp2) return false;
auto tp1 = RankedTensorType::get(shapeVec1, stp2.getElementType());
return isCompatibleForHloTypeInference(tp1, tp2);
return isCompatibleForHloTypeInference(shapeVec1, tp2);
}

LogicalResult matchInt(Value value, int64_t& result) {
58 changes: 35 additions & 23 deletions stablehlo/dialect/TypeInference.cpp
@@ -44,6 +44,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h"
@@ -66,9 +67,11 @@ limitations under the License.
namespace mlir {
namespace hlo {
namespace {

//===----------------------------------------------------------------------===//
// Utils for quantization specific verifications
//===----------------------------------------------------------------------===//

template <typename T>
bool allQuantized(ArrayRef<Type> typeRange) {
return llvm::all_of(
@@ -468,6 +471,27 @@ LogicalResult verifyAddOp(std::optional<Location> location, Operation* op,
return success();
}

// If the shape operand is constant, checks that it is compatible with the
// result's shape. Emits an error if the shapes are incompatible.
LogicalResult verifyShapeOperandIsCompatibleWithResultType(
std::optional<Location> loc, Value shapeOperand, Type resultType) {
if (SmallVector<int64_t> shape;
succeeded(matchInts(shapeOperand, shape)) &&
!isCompatibleForHloTypeInference(shape, resultType)) {
std::string str;
llvm::raw_string_ostream os(str);
llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; });
return emitOptionalError(loc, "output shape [", os.str(),
"] is incompatible with return type of operation ",
resultType);
}
return success();
}

//===----------------------------------------------------------------------===//
// Verifiers
//===----------------------------------------------------------------------===//

LogicalResult verifyTransposeOp(std::optional<Location> location,
Type operandType, ArrayRef<int64_t> permutation,
Type resultType) {
@@ -3760,26 +3784,23 @@ LogicalResult verifyDynamicBroadcastInDimOp(
}
}

if (!isCompatibleForHloTypeInference(outputDimensions, resultType))
return emitOptionalError(
location,
"output_dimensions are incompatible with return type of operation ",
resultType);
if (failed(verifyShapeOperandIsCompatibleWithResultType(
location, outputDimensions, resultType)))
return failure();

return success();
}

LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
Value outputShape, int64_t iotaDimension,
Value result) {
auto shape = cast<ShapedType>(result.getType());
auto resultType = cast<ShapedType>(result.getType());

if (!isCompatibleForHloTypeInference(outputShape, shape))
return emitOptionalError(
location, "output_shape is incompatible with return type of operation ",
result.getType());
if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
resultType)))
return failure();

if (iotaDimension >= shape.getRank() || iotaDimension < 0)
if (iotaDimension >= resultType.getRank() || iotaDimension < 0)
return emitOptionalError(
location,
"iota dimension cannot go beyond the output rank or be negative.");
@@ -3850,18 +3871,9 @@ LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
}
}

if (SmallVector<int64_t> shape;
succeeded(matchInts(outputShape, shape)) &&
!isCompatibleForHloTypeInference(shape, resultType)) {
std::string str;
llvm::raw_string_ostream os(str);
os << "[";
llvm::interleaveComma(shape, os, [&](int64_t i) { os << i; });
os << "]";
return emitOptionalError(location, "output_shape ", os.str(),
" is incompatible with return type of operation ",
resultType);
}
if (failed(verifyShapeOperandIsCompatibleWithResultType(location, outputShape,
resultType)))
return failure();
return success();
}

58 changes: 53 additions & 5 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -1026,7 +1026,7 @@ func.func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape
// -----

func.func @dynamic_broadcast_in_dim_output_dimensions_negative_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
// @expected-error@+2 {{output_dimensions are incompatible with return type of operation 'tensor<3x4xf32>'}}
// @expected-error@+2 {{output shape [-1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
%0 = stablehlo.constant dense<[-1, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
@@ -1035,14 +1035,30 @@ func.func @dynamic_broadcast_in_dim_output_dimensions_negative_size(%arg0: tenso
// -----

func.func @dynamic_broadcast_in_dim_output_dimensions_mismatching_size(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
// @expected-error@+2 {{output_dimensions are incompatible with return type of operation 'tensor<3x4xf32>'}}
// @expected-error@+2 {{output shape [1, 4] is incompatible with return type of operation 'tensor<3x4xf32>'}}
%0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
}

// -----

func.func @dynamic_broadcast_in_dim_output_dimensions_match_result(%arg0: tensor<4xf32>) -> tensor<3x4xf32> {
%0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<3x4xf32>
return %1 : tensor<3x4xf32>
}

// -----

func.func @dynamic_broadcast_in_dim_output_dimensions_compatible_with_result(%arg0: tensor<4xf32>) -> tensor<?x?xf32> {
%0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// -----

func.func @dynamic_broadcast_in_dim_negative_size(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
// expected-error@+1 {{broadcast_dimensions contains invalid value -1 for result with rank 3}}
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: -1>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
@@ -3174,14 +3190,30 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor<?xf32>, %shape: ten
// -----

func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-error@+2 {{output_shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}}
// expected-error@+2 {{output shape [2, 2] is incompatible with return type of operation 'tensor<1x4xf32>'}}
%0 = stablehlo.constant dense<[2, 2]> : tensor<2xi64>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
return %1 : tensor<1x4xf32>
}

// -----

func.func @dynamic_reshape_output_shape_matches_result(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
%0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
return %1 : tensor<1x4xf32>
}

// -----

func.func @dynamic_reshape_output_shape_compatible_with_result(%arg0: tensor<4xf32>) -> tensor<?x?xf32> {
%0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// -----

func.func @dynamic_reshape_dynamic_output_shape(%arg0: tensor<?xf32>, %shape: tensor<?xindex>) -> tensor<1x4xf32> {
// expected-error@+1 {{op operand #1 must be statically shaped}}
%0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<1x4xf32>
@@ -5431,7 +5463,7 @@ func.func @dynamic_iota_invalid_iota_dimension_too_big() -> tensor<?xf32> {
// -----

func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
// @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}}
// @expected-error@+2 {{output shape [-1] is incompatible with return type of operation 'tensor<4xf32>'}}
%0 = stablehlo.constant dense<[-1]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
func.return %1 : tensor<4xf32>
Expand All @@ -5440,14 +5472,30 @@ func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
// -----

func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> {
// @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}}
// @expected-error@+2 {{output shape [1] is incompatible with return type of operation 'tensor<4xf32>'}}
%0 = stablehlo.constant dense<[1]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
func.return %1 : tensor<4xf32>
}

// -----

func.func @dynamic_iota_output_shape_matches_result() -> tensor<4xf32> {
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
func.return %1 : tensor<4xf32>
}

// -----

func.func @dynamic_iota_output_shape_compatible_with_result() -> tensor<?xf32> {
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<?xf32>
func.return %1 : tensor<?xf32>
}

// -----

func.func @first(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
func.return %arg0 : tensor<f32>
}
