Skip to content

Commit

Permalink
Introduce attr that can be backed by either DenseI64ArrayAttr or I64E…
Browse files Browse the repository at this point in the history
…lementsAttr (#1887)

Port `BroadcastInDim`'s `broadcast_dimensions` as an example.

This will allow us (and downstream consumers) to port tests bit by bit.
  • Loading branch information
mlevesquedion authored Dec 15, 2023
1 parent e2e1dee commit 00a3080
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -469,15 +469,12 @@ struct HloBroadcastInDimConverter final
SmallVector<AffineExpr> dimExprs;
dimExprs.reserve(nloops);

if (broadcastOp.getBroadcastDimensions()) {
for (auto [idx, broadcastDim] : llvm::enumerate(
broadcastOp.getBroadcastDimensions().getValues<APInt>())) {
int size = broadcastDim.getSExtValue();
bool expansionNeeded =
operandShape[idx] == 1 && resultType.getShape()[size] != 1;
dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0)
: b->getAffineDimExpr(size));
}
for (auto [idx, broadcastDim] :
llvm::enumerate(broadcastOp.getBroadcastDimensions())) {
bool expansionNeeded =
operandShape[idx] == 1 && resultType.getShape()[broadcastDim] != 1;
dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0)
: b->getAffineDimExpr(broadcastDim));
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
Expand Down Expand Up @@ -566,7 +563,7 @@ struct BroadcastInDimOpToBroadcastConverter final
Location loc = op.getLoc();

SmallVector<int64_t> broadcastDimensions =
llvm::to_vector(op.getBroadcastDimensions().getValues<int64_t>());
llvm::to_vector(op.getBroadcastDimensions());

Value operand = adaptor.getOperand();
auto operandTy = llvm::cast<ShapedType>(operand.getType());
Expand Down
17 changes: 17 additions & 0 deletions stablehlo/dialect/AssemblyFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,23 @@ ParseResult parseDenseI64Array(OpAsmParser& parser,
return success();
}

void printI64DenseArrayOrElements1D(OpAsmPrinter& p, Operation* op,
Attribute attr) {
if (auto elems = dyn_cast<DenseIntElementsAttr>(attr)) {
printDenseI64Array(p, op, elems);
return;
}
dyn_cast<DenseI64ArrayAttr>(attr).print(p);
}

ParseResult parseI64DenseArrayOrElements1D(OpAsmParser& parser,
Attribute& attr) {
if ((attr = DenseI64ArrayAttr::parse(parser, Type{}))) {
return success();
}
return failure();
}

void printSliceRanges(OpAsmPrinter& p, Operation* op,
ArrayRef<int64_t> startIndices,
ArrayRef<int64_t> limitIndices,
Expand Down
13 changes: 13 additions & 0 deletions stablehlo/dialect/AssemblyFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,19 @@ void printDenseI64Array(OpAsmPrinter& p, Operation* op,

ParseResult parseDenseI64Array(OpAsmParser& parser, DenseIntElementsAttr& attr);

// I64DenseArrayOrElements1D - Used to print an attr that can be either
// I64ElementsAttr (DenseIntElementsAttr) or DenseI64ArrayAttr.
//
// Dense elements:
// { dense<[1, 2]> : tensor<2xi64> }
// Array:
// { array<i64: 1, 2> }
void printI64DenseArrayOrElements1D(OpAsmPrinter& p, Operation* op,
Attribute attr);

ParseResult parseI64DenseArrayOrElements1D(OpAsmParser& parser,
Attribute& attr);

// SliceRanges - Used to print multi-dimensional ranges for slice.
void printSliceRanges(OpAsmPrinter& p, Operation* op,
ArrayRef<int64_t> startIndices,
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,20 @@ def StableHLO_ConvolutionAttributes {
);
}

def I64Elements1D : And<[I64ElementsAttr.predicate, CPred<"$_self.cast<DenseIntElementsAttr>().getType().getRank() == 1">]>;

// TODO(#1578) migrate uses to DenseI64ArrayAttr and delete this attr
def I64DenseArrayOrElements1DAttr : Attr<Or<[DenseI64ArrayAttr.predicate, I64Elements1D]>, "either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr."> {
let storageType = "Attribute";
let returnType = "SmallVector<int64_t>";
let convertFromStorage = [{
[&]() -> SmallVector<int64_t> {
if (auto elems = $_self.dyn_cast<DenseIntElementsAttr>()) {
return llvm::to_vector(elems.getValues<int64_t>());
}
return llvm::to_vector($_self.cast<DenseI64ArrayAttr>().asArrayRef());
}()
}];
}

#endif // STABLEHLO_DIALECT_STABLEHLO_ATTRS
2 changes: 2 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2427,6 +2427,7 @@ using mlir::hlo::parseCustomCallTarget;
using mlir::hlo::parseDenseI64Array;
using mlir::hlo::parseDotDimensionNumbers;
using mlir::hlo::parseExponentMantissa;
using mlir::hlo::parseI64DenseArrayOrElements1D;
using mlir::hlo::parsePairwiseOpType;
using mlir::hlo::parseSameOperandsAndResultType;
using mlir::hlo::parseSelectOpType;
Expand All @@ -2439,6 +2440,7 @@ using mlir::hlo::printCustomCallTarget;
using mlir::hlo::printDenseI64Array;
using mlir::hlo::printDotDimensionNumbers;
using mlir::hlo::printExponentMantissa;
using mlir::hlo::printI64DenseArrayOrElements1D;
using mlir::hlo::printPairwiseOpType;
using mlir::hlo::printSameOperandsAndResultType;
using mlir::hlo::printSelectOpType;
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1873,15 +1873,15 @@ def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim",
}];
let arguments = (ins
HLO_Tensor:$operand /*broadcast_in_dim_i1*/,
BroadcastDimAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/
I64DenseArrayOrElements1DAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/
);

