Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for remaining dynamic ops
Browse files Browse the repository at this point in the history
Included ops:
- DynamicPad
- RealDynamicSlice
- DynamicConv
- DynamicGather

I refactored the logic to check speculatability for shaped ops to enable
reuse and allow ops with more than one shape-related operand to be
checked.

This should be the last change adding new speculatability
implementations. All the other ops are either done, pure, or
deprecated. I will confirm this shortly by reviewing the entire opset
and making sure everything is covered (I have been tracking progress in
a personal document).
  • Loading branch information
Michael Levesque-Dion committed Apr 19, 2024
1 parent 3729fde commit f81720a
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 23 deletions.
24 changes: 24 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,5 +611,29 @@ bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) {
[val](int64_t x) { return x == val; });
}

// Determines speculatability for an op whose trailing `count` operands are
// shape-related. Speculation is safe when either (a) the result shape is
// fully dynamic, so no runtime shape can contradict it, or (b) every operand
// shape is static and all shape operands are constants, so the verifier has
// already proven every input/shape relationship.
mlir::Speculation::Speculatability getShapedSpeculatability(Operation* op,
                                                            int64_t count) {
  auto resultType = cast<ShapedType>(op->getResult(0).getType());

  // (a) If no result dimension is static, the result type imposes no
  // expectations that the shape operands could violate at runtime.
  bool hasStaticDim = false;
  for (int64_t dim = 0; dim < resultType.getRank(); ++dim)
    hasStaticDim |= !resultType.isDynamicDim(dim);
  if (!hasStaticDim) return mlir::Speculation::Speculatable;

  // (b) All operand shapes must be static. (The shape operands themselves are
  // statically shaped by ODS constraints.)
  for (Type operandType : op->getOperandTypes())
    if (!cast<ShapedType>(operandType).hasStaticShape())
      return mlir::Speculation::NotSpeculatable;

  // ...and the trailing `count` shape-related operands must be constants.
  int64_t numOperands = op->getNumOperands();
  for (int64_t i = 0; i < count; ++i)
    if (!matchPattern(op->getOperand(numOperands - 1 - i), m_Constant()))
      return mlir::Speculation::NotSpeculatable;
  return mlir::Speculation::Speculatable;
}

} // namespace hlo
} // namespace mlir
26 changes: 8 additions & 18 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,13 @@ void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) {
}
} // namespace bytecode

// Determines the speculatability for a shaped operation `op` with `count`
// shape operands. The last `count` operands are assumed to be shape operands.
// To be speculatable, such an op must either have a fully dynamic result type
// or have only static inputs and constant shape operands.
mlir::Speculation::Speculatability getShapedSpeculatability(Operation *op,
int64_t count);

namespace OpTrait {

template <typename ConcreteType>
Expand Down Expand Up @@ -456,24 +463,7 @@ struct SpeculatableIfAllInputsStaticAndShapeConstantImplTrait
SpeculatableIfAllInputsStaticAndShapeConstantImplTrait> {
// NOTE(review): this span is rendered diff text with the +/- markers lost —
// the lines below mix the REMOVED inline implementation (everything up to and
// including `return mlir::Speculation::NotSpeculatable;`) with the ADDED
// replacement (`return getShapedSpeculatability(op, 1);`), which is why two
// unreachable returns appear back to back. In the post-commit file only the
// one-line delegation remains; confirm against the repository source.
mlir::Speculation::Speculatability getSpeculatability() {
auto op = this->getOperation();
auto resultType = cast<ShapedType>(op->getResult(0).getType());
// The result type's shape is fully dynamic, so there cannot be a mismatch
// with the output shape operand at runtime (the type has no expectations).
if (llvm::all_of(llvm::seq(resultType.getRank()),
[&](int64_t i) { return resultType.isDynamicDim(i); }))
return mlir::Speculation::Speculatable;

// If all inputs are static and the output shape (last operand) is constant,
// then any relationship between the input and the output shape can be
// verified statically. The shape operand is known to be static due to ODS
// constraints.
bool allInputsStatic = llvm::all_of(op->getOperandTypes(), [](Type t) {
return cast<ShapedType>(t).hasStaticShape();
});
if (allInputsStatic &&
matchPattern(op->getOperand(op->getNumOperands() - 1), m_Constant()))
return mlir::Speculation::Speculatable;
return mlir::Speculation::NotSpeculatable;
// Replacement body: delegates to the shared helper with a single trailing
// shape operand (count = 1), preserving the logic deleted above.
return getShapedSpeculatability(op, 1);
}
};

Expand Down
23 changes: 23 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,21 @@ LogicalResult DynamicGatherOp::inferReturnTypeComponents(
adaptor.getDimensionNumbers().getIndexVectorDim(), inferredReturnShapes);
}

