Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for collective ops (#2168)
Browse files Browse the repository at this point in the history
I'm actually not sure if speculation makes sense for these ops. I don't
see any indication in the spec that they could have UB except in the
cases added here, so it's likely fine. Also, I'm not sure if these ops
can have memory effects or not. They definitely involve interacting with
some kind of global state, so they have side effects. The ops weren't
marked "Pure" before this change, so there is no difference on that
front being introduced with this change.

Also, I slightly updated the logic in TestUtils to delete the op when the
speculation check succeeds. Previously we were relying on DCE, but these
ops don't declare that they are free of side effects, so DCE won't
remove them.
  • Loading branch information
mlevesquedion authored Apr 9, 2024
1 parent 2651907 commit ad2a3b3
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 6 deletions.
55 changes: 55 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,24 @@ LogicalResult ReduceScatterOp::verify() {
channelId, getUseGlobalDeviceIds(), getComputation(), getResult());
}

/// Interface method for ConditionallySpeculatable.
///
/// Returns `Speculatable` only when the scatter dimension of the result is
/// dynamic and no other result dimension is static while the matching operand
/// dimension is dynamic. Returns `NotSpeculatable` in every other case.
mlir::Speculation::Speculatability ReduceScatterOp::getSpeculatability() {
  auto operandTy = getOperand().getType();
  auto outTy = getResult().getType();
  auto scatterAxis = getScatterDimension();

  // The runtime extent of the scatter dimension is a function of the number
  // of processes, which is unknown until the program runs. A dynamic result
  // extent makes no promise and so cannot be contradicted; a static one can
  // disagree with the runtime value, which is UB. See scatter_c8 in the spec
  // (presumably the reduce_scatter constraint of that number - TODO confirm).
  if (!outTy.isDynamicDim(scatterAxis)) {
    return mlir::Speculation::NotSpeculatable;
  }

  // For the remaining dimensions, reject a static result extent that is
  // paired with a dynamic operand extent.
  for (size_t axis : llvm::seq(outTy.getRank())) {
    const bool staticOutDynamicIn = axis != scatterAxis &&
                                    !outTy.isDynamicDim(axis) &&
                                    operandTy.isDynamicDim(axis);
    if (staticOutDynamicIn) return mlir::Speculation::NotSpeculatable;
  }

  return mlir::Speculation::Speculatable;
}

//===----------------------------------------------------------------------===//
// CompatibleOperandsAndResultType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -933,6 +951,25 @@ void AllToAllOp::build(OpBuilder& odsBuilder, OperationState& odsState,
/*channel_handle=*/nullptr);
}

/// Interface method for ConditionallySpeculatable.
///
/// Returns `Speculatable` only when both the split and the concat dimension
/// of the result are dynamic and no other result dimension is static while
/// the matching operand dimension is dynamic. Returns `NotSpeculatable` in
/// every other case.
mlir::Speculation::Speculatability AllToAllOp::getSpeculatability() {
  auto operandTy = getOperand().getType();
  auto outTy = getResult().getType();
  auto splitAxis = getSplitDimension();
  auto concatAxis = getConcatDimension();

  // The runtime extents of the split and concat dimensions are a function of
  // the number of processes, which is unknown until the program runs. Dynamic
  // result extents make no promise and so cannot be contradicted; a static
  // one can disagree with the runtime value, which is UB. See all_to_all_c9
  // in the spec.
  const bool bothDynamic =
      outTy.isDynamicDim(splitAxis) && outTy.isDynamicDim(concatAxis);
  if (!bothDynamic) {
    return mlir::Speculation::NotSpeculatable;
  }

  // True for the two dimensions already handled by the check above.
  auto dependsOnProcessCount = [&](size_t axis) {
    return axis == splitAxis || axis == concatAxis;
  };

  // For the remaining dimensions, reject a static result extent that is
  // paired with a dynamic operand extent.
  for (size_t axis : llvm::seq(outTy.getRank())) {
    if (!dependsOnProcessCount(axis) && !outTy.isDynamicDim(axis) &&
        operandTy.isDynamicDim(axis)) {
      return mlir::Speculation::NotSpeculatable;
    }
  }

  return mlir::Speculation::Speculatable;
}

//===----------------------------------------------------------------------===//
// AllGatherOp
//===----------------------------------------------------------------------===//
Expand All @@ -947,6 +984,24 @@ LogicalResult AllGatherOp::verify() {
getUseGlobalDeviceIds(), getResult());
}

