Implement ConditionallySpeculatable for {Dynamic,}BroadcastInDim
Michael Levesque-Dion committed Apr 22, 2024
1 parent ac8ed4c commit 9d9c6fb
Showing 5 changed files with 103 additions and 2 deletions.
22 changes: 22 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
@@ -1343,6 +1343,28 @@ LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
return success();
}

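// Speculation is allowed only when the op provably cannot fail a runtime
// shape check: the operand shape must be static, and either the operand must
// broadcast to any shape or the output shape must be known.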
mlir::Speculation::Speculatability
DynamicBroadcastInDimOp::getSpeculatability() {
auto operandType = getOperand().getType();

// If input is dynamic, not speculatable.
if (!operandType.hasStaticShape())
return mlir::Speculation::NotSpeculatable;

// If every operand dimension is 1, the operand broadcasts to any shape. If
// the result is also fully dynamic, no static result dimension can disagree
// with the runtime output shape, so the op is speculatable.
auto resultDynamic =
llvm::all_of(llvm::seq(getType().getRank()),
[this](int64_t i) { return getType().isDynamicDim(i); });
if (operandType.getNumElements() == 1 && resultDynamic)
return mlir::Speculation::Speculatable;

// If the output shape is a known constant, the verifier has already checked
// it against the static operand and result types, so the op is speculatable.
if (matchPattern(getOutputDimensions(), m_Constant()))
return mlir::Speculation::Speculatable;

return mlir::Speculation::NotSpeculatable;
}

//===----------------------------------------------------------------------===//
// ClampOp
//===----------------------------------------------------------------------===//
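For context, and not part of this commit: generic MLIR transforms consume this hook through the ConditionallySpeculatable interface. A minimal sketch of such a consumer, assuming only the upstream mlir::isSpeculatable and mlir::isMemoryEffectFree helpers from mlir/Interfaces/SideEffectInterfaces.h; canHoistOutOfLoop is a hypothetical name used here for illustration:

#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

// Sketch of a consumer: a pass such as loop-invariant code motion may hoist
// an op out of a loop only if it has no memory effects and is safe to execute
// speculatively. isSpeculatable dispatches to ConditionallySpeculatable, i.e.
// to the getSpeculatability method implemented above.
static bool canHoistOutOfLoop(mlir::Operation *op) {
  return mlir::isMemoryEffectFree(op) && mlir::isSpeculatable(op);
}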
11 changes: 9 additions & 2 deletions stablehlo/dialect/StablehloOps.td
@@ -1912,7 +1912,8 @@ def StableHLO_BroadcastOp : StableHLO_ShapedInterfaceOp<"broadcast",
}

def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim",
[Pure, HLO_CompatibleOperandsAndResultElementType /*broadcast_in_dim_c1*/]> {
[HLO_SpeculatableIfAllInputsStatic, NoMemoryEffect,
HLO_CompatibleOperandsAndResultElementType /*broadcast_in_dim_c1*/]> {
let summary = "BroadcastInDim operation";
let description = [{
Expands the dimensions and/or rank of an input tensor by duplicating the
@@ -1942,7 +1943,8 @@ def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim",
}

def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp<
"dynamic_broadcast_in_dim", [Pure]> {
"dynamic_broadcast_in_dim",
[ConditionallySpeculatable, NoMemoryEffect]> {
let summary = "DynamicBroadcastInDim operation";
let description = [{
This operation is a work in progress, so it is not yet included in
@@ -1984,6 +1986,11 @@ def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp<
$operand `,` $output_dimensions `,` `dims` `=` custom<DenseI64Array>($broadcast_dimensions)
attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

// Note: There is no HLO_CallOp because the standard call operation mlir::func::CallOp
17 changes: 17 additions & 0 deletions stablehlo/dialect/TypeInference.cpp
@@ -3714,6 +3714,23 @@ LogicalResult verifyDynamicBroadcastInDimOp(
" does not refer to a "
"valid operand dimension");

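// When the operand shape is static and the output shape is constant, each
// operand dimension must be 1 or must match the corresponding output
// dimension.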
if (SmallVector<int64_t> shape;
operandType.hasStaticShape() &&
matchInts(outputDimensions, shape).succeeded()) {
for (int64_t i = 0; i != bcastDimensionsSize; ++i) {
auto dimIndex = broadcastDimensions[i];
if (!operandType.isDynamicDim(i)) {
auto dimSize = operandType.getDimSize(i);
auto shapeDimSize = shape[dimIndex];
if (dimSize != 1 && dimSize != shapeDimSize)
return emitOptionalError(
location, "size of operand dimension ", i, " (", dimSize,
") is not equal to 1 or value of shape at index ", dimIndex, " (",
shapeDimSize, ")");
}
}
}

if (!isCompatibleForHloTypeInference(outputDimensions, resultType))
return emitOptionalError(
location,
46 changes: 46 additions & 0 deletions stablehlo/tests/ops_speculatability.mlir
@@ -420,6 +420,16 @@ func.func @tanh(%static_arg: tensor<2xf64>, %dynamic_arg: tensor<?xf64>) {

// -----

// CHECK-LABEL: func @broadcast_in_dim
// CHECK-NEXT: return
func.func @broadcast_in_dim(%static_arg: tensor<1x1xf64>, %dynamic_arg: tensor<?x?xf64>) {
%0 = stablehlo.broadcast_in_dim %static_arg, dims = [0, 1] : (tensor<1x1xf64>) -> tensor<3x3xf64>
"hlo_test_speculatability.is_speculatable"(%0) : (tensor<3x3xf64>) -> ()
%1 = stablehlo.broadcast_in_dim %dynamic_arg, dims = [0, 1] : (tensor<?x?xf64>) -> tensor<3x3xf64>
"hlo_test_speculatability.is_not_speculatable"(%1) : (tensor<3x3xf64>) -> ()
return
}

// CHECK-LABEL: func @pad
// CHECK-NEXT: return
func.func @pad(%static_arg: tensor<2xf64>, %dynamic_arg: tensor<?xf64>, %padding_value: tensor<f64>) {
@@ -1676,6 +1686,42 @@ func.func @scatter(

// -----

// CHECK-LABEL: func @dynamic_broadcast_in_dim
// CHECK-NEXT: return
func.func @dynamic_broadcast_in_dim(
%static_arg_0: tensor<1x1xf64>, %static_arg_1: tensor<1x5xf64>,
%dynamic_arg: tensor<?x?xf64>, %unknown_shape: tensor<2xi32>
) {
%constant_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi32>

// Static input, constant shape
%0 = stablehlo.dynamic_broadcast_in_dim %static_arg_0, %constant_shape, dims = [0, 1] : (tensor<1x1xf64>, tensor<2xi32>) -> tensor<4x5xf64>
"hlo_test_speculatability.is_speculatable"(%0) : (tensor<4x5xf64>) -> ()
%1 = stablehlo.dynamic_broadcast_in_dim %static_arg_0, %constant_shape, dims = [0, 1] : (tensor<1x1xf64>, tensor<2xi32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%1) : (tensor<?x?xf64>) -> ()

// Dynamic input
%2 = stablehlo.dynamic_broadcast_in_dim %dynamic_arg, %constant_shape, dims = [0, 1] : (tensor<?x?xf64>, tensor<2xi32>) -> tensor<4x5xf64>
"hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<4x5xf64>) -> ()
%3 = stablehlo.dynamic_broadcast_in_dim %dynamic_arg, %constant_shape, dims = [0, 1] : (tensor<?x?xf64>, tensor<2xi32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_not_speculatable"(%3) : (tensor<?x?xf64>) -> ()

// Unknown shape, but all operand dimensions are 1, so the operand broadcasts
// to any shape; still speculatable only if the result is fully dynamic
%4 = stablehlo.dynamic_broadcast_in_dim %static_arg_0, %unknown_shape, dims = [0, 1] : (tensor<1x1xf64>, tensor<2xi32>) -> tensor<4x5xf64>
"hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<4x5xf64>) -> ()
%5 = stablehlo.dynamic_broadcast_in_dim %static_arg_0, %unknown_shape, dims = [0, 1] : (tensor<1x1xf64>, tensor<2xi32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%5) : (tensor<?x?xf64>) -> ()

// Unknown shape, but not all dimensions are 1
%6 = stablehlo.dynamic_broadcast_in_dim %static_arg_1, %unknown_shape, dims = [0, 1] : (tensor<1x5xf64>, tensor<2xi32>) -> tensor<4x5xf64>
"hlo_test_speculatability.is_not_speculatable"(%6) : (tensor<4x5xf64>) -> ()
%7 = stablehlo.dynamic_broadcast_in_dim %static_arg_1, %unknown_shape, dims = [0, 1] : (tensor<1x5xf64>, tensor<2xi32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_not_speculatable"(%7) : (tensor<?x?xf64>) -> ()
return
}

// -----

// CHECK-LABEL: func @dynamic_iota
// CHECK-NEXT: return
func.func @dynamic_iota(%unknown_shape: tensor<2xi32>) {
9 changes: 9 additions & 0 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -1067,6 +1067,15 @@ func.func @dynamic_broadcast_in_dim_dynamic_output_shape(%arg0: tensor<?x?xi32>,

// -----

func.func @dynamic_broadcast_in_dim_input_mismatch_with_shape(%arg0: tensor<1x3xi32>) {
%shape = stablehlo.constant dense<[2, 1, 1]> : tensor<3xi32>
// expected-error@+1 {{size of operand dimension 1 (3) is not equal to 1 or value of shape at index 2 (1)}}
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<1x3xi32>, tensor<3xi32>) -> tensor<?x?x?xi32>
return
}

// -----

// CHECK-LABEL: func @broadcast_in_dim
func.func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> {
%0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32>
