Remove unnecessary {RankedTensor,Shaped}Type casts (#2217)
I revisited all the casts to ShapedType and RankedTensorType in light of
the fact that most op input/output types are ranked, and removed the
unnecessary casts.
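
For illustration only (this sketch is not part of the diff): when ODS constrains an operand or result to a ranked tensor, the generated accessor returns a typed value (e.g. `mlir::TypedValue<mlir::RankedTensorType>`), so `.getType()` already yields the ranked type and the runtime cast just restates what the C++ type system knows.

```cpp
#include "mlir/IR/BuiltinTypes.h"  // mlir::RankedTensorType, mlir::ShapedType

// Sketch under the assumption that OpTy's operand is ODS-constrained to a
// ranked tensor, so getOperand() returns a value whose getType() is already
// mlir::RankedTensorType.
template <typename OpTy>
int64_t getOperandRank(OpTy op) {
  // The accessor's return type already carries RankedTensorType; it also
  // converts implicitly to the broader ShapedType interface where needed.
  mlir::RankedTensorType operandType = op.getOperand().getType();

  // The old pattern re-derived the same fact with a runtime cast:
  //   auto operandType =
  //       llvm::cast<mlir::ShapedType>(op.getOperand().getType());

  return operandType.getRank();
}
```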

This PR is sort of an extension of #2183.

Also simplified a few `op.getResult().getType()` calls to `op.getType()`,
since I was modifying that code anyway.
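
A small sketch of that simplification (again not part of the diff), assuming a single-result op: such ops expose a `getType()` accessor that returns the result type directly, so the detour through `getResult()` is redundant.

```cpp
// Sketch only: OpTy stands for any single-result op. Both expressions below
// denote the same type; the second just spells out the intermediate Value.
template <typename OpTy>
bool resultAccessorsAgree(OpTy op) {
  return op.getType() == op.getResult().getType();
}
```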

Fixes #2065
mlevesquedion authored Apr 15, 2024
1 parent d68db56 commit db73020
Showing 5 changed files with 21 additions and 27 deletions.
@@ -397,8 +397,7 @@ struct BroadcastConverter final

  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
                                                   Builder *b) {
-    ShapedType inputType =
-        llvm::cast<ShapedType>(broadcastOp.getOperand().getType());
+    ShapedType inputType = broadcastOp.getOperand().getType();
    unsigned inputRank = inputType.getRank();
    unsigned nloops = getHloOpResultType(broadcastOp).getRank();

@@ -458,7 +457,7 @@ struct HloBroadcastInDimConverter final
  static SmallVector<AffineMap, 2> getIndexingMaps(
      mlir::stablehlo::BroadcastInDimOp broadcastOp, Builder *b) {
    ShapedType resultType = getHloOpResultType(broadcastOp);
-    auto operandType = cast<ShapedType>(broadcastOp.getOperand().getType());
+    auto operandType = broadcastOp.getOperand().getType();
    unsigned nloops = resultType.getRank();

    // The input is a scalar, i.e. this is a scalar broadcast op.
@@ -1047,7 +1046,7 @@ struct ReshapeOpConverter final
    Value operand = adaptor.getOperand();
    auto operandType = llvm::cast<ShapedType>(operand.getType());
    Type elemType = operandType.getElementType();
-    auto resultType = llvm::cast<ShapedType>(reshapeOp.getType());
+    ShapedType resultType = reshapeOp.getType();

    if (!resultType.hasStaticShape()) return failure();

@@ -1901,7 +1900,7 @@ struct SelectAndScatterNoOverlapConverter final
    auto sourceTy = llvm::cast<RankedTensorType>(source.getType());
    auto operandTy = llvm::cast<RankedTensorType>(operand.getType());
    auto initTy = llvm::cast<RankedTensorType>(init.getType());
-    auto resultTy = llvm::cast<RankedTensorType>(op.getResult().getType());
+    auto resultTy = op.getType();

    auto indexETy = b.getI32Type();
    auto srcETy = operandTy.getElementType();
@@ -591,7 +591,7 @@ struct DepthwiseConvolutionOpConversion final
    // Make sure that this is depthwise convolution.
    int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension();
    int64_t inputFeatureCount =
-        cast<ShapedType>(op.getLhs().getType()).getDimSize(inputFeatureDim);
+        op.getLhs().getType().getDimSize(inputFeatureDim);
    if (static_cast<int64_t>(op.getFeatureGroupCount()) != inputFeatureCount) {
      return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
    }
@@ -646,8 +645,7 @@ struct DepthwiseConvolutionOpConversion final
        op.getLhsDilationAttr(), spatialDimMapping,
        rewriter);

-    auto filterDims =
-        llvm::to_vector(cast<ShapedType>(op.getRhs().getType()).getShape());
+    auto filterDims = llvm::to_vector(op.getRhs().getType().getShape());

    auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) {
      SmallVector<ReassociationIndices> reassociations;
@@ -680,8 +679,7 @@ struct DepthwiseConvolutionOpConversion final
      reshapedFilterDims[kernelOutputFeatureDimension] /=
          op.getFeatureGroupCount();
      auto reshapedFilterType = RankedTensorType::get(
-          reshapedFilterDims,
-          cast<ShapedType>(op.getRhs().getType()).getElementType());
+          reshapedFilterDims, op.getRhs().getType().getElementType());

      reshapedFilter = rewriter.create<mlir::stablehlo::ReshapeOp>(
          loc, reshapedFilterType, filter);
@@ -33,10 +33,8 @@ enum class DotOperationType {
};

DotOperationType getDotOperationType(mlir::stablehlo::DotOp dotOp) {
-  ArrayRef<int64_t> lhsShape =
-      cast<ShapedType>(dotOp.getLhs().getType()).getShape();
-  ArrayRef<int64_t> rhsShape =
-      cast<ShapedType>(dotOp.getRhs().getType()).getShape();
+  ArrayRef<int64_t> lhsShape = dotOp.getLhs().getType().getShape();
+  ArrayRef<int64_t> rhsShape = dotOp.getRhs().getType().getShape();
  auto shapeMatches = [](int64_t a, int64_t b) {
    return a == ShapedType::kDynamic || b == ShapedType::kDynamic || a == b;
  };
@@ -131,7 +129,7 @@ struct DotGeneralBatchMatMulOpConversion final
    if (failed(verifyHloOpBufferOrTensorSemantics(op))) {
      return failure();
    }
-    if (llvm::cast<RankedTensorType>(op.getType()).getRank() != 3) {
+    if (op.getType().getRank() != 3) {
      return rewriter.notifyMatchFailure(op, "expected a batch matmul");
    }

stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp (21 changes: 10 additions & 11 deletions)
@@ -95,10 +95,10 @@ struct ConvertStablehloDotOp : public OpRewritePattern<stablehlo::DotOp> {

  LogicalResult matchAndRewrite(stablehlo::DotOp op,
                                PatternRewriter& rewriter) const override {
-    auto lhsType = cast<RankedTensorType>(op.getLhs().getType());
-    auto rhsType = cast<RankedTensorType>(op.getRhs().getType());
+    auto lhsType = op.getLhs().getType();
+    auto rhsType = op.getRhs().getType();

-    auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
+    auto resultType = op.getType();
    if (!resultType) {
      return rewriter.notifyMatchFailure(op,
                                         "result tensor does not have shape");
@@ -184,15 +184,14 @@ struct ConvertStablehloIotaOp : public OpRewritePattern<stablehlo::IotaOp> {

  LogicalResult matchAndRewrite(stablehlo::IotaOp op,
                                PatternRewriter& rewriter) const override {
-    auto resultType = op.getResult().getType();
-    auto elementType = cast<ShapedType>(resultType).getElementType();
-    auto resultRankedType = cast<RankedTensorType>(resultType);
+    auto resultType = op.getType();
+    auto elementType = resultType.getElementType();

-    if (!resultRankedType.hasStaticShape()) {
+    if (!resultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op, "result tensor must be static");
    }

-    auto resultShape = resultRankedType.getShape();
+    auto resultShape = resultType.getShape();
    auto iotaDimension = op.getIotaDimension();
    int64_t iotaArrayLength = resultShape[iotaDimension];

@@ -243,21 +242,21 @@ struct ConvertStablehloGatherOp : public OpRewritePattern<stablehlo::GatherOp> {
                                PatternRewriter& rewriter) const override {
    // The input operand must be 3D, with shape [N, K, C].
    auto operand = op.getOperand();
-    auto operandType = cast<RankedTensorType>(operand.getType());
+    auto operandType = operand.getType();
    if (operandType.getRank() != 3) {
      return rewriter.notifyMatchFailure(op, "operand must have rank of 3");
    }

    // The indices tensor must be 2D, with shape [N, W].
    auto startIndices = op.getStartIndices();
-    auto startIndicesType = cast<RankedTensorType>(startIndices.getType());
+    auto startIndicesType = startIndices.getType();
    if (startIndicesType.getRank() != 2) {
      return rewriter.notifyMatchFailure(op,
                                         "start_indices must have rank of 2");
    }

    // The result tensor must be 3D, with shape [N, W, C].
-    auto resultType = cast<RankedTensorType>(op.getResult().getType());
+    auto resultType = op.getType();
    if (resultType.getRank() != 3) {
      return rewriter.notifyMatchFailure(op, "result must have rank of 3");
    }
stablehlo/dialect/StablehloOps.cpp (2 changes: 1 addition & 1 deletion)
@@ -1157,7 +1157,7 @@ LogicalResult BitcastConvertOp::reifyReturnTypeShapes(
    OpBuilder& builder, ValueRange operands,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  auto operandType = cast<RankedTensorType>(operands[0].getType());
-  auto resultType = cast<RankedTensorType>(getType());
+  auto resultType = getType();

  // Shape-changing bitcast convert is not implemented.
  // TODO(kramerb): This could be done by adjusting the last dimension.
