Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stablehlo] Add batching dims to stablehlo.gather and stablehlo.scatter #2259

Merged
merged 50 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
6b2e42a
[RFC] Add batching dims to `stablehlo.gather` and `stable.scatter` sp…
tomnatan30 Mar 12, 2024
84a4fda
[RFC] Add batching dims to `stablehlo.gather` and `stable.scatter` sp…
tomnatan30 Mar 12, 2024
64a2039
resolve conflicts
tomnatan30 Mar 12, 2024
d32cfcc
additional name fixes
tomnatan30 Mar 12, 2024
382d5c5
typo
tomnatan30 Mar 12, 2024
672b9cc
fix arrows in image
tomnatan30 Mar 12, 2024
35c50a7
Merge branch 'openxla:main' into gather-scatter-batch-spec
tomnatan30 Apr 25, 2024
038166b
[stablehlo] Add batching dims to stablehlo.gather and stablehlo.scatter
tomnatan30 Apr 26, 2024
39da3ab
[stablehlo] Add batching dims to stablehlo.gather and stablehlo.scatter
tomnatan30 Apr 26, 2024
f9e21a9
Revert "[stablehlo] Add batching dims to stablehlo.gather and stableh…
tomnatan30 Apr 26, 2024
36a311f
update stablehlo python tests
tomnatan30 Apr 26, 2024
5420453
Merge branch 'openxla:main' into gather-scatter-batch-impl
tomnatan30 May 1, 2024
dd466a4
fix clang format in Ops.h/cpp
tomnatan30 May 1, 2024
554b206
fix clang format and cmake failure in TypeInference.cpp
tomnatan30 May 1, 2024
797a84e
Merge branch 'openxla:main' into gather-scatter-batch-impl
tomnatan30 May 2, 2024
ef929e8
fix cmake failure in StablehloModule.cpp
tomnatan30 May 2, 2024
05fa434
fix conflict in verify_scatter.mlir
tomnatan30 May 7, 2024
b7fcd28
verify new attributes for DynamicGatherOp
tomnatan30 May 7, 2024
2bfcf01
patching compat
GleasonK May 7, 2024
9e02518
address review comments
tomnatan30 May 7, 2024
3a52251
change new version to 20
GleasonK May 7, 2024
ccb43b6
minor changes
tomnatan30 May 7, 2024
bade61e
resolve additional review comments
tomnatan30 May 7, 2024
25d1152
fix merge conflicts
tomnatan30 May 7, 2024
c9a8e3f
fix build issues
tomnatan30 May 7, 2024
39607c1
fix minor issue and generate serialized stablehlo_legalize_to_vhlo file
tomnatan30 May 7, 2024
1ed7fc2
fix clang format
tomnatan30 May 7, 2024
72fd5fa
fix whitespaces
tomnatan30 May 7, 2024
784918d
additional minor fixes
tomnatan30 May 7, 2024
6c0ab52
additional minor fixes
tomnatan30 May 7, 2024
7761e47
verification message change
tomnatan30 May 7, 2024
046647e
Merge branch 'openxla:main' into gather-scatter-batch-impl
tomnatan30 May 8, 2024
ded51ea
fix scatter downgrade patterns
GleasonK May 8, 2024
320397e
minor verification message change
tomnatan30 May 8, 2024
71a8948
minor fix
tomnatan30 May 8, 2024
a702720
Fix lint in BUILD.bazel
ghpvnist May 8, 2024
883e443
Update table formatting and fix typos (#2300)
ghpvnist May 8, 2024
3f12b29
dynamic_reshape op spec (#2284)
abhigunj May 8, 2024
bcec619
Change uses of deprecated method a.cast to cast<>(a) (#2301)
abhigunj May 8, 2024
f905d7a
Remove dynamic_{iota,reshape} from list of ops to spec (#2303)
May 8, 2024
ddb5d2a
dynamic_iota op : match ODS description with the spec (#2305)
abhigunj May 8, 2024
8a5120d
Merge branch 'main' into gather-scatter-batch-impl
tomnatan30 May 9, 2024
69e3bc4
fix merge conflicts
tomnatan30 May 15, 2024
c309cf4
resolve review comments
tomnatan30 May 15, 2024
a802f22
restore stablehlo_legalize_to_vhlo.0_20_0.mlir
tomnatan30 May 15, 2024
44929ec
add test for dynamic_gather c12
tomnatan30 May 15, 2024
c62a491
fix whitespaces
tomnatan30 May 15, 2024
41969bf
fix deprecated dyn_cast
tomnatan30 May 15, 2024
28c9f9b
Merge branch 'openxla:main' into gather-scatter-batch-impl
tomnatan30 May 15, 2024
6827859
remove extra new line
tomnatan30 May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 189 additions & 89 deletions docs/spec.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion rfcs/20240311-gather-scatter-batching-dims.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ The following diagram shows how elements in `updates...` map on elements in
`updates...` indices and explains in detail which `results...` indices they
correspond to.

![](images/20240311-gather-scatter-batching-dims/scatter.svg)
![scatter](images/20240311-gather-scatter-batching-dims/scatter.svg)

More formally, for all `update_index` in `index_space(updates[0])`:

Expand Down
4 changes: 4 additions & 0 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def StableHLO_ScatterDimensionNumbers : AttrDef<StableHLO_Dialect, "ScatterDimen
let parameters = (ins
StableHLO_Dims:$updateWindowDims,
StableHLO_Dims:$insertedWindowDims,
StableHLO_Dims:$inputBatchingDims,
StableHLO_Dims:$scatterIndicesBatchingDims,
StableHLO_Dims:$scatterDimsToOperandDims,
"int64_t":$indexVectorDim
);
Expand All @@ -61,6 +63,8 @@ def StableHLO_GatherDimensionNumbers : AttrDef<StableHLO_Dialect, "GatherDimensi
let parameters = (ins
StableHLO_Dims:$offsetDims,
StableHLO_Dims:$collapsedSliceDims,
StableHLO_Dims:$operandBatchingDims,
StableHLO_Dims:$startIndicesBatchingDims,
StableHLO_Dims:$startIndexMap,
"int64_t":$indexVectorDim
);
Expand Down
23 changes: 16 additions & 7 deletions stablehlo/dialect/StablehloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,25 +491,30 @@ GatherDimensionNumbersAttr
StablehloBytecodeInterface::readGatherDimensionNumbersAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
llvm::SmallVector<int64_t> offsetDims, collapsedSliceDims, startIndexMap;
llvm::SmallVector<int64_t> offsetDims, collapsedSliceDims,
operandBatchingDims, startIndicesBatchingDims, startIndexMap;
int64_t indexVectorDim;

if (failed(reader.readSignedVarInts(offsetDims)) ||
failed(reader.readSignedVarInts(collapsedSliceDims)) ||
failed(reader.readSignedVarInts(operandBatchingDims)) ||
failed(reader.readSignedVarInts(startIndicesBatchingDims)) ||
failed(reader.readSignedVarInts(startIndexMap)) ||
failed(reader.readSignedVarInt(indexVectorDim)))
return GatherDimensionNumbersAttr();

return GatherDimensionNumbersAttr::get(getContext(), offsetDims,
collapsedSliceDims, startIndexMap,
indexVectorDim);
return GatherDimensionNumbersAttr::get(
getContext(), offsetDims, collapsedSliceDims, operandBatchingDims,
startIndicesBatchingDims, startIndexMap, indexVectorDim);
}

void StablehloBytecodeInterface::write(GatherDimensionNumbersAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(stablehlo_encoding::kGatherDimensionNumbers);
writer.writeSignedVarInts(attr.getOffsetDims());
writer.writeSignedVarInts(attr.getCollapsedSliceDims());
writer.writeSignedVarInts(attr.getOperandBatchingDims());
writer.writeSignedVarInts(attr.getStartIndicesBatchingDims());
writer.writeSignedVarInts(attr.getStartIndexMap());
writer.writeSignedVarInt(attr.getIndexVectorDim());
}
Expand Down Expand Up @@ -599,25 +604,29 @@ StablehloBytecodeInterface::readScatterDimensionNumbersAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
llvm::SmallVector<int64_t> updateWindowDims, insertedWindowDims,
scatterDimsToOperandDims;
inputBatchingDims, scatterIndicesBatchingDims, scatterDimsToOperandDims;
int64_t indexVectorDim;

if (failed(reader.readSignedVarInts(updateWindowDims)) ||
failed(reader.readSignedVarInts(insertedWindowDims)) ||
failed(reader.readSignedVarInts(inputBatchingDims)) ||
failed(reader.readSignedVarInts(scatterIndicesBatchingDims)) ||
failed(reader.readSignedVarInts(scatterDimsToOperandDims)) ||
failed(reader.readSignedVarInt(indexVectorDim)))
return ScatterDimensionNumbersAttr();

return ScatterDimensionNumbersAttr::get(
getContext(), updateWindowDims, insertedWindowDims,
scatterDimsToOperandDims, indexVectorDim);
getContext(), updateWindowDims, insertedWindowDims, inputBatchingDims,
scatterIndicesBatchingDims, scatterDimsToOperandDims, indexVectorDim);
}

void StablehloBytecodeInterface::write(ScatterDimensionNumbersAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(stablehlo_encoding::kScatterDimensionNumbersAttr);
writer.writeSignedVarInts(attr.getUpdateWindowDims());
writer.writeSignedVarInts(attr.getInsertedWindowDims());
writer.writeSignedVarInts(attr.getInputBatchingDims());
writer.writeSignedVarInts(attr.getScatterIndicesBatchingDims());
writer.writeSignedVarInts(attr.getScatterDimsToOperandDims());
writer.writeSignedVarInt(attr.getIndexVectorDim());
}
Expand Down
39 changes: 31 additions & 8 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands,
hlo::reifyGatherDimSizes(resultRank, getStartIndicesDim, getSliceDim,
op->getDimensionNumbers().getOffsetDims(),
op->getDimensionNumbers().getCollapsedSliceDims(),
op->getDimensionNumbers().getStartIndexMap(),
op->getDimensionNumbers().getOperandBatchingDims(),
op->getDimensionNumbers().getIndexVectorDim(),
shapeValues);

Expand Down Expand Up @@ -784,6 +784,8 @@ LogicalResult GatherOp::inferReturnTypeComponents(
location, adaptor.getOperand(), adaptor.getStartIndices(),
adaptor.getDimensionNumbers().getOffsetDims(),
adaptor.getDimensionNumbers().getCollapsedSliceDims(),
adaptor.getDimensionNumbers().getOperandBatchingDims(),
adaptor.getDimensionNumbers().getStartIndicesBatchingDims(),
adaptor.getDimensionNumbers().getStartIndexMap(),
adaptor.getDimensionNumbers().getIndexVectorDim(),
adaptor.getSliceSizes(), inferredReturnShapes);
Expand Down Expand Up @@ -823,6 +825,8 @@ LogicalResult DynamicGatherOp::inferReturnTypeComponents(
location, adaptor.getOperand(), adaptor.getStartIndices(),
adaptor.getSliceSizes(), adaptor.getDimensionNumbers().getOffsetDims(),
adaptor.getDimensionNumbers().getCollapsedSliceDims(),
adaptor.getDimensionNumbers().getOperandBatchingDims(),
adaptor.getDimensionNumbers().getStartIndicesBatchingDims(),
adaptor.getDimensionNumbers().getStartIndexMap(),
adaptor.getDimensionNumbers().getIndexVectorDim(), inferredReturnShapes);
}
Expand Down Expand Up @@ -2510,6 +2514,8 @@ LogicalResult ScatterOp::verify() {
getLoc(), getInputs(), getScatterIndices(), getUpdates(),
getScatterDimensionNumbers().getUpdateWindowDims(),
getScatterDimensionNumbers().getInsertedWindowDims(),
getScatterDimensionNumbers().getInputBatchingDims(),
getScatterDimensionNumbers().getScatterIndicesBatchingDims(),
getScatterDimensionNumbers().getScatterDimsToOperandDims(),
getScatterDimensionNumbers().getIndexVectorDim(), getUpdateComputation());
}
Expand Down Expand Up @@ -2801,6 +2807,9 @@ void ScatterDimensionNumbersAttr::print(AsmPrinter& printer) const {
printStruct(printer, "scatter",
std::make_pair("update_window_dims", getUpdateWindowDims()),
std::make_pair("inserted_window_dims", getInsertedWindowDims()),
std::make_pair("input_batching_dims", getInputBatchingDims()),
std::make_pair("scatter_indices_batching_dims",
getScatterIndicesBatchingDims()),
std::make_pair("scatter_dims_to_operand_dims",
getScatterDimsToOperandDims()),
std::make_pair("index_vector_dim", getIndexVectorDim()));
Expand All @@ -2809,15 +2818,20 @@ Attribute ScatterDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
if (failed(parser.parseLess())) return {};
SmallVector<int64_t> updateWindowDims;
SmallVector<int64_t> insertedWindowDims;
SmallVector<int64_t> inputBatchingDims;
SmallVector<int64_t> scatterIndicesBatchingDims;
SmallVector<int64_t> scatterDimsToOperandDims;
int64_t indexVectorDim = 0;

if (failed(parseStruct(
parser,
{"update_window_dims", "inserted_window_dims",
"scatter_dims_to_operand_dims", "index_vector_dim"},
{"update_window_dims", "inserted_window_dims", "input_batching_dims",
"scatter_indices_batching_dims", "scatter_dims_to_operand_dims",
"index_vector_dim"},
{[&]() { return parseDims(parser, updateWindowDims); },
[&]() { return parseDims(parser, insertedWindowDims); },
[&]() { return parseDims(parser, inputBatchingDims); },
[&]() { return parseDims(parser, scatterIndicesBatchingDims); },
[&]() { return parseDims(parser, scatterDimsToOperandDims); },
[&]() { return parser.parseInteger(indexVectorDim); }}))) {
parser.emitError(parser.getCurrentLocation())
Expand All @@ -2827,13 +2841,17 @@ Attribute ScatterDimensionNumbersAttr::parse(AsmParser& parser, Type type) {

return ScatterDimensionNumbersAttr::get(
parser.getContext(), updateWindowDims, insertedWindowDims,
scatterDimsToOperandDims, indexVectorDim);
inputBatchingDims, scatterIndicesBatchingDims, scatterDimsToOperandDims,
indexVectorDim);
}

// Custom printer and parser for GatherDimensionNumbersAttr.
void GatherDimensionNumbersAttr::print(AsmPrinter& printer) const {
printStruct(printer, "gather", std::make_pair("offset_dims", getOffsetDims()),
std::make_pair("collapsed_slice_dims", getCollapsedSliceDims()),
std::make_pair("operand_batching_dims", getOperandBatchingDims()),
std::make_pair("start_indices_batching_dims",
getStartIndicesBatchingDims()),
std::make_pair("start_index_map", getStartIndexMap()),
std::make_pair("index_vector_dim", getIndexVectorDim()));
}
Expand All @@ -2843,25 +2861,30 @@ Attribute GatherDimensionNumbersAttr::parse(AsmParser& parser, Type type) {

SmallVector<int64_t> offsetDims;
SmallVector<int64_t> collapsedSliceDims;
SmallVector<int64_t> operandBatchingDims;
SmallVector<int64_t> startIndicesBatchingDims;
SmallVector<int64_t> startIndexMap;
int64_t indexVectorDim = 0;

if (failed(parseStruct(
parser,
{"offset_dims", "collapsed_slice_dims", "start_index_map",
{"offset_dims", "collapsed_slice_dims", "operand_batching_dims",
"start_indices_batching_dims", "start_index_map",
"index_vector_dim"},
{[&]() { return parseDims(parser, offsetDims); },
[&]() { return parseDims(parser, collapsedSliceDims); },
[&]() { return parseDims(parser, operandBatchingDims); },
[&]() { return parseDims(parser, startIndicesBatchingDims); },
[&]() { return parseDims(parser, startIndexMap); },
[&]() { return parser.parseInteger(indexVectorDim); }}))) {
parser.emitError(parser.getCurrentLocation())
<< "failed parsing gather dimension numbers attribute";
return {};
}

return GatherDimensionNumbersAttr::get(parser.getContext(), offsetDims,
collapsedSliceDims, startIndexMap,
indexVectorDim);
return GatherDimensionNumbersAttr::get(
parser.getContext(), offsetDims, collapsedSliceDims, operandBatchingDims,
startIndicesBatchingDims, startIndexMap, indexVectorDim);
}

// Custom printer and parser for DotDimensionNumbersAttr.
Expand Down
48 changes: 26 additions & 22 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2551,8 +2551,8 @@ def StableHLO_FftOp: StableHLO_Op<"fft",

def StableHLO_GatherOp: StableHLO_Op<"gather",
[ConditionallySpeculatable, NoMemoryEffect,
InferTensorTypeWithReify /*gather_c13*/,
AllElementTypesMatch<["operand", "result"]> /*gather_c14*/]> {
InferTensorTypeWithReify /*gather_c22*/,
AllElementTypesMatch<["operand", "result"]> /*gather_c23*/]> {
let summary = "Gather operation";
let description = [{
Gathers slices from `operand` tensor from offsets specified in
Expand All @@ -2565,22 +2565,24 @@ def StableHLO_GatherOp: StableHLO_Op<"gather",
```mlir
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
slice_sizes = array<i64: 1, 2, 2>,
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi64>
```
}];

let arguments = (ins
HLO_Tensor:$operand /*gather_i1*/,
HLO_IntTensor:$start_indices /*gather_i2*/,
StableHLO_GatherDimensionNumbers:$dimension_numbers /*gather_i3, gather_i4, gather_i5, gather_i6*/,
GenericDenseI64ArrayAttr:$slice_sizes /*gather_i7*/,
DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted /*gather_i8*/
StableHLO_GatherDimensionNumbers:$dimension_numbers /*gather_i3...gather_i8*/,
GenericDenseI64ArrayAttr:$slice_sizes /*gather_i9*/,
DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted /*gather_i10*/
);

let results = (outs HLO_Tensor:$result);
Expand Down Expand Up @@ -2716,8 +2718,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape",
def StableHLO_ScatterOp: StableHLO_Op<"scatter",
[ConditionallySpeculatable, RecursiveMemoryEffects,
SameVariadicOperandSize /*scatter_c5*/,
DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16,
scater_c17*/]> {
DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c24,
scater_c25*/]> {
let summary = "Scatter operation";
let description = [{
Produces `results` tensors which are equal to `inputs` tensors except that
Expand All @@ -2735,25 +2737,27 @@ def StableHLO_ScatterOp: StableHLO_Op<"scatter",
stablehlo.return %0 : tensor<i64>
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
tomnatan30 marked this conversation as resolved.
Show resolved Hide resolved
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
```
}];
let arguments = (ins
Variadic<HLO_Tensor>:$inputs, /*scatter_i1*/
RankedTensorOf<[AnyInteger, Index]>:$scatter_indices, /*scatter_i2*/
Variadic<HLO_Tensor>:$updates, /*scatter_i3*/
StableHLO_ScatterDimensionNumbers:$scatter_dimension_numbers, /*scatter_i4...scatter_i7*/
DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted, /*scatter_i8*/
DefaultValuedOptionalAttr<BoolAttr, "false">:$unique_indices /*scatter_i9*/
StableHLO_ScatterDimensionNumbers:$scatter_dimension_numbers, /*scatter_i4...scatter_i9*/
DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted, /*scatter_i10*/
DefaultValuedOptionalAttr<BoolAttr, "false">:$unique_indices /*scatter_i11*/
);

let regions = (region SizedRegion<1>:$update_computation /*scatter_i10*/);
let regions = (region SizedRegion<1>:$update_computation /*scatter_i12*/);

let results = (outs Variadic<HLO_Tensor>);

Expand Down
Loading
Loading