let results = (outs HLO_StaticShapeTensor);

let hasVerifier = 1;

let assemblyFormat = [{
$operand `,` `dims` `=` custom<DenseI64Array>($broadcast_dimensions)
$operand `,` `dims` `=` custom<I64DenseArrayOrElements1D>($broadcast_dimensions)
attr-dict `:` functional-type(operands, results)
}];
}
Expand Down
16 changes: 5 additions & 11 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3198,7 +3198,7 @@ LogicalResult verifyBitcastConvertOp(std::optional<Location> location,

LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
Value operand,
DenseIntElementsAttr broadcastDimensions,
ArrayRef<int64_t> broadcastDimensions,
Value result) {
auto operandType = operand.getType().dyn_cast<RankedTensorType>();
if (!operandType) {
Expand All @@ -3207,29 +3207,23 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
return success();
}

auto dimensionsType = broadcastDimensions.getType();
auto dimensionsRank = dimensionsType.getRank();
// broadcast_in_dim_i2
if (dimensionsRank != 1)
return emitOptionalError(location, "broadcast_dimensions has rank ",
dimensionsRank, " instead of rank 1");
// broadcast_in_dim_c2
auto dimensionsSize = dimensionsType.getNumElements();
auto dimensionsSize = broadcastDimensions.size();
auto operandRank = operandType.getRank();
if (dimensionsSize != operandRank)
if (static_cast<int64_t>(dimensionsSize) != operandRank)
return emitOptionalError(location, "broadcast_dimensions size (",
dimensionsSize, ") does not match operand rank (",
operandRank, ")");

auto dimensions = llvm::to_vector(broadcastDimensions.getValues<int64_t>());
auto dimensions = llvm::to_vector(broadcastDimensions);
// broadcast_in_dim_c4
if (hasDuplicates(dimensions))
return emitOptionalError(location,
"broadcast_dimensions should not have duplicates");

auto resultType = result.getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank();
for (int i = 0; i != dimensionsSize; ++i) {
for (size_t i = 0; i != dimensionsSize; ++i) {
auto dimIndex = dimensions[i];
// broadcast_in_dim_c3
if (dimIndex < 0 || dimIndex >= resultRank)
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ LogicalResult verifyBitcastConvertOp(std::optional<Location> location,

LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
Value operand,
DenseIntElementsAttr broadcastDimensions,
ArrayRef<int64_t> broadcastDimensions,
Value result);

LogicalResult verifyCollectiveBroadcastOp(std::optional<Location> location,
Expand Down
21 changes: 20 additions & 1 deletion stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ func.func @broadcast_in_dim_c5(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
// -----

func.func @broadcast_in_dim_i2(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
// expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}}
// expected-error@+1 {{failed to satisfy constraint: either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr}}
%0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
func.return %0 : tensor<1x2x3xi32>
}
Expand Down Expand Up @@ -5639,3 +5639,22 @@ func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> {
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32>
func.return %1 : tensor<4xf32>
}

// Tests for I64DenseArrayOrElementsAttr.

// -----

// CHECK-LABEL: func @broadcast_in_dim_elements
func.func @broadcast_in_dim_elements(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> {
%0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32>
func.return %0 : tensor<1x2x2xi32>
}

// -----

// CHECK-LABEL: func @broadcast_in_dim_dense_array
func.func @broadcast_in_dim_dense_array(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> {
%0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32>
func.return %0 : tensor<1x2x2xi32>
}

4 changes: 4 additions & 0 deletions stablehlo/transforms/VhloLegalizeToStablehlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,10 @@ SpecialResult convertSpecial(const OpConversionPattern<VhloOpTy>& pattern,
vhloName == "strides")
return convertDenseArray(vhloName, vhloAttr, stablehloAttrs);
}
if constexpr (std::is_same<VhloOpTy, vhlo::BroadcastInDimOpV1>::value) {
if (vhloName == "broadcast_dimensions")
return convertDenseArray(vhloName, vhloAttr, stablehloAttrs);
}
return notSpecial();
}

Expand Down

0 comments on commit 00a3080

Please sign in to comment.