Skip to content

Commit

Permalink
Add rewrite pattern into AggresiveSimplification that turns compare+s…
Browse files Browse the repository at this point in the history
…elect into min/max (#2244)

This patch rewrites a comparison and selection combination that utilizes
the same operands into a min/max operation:

```
func.func @test(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
  %0 = stablehlo.compare GE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %1 = stablehlo.compare LE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  return %s0, %s1 : tensor<2xi32>, tensor<2xi32>
}
```
transformed into:
```
func.func @test(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
  %0 = stablehlo.maximum %arg0, %arg1 : tensor<2xi32>
  %1 = stablehlo.minimum %arg0, %arg1 : tensor<2xi32>
  return %0, %1 : tensor<2xi32>, tensor<2xi32>
}
```
  • Loading branch information
mvpant authored Apr 22, 2024
1 parent f394f9b commit 1455e52
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 0 deletions.
86 changes: 86 additions & 0 deletions stablehlo/tests/stablehlo_aggressive_simplification.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,92 @@ func.func @select(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %argC: tensor<2xi1

// -----

// CHECK-LABEL: func.func @select_into_minmax1
// CHECK-SAME: [[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<2xi32>, [[ARG2:%.+]]: tensor<2xi32>, [[ARG3:%.+]]: tensor<2xi32>)
func.func @select_into_minmax1(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>,
%arg2: tensor<2xi32>, %arg3: tensor<2xi32>)
-> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) {

%0 = stablehlo.compare EQ, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%1 = stablehlo.compare NE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%2 = stablehlo.compare GE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%3 = stablehlo.compare GT, %arg0, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%4 = stablehlo.compare LE, %arg1, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%5 = stablehlo.compare LT, %arg1, %arg3, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>

%s0 = stablehlo.select %0, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%s1 = stablehlo.select %1, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%s2 = stablehlo.select %2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%s3 = stablehlo.select %3, %arg0, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%s4 = stablehlo.select %4, %arg1, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%s5 = stablehlo.select %5, %arg1, %arg3 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>

// CHECK-DAG: [[C0:%.+]] = stablehlo.compare EQ, [[ARG0]], [[ARG1]], SIGNED
// CHECK-DAG: [[C1:%.+]] = stablehlo.compare NE, [[ARG0]], [[ARG1]], SIGNED

// CHECK-DAG: [[S0:%.+]] = stablehlo.select [[C0]], [[ARG0]], [[ARG1]]
// CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], [[ARG0]], [[ARG1]]
// CHECK-DAG: [[S2:%.+]] = stablehlo.maximum [[ARG0]], [[ARG1]]
// CHECK-DAG: [[S3:%.+]] = stablehlo.maximum [[ARG0]], [[ARG2]]
// CHECK-DAG: [[S4:%.+]] = stablehlo.minimum [[ARG1]], [[ARG2]]
// CHECK-DAG: [[S5:%.+]] = stablehlo.minimum [[ARG1]], [[ARG3]]

// CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]]
return %s0, %s1, %s2, %s3, %s4, %s5 :
tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>
}

// -----

// CHECK-LABEL: func.func @select_into_minmax2
// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>, [[ARG1:%.+]]: tensor<i32>, [[ARG2:%.+]]: tensor<i32>, [[ARG3:%.+]]: tensor<i32>)
func.func @select_into_minmax2(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>)
-> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>,
tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {

%0 = stablehlo.compare GT, %arg1, %arg0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%1 = stablehlo.compare GT, %arg1, %arg2, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%2 = stablehlo.compare GE, %arg1, %arg3, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%3 = stablehlo.compare GE, %arg1, %arg2, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

%s0 = stablehlo.select %0, %arg0, %arg1 : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%s1 = stablehlo.select %1, %arg0, %arg1 : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%s2 = stablehlo.select %2, %arg3, %arg1 : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%s3 = stablehlo.select %3, %arg0, %arg2 : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>

%4 = stablehlo.compare LT, %arg1, %arg2, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%5 = stablehlo.compare LT, %arg0, %arg2, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%6 = stablehlo.compare LE, %arg2, %arg3, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%7 = stablehlo.compare LE, %arg0, %arg2, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

%s4 = stablehlo.select %4, %arg2, %arg1 : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%s5 = stablehlo.select %5, %arg1, %arg2 : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%s6 = stablehlo.select %6, %arg3, %arg2 : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
%s7 = stablehlo.select %7, %arg2, %arg3 : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>

// CHECK-DAG: [[C1:%.+]] = stablehlo.compare GT, [[ARG1]], [[ARG2]], SIGNED
// CHECK-DAG: [[C3:%.+]] = stablehlo.compare GE, [[ARG1]], [[ARG2]], SIGNED

// CHECK-DAG: [[S0:%.+]] = stablehlo.minimum [[ARG0]], [[ARG1]]
// CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], [[ARG0]], [[ARG1]]
// CHECK-DAG: [[S2:%.+]] = stablehlo.minimum [[ARG3]], [[ARG1]]
// CHECK-DAG: [[S3:%.+]] = stablehlo.select [[C3]], [[ARG0]], [[ARG2]]

// CHECK-DAG: [[C5:%.+]] = stablehlo.compare LT, [[ARG0]], [[ARG2]], SIGNED
// CHECK-DAG: [[C7:%.+]] = stablehlo.compare LE, [[ARG0]], [[ARG2]], SIGNED

// CHECK-DAG: [[S4:%.+]] = stablehlo.maximum [[ARG2]], [[ARG1]]
// CHECK-DAG: [[S5:%.+]] = stablehlo.select [[C5]], [[ARG1]], [[ARG2]]
// CHECK-DAG: [[S6:%.+]] = stablehlo.maximum [[ARG3]], [[ARG2]]
// CHECK-DAG: [[S7:%.+]] = stablehlo.select [[C7]], [[ARG2]], [[ARG3]]

// CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]], [[S6]], [[S7]]
return %s0, %s1, %s2, %s3, %s4, %s5, %s6, %s7 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>,
tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}

// -----

// CHECK-LABEL: func.func @broadcast_in_dim
// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>)
func.func @broadcast_in_dim(%arg0: tensor<3x3xi32>)
Expand Down
50 changes: 50 additions & 0 deletions stablehlo/transforms/StablehloAggressiveSimplification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,55 @@ struct SelectOpCanon final : OpRewritePattern<mlir::stablehlo::SelectOp> {
}
};

