Commit 71ddfe8
Simplify getResult().getType() to just getType() (#2218)
mlevesquedion authored Apr 16, 2024
1 parent db73020 commit 71ddfe8
Showing 6 changed files with 26 additions and 29 deletions.
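In MLIR, an op that declares exactly one result inherits OpTrait::OneTypedResult<T>::Impl, which provides a getType() accessor, so op.getResult().getType() and op.getType() are interchangeable on single-result ops; the shorter spelling also removes several awkward line wraps below. A minimal sketch of the equivalence (MyOp and the surrounding pattern are hypothetical, not part of this commit):

    #include <cassert>
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    // Sketch only: assumes MyOp is a single-result op, so it gets a
    // trait-provided getType() from OpTrait::OneTypedResult<T>::Impl.
    struct MyOpPattern : OpRewritePattern<MyOp> {
      using OpRewritePattern::OpRewritePattern;
      LogicalResult matchAndRewrite(MyOp op,
                                    PatternRewriter &rewriter) const override {
        Type viaResult = op.getResult().getType();  // through the result Value
        Type direct = op.getType();                 // trait-provided shortcut
        assert(viaResult == direct);  // same type, by construction
        return failure();  // illustration only; nothing to rewrite
      }
    };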
@@ -2301,8 +2301,7 @@ struct PadOpConversion final : OpConversionPattern<mlir::stablehlo::PadOp> {
       mlir::stablehlo::PadOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto resultType =
-        getTypeConverter()->convertType<ShapedType>(op.getResult().getType());
+    auto resultType = getTypeConverter()->convertType<ShapedType>(op.getType());
     if (!resultType)
       return rewriter.notifyMatchFailure(op, "type conversion failed");

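Two spellings of the converted-type lookup appear in this commit: the PadOp hunk above uses the templated TypeConverter::convertType<ShapedType>(...), which folds the cast into the call, while the convolution hunks below keep an explicit dyn_cast_or_null. A short sketch of the equivalence (assuming, as in these patterns, an OpConversionPattern whose getTypeConverter() is non-null):

    // Both forms yield a null ShapedType when type conversion fails or the
    // converted type is not a ShapedType.
    ShapedType a = getTypeConverter()->convertType<ShapedType>(op.getType());
    ShapedType b = dyn_cast_or_null<ShapedType>(
        getTypeConverter()->convertType(op.getType()));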
@@ -189,7 +189,7 @@ struct NormalConvolutionOpConversion final
     Value filter = adaptor.getRhs();
     filter = applyConvolutionReversal(loc, rewriter, op, filter);
     auto resultType = dyn_cast_or_null<ShapedType>(
-        getTypeConverter()->convertType(op.getResult().getType()));
+        getTypeConverter()->convertType(op.getType()));
     if (!resultType) {
       return rewriter.notifyMatchFailure(op, "type conversion failed");
     }
@@ -304,7 +304,7 @@ struct ConvolutionOpGeneralConversion final
     MLIRContext *ctx = op.getContext();

     auto resultType = dyn_cast_or_null<ShapedType>(
-        getTypeConverter()->convertType(op.getResult().getType()));
+        getTypeConverter()->convertType(op.getType()));
     if (!resultType) {
       return rewriter.notifyMatchFailure(op, "type conversion failed");
     }
@@ -623,7 +623,7 @@ struct DepthwiseConvolutionOpConversion final
     Value input = adaptor.getLhs();
     Value filter = adaptor.getRhs();
     auto resultType = dyn_cast_or_null<RankedTensorType>(
-        getTypeConverter()->convertType(op.getResult().getType()));
+        getTypeConverter()->convertType(op.getType()));
     if (!resultType) {
       return rewriter.notifyMatchFailure(op, "type conversion failed");
     }
@@ -828,7 +828,7 @@ struct RngUniformConversion final
           op, "expected min/max for rng op to be FloatType");
     }
     auto targetTy = dyn_cast_or_null<ShapedType>(
-        getTypeConverter()->convertType(op.getResult().getType()));
+        getTypeConverter()->convertType(op.getType()));
     if (!targetTy) {
       return rewriter.notifyMatchFailure(
           op, "expected target shape of rng op to be ShapedType");
@@ -85,7 +85,7 @@ struct ConvertStablehloConcatenateOp
   LogicalResult matchAndRewrite(stablehlo::ConcatenateOp op,
                                 PatternRewriter& rewriter) const override {
     rewriter.replaceOpWithNewOp<tosa::ConcatOp>(
-        op, op.getResult().getType(), op.getInputs(), op.getDimension());
+        op, op.getType(), op.getInputs(), op.getDimension());
     return success();
   }
 };
@@ -398,7 +398,7 @@ struct ConvertStablehloSliceOp : public OpRewritePattern<stablehlo::SliceOp> {
     }

     rewriter.replaceOpWithNewOp<tosa::SliceOp>(
-        op, op.getResult().getType(), op.getOperand(),
+        op, op.getType(), op.getOperand(),
         rewriter.getDenseI64ArrayAttr(startIndicesI64),
         rewriter.getDenseI64ArrayAttr(size));
     return success();
@@ -422,7 +422,7 @@ struct ConvertStablehloTransposeOp
                                       rewriter.getI64Type());
     auto constOp = rewriter.create<tosa::ConstOp>(
         op->getLoc(), type, DenseIntElementsAttr::get(type, perms));
-    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(op, op.getResult().getType(),
+    rewriter.replaceOpWithNewOp<tosa::TransposeOp>(op, op.getType(),
                                                    op.getOperand(), constOp);
     return success();
   }
8 changes: 4 additions & 4 deletions stablehlo/transforms/ChloLegalizeToStablehlo.cpp
@@ -214,7 +214,7 @@ struct ConvertTrivialNonBroadcastBinaryOp final
     }

     rewriter.replaceOp(
-        op, ValueRange{Adaptor::createOp(op, op.getResult().getType(),
+        op, ValueRange{Adaptor::createOp(op, op.getType(),
                                          adaptor.getOperands(), rewriter)});
     return success();
   }
@@ -245,7 +245,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp final
     Value rhs = adaptor.getRhs();
     auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
     auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
-    auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());
     if (!lhsType || !rhsType || !resultType) return failure();

     // Check for "numpy"-style rank broadcast.
@@ -363,7 +363,7 @@ struct ConvertSelectOp final
     auto predType = dyn_cast<RankedTensorType>(pred.getType());
     auto onTrueType = dyn_cast<RankedTensorType>(onTrue.getType());
     auto onFalseType = dyn_cast<RankedTensorType>(onFalse.getType());
-    auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());
     if (!predType || !onTrueType || !onFalseType || !resultType) {
       return failure();
     }
@@ -1242,7 +1242,7 @@ struct ConvertErfInvOp final : OpConversionPattern<mlir::chlo::ErfInvOp> {
       mlir::chlo::ErfInvOp op, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    if (op.getResult().getType().getElementType().isF64()) {
+    if (op.getType().getElementType().isF64()) {
       rewriter.replaceOp(op, erfInv64(rewriter, loc, adaptor.getOperands()));
       return success();
     }
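The ErfInvOp hunk above relies on the trait-generated accessor being typed: OneTypedResult<Ty>::Impl::getType() returns Ty itself (here a shaped tensor type), so the chained .getElementType() still compiles once getResult() is dropped. A sketch of why (hypothetical single-result op whose declared result type is RankedTensorType):

    // getResult() yields a TypedValue<RankedTensorType>, whose getType()
    // returns RankedTensorType; the trait accessor returns the same static
    // type directly, so element-type queries chain either way.
    RankedTensorType viaResult = op.getResult().getType();
    RankedTensorType direct = op.getType();
    bool isF64 = direct.getElementType().isF64();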
30 changes: 14 additions & 16 deletions stablehlo/transforms/ShapeLegalizeToStablehlo.cpp
@@ -199,7 +199,7 @@ struct ConvertComputeReshapeShapeOpPattern
     // results).
     // This cannot error out given how the operation is currently defined.
     auto resultIndex = maybeCastToIndex(op.getResult(), resultI32, rewriter);
-    if (!resultIndex || resultIndex.getType() != op.getResult().getType())
+    if (!resultIndex || resultIndex.getType() != op.getType())
       return rewriter.notifyMatchFailure(op, "cast to index failed");
     rewriter.replaceOp(op, resultIndex);
     return success();
@@ -238,7 +238,7 @@ struct ConvertNumElementsOpPattern
     // Cast result from tensor<i32> to index.
     // This will error out if the result is !shape.size.
     auto resultIndex = castToIndex(rewriter, op.getLoc(), resultI32);
-    if (!resultIndex || resultIndex.getType() != op.getResult().getType())
+    if (!resultIndex || resultIndex.getType() != op.getType())
       return rewriter.notifyMatchFailure(op, "cast to index failed");
     rewriter.replaceOp(op, resultIndex);
     return success();
@@ -279,7 +279,7 @@ struct ConvertShapeOfOpPattern : public OpRewritePattern<shape::ShapeOfOp> {
     // Cast result from tensor<Nxi32> to tensor<Nxindex>.
     // This will error out if the result is !shape.shape.
     auto shapeIndex = castToIndex(rewriter, op.getLoc(), shapeI32);
-    if (!shapeIndex || shapeIndex.getType() != op.getResult().getType())
+    if (!shapeIndex || shapeIndex.getType() != op.getType())
       return rewriter.notifyMatchFailure(op, "cast to index failed");
     rewriter.replaceOp(op, shapeIndex);
     return success();
@@ -291,7 +291,7 @@ struct ConvertConstShapeOpPattern
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(shape::ConstShapeOp op,
                                 PatternRewriter& rewriter) const override {
-    auto operandType = dyn_cast<RankedTensorType>(op.getResult().getType());
+    auto operandType = dyn_cast<RankedTensorType>(op.getType());
     if (!operandType)
       return rewriter.notifyMatchFailure(op, "expected ranked operand");

@@ -436,8 +436,7 @@ struct ConvertShapeBroadcastOpPattern
     auto broadcasted = rewriter.create<MaxOp>(op->getLoc(), shape1, shape2);

     auto broadcastedIndex = castToIndex(rewriter, op.getLoc(), broadcasted);
-    if (!broadcastedIndex ||
-        broadcastedIndex.getType() != op.getResult().getType())
+    if (!broadcastedIndex || broadcastedIndex.getType() != op.getType())
       return rewriter.notifyMatchFailure(op, "cast to index failed");
     rewriter.replaceOp(op, broadcastedIndex);
     return success();
@@ -498,17 +497,17 @@ struct ConvertTensorExtractPattern
     Value extractedScalarTensor = rewriter.create<ReshapeOp>(
         op.getLoc(), RankedTensorType::get({}, rewriter.getI32Type()),
         extractedTensor);
-    if (getElementTypeOrSelf(op.getResult().getType()).isIndex()) {
+    if (getElementTypeOrSelf(op.getType()).isIndex()) {
       auto extractedIndex =
           castToIndex(rewriter, op.getLoc(), extractedScalarTensor);
       rewriter.replaceOp(op, extractedIndex);
     } else {
       // For the special case when the input is a i32 tensor and output is i32,
       // convert the result back to i32 to be consistent:
       // unrealized_conversion_cast tensor<i32> -> i32
-      rewriter.replaceOp(op, rewriter.create<UnrealizedConversionCastOp>(
-                                 op.getLoc(), op.getResult().getType(),
-                                 extractedScalarTensor));
+      rewriter.replaceOp(op,
+                         rewriter.create<UnrealizedConversionCastOp>(
+                             op.getLoc(), op.getType(), extractedScalarTensor));
     }
     return success();
   }
@@ -519,8 +518,7 @@ struct ConvertTensorFromElementsPattern
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(tensor::FromElementsOp op,
                                 PatternRewriter& rewriter) const override {
-    auto tensorType =
-        dyn_cast_or_null<RankedTensorType>(op.getResult().getType());
+    auto tensorType = dyn_cast_or_null<RankedTensorType>(op.getType());
     if (!tensorType)
       return rewriter.notifyMatchFailure(op, "expected constant index op");

@@ -529,9 +527,9 @@
     // tensor.from_elements i64 -> tensor<i64>
     // This is converted to unrealized_conversion_cast i64 -> tensor<i64>,
     // which is later cancelled with previous unrealized_conversion_cast op.
-    rewriter.replaceOp(
-        op, rewriter.create<UnrealizedConversionCastOp>(
-                op.getLoc(), op.getResult().getType(), op.getElements()[0]));
+    rewriter.replaceOp(op,
+                       rewriter.create<UnrealizedConversionCastOp>(
+                           op.getLoc(), op.getType(), op.getElements()[0]));
     return success();
   }

@@ -558,7 +556,7 @@ struct ConvertTensorFromElementsPattern
         /*dimension=*/0);

     tensorI32 = maybeCastToIndex(op.getResult(), tensorI32, rewriter);
-    if (!tensorI32 || tensorI32.getType() != op.getResult().getType())
+    if (!tensorI32 || tensorI32.getType() != op.getType())
       return rewriter.notifyMatchFailure(op, "cast to index failed");
     rewriter.replaceOp(op, tensorI32);
     return success();
