Skip to content

Commit

Permalink
Port DynamicBroadcastInDim to I64DenseArrayOrElements1DAttr (#1893)
Browse files Browse the repository at this point in the history
This is an intermediate step in the migration to `DenseI64ArrayAttr`.

This also includes changing `dense<...>` to `array<...>` in the tests
for `BroadcastInDim` and `DynamicBroadcastInDim` (both are supported,
but we want to move towards `array<...>`).

#1578
  • Loading branch information
mlevesquedion authored Dec 16, 2023
1 parent 69eeee7 commit 76901e3
Show file tree
Hide file tree
Showing 22 changed files with 120 additions and 129 deletions.
2 changes: 1 addition & 1 deletion docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -1417,7 +1417,7 @@ in the `operand` tensor and produces a `result` tensor. More formally,
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
Expand Down
10 changes: 5 additions & 5 deletions stablehlo/conversions/linalg/tests/miscellaneous.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ func.func @einsum_dynamic_size_broadcast_dot(%arg0: tensor<?x?x4xf32>, %arg1: te
// CHECK: func @broadcast_in_dim
func.func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
  // Migrated attribute form: DenseI64ArrayAttr (`array<i64: ...>`) instead of
  // the legacy `dense<...> : tensor<3xi64>` elements attribute.
  %0 = "stablehlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = array<i64: 4, 0, 2>}
      : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
  func.return %0 : tensor<7x10x6x4x5xf32>
}
Expand All @@ -385,7 +385,7 @@ func.func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf3
// CHECK: func @broadcast_in_dim_ui32
func.func @broadcast_in_dim_ui32(%operand: tensor<5x7x1xui32>) -> tensor<7x10x6x4x5xui32> {
  // Same broadcast as @broadcast_in_dim but exercising an unsigned element type.
  %0 = "stablehlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = array<i64: 4, 0, 2>}
      : (tensor<5x7x1xui32>) -> tensor<7x10x6x4x5xui32>
  func.return %0 : tensor<7x10x6x4x5xui32>
}
Expand Down Expand Up @@ -413,7 +413,7 @@ func.func @broadcast_in_dim_ui32(%operand: tensor<5x7x1xui32>) -> tensor<7x10x6x
func.func @broadcast_in_dim_with_one_to_one(
    %operand: tensor<1xf32>) -> tensor<1x5xf32> {
  // Size-1 operand dimension maps one-to-one onto result dimension 0.
  %0 = "stablehlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = array<i64: 0>}
      : (tensor<1xf32>) -> tensor<1x5xf32>
  func.return %0 : tensor<1x5xf32>
}
Expand All @@ -436,7 +436,7 @@ func.func @broadcast_in_dim_with_one_to_one(
func.func @broadcast_in_dim_with_transpose(
    %operand: tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32> {
  // Broadcast dimensions permute the operand axes (transpose + broadcast).
  %0 = "stablehlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = array<i64: 2, 0, 1>}
      : (tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32>
  func.return %0 : tensor<3x4x2x5xf32>
}
Expand All @@ -460,7 +460,7 @@ func.func @broadcast_in_dim_with_transpose(
// CHECK: func @broadcast_in_dim_scalar
func.func @broadcast_in_dim_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
  // Rank-0 operand: broadcast_dimensions is the empty i64 array.
  %0 = "stablehlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = array<i64>}
      : (tensor<f32>) -> tensor<7x10x6xf32>
  func.return %0 : tensor<7x10x6xf32>
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -627,10 +627,7 @@ struct HloDynamicBroadcastInDimConverter final
SmallVector<AffineExpr> dimExprs(operandType.getRank(), nullptr);

// Use static type info.
auto bcastDims =
llvm::map_to_vector(op.getBroadcastDimensions(), [](const APInt &d) {
return static_cast<int64_t>(d.getLimitedValue());
});
auto bcastDims = op.getBroadcastDimensions();
for (auto [idx, dim] : llvm::enumerate(operandType.getShape())) {
if (ShapedType::isDynamic(dim)) continue;

Expand All @@ -640,17 +637,13 @@ struct HloDynamicBroadcastInDimConverter final
}

// Use annotated expansion behavior, if available.
if (op.getKnownExpandingDimensions()) {
for (const auto &it :
op.getKnownExpandingDimensions()->getValues<APInt>()) {
auto i = it.getLimitedValue();
if (auto dims = op.getKnownExpandingDimensions()) {
for (const auto &i : *dims) {
dimExprs[i] = rewriter.getAffineConstantExpr(0);
}
}
if (op.getKnownNonexpandingDimensions()) {
for (const auto &it :
op.getKnownNonexpandingDimensions()->getValues<APInt>()) {
auto i = it.getLimitedValue();
if (auto dims = op.getKnownNonexpandingDimensions()) {
for (const auto &i : *dims) {
dimExprs[i] = rewriter.getAffineDimExpr(bcastDims[i]);
}
}
Expand Down Expand Up @@ -697,7 +690,7 @@ struct DynamicBroadcastInDimOpToBroadcastConverter final
if (!resultTy) return failure();

SmallVector<int64_t> broadcastDimensions =
llvm::to_vector(op.getBroadcastDimensions().getValues<int64_t>());
llvm::to_vector(op.getBroadcastDimensions());

SmallVector<std::optional<bool>> expansionBehavior(
broadcastDimensions.size());
Expand All @@ -709,16 +702,14 @@ struct DynamicBroadcastInDimOpToBroadcastConverter final
}

// Use annotated expansion behavior, if available.
if (op.getKnownExpandingDimensions()) {
for (const auto &it :
op.getKnownExpandingDimensions()->getValues<int64_t>()) {
expansionBehavior[it] = true;
if (auto dims = op.getKnownExpandingDimensions()) {
for (const auto &i : *dims) {
expansionBehavior[i] = true;
}
}
if (op.getKnownNonexpandingDimensions()) {
for (const auto &it :
op.getKnownNonexpandingDimensions()->getValues<int64_t>()) {
expansionBehavior[it] = false;
if (auto dims = op.getKnownNonexpandingDimensions()) {
for (const auto &i : *dims) {
expansionBehavior[i] = false;
}
}

Expand Down
12 changes: 6 additions & 6 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1906,17 +1906,17 @@ def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp<
let arguments = (ins
HLO_Tensor:$operand,
HLO_DimensionTensor:$output_dimensions,
BroadcastDimAttr:$broadcast_dimensions,
OptionalAttr<BroadcastDimAttr>:$known_expanding_dimensions,
OptionalAttr<BroadcastDimAttr>:$known_nonexpanding_dimensions
I64DenseArrayOrElements1DAttr:$broadcast_dimensions,
OptionalAttr<I64DenseArrayOrElements1DAttr>:$known_expanding_dimensions,
OptionalAttr<I64DenseArrayOrElements1DAttr>:$known_nonexpanding_dimensions
);

let results = (outs HLO_Tensor);

let builders = [
OpBuilder<(ins
OpBuilder<(ins
"Type":$result_type, "Value":$operand, "Value":$output_dimensions,
"DenseIntElementsAttr":$broadcast_dimensions), [{
"Attribute":$broadcast_dimensions), [{
build($_builder, $_state, result_type, operand, output_dimensions,
broadcast_dimensions, /*known_expanding_dimensions=*/{},
/*known_nonexpanding_dimensions=*/{});
Expand All @@ -1926,7 +1926,7 @@ def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp<
let hasVerifier = 1;

let assemblyFormat = [{
$operand `,` $output_dimensions `,` `dims` `=` custom<DenseI64Array>($broadcast_dimensions)
$operand `,` $output_dimensions `,` `dims` `=` custom<I64DenseArrayOrElements1D>($broadcast_dimensions)
attr-dict `:` functional-type(operands, results)
}];
}
Expand Down
23 changes: 8 additions & 15 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3402,9 +3402,9 @@ LogicalResult verifyDotGeneralOp(std::optional<Location> location, Value lhs,

LogicalResult verifyDynamicBroadcastInDimOp(
std::optional<Location> location, Value operand, Value outputDimensions,
DenseIntElementsAttr broadcastDimensions,
std::optional<DenseIntElementsAttr> knownExpandingDimensions,
std::optional<DenseIntElementsAttr> knownNonexpandingDimensions,
ArrayRef<int64_t> broadcastDimensions,
std::optional<ArrayRef<int64_t>> knownExpandingDimensions,
std::optional<ArrayRef<int64_t>> knownNonexpandingDimensions,
Value result) {
auto operandType = operand.getType().dyn_cast<RankedTensorType>();
auto resultType = result.getType().dyn_cast<RankedTensorType>();
Expand All @@ -3421,14 +3421,7 @@ LogicalResult verifyDynamicBroadcastInDimOp(

// Verify broadcast_dimensions.
auto bcastDimensions = broadcastDimensions;
auto bcastDimensionsType = broadcastDimensions.getType();
auto bcastDimensionsRank = bcastDimensionsType.getRank();
// TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1.
if (bcastDimensionsRank != 1)
return emitOptionalError(location, "broadcast_dimensions has rank ",
bcastDimensionsRank, " instead of rank 1");

auto bcastDimensionsSize = bcastDimensionsType.getNumElements();
int64_t bcastDimensionsSize = bcastDimensions.size();
if (bcastDimensionsSize != operandRank)
return emitOptionalError(
location, "broadcast_dimensions size (", bcastDimensionsSize,
Expand All @@ -3439,7 +3432,7 @@ LogicalResult verifyDynamicBroadcastInDimOp(
") is less than operand rank (", operandRank, ")");

for (int i = 0; i != bcastDimensionsSize; ++i) {
auto dimIndex = bcastDimensions.getValues<int64_t>()[i];
auto dimIndex = bcastDimensions[i];
if (dimIndex < 0 || dimIndex >= resultRank)
return emitOptionalError(location,
"broadcast_dimensions contains invalid value ",
Expand Down Expand Up @@ -3467,11 +3460,11 @@ LogicalResult verifyDynamicBroadcastInDimOp(
int64_t numKnownExpansionBehavior = 0;
DenseSet<int64_t> knownExpansionBehavior;
auto collectExpansionBehaviorDims =
[&](const std::optional<DenseIntElementsAttr>& attr) {
[&](const std::optional<ArrayRef<int64_t>>& attr) {
if (!attr) return;
for (const APInt& it : *attr) {
for (const auto& i : attr.value()) {
numKnownExpansionBehavior++;
knownExpansionBehavior.insert(it.getLimitedValue());
knownExpansionBehavior.insert(i);
}
};
collectExpansionBehaviorDims(knownExpandingDimensions);
Expand Down
7 changes: 3 additions & 4 deletions stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,9 @@ LogicalResult verifyDotGeneralOp(std::optional<Location> location, Value lhs,

// Verifies stablehlo.dynamic_broadcast_in_dim: checks broadcast_dimensions
// size/values against the operand and result ranks, and that the optional
// known_(non)expanding_dimensions annotations are disjoint and in range.
// Post-migration signature: dimension lists are passed as ArrayRef<int64_t>
// (from DenseI64ArrayAttr) rather than DenseIntElementsAttr.
LogicalResult verifyDynamicBroadcastInDimOp(
    std::optional<Location> location, Value operand, Value outputDimensions,
    ArrayRef<int64_t> broadcastDimensions,
    std::optional<ArrayRef<int64_t>> knownExpandingDimensions,
    std::optional<ArrayRef<int64_t>> knownNonexpandingDimensions, Value result);

LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
Value outputShape, int64_t iotaDimension,
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/tests/interpret_broadcast_in_dim.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
func.func @broadcast_in_dim() {
%operand = stablehlo.constant dense<[[1], [2], [3]]> : tensor<3x1xi64>
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = dense<[0, 2]>: tensor<2xi64>
broadcast_dimensions = array<i64: 0, 2>
} : (tensor<3x1xi64>) -> tensor<3x2x2xi64>
check.expect_eq_const %result, dense<[[[1, 1], [1, 1]], [[2, 2], [2, 2]], [[3, 3], [3, 3]]]> : tensor<3x2x2xi64>
func.return
Expand Down
34 changes: 17 additions & 17 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -859,39 +859,39 @@ func.func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tenso

// CHECK-LABEL: func @dynamic_broadcast_in_dim
func.func @dynamic_broadcast_in_dim(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
  // Positive test: dynamic operand/result shapes with array-form attribute.
  %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<?x?x?xi32>
  func.return %0 : tensor<?x?x?xi32>
}

// -----

// CHECK-LABEL: func @dynamic_broadcast_in_dim_unranked
func.func @dynamic_broadcast_in_dim_unranked(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<*xi32> {
  // Positive test: unranked result is accepted.
  %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<*xi32>
  func.return %0 : tensor<*xi32>
}

// -----

// CHECK-LABEL: func @dynamic_broadcast_in_dim_unknown_dim
func.func @dynamic_broadcast_in_dim_unknown_dim(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<?x?x?xf32> {
  // Positive test: result dims are unknown, so no size check can fire.
  %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 2>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
  func.return %0 : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: func @dynamic_broadcast_in_dim_ok_dim
func.func @dynamic_broadcast_in_dim_ok_dim(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
  // Positive test: size-1 operand dim broadcasts into result dim 2.
  %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 2>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
  func.return %0 : tensor<7x8x9xf32>
}

// -----

func.func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
  // expected-error@+1 {{size of operand dimension 0 (32) is not compatible with size of result dimension 2 (9)}}
  %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 2>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
  func.return %0 : tensor<7x8x9xf32>
}

Expand All @@ -917,63 +917,63 @@ func.func @dynamic_broadcast_in_dim_output_dimensions_mismatching_size(%arg0: te

func.func @dynamic_broadcast_in_dim_negative_size(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
  // expected-error@+1 {{broadcast_dimensions contains invalid value -1 for result with rank 3}}
  %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: -1>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
  func.return %0 : tensor<7x8x9xf32>
}

// -----

func.func @dynamic_broadcast_in_dim_too_large(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> {
  // expected-error@+1 {{broadcast_dimensions contains invalid value 3 for result with rank 3}}
  %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 3>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32>
  func.return %0 : tensor<7x8x9xf32>
}

// -----

// CHECK-LABEL: func @broadcast_in_dim
func.func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> {
  // Positive test for the static broadcast_in_dim op.
  %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32>
  func.return %0 : tensor<1x2x2xi32>
}

// -----

func.func @broadcast_in_dim_c2(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
  // expected-error@+1 {{broadcast_dimensions size (1) does not match operand rank (2)}}
  %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
  func.return %0 : tensor<1x2x3xi32>
}

// -----

func.func @broadcast_in_dim_c3(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> {
  // expected-error@+1 {{broadcast_dimensions contains invalid value -1 for result with rank 3}}
  %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: -1, 2>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32>
  func.return %0 : tensor<1x2x2xi32>
}

// -----

func.func @broadcast_in_dim_c3(%arg0: tensor<1x2x3xi32>) -> tensor<3xi32> {
  // expected-error@+1 {{broadcast_dimensions contains invalid value 1 for result with rank 1}}
  %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 0,1,2>} : (tensor<1x2x3xi32>) -> tensor<3xi32>
  func.return %0 : tensor<3xi32>
}

// -----

func.func @broadcast_in_dim_c4(%arg0: tensor<1x1x3xi32>) -> tensor<1x2x3xi32> {
  // expected-error@+1 {{broadcast_dimensions should not have duplicates}}
  %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 0,0,2>} : (tensor<1x1x3xi32>) -> tensor<1x2x3xi32>
  func.return %0 : tensor<1x2x3xi32>
}

// -----

func.func @broadcast_in_dim_c5(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {
  // expected-error@+1 {{size of operand dimension 0 (3) is not equal to 1 or size of result dimension 1 (2)}}
  %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1>} : (tensor<3xi32>) -> tensor<1x2x3xi32>
  func.return %0 : tensor<1x2x3xi32>
}

Expand All @@ -992,7 +992,7 @@ func.func @broadcast_in_dim_i2(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
func.func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor<?xf32>) -> tensor<2xf32> {
  // Positive test: dynamically-shaped operand is allowed.
  %0 = "stablehlo.broadcast_in_dim"(%arg0) {
    broadcast_dimensions = array<i64: 0>
  } : (tensor<?xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}
Expand All @@ -1004,7 +1004,7 @@ func.func @broadcast_in_dim_dynamic_shaped_operand(%arg0 : tensor<?xf32>) -> ten
// CHECK-LABEL: func @broadcast_in_dim_unranked_operand
func.func @broadcast_in_dim_unranked_operand(%arg0 : tensor<*xf32>) -> tensor<2xf32> {
  // Positive test: unranked operand is allowed.
  %0 = "stablehlo.broadcast_in_dim"(%arg0) {
    broadcast_dimensions = array<i64: 0>
  } : (tensor<*xf32>) -> tensor<2xf32>
  func.return %0 : tensor<2xf32>
}
Expand Down Expand Up @@ -5342,8 +5342,8 @@ func.func @quantization_supported_ops(%arg0: tensor<1x2x2x!quant.uniform<i8:f32,
}

func.func @per_axis_quantized_ops(%arg0: tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>, %arg1: tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:0, {0.1:-30}>>) {
%0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,3]> : tensor<3xi64>} : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform<i8<-128:127>:f32:3, {0.1:-30, 0.5:-20}>>
%1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:0, {0.1:-30}>>) -> tensor<2x2x2x!quant.uniform<i8<-128:127>:f32:0, {0.1:-30, 0.1:-30}>>
%0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 0, 1, 3>} : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform<i8<-128:127>:f32:3, {0.1:-30, 0.5:-20}>>
%1 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = array<i64: 0, 1, 2>} : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:0, {0.1:-30}>>) -> tensor<2x2x2x!quant.uniform<i8<-128:127>:f32:0, {0.1:-30, 0.1:-30}>>
%2 = stablehlo.reshape %arg0 : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<2x2x!quant.uniform<i8<-128:127>:f32:1, {0.1:-30, 0.5:-20}>>
%3 = "stablehlo.transpose"(%arg0) {permutation = array<i64: 0, 2, 1>}: (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:1, {0.1:-30, 0.5:-20}>>
func.return
Expand Down
Loading

0 comments on commit 76901e3

Please sign in to comment.