From 71ddfe8e0851dc282363a3454aae70ea02b2170a Mon Sep 17 00:00:00 2001 From: mlevesquedion Date: Mon, 15 Apr 2024 17:13:48 -0700 Subject: [PATCH] Simplify getResult().getType() to just getType() (#2218) --- .../transforms/StablehloLegalizeToLinalg.cpp | 3 +- .../StablehloToLinalgConvolution.cpp | 6 ++-- .../transforms/StablehloToLinalgRandom.cpp | 2 +- .../transforms/StablehloLegalizeToTosa.cpp | 6 ++-- .../transforms/ChloLegalizeToStablehlo.cpp | 8 ++--- .../transforms/ShapeLegalizeToStablehlo.cpp | 30 +++++++++---------- 6 files changed, 26 insertions(+), 29 deletions(-) diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 2ae76a00c1..4d6c15b20c 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -2301,8 +2301,7 @@ struct PadOpConversion final : OpConversionPattern { mlir::stablehlo::PadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto resultType = - getTypeConverter()->convertType(op.getResult().getType()); + auto resultType = getTypeConverter()->convertType(op.getType()); if (!resultType) return rewriter.notifyMatchFailure(op, "type conversion failed"); diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp index 39c0634321..e5531ec9e5 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp @@ -189,7 +189,7 @@ struct NormalConvolutionOpConversion final Value filter = adaptor.getRhs(); filter = applyConvolutionReversal(loc, rewriter, op, filter); auto resultType = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); + getTypeConverter()->convertType(op.getType())); if (!resultType) { return rewriter.notifyMatchFailure(op, "type conversion failed"); } @@ -304,7 +304,7 @@ struct ConvolutionOpGeneralConversion final MLIRContext *ctx = op.getContext(); auto resultType = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); + getTypeConverter()->convertType(op.getType())); if (!resultType) { return rewriter.notifyMatchFailure(op, "type conversion failed"); } @@ -623,7 +623,7 @@ struct DepthwiseConvolutionOpConversion final Value input = adaptor.getLhs(); Value filter = adaptor.getRhs(); auto resultType = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); + getTypeConverter()->convertType(op.getType())); if (!resultType) { return rewriter.notifyMatchFailure(op, "type conversion failed"); } diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgRandom.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgRandom.cpp index 998a13dc57..c55dfd7ab4 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgRandom.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgRandom.cpp @@ -828,7 +828,7 @@ struct RngUniformConversion final op, "expected min/max for rng op to be FloatType"); } auto targetTy = dyn_cast_or_null( - getTypeConverter()->convertType(op.getResult().getType())); + getTypeConverter()->convertType(op.getType())); if (!targetTy) { return rewriter.notifyMatchFailure( op, "expected target shape of rng op to be ShapedType"); diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp index 9f6a01ea43..39f1810763 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp @@ -85,7 +85,7 @@ struct ConvertStablehloConcatenateOp LogicalResult matchAndRewrite(stablehlo::ConcatenateOp op, PatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), op.getInputs(), op.getDimension()); + op, op.getType(), op.getInputs(), op.getDimension()); return success(); } }; @@ -398,7 +398,7 @@ struct ConvertStablehloSliceOp : public OpRewritePattern { } rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), op.getOperand(), + op, op.getType(), op.getOperand(), rewriter.getDenseI64ArrayAttr(startIndicesI64), rewriter.getDenseI64ArrayAttr(size)); return success(); @@ -422,7 +422,7 @@ struct ConvertStablehloTransposeOp rewriter.getI64Type()); auto constOp = rewriter.create( op->getLoc(), type, DenseIntElementsAttr::get(type, perms)); - rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand(), constOp); return success(); } diff --git a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp index 3ad41861bb..cc81f3b6a4 100644 --- a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp @@ -214,7 +214,7 @@ struct ConvertTrivialNonBroadcastBinaryOp final } rewriter.replaceOp( - op, ValueRange{Adaptor::createOp(op, op.getResult().getType(), + op, ValueRange{Adaptor::createOp(op, op.getType(), adaptor.getOperands(), rewriter)}); return success(); } @@ -245,7 +245,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp final Value rhs = adaptor.getRhs(); auto lhsType = dyn_cast(lhs.getType()); auto rhsType = dyn_cast(rhs.getType()); - auto resultType = dyn_cast(op.getResult().getType()); + auto resultType = dyn_cast(op.getType()); if (!lhsType || !rhsType || !resultType) return failure(); // Check for "numpy"-style rank broadcast. @@ -363,7 +363,7 @@ struct ConvertSelectOp final auto predType = dyn_cast(pred.getType()); auto onTrueType = dyn_cast(onTrue.getType()); auto onFalseType = dyn_cast(onFalse.getType()); - auto resultType = dyn_cast(op.getResult().getType()); + auto resultType = dyn_cast(op.getType()); if (!predType || !onTrueType || !onFalseType || !resultType) { return failure(); } @@ -1242,7 +1242,7 @@ struct ConvertErfInvOp final : OpConversionPattern { mlir::chlo::ErfInvOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - if (op.getResult().getType().getElementType().isF64()) { + if (op.getType().getElementType().isF64()) { rewriter.replaceOp(op, erfInv64(rewriter, loc, adaptor.getOperands())); return success(); } diff --git a/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp b/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp index ca57e42537..24995a094c 100644 --- a/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ShapeLegalizeToStablehlo.cpp @@ -199,7 +199,7 @@ struct ConvertComputeReshapeShapeOpPattern // results). // This cannot error out given how the operation is currently defined. auto resultIndex = maybeCastToIndex(op.getResult(), resultI32, rewriter); - if (!resultIndex || resultIndex.getType() != op.getResult().getType()) + if (!resultIndex || resultIndex.getType() != op.getType()) return rewriter.notifyMatchFailure(op, "cast to index failed"); rewriter.replaceOp(op, resultIndex); return success(); @@ -238,7 +238,7 @@ struct ConvertNumElementsOpPattern // Cast result from tensor to index. // This will error out if the result is !shape.size. auto resultIndex = castToIndex(rewriter, op.getLoc(), resultI32); - if (!resultIndex || resultIndex.getType() != op.getResult().getType()) + if (!resultIndex || resultIndex.getType() != op.getType()) return rewriter.notifyMatchFailure(op, "cast to index failed"); rewriter.replaceOp(op, resultIndex); return success(); @@ -279,7 +279,7 @@ struct ConvertShapeOfOpPattern : public OpRewritePattern { // Cast result from tensor to tensor. // This will error out if the result is !shape.shape. auto shapeIndex = castToIndex(rewriter, op.getLoc(), shapeI32); - if (!shapeIndex || shapeIndex.getType() != op.getResult().getType()) + if (!shapeIndex || shapeIndex.getType() != op.getType()) return rewriter.notifyMatchFailure(op, "cast to index failed"); rewriter.replaceOp(op, shapeIndex); return success(); @@ -291,7 +291,7 @@ struct ConvertConstShapeOpPattern using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::ConstShapeOp op, PatternRewriter& rewriter) const override { - auto operandType = dyn_cast(op.getResult().getType()); + auto operandType = dyn_cast(op.getType()); if (!operandType) return rewriter.notifyMatchFailure(op, "expected ranked operand"); @@ -436,8 +436,7 @@ struct ConvertShapeBroadcastOpPattern auto broadcasted = rewriter.create(op->getLoc(), shape1, shape2); auto broadcastedIndex = castToIndex(rewriter, op.getLoc(), broadcasted); - if (!broadcastedIndex || - broadcastedIndex.getType() != op.getResult().getType()) + if (!broadcastedIndex || broadcastedIndex.getType() != op.getType()) return rewriter.notifyMatchFailure(op, "cast to index failed"); rewriter.replaceOp(op, broadcastedIndex); return success(); @@ -498,7 +497,7 @@ struct ConvertTensorExtractPattern Value extractedScalarTensor = rewriter.create( op.getLoc(), RankedTensorType::get({}, rewriter.getI32Type()), extractedTensor); - if (getElementTypeOrSelf(op.getResult().getType()).isIndex()) { + if (getElementTypeOrSelf(op.getType()).isIndex()) { auto extractedIndex = castToIndex(rewriter, op.getLoc(), extractedScalarTensor); rewriter.replaceOp(op, extractedIndex); @@ -506,9 +505,9 @@ struct ConvertTensorExtractPattern // For the special case when the input is a i32 tensor and output is i32, // convert the result back to i32 to be consistent: // unrealized_conversion_cast tensor -> i32 - rewriter.replaceOp(op, rewriter.create( - op.getLoc(), op.getResult().getType(), - extractedScalarTensor)); + rewriter.replaceOp(op, + rewriter.create( + op.getLoc(), op.getType(), extractedScalarTensor)); } return success(); } @@ -519,8 +518,7 @@ struct ConvertTensorFromElementsPattern using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::FromElementsOp op, PatternRewriter& rewriter) const override { - auto tensorType = - dyn_cast_or_null(op.getResult().getType()); + auto tensorType = dyn_cast_or_null(op.getType()); if (!tensorType) return rewriter.notifyMatchFailure(op, "expected constant index op"); @@ -529,9 +527,9 @@ struct ConvertTensorFromElementsPattern // tensor.from_elements i64 -> tensor // This is converted to unrealized_conversion_cast i64 -> tensor, // which is later cancelled with previous unrealized_conversion_cast op. - rewriter.replaceOp( - op, rewriter.create( - op.getLoc(), op.getResult().getType(), op.getElements()[0])); + rewriter.replaceOp(op, + rewriter.create( + op.getLoc(), op.getType(), op.getElements()[0])); return success(); } @@ -558,7 +556,7 @@ struct ConvertTensorFromElementsPattern /*dimension=*/0); tensorI32 = maybeCastToIndex(op.getResult(), tensorI32, rewriter); - if (!tensorI32 || tensorI32.getType() != op.getResult().getType()) + if (!tensorI32 || tensorI32.getType() != op.getType()) return rewriter.notifyMatchFailure(op, "cast to index failed"); rewriter.replaceOp(op, tensorI32); return success();