Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for DynamicReshape (#2231)
Browse files Browse the repository at this point in the history
  • Loading branch information
mlevesquedion authored Apr 18, 2024
1 parent 1beb2c3 commit b911843
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 14 deletions.
23 changes: 22 additions & 1 deletion stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,8 @@ mlir::Speculation::Speculatability ConcatenateOp::getSpeculatability() {
//===----------------------------------------------------------------------===//

// Verifies the op by delegating to the shared TypeInference verifier, which
// checks that the output_shape operand, the operand type, and the result type
// are mutually consistent.
LogicalResult DynamicReshapeOp::verify() {
  auto loc = getLoc();
  return hlo::verifyDynamicReshapeOp(loc, getOperand(), getOutputShape(),
                                     getResult());
}

LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
Expand All @@ -1508,6 +1509,26 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
return success();
}

// Interface method for ConditionallySpeculatable. The op is speculatable
// exactly when no runtime shape mismatch with the result type is possible.
mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() {
  auto resultType = getType();

  // A fully dynamic result type places no static expectation on the shape,
  // so speculation is always safe.
  bool resultFullyDynamic =
      llvm::all_of(resultType.getShape(),
                   [](int64_t dim) { return ShapedType::isDynamic(dim); });
  if (resultFullyDynamic) return mlir::Speculation::Speculatable;

  // Here the result has at least one static dimension. Speculation is safe
  // only when the input shape is fully static AND the shape operand is a
  // constant, because then any mismatch is caught statically by the verifier:
  // - a dynamic input dimension could make the element count disagree with
  //   the result type at runtime;
  // - a non-constant output_shape could disagree at runtime with the static
  //   result dimension(s).
  bool inputStatic = getOperand().getType().hasStaticShape();
  bool shapeIsConstant = matchPattern(getOutputShape(), m_Constant());
  if (inputStatic && shapeIsConstant) return mlir::Speculation::Speculatable;

  return mlir::Speculation::NotSpeculatable;
}

//===----------------------------------------------------------------------===//
// DynamicSliceOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 7 additions & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2637,7 +2637,8 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape",
}];
}

def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [Pure]> {
def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape",
[ConditionallySpeculatable, NoMemoryEffect]> {
let summary = "DynamicReshape operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand All @@ -2659,6 +2660,11 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_ScatterOp: StableHLO_Op<"scatter",
Expand Down
19 changes: 18 additions & 1 deletion stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3780,13 +3780,30 @@ LogicalResult verifyDynamicPadOp(std::optional<Location> location,
}

LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
Value outputShape, Value result) {
Value operand, Value outputShape,
Value result) {
auto resultType = cast<ShapedType>(result.getType());
auto outputShapeType = cast<ShapedType>(outputShape.getType());
if (outputShapeType.getDimSize(0) != resultType.getRank())
return emitOptionalError(location,
"output should have a rank equal to the number of "
"elements in output_shape");

auto operandType = cast<RankedTensorType>(operand.getType());
if (SmallVector<int64_t> shape; operandType.hasStaticShape() &&
matchInts(outputShape, shape).succeeded()) {
int64_t operandCount = operandType.getNumElements();
int64_t shapeCount = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
if (operandCount != shapeCount) {
return emitOptionalError(location,
"output_shape is incompatible with input type "
"of operation: input has ",
operandCount, " elements, but output_shape has ",
shapeCount);
}
}

if (!isCompatibleForHloTypeInference(outputShape, resultType))
return emitOptionalError(
location, "output_shape is incompatible with return type of operation ",
Expand Down
3 changes: 2 additions & 1 deletion stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,8 @@ LogicalResult verifyDynamicPadOp(std::optional<Location> location,
Value interiorPadding, Value result);

LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
Value outputShape, Value result);
Value operand, Value outputShape,
Value result);

LogicalResult verifyInfeedOp(HloDialectInterface* dialect,
std::optional<Location> location,
Expand Down
31 changes: 31 additions & 0 deletions stablehlo/tests/ops_speculatability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1701,6 +1701,37 @@ func.func @dynamic_iota(%unknown_shape: tensor<2xi32>) {

// -----

// Speculatability of dynamic_reshape: speculatable when the result type is
// fully dynamic (nothing static to contradict at runtime), or when the input
// shape is static and the shape operand is constant (any mismatch is caught
// by the verifier). Otherwise a runtime mismatch is possible, so the op is
// not speculatable.

// CHECK-LABEL: func @dynamic_reshape
// CHECK-NEXT: return
func.func @dynamic_reshape(
%static_arg: tensor<4x5xf64>, %dynamic_arg: tensor<?x?xf64>,
%unknown_shape: tensor<2xi32>
) {
%constant_shape = stablehlo.constant dense<[5, 4]> : tensor<2xi32>

// Static input, constant shape: verifier catches any mismatch statically.
%0 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64>
"hlo_test_speculatability.is_speculatable"(%0) : (tensor<5x4xf64>) -> ()
%1 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%1) : (tensor<?x?xf64>) -> ()

// Dynamic input: element count may disagree with a static result dimension
// at runtime, unless the result is fully dynamic.
%2 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor<?x?xf64>, tensor<2xi32>) -> tensor<5x4xf64>
"hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<5x4xf64>) -> ()
%3 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor<?x?xf64>, tensor<2xi32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%3) : (tensor<?x?xf64>) -> ()

// Unknown (non-constant) shape: may disagree with a static result dimension
// at runtime, unless the result is fully dynamic.
%4 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64>
"hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<5x4xf64>) -> ()
%5 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%5) : (tensor<?x?xf64>) -> ()

return
}

// -----

// Recursively speculatable ops

// -----
Expand Down
20 changes: 10 additions & 10 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3164,18 +3164,9 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor<?xf32>, %shape: ten

// -----

// A constant output_shape containing a negative size is rejected: it cannot
// be compatible with the static result type. (No lines may be inserted
// between the expected-error comment and the op; @+2 counts source lines.)
func.func @dynamic_reshape_output_shape_negative_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}}
%0 = stablehlo.constant dense<[-1, 1]> : tensor<2xi64>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
return %1 : tensor<1x4xf32>
}

// -----

func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}}
%0 = stablehlo.constant dense<[1, 1]> : tensor<2xi64>
%0 = stablehlo.constant dense<[2, 2]> : tensor<2xi64>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
return %1 : tensor<1x4xf32>
}
Expand All @@ -3190,6 +3181,15 @@ func.func @dynamic_reshape_dynamic_output_shape(%arg0: tensor<?xf32>, %shape: te

// -----

// When the input shape is static and output_shape is constant, the element
// counts must agree even though the result type is fully dynamic: the input
// has 2*5 = 10 elements but output_shape implies 2*3*4 = 24.
func.func @dynamic_reshape_input_count_mismatch_shape_count(%arg0: tensor<2x5xf32>) -> tensor<?x?x?xf32> {
%0 = stablehlo.constant dense<[2, 3, 4]> : tensor<3xi32>
// expected-error@+1 {{output_shape is incompatible with input type of operation: input has 10 elements, but output_shape has 24}}
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<2x5xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}

// -----

func.func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
%0 = "stablehlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
func.return %0 : tensor<2x4xf32>
Expand Down

0 comments on commit b911843

Please sign in to comment.