Skip to content

Commit

Permalink
Add EvalOrPattern to StablehloRefineShapes pass (#1867)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghpvnist authored Dec 1, 2023
1 parent b3d018f commit 57e5a4a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
13 changes: 13 additions & 0 deletions stablehlo/tests/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,19 @@ func.func @eval_multiply() -> tensor<i64> {

// -----

// CHECK-LABEL: func @eval_or
func.func @eval_or() -> tensor<i1> {
// CHECK-NOT: stablehlo.or
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<true> : tensor<i1>
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<true> : tensor<i1>
%1 = stablehlo.constant dense<false> : tensor<i1>
%2 = stablehlo.or %0, %1 : tensor<i1>
func.return %2 : tensor<i1>
}

// -----

// CHECK-LABEL: func @eval_remainder
func.func @eval_remainder() -> tensor<i64> {
// CHECK-NOT: stablehlo.remainder
Expand Down
15 changes: 15 additions & 0 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,20 @@ struct EvalMulOpPattern : public OpRewritePattern<MulOp> {
}
};

struct EvalOrOpPattern : public OpRewritePattern<OrOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OrOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
if (!resultType.getElementType().isInteger(1))
return rewriter.notifyMatchFailure(op, "expected boolean element type");

return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) {
return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0);
});
}
};

struct EvalRemOpPattern : public OpRewritePattern<RemOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(RemOp op,
Expand Down Expand Up @@ -1165,6 +1179,7 @@ struct StablehloRefineShapesPass
patterns.add<EvalMaxOpPattern>(&getContext());
patterns.add<EvalMinOpPattern>(&getContext());
patterns.add<EvalMulOpPattern>(&getContext());
patterns.add<EvalOrOpPattern>(&getContext());
patterns.add<EvalRemOpPattern>(&getContext());
patterns.add<EvalReshapeOpPattern>(&getContext());
patterns.add<EvalSelectOpPattern>(&getContext());
Expand Down

0 comments on commit 57e5a4a

Please sign in to comment.