Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for DynamicReshape
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Levesque-Dion committed Apr 17, 2024
1 parent 19e3142 commit facd6f8
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
20 changes: 20 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,26 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
return success();
}

mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() {
  // A fully dynamic result shape makes no static promise about the output,
  // so speculating the op can never violate a shape expectation.
  auto resultType = getType();
  bool anyStaticResultDim = false;
  for (int64_t dim : llvm::seq(resultType.getRank())) {
    if (!resultType.isDynamicDim(dim)) {
      anyStaticResultDim = true;
      break;
    }
  }
  if (!anyStaticResultDim) return mlir::Speculation::Speculatable;

  // At this point the result has at least one static dimension. Speculation
  // remains safe only when both the operand shape is fully static and the
  // shape operand is a compile-time constant, because then the output shape
  // is inferable and any mismatch is caught statically. A dynamic operand
  // dimension could make the element count disagree with the output, and a
  // non-constant shape operand could disagree with the static result
  // dimension at runtime.
  bool operandFullyStatic = getOperand().getType().hasStaticShape();
  bool shapeIsConstant = matchPattern(getOutputShape(), m_Constant());
  if (operandFullyStatic && shapeIsConstant)
    return mlir::Speculation::Speculatable;

  return mlir::Speculation::NotSpeculatable;
}

//===----------------------------------------------------------------------===//
// DynamicSliceOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 7 additions & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2637,7 +2637,8 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape",
}];
}

def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [Pure]> {
def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape",
[ConditionallySpeculatable, NoMemoryEffect]> {
let summary = "DynamicReshape operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand All @@ -2659,6 +2660,11 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_ScatterOp: StableHLO_Op<"scatter",
Expand Down
36 changes: 36 additions & 0 deletions stablehlo/tests/ops_speculatability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1701,6 +1701,42 @@ func.func @dynamic_iota(%unknown_shape: tensor<2xi32>) {

// -----

// CHECK-LABEL: func @dynamic_reshape
// CHECK-NEXT: return
// Exercises DynamicReshapeOp::getSpeculatability across the four input
// combinations: {static, dynamic} operand x {constant, unknown} shape,
// each with both a static and a fully dynamic result type.
func.func @dynamic_reshape(
  %static_arg: tensor<4x5xf64>, %dynamic_arg: tensor<?x?xf64>,
  %unknown_shape: tensor<2xi32>
) {
  %constant_shape = stablehlo.constant dense<[5, 4]> : tensor<2xi32>

  // Static input, constant shape: output is fully inferable, so the op is
  // speculatable regardless of whether the result type is static or dynamic.
  %0 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64>
  "hlo_test_speculatability.is_speculatable"(%0) : (tensor<5x4xf64>) -> ()
  %1 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<?x?xf64>
  "hlo_test_speculatability.is_speculatable"(%1) : (tensor<?x?xf64>) -> ()

  // NOTE(review): %2/%3 are byte-identical to %0/%1 above — possibly a
  // copy-paste leftover; confirm whether a distinct case was intended.
  %2 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64>
  "hlo_test_speculatability.is_speculatable"(%2) : (tensor<5x4xf64>) -> ()
  %3 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<?x?xf64>
  "hlo_test_speculatability.is_speculatable"(%3) : (tensor<?x?xf64>) -> ()

  // Dynamic input: element count may mismatch a static result dimension, so
  // only the fully dynamic result type is speculatable.
  %4 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor<?x?xf64>, tensor<2xi32>) -> tensor<5x4xf64>
  "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<5x4xf64>) -> ()
  %5 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor<?x?xf64>, tensor<2xi32>) -> tensor<?x?xf64>
  "hlo_test_speculatability.is_speculatable"(%5) : (tensor<?x?xf64>) -> ()

  // Unknown (non-constant) shape operand: may disagree with a static result
  // dimension at runtime, so only the fully dynamic result is speculatable.
  %6 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64>
  "hlo_test_speculatability.is_not_speculatable"(%6) : (tensor<5x4xf64>) -> ()
  %7 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<?x?xf64>
  "hlo_test_speculatability.is_speculatable"(%7) : (tensor<?x?xf64>) -> ()

  return
}

// -----

// Recursively speculatable ops

// -----
Expand Down

0 comments on commit facd6f8

Please sign in to comment.