diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index c669c9d1b9..432b3b5a5a 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -469,15 +469,12 @@ struct HloBroadcastInDimConverter final SmallVector dimExprs; dimExprs.reserve(nloops); - if (broadcastOp.getBroadcastDimensions()) { - for (auto [idx, broadcastDim] : llvm::enumerate( - broadcastOp.getBroadcastDimensions().getValues())) { - int size = broadcastDim.getSExtValue(); - bool expansionNeeded = - operandShape[idx] == 1 && resultType.getShape()[size] != 1; - dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0) - : b->getAffineDimExpr(size)); - } + for (auto [idx, broadcastDim] : + llvm::enumerate(broadcastOp.getBroadcastDimensions())) { + bool expansionNeeded = + operandShape[idx] == 1 && resultType.getShape()[broadcastDim] != 1; + dimExprs.push_back(expansionNeeded ? 
b->getAffineConstantExpr(0) + : b->getAffineDimExpr(broadcastDim)); } return { AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), @@ -566,7 +563,7 @@ struct BroadcastInDimOpToBroadcastConverter final Location loc = op.getLoc(); SmallVector broadcastDimensions = - llvm::to_vector(op.getBroadcastDimensions().getValues()); + llvm::to_vector(op.getBroadcastDimensions()); Value operand = adaptor.getOperand(); auto operandTy = llvm::cast(operand.getType()); diff --git a/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/dialect/AssemblyFormat.cpp index aba782ce2b..1fcfd1f3ab 100644 --- a/stablehlo/dialect/AssemblyFormat.cpp +++ b/stablehlo/dialect/AssemblyFormat.cpp @@ -279,6 +279,23 @@ ParseResult parseDenseI64Array(OpAsmParser& parser, return success(); } +void printI64DenseArrayOrElements1D(OpAsmPrinter& p, Operation* op, + Attribute attr) { + if (auto elems = dyn_cast(attr)) { + printDenseI64Array(p, op, elems); + return; + } + dyn_cast(attr).print(p); +} + +ParseResult parseI64DenseArrayOrElements1D(OpAsmParser& parser, + Attribute& attr) { + if ((attr = DenseI64ArrayAttr::parse(parser, Type{}))) { + return success(); + } + return failure(); +} + void printSliceRanges(OpAsmPrinter& p, Operation* op, ArrayRef startIndices, ArrayRef limitIndices, diff --git a/stablehlo/dialect/AssemblyFormat.h b/stablehlo/dialect/AssemblyFormat.h index 9336134ae0..7ac0d999c3 100644 --- a/stablehlo/dialect/AssemblyFormat.h +++ b/stablehlo/dialect/AssemblyFormat.h @@ -186,6 +186,19 @@ void printDenseI64Array(OpAsmPrinter& p, Operation* op, ParseResult parseDenseI64Array(OpAsmParser& parser, DenseIntElementsAttr& attr); +// I64DenseArrayOrElements1D - Used to print an attr that can be either +// I64ElementsAttr (DenseIntElementsAttr) or DenseI64ArrayAttr. 
+// +// Dense elements: +// { dense<[1, 2]> : tensor<2xi64> } +// Array: +// { array } +void printI64DenseArrayOrElements1D(OpAsmPrinter& p, Operation* op, + Attribute attr); + +ParseResult parseI64DenseArrayOrElements1D(OpAsmParser& parser, + Attribute& attr); + // SliceRanges - Used to print multi-dimensional ranges for slice. void printSliceRanges(OpAsmPrinter& p, Operation* op, ArrayRef startIndices, diff --git a/stablehlo/dialect/StablehloAttrs.td b/stablehlo/dialect/StablehloAttrs.td index 53b3d25098..c85619358e 100644 --- a/stablehlo/dialect/StablehloAttrs.td +++ b/stablehlo/dialect/StablehloAttrs.td @@ -209,4 +209,20 @@ def StableHLO_ConvolutionAttributes { ); } +def I64Elements1D : And<[I64ElementsAttr.predicate, CPred<"$_self.cast().getType().getRank() == 1">]>; + +// TODO(#1578) migrate uses to DenseI64ArrayAttr and delete this attr +def I64DenseArrayOrElements1DAttr : Attr, "either a DenseI64ArrayAttr or a 1-dimensional I64ElementsAttr."> { + let storageType = "Attribute"; + let returnType = "SmallVector"; + let convertFromStorage = [{ + [&]() -> SmallVector { + if (auto elems = $_self.dyn_cast()) { + return llvm::to_vector(elems.getValues()); + } + return llvm::to_vector($_self.cast().asArrayRef()); + }() + }]; +} + #endif // STABLEHLO_DIALECT_STABLEHLO_ATTRS diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 6e37198d90..9b0933a2d4 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -2427,6 +2427,7 @@ using mlir::hlo::parseCustomCallTarget; using mlir::hlo::parseDenseI64Array; using mlir::hlo::parseDotDimensionNumbers; using mlir::hlo::parseExponentMantissa; +using mlir::hlo::parseI64DenseArrayOrElements1D; using mlir::hlo::parsePairwiseOpType; using mlir::hlo::parseSameOperandsAndResultType; using mlir::hlo::parseSelectOpType; @@ -2439,6 +2440,7 @@ using mlir::hlo::printCustomCallTarget; using mlir::hlo::printDenseI64Array; using mlir::hlo::printDotDimensionNumbers; using 
mlir::hlo::printExponentMantissa; +using mlir::hlo::printI64DenseArrayOrElements1D; using mlir::hlo::printPairwiseOpType; using mlir::hlo::printSameOperandsAndResultType; using mlir::hlo::printSelectOpType; diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index a822228e99..d52fe0712d 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -1873,7 +1873,7 @@ def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim", }]; let arguments = (ins HLO_Tensor:$operand /*broadcast_in_dim_i1*/, - BroadcastDimAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/ + I64DenseArrayOrElements1DAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/ ); let results = (outs HLO_StaticShapeTensor); @@ -1881,7 +1881,7 @@ def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim", let hasVerifier = 1; let assemblyFormat = [{ - $operand `,` `dims` `=` custom($broadcast_dimensions) + $operand `,` `dims` `=` custom($broadcast_dimensions) attr-dict `:` functional-type(operands, results) }]; } diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 03c7dd0d40..7917a6e12e 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -3198,7 +3198,7 @@ LogicalResult verifyBitcastConvertOp(std::optional location, LogicalResult verifyBroadcastInDimOp(std::optional location, Value operand, - DenseIntElementsAttr broadcastDimensions, + ArrayRef broadcastDimensions, Value result) { auto operandType = operand.getType().dyn_cast(); if (!operandType) { @@ -3207,21 +3207,15 @@ LogicalResult verifyBroadcastInDimOp(std::optional location, return success(); } - auto dimensionsType = broadcastDimensions.getType(); - auto dimensionsRank = dimensionsType.getRank(); - // broadcast_in_dim_i2 - if (dimensionsRank != 1) - return emitOptionalError(location, "broadcast_dimensions has rank ", - dimensionsRank, " instead of rank 1"); // broadcast_in_dim_c2 - auto dimensionsSize 
= dimensionsType.getNumElements(); + auto dimensionsSize = broadcastDimensions.size(); auto operandRank = operandType.getRank(); - if (dimensionsSize != operandRank) + if (static_cast(dimensionsSize) != operandRank) return emitOptionalError(location, "broadcast_dimensions size (", dimensionsSize, ") does not match operand rank (", operandRank, ")"); - auto dimensions = llvm::to_vector(broadcastDimensions.getValues()); + auto dimensions = llvm::to_vector(broadcastDimensions); // broadcast_in_dim_c4 if (hasDuplicates(dimensions)) return emitOptionalError(location, @@ -3229,7 +3223,7 @@ LogicalResult verifyBroadcastInDimOp(std::optional location, auto resultType = result.getType().cast(); auto resultRank = resultType.getRank(); - for (int i = 0; i != dimensionsSize; ++i) { + for (size_t i = 0; i != dimensionsSize; ++i) { auto dimIndex = dimensions[i]; // broadcast_in_dim_c3 if (dimIndex < 0 || dimIndex >= resultRank) diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index a090108b29..f16e19e1a0 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -385,7 +385,7 @@ LogicalResult verifyBitcastConvertOp(std::optional location, LogicalResult verifyBroadcastInDimOp(std::optional location, Value operand, - DenseIntElementsAttr broadcastDimensions, + ArrayRef broadcastDimensions, Value result); LogicalResult verifyCollectiveBroadcastOp(std::optional location, diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index bd9bfe15eb..0383a5d406 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -980,7 +980,7 @@ func.func @broadcast_in_dim_c5(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // ----- func.func @broadcast_in_dim_i2(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { - // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} + // expected-error@+1 {{failed to satisfy constraint: either a DenseI64ArrayAttr or a 
1-dimensional I64ElementsAttr}} %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } @@ -5639,3 +5639,22 @@ func.func @dynamic_iota_output_shape_mismatching_size() -> tensor<4xf32> { %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<4xf32> func.return %1 : tensor<4xf32> } + +// Tests for I64DenseArrayOrElements1DAttr. + +// ----- + +// CHECK-LABEL: func @broadcast_in_dim_elements +func.func @broadcast_in_dim_elements(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + func.return %0 : tensor<1x2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @broadcast_in_dim_dense_array +func.func @broadcast_in_dim_dense_array(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { + %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + func.return %0 : tensor<1x2x2xi32> +} + diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp index 094164685d..fda0b3678a 100644 --- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp @@ -521,6 +521,10 @@ SpecialResult convertSpecial(const OpConversionPattern& pattern, vhloName == "strides") return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); } + if constexpr (std::is_same::value) { + if (vhloName == "broadcast_dimensions") + return convertDenseArray(vhloName, vhloAttr, stablehloAttrs); + } return notSpecial(); }