Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for remaining dynamic ops
Browse files Browse the repository at this point in the history
Included ops:
- DynamicPad
- RealDynamicSlice
- DynamicConv
- DynamicGather

I refactored the logic to check speculatability for shaped ops to enable
reuse and allow ops with more than one shape-related operand to be
checked.

This should be the last change adding new speculatability
implementations. All the other ops are either done, pure, or
deprecated. I will confirm this shortly by reviewing the entire opset
and making sure everything is covered (I have been tracking progress in
a personal document).
  • Loading branch information
Michael Levesque-Dion committed Apr 29, 2024
1 parent 0a91535 commit 5932d82
Show file tree
Hide file tree
Showing 6 changed files with 315 additions and 26 deletions.
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,5 +611,22 @@ bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) {
[val](int64_t x) { return x == val; });
}

mlir::Speculation::Speculatability getShapedSpeculatability(
Operation* op, int64_t shapeCount) {
// If all inputs are static and the shape-related operands are constant
// then any relationship between the input, the shapes and the output can be
// verified statically.
bool allInputsStatic = llvm::all_of(op->getOperandTypes(), [](Type t) {
return cast<ShapedType>(t).hasStaticShape();
});
bool allShapesConstant = llvm::all_of(llvm::seq(shapeCount), [&](int64_t i) {
return matchPattern(op->getOperand(op->getNumOperands() - 1 - i),
m_Constant());
});
return allInputsStatic && allShapesConstant
? mlir::Speculation::Speculatable
: mlir::Speculation::NotSpeculatable;
}

} // namespace hlo
} // namespace mlir
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,13 @@ void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) {
}
} // namespace bytecode

// Determines the speculatability for a shaped operation `op` with `shapeCount`
// shape operands. The last `count` operands are assumed to be shape operands.
// To be speculatable, such an op must either have a fully dynamic result type
// or have only static inputs and constant shape operands.
mlir::Speculation::Speculatability getShapedSpeculatability(Operation *op,
int64_t shapeCount);

namespace OpTrait {

template <typename ConcreteType>
Expand Down Expand Up @@ -470,6 +477,16 @@ struct RecursivelySpeculatableIfAllInputsStaticImplTrait
}
};

template <typename ConcreteType>
struct SpeculatableIfAllInputsStaticAndShapeConstantImplTrait
: public mlir::OpTrait::TraitBase<
ConcreteType,
SpeculatableIfAllInputsStaticAndShapeConstantImplTrait> {
mlir::Speculation::Speculatability getSpeculatability() {
return getShapedSpeculatability(this->getOperation(), 1);
}
};

} // namespace OpTrait
} // namespace hlo
} // namespace mlir
Expand Down
9 changes: 9 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -368,4 +368,13 @@ def HLO_RecursivelySpeculatableIfAllInputsStaticImplTrait
def HLO_RecursivelySpeculatableIfAllInputsStatic : TraitList<[
ConditionallySpeculatable, HLO_RecursivelySpeculatableIfAllInputsStaticImplTrait]>;

def HLO_SpeculatableIfAllInputsStaticAndShapeConstantImplTrait
: HLO_NativeOpTrait<"SpeculatableIfAllInputsStaticAndShapeConstantImplTrait">;

// This trait is the same as HLO_SpeculatableIfAllInputsStatic, but for ops that
// take a shape as their last operand. Such ops are speculatable if either the
// output is dynamic or all inputs are static and the shape is constant.
def HLO_SpeculatableIfAllInputsStaticAndShapeConstant : TraitList<[
ConditionallySpeculatable, HLO_SpeculatableIfAllInputsStaticAndShapeConstantImplTrait]>;

#endif // STABLEHLO_DIALECT_BASE
20 changes: 8 additions & 12 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1532,18 +1532,6 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
return success();
}

mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() {
// If the input is static and the shape operand is constant, the output
// shape can be inferred and any mismatch will be caught statically.
// If any dimension in the input is dynamic, or if the shape is not known,
// the number of elements may disagree at runtime.
if (getOperand().getType().hasStaticShape() &&
matchPattern(getOutputShape(), m_Constant()))
return mlir::Speculation::Speculatable;

return mlir::Speculation::NotSpeculatable;
}

//===----------------------------------------------------------------------===//
// DynamicSliceOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1613,6 +1601,10 @@ LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
return success();
}

mlir::Speculation::Speculatability RealDynamicSliceOp::getSpeculatability() {
return hlo::getShapedSpeculatability(getOperation(), /*count=*/3);
}

//===----------------------------------------------------------------------===//
// InfeedOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2160,6 +2152,10 @@ LogicalResult DynamicPadOp::reifyReturnTypeShapes(
return success();
}

mlir::Speculation::Speculatability DynamicPadOp::getSpeculatability() {
return hlo::getShapedSpeculatability(getOperation(), /*count=*/3);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 20 additions & 11 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [Condit
$output_shape `,` `dim` `=` $iota_dimension attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
Expand Down Expand Up @@ -2653,7 +2653,7 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape",
}

def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape",
[ConditionallySpeculatable, NoMemoryEffect]> {
[HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> {
let summary = "DynamicReshape operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand All @@ -2675,11 +2675,6 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape",
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_ScatterOp: StableHLO_Op<"scatter",
Expand Down Expand Up @@ -3375,7 +3370,8 @@ def StableHLO_ReducePrecisionOp : StableHLO_Op<"reduce_precision",

def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<
"real_dynamic_slice",
[Pure, AllElementTypesMatch<["operand", "result"]>,
[ConditionallySpeculatable, NoMemoryEffect,
AllElementTypesMatch<["operand", "result"]>,
AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
let summary = "RealDynamicSlice operation";
let description = [{
Expand Down Expand Up @@ -3403,10 +3399,16 @@ def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
[Pure, AllElementTypesMatch<["operand", "padding_value", "result"]>,
[ConditionallySpeculatable, NoMemoryEffect,
AllElementTypesMatch<["operand", "padding_value", "result"]>,
AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
let summary = "DynamicPad operation";
let description = [{
Expand Down Expand Up @@ -3440,10 +3442,16 @@ def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather",
[InferTensorTypeWithReify, Pure]> {
[HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect,
InferTensorTypeWithReify]> {
let summary = "DynamicGather operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand Down Expand Up @@ -3476,7 +3484,8 @@ def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather",
let results = (outs HLO_Tensor);
}

def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", [Pure]> {
def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
[HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> {
let summary = "DynamicConv operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand Down
Loading

0 comments on commit 5932d82

Please sign in to comment.