// Speculatability for DynamicGather. Unlike the shaped-op helper, this op has
// its own rules because of the indices_are_sorted attribute.
mlir::Speculation::Speculatability DynamicGatherOp::getSpeculatability() {
  // With indices_are_sorted set, unsorted start_indices are undefined
  // behavior, so speculation is unsafe. A possible refinement would be to
  // allow speculation when the start_indices are constant and verifiably
  // sorted, but that check could be somewhat costly for unclear ROI.
  if (getIndicesAreSorted()) return mlir::Speculation::NotSpeculatable;

  // slice_sizes must be a compile-time constant...
  if (!matchPattern(getSliceSizes(), m_Constant()))
    return mlir::Speculation::NotSpeculatable;

  // ...and every operand must have a fully static shape; only then can all
  // input/shape relationships be checked before executing the op.
  for (Type operandType : this->getOperation()->getOperandTypes())
    if (!cast<RankedTensorType>(operandType).hasStaticShape())
      return mlir::Speculation::NotSpeculatable;
  return mlir::Speculation::Speculatable;
}

//===----------------------------------------------------------------------===//
// GetDimensionSizeOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1557,6 +1572,10 @@ LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
return success();
}

// Speculatability for RealDynamicSlice: delegate to the shared shaped-op
// helper. The three trailing operands (start_indices, limit_indices, strides)
// are the shape-related operands the helper inspects.
mlir::Speculation::Speculatability RealDynamicSliceOp::getSpeculatability() {
  Operation* op = getOperation();
  return hlo::getShapedSpeculatability(op, 3);
}

//===----------------------------------------------------------------------===//
// InfeedOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2085,6 +2104,10 @@ LogicalResult DynamicPadOp::reifyReturnTypeShapes(
return success();
}

// Speculatability for DynamicPad: delegate to the shared shaped-op helper.
// The three trailing operands (edge_padding_low, edge_padding_high,
// interior_padding) are the shape-related operands the helper inspects.
mlir::Speculation::Speculatability DynamicPadOp::getSpeculatability() {
  Operation* op = getOperation();
  return hlo::getShapedSpeculatability(op, 3);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
Expand Down
28 changes: 23 additions & 5 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1345,7 +1345,7 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
let results = (outs HLO_Tensor);
let hasVerifier = 1;

let extraClassDeclaration = [{
let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
Expand Down Expand Up @@ -3348,7 +3348,8 @@ def StableHLO_ReducePrecisionOp : StableHLO_Op<"reduce_precision",

def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<
"real_dynamic_slice",
[Pure, AllElementTypesMatch<["operand", "result"]>,
[ConditionallySpeculatable, NoMemoryEffect,
AllElementTypesMatch<["operand", "result"]>,
AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
let summary = "RealDynamicSlice operation";
let description = [{
Expand Down Expand Up @@ -3376,10 +3377,16 @@ def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
[Pure, AllElementTypesMatch<["operand", "padding_value", "result"]>,
[ConditionallySpeculatable, NoMemoryEffect,
AllElementTypesMatch<["operand", "padding_value", "result"]>,
AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
let summary = "DynamicPad operation";
let description = [{
Expand Down Expand Up @@ -3413,10 +3420,15 @@ def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather",
[InferTensorTypeWithReify, Pure]> {
[ConditionallySpeculatable, NoMemoryEffect, InferTensorTypeWithReify]> {
let summary = "DynamicGather operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand Down Expand Up @@ -3447,9 +3459,15 @@ def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather",
DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted
);
let results = (outs HLO_Tensor);

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", [Pure]> {
def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
[HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> {
let summary = "DynamicConv operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand Down
Loading

0 comments on commit f81720a

Please sign in to comment.