struct CompareSelectIntoMinMax final
: OpRewritePattern<mlir::stablehlo::SelectOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op,
PatternRewriter &rewriter) const override {
Value pred = op.getPred();
Value trueVal = op.getOnTrue();
Value falseVal = op.getOnFalse();

auto cmpOp = pred.getDefiningOp<mlir::stablehlo::CompareOp>();
if (!cmpOp) return failure();

using mlir::stablehlo::ComparisonDirection;
ComparisonDirection direction = cmpOp.getComparisonDirection();
Value cmpLhs = cmpOp.getLhs();
Value cmpRhs = cmpOp.getRhs();

// Turn into canonical form:
// b <= a ? a : b ---> a >= b ? a : b
// b < a ? a : b ---> a > b ? a : b
// b >= a ? a : b ---> a <= b ? a : b
// b > a ? a : b ---> a < b ? a : b
if (cmpLhs == falseVal && cmpRhs == trueVal) {
direction = invertDirection(direction);
} else if (!(cmpLhs == trueVal && cmpRhs == falseVal)) {
return failure();
}

switch (direction) {
case ComparisonDirection::GE:
case ComparisonDirection::GT: {
rewriter.replaceOpWithNewOp<mlir::stablehlo::MaxOp>(op, trueVal,
falseVal);
return success();
}
case ComparisonDirection::LE:
case ComparisonDirection::LT: {
rewriter.replaceOpWithNewOp<mlir::stablehlo::MinOp>(op, trueVal,
falseVal);
return success();
}
default: {
return failure();
}
}
}
};

struct BroadcastInDimOpCanon final
: OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -1161,6 +1210,7 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
patterns->add<
// Arithmetic ops.
AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon, SelectOpCanon,
CompareSelectIntoMinMax,
// Complex ops.
RealOpCanon, ImagOpCanon,
// Query ops.
Expand Down

0 comments on commit 1455e52

Please sign in to comment.