diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
index 6bff5a3cb7..2ae76a00c1 100644
--- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
+++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp
@@ -397,8 +397,7 @@ struct BroadcastConverter final

   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
                                                    Builder *b) {
-    ShapedType inputType =
-        llvm::cast<ShapedType>(broadcastOp.getOperand().getType());
+    ShapedType inputType = broadcastOp.getOperand().getType();
     unsigned inputRank = inputType.getRank();
     unsigned nloops = getHloOpResultType(broadcastOp).getRank();

@@ -458,7 +457,7 @@ struct HloBroadcastInDimConverter final
   static SmallVector<AffineMap, 2> getIndexingMaps(
       mlir::stablehlo::BroadcastInDimOp broadcastOp, Builder *b) {
     ShapedType resultType = getHloOpResultType(broadcastOp);
-    auto operandType = cast<ShapedType>(broadcastOp.getOperand().getType());
+    auto operandType = broadcastOp.getOperand().getType();
     unsigned nloops = resultType.getRank();

     // The input is a scalar, i.e. this is a scalar broadcast op.
@@ -1047,7 +1046,7 @@ struct ReshapeOpConverter final
     Value operand = adaptor.getOperand();
     auto operandType = llvm::cast<ShapedType>(operand.getType());
     Type elemType = operandType.getElementType();
-    auto resultType = llvm::cast<ShapedType>(reshapeOp.getType());
+    ShapedType resultType = reshapeOp.getType();

     if (!resultType.hasStaticShape()) return failure();

@@ -1901,7 +1900,7 @@ struct SelectAndScatterNoOverlapConverter final
     auto sourceTy = llvm::cast<ShapedType>(source.getType());
     auto operandTy = llvm::cast<ShapedType>(operand.getType());
     auto initTy = llvm::cast<ShapedType>(init.getType());
-    auto resultTy = llvm::cast<ShapedType>(op.getResult().getType());
+    auto resultTy = op.getType();

     auto indexETy = b.getI32Type();
     auto srcETy = operandTy.getElementType();
diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp
index 6d19a8d072..39c0634321 100644
--- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp
+++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgConvolution.cpp
@@ -591,7 +591,7 @@ struct DepthwiseConvolutionOpConversion final
     // Make sure that this is depthwise convolution.
     int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension();
     int64_t inputFeatureCount =
-        cast<ShapedType>(op.getLhs().getType()).getDimSize(inputFeatureDim);
+        op.getLhs().getType().getDimSize(inputFeatureDim);
     if (static_cast<int64_t>(op.getFeatureGroupCount()) != inputFeatureCount) {
       return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
     }
@@ -646,8 +646,7 @@ struct DepthwiseConvolutionOpConversion final
         op.getLhsDilationAttr(), spatialDimMapping, rewriter);

-    auto filterDims =
-        llvm::to_vector(cast<ShapedType>(op.getRhs().getType()).getShape());
+    auto filterDims = llvm::to_vector(op.getRhs().getType().getShape());

    auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) {
      SmallVector<ReassociationIndices> reassociations;
@@ -680,8 +679,7 @@ struct DepthwiseConvolutionOpConversion final
      reshapedFilterDims[kernelOutputFeatureDimension] /=
          op.getFeatureGroupCount();
      auto reshapedFilterType = RankedTensorType::get(
-          reshapedFilterDims,
-          cast<ShapedType>(op.getRhs().getType()).getElementType());
+          reshapedFilterDims, op.getRhs().getType().getElementType());
      reshapedFilter = rewriter.create<mlir::stablehlo::ReshapeOp>(
          loc, reshapedFilterType, filter);

diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgDotProduct.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgDotProduct.cpp
index 1addafbd94..2735e82b4c 100644
--- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgDotProduct.cpp
+++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgDotProduct.cpp
@@ -33,10 +33,8 @@ enum class DotOperationType {
 };

 DotOperationType getDotOperationType(mlir::stablehlo::DotOp dotOp) {
-  ArrayRef<int64_t> lhsShape =
-      cast<ShapedType>(dotOp.getLhs().getType()).getShape();
-  ArrayRef<int64_t> rhsShape =
-      cast<ShapedType>(dotOp.getRhs().getType()).getShape();
+  ArrayRef<int64_t> lhsShape = dotOp.getLhs().getType().getShape();
+  ArrayRef<int64_t> rhsShape = dotOp.getRhs().getType().getShape();
   auto shapeMatches = [](int64_t a, int64_t b) {
     return a == ShapedType::kDynamic || b == ShapedType::kDynamic || a == b;
   };
@@ -131,7 +129,7 @@ struct DotGeneralBatchMatMulOpConversion final
     if (failed(verifyHloOpBufferOrTensorSemantics(op))) {
       return failure();
     }
-    if (llvm::cast<ShapedType>(op.getType()).getRank() != 3) {
+    if (op.getType().getRank() != 3) {
       return rewriter.notifyMatchFailure(op, "expected a batch matmul");
     }

diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
index 75954c4ad4..9f6a01ea43 100644
--- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
+++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
@@ -95,10 +95,10 @@ struct ConvertStablehloDotOp : public OpRewritePattern<stablehlo::DotOp> {
   LogicalResult matchAndRewrite(stablehlo::DotOp op,
                                 PatternRewriter& rewriter) const override {
-    auto lhsType = cast<ShapedType>(op.getLhs().getType());
-    auto rhsType = cast<ShapedType>(op.getRhs().getType());
+    auto lhsType = op.getLhs().getType();
+    auto rhsType = op.getRhs().getType();

-    auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
+    auto resultType = op.getType();
     if (!resultType) {
       return rewriter.notifyMatchFailure(op,
                                          "result tensor does not have shape");
     }
@@ -184,15 +184,14 @@ struct ConvertStablehloIotaOp : public OpRewritePattern<stablehlo::IotaOp> {
   LogicalResult matchAndRewrite(stablehlo::IotaOp op,
                                 PatternRewriter& rewriter) const override {
-    auto resultType = op.getResult().getType();
-    auto elementType = cast<ShapedType>(resultType).getElementType();
-    auto resultRankedType = cast<RankedTensorType>(resultType);
+    auto resultType = op.getType();
+    auto elementType = resultType.getElementType();

-    if (!resultRankedType.hasStaticShape()) {
+    if (!resultType.hasStaticShape()) {
       return rewriter.notifyMatchFailure(op, "result tensor must be static");
     }

-    auto resultShape = resultRankedType.getShape();
+    auto resultShape = resultType.getShape();
     auto iotaDimension = op.getIotaDimension();

     int64_t iotaArrayLength = resultShape[iotaDimension];

@@ -243,21 +242,21 @@ struct ConvertStablehloGatherOp : public OpRewritePattern<stablehlo::GatherOp> {
                                 PatternRewriter& rewriter) const override {
     // The input operand must be 3D, with shape [N, K, C].
     auto operand = op.getOperand();
-    auto operandType = cast<ShapedType>(operand.getType());
+    auto operandType = operand.getType();
     if (operandType.getRank() != 3) {
       return rewriter.notifyMatchFailure(op, "operand must have rank of 3");
     }

     // The indices tensor must be 2D, with shape [N, W].
     auto startIndices = op.getStartIndices();
-    auto startIndicesType = cast<ShapedType>(startIndices.getType());
+    auto startIndicesType = startIndices.getType();
     if (startIndicesType.getRank() != 2) {
       return rewriter.notifyMatchFailure(op,
                                          "start_indices must have rank of 2");
     }

     // The result tensor must be 3D, with shape [N, W, C].
-    auto resultType = cast<ShapedType>(op.getResult().getType());
+    auto resultType = op.getType();
     if (resultType.getRank() != 3) {
       return rewriter.notifyMatchFailure(op, "result must have rank of 3");
     }
diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp
index c96f942863..3f73cfd4f9 100644
--- a/stablehlo/dialect/StablehloOps.cpp
+++ b/stablehlo/dialect/StablehloOps.cpp
@@ -1157,7 +1157,7 @@ LogicalResult BitcastConvertOp::reifyReturnTypeShapes(
     OpBuilder& builder, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   auto operandType = cast<ShapedType>(operands[0].getType());
-  auto resultType = cast<ShapedType>(getType());
+  auto resultType = getType();

   // Shape-changing bitcast convert is not implemented.
   // TODO(kramerb): This could be done by adjusting the last dimension.
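
Note on the pattern: every hunk above makes the same simplification. The ODS-generated operand and result accessors on these StableHLO ops already return shaped tensor types, so wrapping them in llvm::cast<ShapedType> / dyn_cast<ShapedType> is redundant and the typed getType() / getOperand().getType() calls can be used directly. Below is a minimal, self-contained sketch of the before/after shape of this cleanup; the free function checkStaticReshape and its failure message are illustrative only and not part of the patch, while the ShapedType resultType = reshapeOp.getType() assignment mirrors the ReshapeOpConverter hunk above.

    #include "mlir/IR/BuiltinTypeInterfaces.h"
    #include "mlir/IR/PatternMatch.h"
    #include "stablehlo/dialect/StablehloOps.h"

    using namespace mlir;

    // Illustrative helper (not in the patch): shows the cast-free form used
    // throughout the diff. The ODS accessor reshapeOp.getType() already
    // yields a ShapedType, so no llvm::cast<ShapedType>(...) is needed.
    LogicalResult checkStaticReshape(stablehlo::ReshapeOp reshapeOp,
                                     PatternRewriter &rewriter) {
      // Before: auto resultType = llvm::cast<ShapedType>(reshapeOp.getType());
      ShapedType resultType = reshapeOp.getType();
      if (!resultType.hasStaticShape())
        return rewriter.notifyMatchFailure(reshapeOp, "expected a static shape");
      return success();
    }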