From 348e9a0069a4203d28b54fc971e47f67415d7178 Mon Sep 17 00:00:00 2001 From: "m.pantilimonov" Date: Fri, 19 Apr 2024 22:12:04 +0300 Subject: [PATCH 1/2] Add rewrite pattern into AggresiveSimplification that turns compare+select into min/max --- .../stablehlo_aggressive_simplification.mlir | 66 +++++++++++++++++++ .../StablehloAggressiveSimplification.cpp | 49 ++++++++++++++ 2 files changed, 115 insertions(+) diff --git a/stablehlo/tests/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/stablehlo_aggressive_simplification.mlir index e332adbb6c..fa989f1289 100644 --- a/stablehlo/tests/stablehlo_aggressive_simplification.mlir +++ b/stablehlo/tests/stablehlo_aggressive_simplification.mlir @@ -238,6 +238,72 @@ func.func @select(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %argC: tensor<2xi1 // ----- +// CHECK-LABEL: func.func @select_into_minmax1 +// CHECK-SAME: [[ARG0:%.+]]: tensor<2xi32>, [[ARG1:%.+]]: tensor<2xi32>, [[ARG2:%.+]]: tensor<2xi32>, [[ARG3:%.+]]: tensor<2xi32>) +func.func @select_into_minmax1(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, + %arg2: tensor<2xi32>, %arg3: tensor<2xi32>) + -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) { + + %0 = stablehlo.compare EQ, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %1 = stablehlo.compare NE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %2 = stablehlo.compare GE, %arg0, %arg1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %3 = stablehlo.compare GT, %arg0, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %4 = stablehlo.compare LE, %arg1, %arg2, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %5 = stablehlo.compare LT, %arg1, %arg3, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + + %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s2 = stablehlo.select %2, %arg0, %arg1 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s3 = stablehlo.select %3, %arg0, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s4 = stablehlo.select %4, %arg1, %arg2 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %s5 = stablehlo.select %5, %arg1, %arg3 : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + + // CHECK-DAG: [[C0:%.+]] = stablehlo.compare EQ, [[ARG0]], [[ARG1]], SIGNED + // CHECK-DAG: [[C1:%.+]] = stablehlo.compare NE, [[ARG0]], [[ARG1]], SIGNED + + // CHECK-DAG: [[S0:%.+]] = stablehlo.select [[C0]], [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S2:%.+]] = stablehlo.maximum [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S3:%.+]] = stablehlo.maximum [[ARG0]], [[ARG2]] + // CHECK-DAG: [[S4:%.+]] = stablehlo.minimum [[ARG1]], [[ARG2]] + // CHECK-DAG: [[S5:%.+]] = stablehlo.minimum [[ARG1]], [[ARG3]] + + // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]] + return %s0, %s1, %s2, %s3, %s4, %s5 : + tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: func.func @select_into_minmax2 +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor) +func.func @select_into_minmax2(%arg0: tensor, %arg1: tensor, %arg2: tensor) + -> (tensor, tensor, tensor, tensor) { + + %0 = stablehlo.compare GT, %arg1, %arg0, SIGNED : (tensor, tensor) -> tensor + %1 = stablehlo.compare GT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor + %2 = stablehlo.compare LT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor + %3 = stablehlo.compare LT, %arg0, %arg2, SIGNED : (tensor, tensor) -> tensor + + %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor + %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor + %s2 = stablehlo.select %2, %arg2, %arg1 : (tensor, tensor, tensor) -> tensor + %s3 = stablehlo.select %3, %arg1, %arg2 : (tensor, tensor, tensor) -> tensor + + // CHECK-DAG: [[C0:%.+]] = stablehlo.compare GT, [[ARG1]], [[ARG2]], SIGNED + // CHECK-DAG: [[C1:%.+]] = stablehlo.compare LT, [[ARG0]], [[ARG2]], SIGNED + + // CHECK-DAG: [[S0:%.+]] = stablehlo.minimum [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C0]], [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S2:%.+]] = stablehlo.maximum [[ARG2]], [[ARG1]] + // CHECK-DAG: [[S3:%.+]] = stablehlo.select [[C1]], [[ARG1]], [[ARG2]] + + // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]] + return %s0, %s1, %s2, %s3 : tensor, tensor, tensor, tensor +} + +// ----- + // CHECK-LABEL: func.func @broadcast_in_dim // CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>) func.func @broadcast_in_dim(%arg0: tensor<3x3xi32>) diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/StablehloAggressiveSimplification.cpp index 9e91efd00d..2a21312432 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp +++ b/stablehlo/transforms/StablehloAggressiveSimplification.cpp @@ -392,6 +392,54 @@ struct SelectOpCanon final : OpRewritePattern { } }; +struct SelectIntoMinMax final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op, + PatternRewriter &rewriter) const override { + Value pred = op.getPred(); + Value trueVal = op.getOnTrue(); + Value falseVal = op.getOnFalse(); + + auto cmpOp = pred.getDefiningOp(); + if (!cmpOp) return failure(); + + using mlir::stablehlo::ComparisonDirection; + ComparisonDirection direction = cmpOp.getComparisonDirection(); + Value lhs = cmpOp.getLhs(); + Value cmpRhs = cmpOp.getRhs(); + + // Turn into canonical form: + // b <= a ? a : b ---> a >= b ? a : b + // b < a ? a : b ---> a > b ? a : b + // b >= a ? a : b ---> a <= b ? a : b + // b > a ? a : b ---> a < b ? a : b + if (lhs == falseVal && cmpRhs == trueVal) { + direction = invertDirection(direction); + } else if (!(lhs == trueVal && cmpRhs == falseVal)) { + return failure(); + } + + switch (direction) { + case ComparisonDirection::GE: + case ComparisonDirection::GT: { + rewriter.replaceOpWithNewOp(op, trueVal, + falseVal); + return success(); + } + case ComparisonDirection::LE: + case ComparisonDirection::LT: { + rewriter.replaceOpWithNewOp(op, trueVal, + falseVal); + return success(); + } + default: { + return failure(); + } + } + } +}; + struct BroadcastInDimOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1047,6 +1095,7 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context, patterns->add< // Arithmetic ops. AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon, SelectOpCanon, + SelectIntoMinMax, // Complex ops. RealOpCanon, ImagOpCanon, // Query ops. From 4c9ded0a4102fdfed8ccf4ebd5b774ad06f2b1c9 Mon Sep 17 00:00:00 2001 From: "m.pantilimonov" Date: Mon, 22 Apr 2024 12:06:17 +0300 Subject: [PATCH 2/2] apply comments --- .../stablehlo_aggressive_simplification.mlir | 48 +++++++++++++------ .../StablehloAggressiveSimplification.cpp | 11 +++-- 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/stablehlo/tests/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/stablehlo_aggressive_simplification.mlir index fa989f1289..bdcab59625 100644 --- a/stablehlo/tests/stablehlo_aggressive_simplification.mlir +++ b/stablehlo/tests/stablehlo_aggressive_simplification.mlir @@ -276,30 +276,50 @@ func.func @select_into_minmax1(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, // ----- // CHECK-LABEL: func.func @select_into_minmax2 -// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor) -func.func @select_into_minmax2(%arg0: tensor, %arg1: tensor, %arg2: tensor) - -> (tensor, tensor, tensor, tensor) { +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor, [[ARG3:%.+]]: tensor) +func.func @select_into_minmax2(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) + -> (tensor, tensor, tensor, tensor, + tensor, tensor, tensor, tensor) { %0 = stablehlo.compare GT, %arg1, %arg0, SIGNED : (tensor, tensor) -> tensor %1 = stablehlo.compare GT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor - %2 = stablehlo.compare LT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor - %3 = stablehlo.compare LT, %arg0, %arg2, SIGNED : (tensor, tensor) -> tensor + %2 = stablehlo.compare GE, %arg1, %arg3, SIGNED : (tensor, tensor) -> tensor + %3 = stablehlo.compare GE, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor %s0 = stablehlo.select %0, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor %s1 = stablehlo.select %1, %arg0, %arg1 : (tensor, tensor, tensor) -> tensor - %s2 = stablehlo.select %2, %arg2, %arg1 : (tensor, tensor, tensor) -> tensor - %s3 = stablehlo.select %3, %arg1, %arg2 : (tensor, tensor, tensor) -> tensor + %s2 = stablehlo.select %2, %arg3, %arg1 : (tensor, tensor, tensor) -> tensor + %s3 = stablehlo.select %3, %arg0, %arg2 : (tensor, tensor, tensor) -> tensor - // CHECK-DAG: [[C0:%.+]] = stablehlo.compare GT, [[ARG1]], [[ARG2]], SIGNED - // CHECK-DAG: [[C1:%.+]] = stablehlo.compare LT, [[ARG0]], [[ARG2]], SIGNED + %4 = stablehlo.compare LT, %arg1, %arg2, SIGNED : (tensor, tensor) -> tensor + %5 = stablehlo.compare LT, %arg0, %arg2, SIGNED : (tensor, tensor) -> tensor + %6 = stablehlo.compare LE, %arg2, %arg3, SIGNED : (tensor, tensor) -> tensor + %7 = stablehlo.compare LE, %arg0, %arg2, SIGNED : (tensor, tensor) -> tensor + + %s4 = stablehlo.select %4, %arg2, %arg1 : (tensor, tensor, tensor) -> tensor + %s5 = stablehlo.select %5, %arg1, %arg2 : (tensor, tensor, tensor) -> tensor + %s6 = stablehlo.select %6, %arg3, %arg2 : (tensor, tensor, tensor) -> tensor + %s7 = stablehlo.select %7, %arg2, %arg3 : (tensor, tensor, tensor) -> tensor + + // CHECK-DAG: [[C1:%.+]] = stablehlo.compare GT, [[ARG1]], [[ARG2]], SIGNED + // CHECK-DAG: [[C3:%.+]] = stablehlo.compare GE, [[ARG1]], [[ARG2]], SIGNED // CHECK-DAG: [[S0:%.+]] = stablehlo.minimum [[ARG0]], [[ARG1]] - // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C0]], [[ARG0]], [[ARG1]] - // CHECK-DAG: [[S2:%.+]] = stablehlo.maximum [[ARG2]], [[ARG1]] - // CHECK-DAG: [[S3:%.+]] = stablehlo.select [[C1]], [[ARG1]], [[ARG2]] + // CHECK-DAG: [[S1:%.+]] = stablehlo.select [[C1]], [[ARG0]], [[ARG1]] + // CHECK-DAG: [[S2:%.+]] = stablehlo.minimum [[ARG3]], [[ARG1]] + // CHECK-DAG: [[S3:%.+]] = stablehlo.select [[C3]], [[ARG0]], [[ARG2]] + + // CHECK-DAG: [[C5:%.+]] = stablehlo.compare LT, [[ARG0]], [[ARG2]], SIGNED + // CHECK-DAG: [[C7:%.+]] = stablehlo.compare LE, [[ARG0]], [[ARG2]], SIGNED + + // CHECK-DAG: [[S4:%.+]] = stablehlo.maximum [[ARG2]], [[ARG1]] + // CHECK-DAG: [[S5:%.+]] = stablehlo.select [[C5]], [[ARG1]], [[ARG2]] + // CHECK-DAG: [[S6:%.+]] = stablehlo.maximum [[ARG3]], [[ARG2]] + // CHECK-DAG: [[S7:%.+]] = stablehlo.select [[C7]], [[ARG2]], [[ARG3]] - // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]] - return %s0, %s1, %s2, %s3 : tensor, tensor, tensor, tensor + // CHECK-NEXT: return [[S0]], [[S1]], [[S2]], [[S3]], [[S4]], [[S5]], [[S6]], [[S7]] + return %s0, %s1, %s2, %s3, %s4, %s5, %s6, %s7 : tensor, tensor, tensor, tensor, + tensor, tensor, tensor, tensor } // ----- diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/StablehloAggressiveSimplification.cpp index 2a21312432..a0bfd6f687 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp +++ b/stablehlo/transforms/StablehloAggressiveSimplification.cpp @@ -392,7 +392,8 @@ struct SelectOpCanon final : OpRewritePattern { } }; -struct SelectIntoMinMax final : OpRewritePattern { +struct CompareSelectIntoMinMax final + : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::stablehlo::SelectOp op, @@ -406,7 +407,7 @@ struct SelectIntoMinMax final : OpRewritePattern { using mlir::stablehlo::ComparisonDirection; ComparisonDirection direction = cmpOp.getComparisonDirection(); - Value lhs = cmpOp.getLhs(); + Value cmpLhs = cmpOp.getLhs(); Value cmpRhs = cmpOp.getRhs(); // Turn into canonical form: @@ -414,9 +415,9 @@ struct SelectIntoMinMax final : OpRewritePattern { // b < a ? a : b ---> a > b ? a : b // b >= a ? a : b ---> a <= b ? a : b // b > a ? a : b ---> a < b ? a : b - if (lhs == falseVal && cmpRhs == trueVal) { + if (cmpLhs == falseVal && cmpRhs == trueVal) { direction = invertDirection(direction); - } else if (!(lhs == trueVal && cmpRhs == falseVal)) { + } else if (!(cmpLhs == trueVal && cmpRhs == falseVal)) { return failure(); } @@ -1095,7 +1096,7 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context, patterns->add< // Arithmetic ops. AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon, SelectOpCanon, - SelectIntoMinMax, + CompareSelectIntoMinMax, // Complex ops. RealOpCanon, ImagOpCanon, // Query ops.