From 1455e52d2d3014ef1fa79ab7dff7f1cd390e7a0c Mon Sep 17 00:00:00 2001
From: Michael
Date: Mon, 22 Apr 2024 17:59:57 +0300
Subject: [PATCH] Add rewrite pattern into AggressiveSimplification that turns compare+select into min/max (#2244)

This patch rewrites a comparison and selection combination that utilizes
the same operands into a min/max operation:
```
func.func @test(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
  %0 = stablehlo.compare GE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %1 = stablehlo.compare LE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  return %s0, %s1 : tensor<2xi32>, tensor<2xi32>
}
```
is transformed into:
```
func.func @test(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
  %0 = stablehlo.maximum %arg0, %arg1 : tensor<2xi32>
  %1 = stablehlo.minimum %arg0, %arg1 : tensor<2xi32>
  return %0, %1 : tensor<2xi32>, tensor<2xi32>
}
```
---
 .../stablehlo_aggressive_simplification.mlir  | 86 +++++++++++++++++++
 .../StablehloAggressiveSimplification.cpp     | 61 +++++++++++++
 2 files changed, 147 insertions(+)

diff --git a/stablehlo/tests/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/stablehlo_aggressive_simplification.mlir
index acf80e9231..ab229ceefd 100644
--- a/stablehlo/tests/stablehlo_aggressive_simplification.mlir
+++ b/stablehlo/tests/stablehlo_aggressive_simplification.mlir
@@ -238,6 +238,92 @@ func.func @select(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %argC: tensor<2xi1
 
 // -----
 
+// CHECK-LABEL: func.func @select_into_minmax1
+// CHECK-SAME: [[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<2xi32>, [[ARG2:%.+]]: tensor<2xi32>, [[ARG3:%.+]]: tensor<2xi32>)
+func.func @select_into_minmax1(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>,
+                               %arg2: tensor<2xi32>, %arg3: tensor<2xi32>)
+  -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) {
+
+  %0 = stablehlo.compare EQ, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  %1 = stablehlo.compare NE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  %2 = stablehlo.compare GE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  %3 = stablehlo.compare GT, %arg0, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  %4 = stablehlo.compare LE, %arg1, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+  %5 = stablehlo.compare LT, %arg1, %arg3, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+
+  %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %s2 = stablehlo.select %2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %s3 = stablehlo.select %3, %arg0, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %s4 = stablehlo.select %4, %arg1, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  %s5 = stablehlo.select %5, %arg1, %arg3 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+
+  // CHECK-DAG: [[C0:%.+]] = stablehlo.compare EQ, [[ARG0]], [[ARG1]], SIGNED
+  // CHECK-DAG: [[C1:%.+]] = stablehlo.compare NE, [[ARG0]], [[ARG1]], SIGNED
+
+  // CHECK-DAG: [[S0:%.+]] = stablehlo.select [[C0]], [[ARG0]], [[ARG1]]
+  // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], [[ARG0]], [[ARG1]]
+  // CHECK-DAG: [[S2:%.+]] = stablehlo.maximum [[ARG0]], [[ARG1]]
+  // CHECK-DAG: [[S3:%.+]] = stablehlo.maximum [[ARG0]], [[ARG2]]
+  // CHECK-DAG: [[S4:%.+]] = stablehlo.minimum [[ARG1]], [[ARG2]]
+  // CHECK-DAG: [[S5:%.+]] = stablehlo.minimum [[ARG1]], [[ARG3]]
+
+  // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]]
+  return %s0, %s1, %s2, %s3, %s4, %s5 :
+         tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @select_into_minmax2
+// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor)
+func.func @select_into_minmax2(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor)
+  -> (tensor, tensor, tensor, tensor,
+      tensor, tensor, tensor, tensor) {
+
+  %0 = stablehlo.compare GT, %arg1, %arg0, SIGNED : (tensor, tensor) -> tensor
+  %1 = stablehlo.compare GT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor
+  %2 = stablehlo.compare GE, %arg1, %arg3, SIGNED : (tensor, tensor) -> tensor
+  %3 = stablehlo.compare GE, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor
+
+  %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor
+  %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor
+  %s2 = stablehlo.select %2, %arg3, %arg1 : (tensor, tensor, tensor) -> tensor
+  %s3 = stablehlo.select %3, %arg0, %arg2 : (tensor, tensor, tensor) -> tensor
+
+  %4 = stablehlo.compare LT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor
+  %5 = stablehlo.compare LT, %arg0, %arg2, SIGNED : (tensor, tensor) -> tensor
+  %6 = stablehlo.compare LE, %arg2, %arg3, SIGNED : (tensor, tensor) -> tensor
+  %7 = stablehlo.compare LE, %arg0, %arg2, SIGNED : (tensor, tensor) -> tensor
+
+  %s4 = stablehlo.select %4, %arg2, %arg1 : (tensor, tensor, tensor) -> tensor
+  %s5 = stablehlo.select %5, %arg1, %arg2 : (tensor, tensor, tensor) -> tensor
+  %s6 = stablehlo.select %6, %arg3, %arg2 : (tensor, tensor, tensor) -> tensor
+  %s7 = stablehlo.select %7, %arg2, %arg3 : (tensor, tensor, tensor) -> tensor
+
+  // CHECK-DAG: [[C1:%.+]] = stablehlo.compare GT, [[ARG1]], [[ARG2]], SIGNED
+  // CHECK-DAG: [[C3:%.+]] = stablehlo.compare GE, [[ARG1]], [[ARG2]], SIGNED
+
+  // CHECK-DAG: [[S0:%.+]] = stablehlo.minimum [[ARG0]], [[ARG1]]
+  // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], [[ARG0]], [[ARG1]]
+  // CHECK-DAG: [[S2:%.+]] = stablehlo.minimum [[ARG3]], [[ARG1]]
+  // CHECK-DAG: [[S3:%.+]] = stablehlo.select [[C3]], [[ARG0]], [[ARG2]]
+
+  // CHECK-DAG: [[C5:%.+]] = stablehlo.compare LT, [[ARG0]], [[ARG2]], SIGNED
+  // CHECK-DAG: [[C7:%.+]] = stablehlo.compare LE, [[ARG0]], [[ARG2]], SIGNED
+
+  // CHECK-DAG: [[S4:%.+]] = stablehlo.maximum [[ARG2]], [[ARG1]]
+  // CHECK-DAG: [[S5:%.+]] = stablehlo.select [[C5]], [[ARG1]], [[ARG2]]
+  // CHECK-DAG: [[S6:%.+]] = stablehlo.maximum [[ARG3]], [[ARG2]]
+  // CHECK-DAG: [[S7:%.+]] = stablehlo.select [[C7]], [[ARG2]], [[ARG3]]
+
+  // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]], [[S6]], [[S7]]
+  return %s0, %s1, %s2, %s3, %s4, %s5, %s6, %s7 : tensor, tensor, tensor, tensor,
+                                                  tensor, tensor, tensor, tensor
+}
+
+// -----
+
 // CHECK-LABEL: func.func @broadcast_in_dim
 // CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>)
 func.func @broadcast_in_dim(%arg0: tensor<3x3xi32>)
diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/StablehloAggressiveSimplification.cpp
index cb3223edbc..8e4fac9d46 100644
--- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp
+++ b/stablehlo/transforms/StablehloAggressiveSimplification.cpp
@@ -395,6 +395,66 @@ struct SelectOpCanon final : OpRewritePattern {
   }
 };
 
+/// Rewrites a `stablehlo.select` whose predicate is a `stablehlo.compare` of
+/// the very same two values it selects between into a min/max op:
+///   a >= b ? a : b  or  a > b ? a : b   --->  stablehlo.maximum a, b
+///   a <= b ? a : b  or  a < b ? a : b   --->  stablehlo.minimum a, b
+/// The compare operands may also appear in swapped order. Any other
+/// comparison direction (e.g. EQ / NE) or operand mix is left untouched.
+///
+/// NOTE(review): for floating-point operands a compare+select and a min/max
+/// may disagree on NaN inputs; the tests only cover integers -- confirm that
+/// this is acceptable for an "aggressive" simplification.
+struct CompareSelectIntoMinMax final
+    : OpRewritePattern<mlir::stablehlo::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op,
+                                PatternRewriter &rewriter) const override {
+    Value pred = op.getPred();
+    Value trueVal = op.getOnTrue();
+    Value falseVal = op.getOnFalse();
+
+    // Only a predicate produced directly by a stablehlo.compare qualifies.
+    auto cmpOp = pred.getDefiningOp<mlir::stablehlo::CompareOp>();
+    if (!cmpOp) return failure();
+
+    using mlir::stablehlo::ComparisonDirection;
+    ComparisonDirection direction = cmpOp.getComparisonDirection();
+    Value cmpLhs = cmpOp.getLhs();
+    Value cmpRhs = cmpOp.getRhs();
+
+    // Turn into canonical form:
+    // b <= a ? a : b  --->  a >= b ? a : b
+    // b <  a ? a : b  --->  a >  b ? a : b
+    // b >= a ? a : b  --->  a <= b ? a : b
+    // b >  a ? a : b  --->  a <  b ? a : b
+    if (cmpLhs == falseVal && cmpRhs == trueVal) {
+      direction = invertDirection(direction);
+    } else if (!(cmpLhs == trueVal && cmpRhs == falseVal)) {
+      return failure();
+    }
+
+    switch (direction) {
+      case ComparisonDirection::GE:
+      case ComparisonDirection::GT: {
+        rewriter.replaceOpWithNewOp<mlir::stablehlo::MaxOp>(op, trueVal,
+                                                            falseVal);
+        return success();
+      }
+      case ComparisonDirection::LE:
+      case ComparisonDirection::LT: {
+        rewriter.replaceOpWithNewOp<mlir::stablehlo::MinOp>(op, trueVal,
+                                                            falseVal);
+        return success();
+      }
+      default: {
+        return failure();
+      }
+    }
+  }
+};
+
 struct BroadcastInDimOpCanon final
     : OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
@@ -1161,6 +1221,7 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
   patterns->add<
       // Arithmetic ops.
       AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon, SelectOpCanon,
+      CompareSelectIntoMinMax,
       // Complex ops.
       RealOpCanon, ImagOpCanon,
       // Query ops.