/// Interface method for ConditionallySpeculatable.
///
/// Returns `Speculatable` only when the all-gather dimension of the result is
/// dynamic and no other result dimension is static while the matching operand
/// dimension is dynamic. Returns `NotSpeculatable` in every other case.
mlir::Speculation::Speculatability AllGatherOp::getSpeculatability() {
  auto operandTy = getOperand().getType();
  auto outTy = getResult().getType();
  auto gatherAxis = getAllGatherDim();

  // The runtime extent of the all-gather dimension is a function of the
  // number of processes, which is unknown until the program runs. A dynamic
  // result extent makes no promise and so cannot be contradicted; a static
  // one can disagree with the runtime value, which is UB. See all_gather_c6
  // in the spec.
  if (!outTy.isDynamicDim(gatherAxis)) {
    return mlir::Speculation::NotSpeculatable;
  }

  // For the remaining dimensions, reject a static result extent that is
  // paired with a dynamic operand extent.
  for (size_t axis : llvm::seq(outTy.getRank())) {
    if (axis == gatherAxis) continue;       // Already handled above.
    if (outTy.isDynamicDim(axis)) continue;  // Dynamic result: nothing to violate.
    if (operandTy.isDynamicDim(axis)) {
      return mlir::Speculation::NotSpeculatable;
    }
  }

  return mlir::Speculation::Speculatable;
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 25 additions & 6 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1317,7 +1317,7 @@ def StableHLO_WhileOp: StableHLO_Op<"while", [
}

def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
[SameOperandsAndResultElementType] /*all_gather_c6*/> {
[ConditionallySpeculatable, SameOperandsAndResultElementType] /*all_gather_c6*/> {
string summary = "AllGather operation";
string description = [{
Within each process group in the process grid, concatenates the values of the
Expand Down Expand Up @@ -1346,10 +1346,16 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
);
let results = (outs HLO_Tensor);
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
[InferTensorType /*all_reduce_c6, all_reduce_c7*/]> {
[HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
InferTensorType /*all_reduce_c6, all_reduce_c7*/]> {
let summary = "AllReduce operation";
let description = [{
Within each process group in the process grid, applies a reduction function
Expand Down Expand Up @@ -1384,7 +1390,7 @@ def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
let hasVerifier = 1;
}

def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter"> {
def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", [ConditionallySpeculatable]> {
let summary = "ReduceScatter operation";
let description = [{
Within each process group in the process grid, performs reduction, using
Expand Down Expand Up @@ -1419,10 +1425,16 @@ def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter"> {
let regions = (region SizedRegion<1>:$computation /*reduce_scatter_i6*/);
let results = (outs HLO_Tensor);
let hasVerifier = 1;

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",
[SameOperandsAndResultElementType /*all_to_all_c9*/,
[ConditionallySpeculatable,
SameOperandsAndResultElementType /*all_to_all_c9*/,
InferTensorType /*all_to_all_c9*/]> {
let summary = "AllToAll operation";
let description = [{
Expand Down Expand Up @@ -1464,6 +1476,11 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",
"::mlir::IntegerAttr": $concat_dimension,
"::mlir::IntegerAttr": $split_count,
"::mlir::DenseIntElementsAttr": $replica_groups)>];

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
Expand Down Expand Up @@ -2052,7 +2069,8 @@ def StableHLO_ConcatenateOp : StableHLO_ShapedInterfaceOp<"concatenate",


def StableHLO_CollectiveBroadcastOp: StableHLO_Op<"collective_broadcast",
[HLO_CompatibleOperandsAndResultType,
[HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
HLO_CompatibleOperandsAndResultType,
SameOperandsAndResultElementType /*collective_broadcast_c3*/]> {
let summary = "CollectiveBroadcast operation";
let description = [{
Expand Down Expand Up @@ -2088,7 +2106,8 @@ def StableHLO_CollectiveBroadcastOp: StableHLO_Op<"collective_broadcast",
}

def StableHLO_CollectivePermuteOp: StableHLO_Op<"collective_permute",
[HLO_CompatibleOperandsAndResultType,
[HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
HLO_CompatibleOperandsAndResultType,
SameOperandsAndResultElementType /*collective_permute_c5*/]> {
let summary = "CollectivePermute operation";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions stablehlo/tests/TestUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ LogicalResult checkSpeculatability(PatternRewriter &rewriter, Operation *op,

if (definingOp.getSpeculatability() == spec) {
rewriter.eraseOp(op);
rewriter.eraseOp(definingOp);
return success();
}

Expand Down
Loading

0 comments on commit ad2a3b3

Please sign in to comment.