Clean up hasRank checks (#2119)
These checks are no longer necessary in most cases (they are still
necessary in some cases, e.g. when dealing with custom call args/results
or CHLO op args/results).
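
As a rough illustration of the pattern being removed (a minimal sketch with a placeholder name `resultType`, not code copied verbatim from this commit): where the type system now guarantees a ranked tensor, the rank guard is redundant and only the remaining condition survives.

  // Before: defensively guard against unranked tensors.
  if (resultType.hasRank() && !resultType.hasStaticShape()) {
    // ... reify the dynamic output shape ...
  }

  // After: rankedness is guaranteed, so only the static-shape test remains.
  if (!resultType.hasStaticShape()) {
    // ... reify the dynamic output shape ...
  }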

I plan to make further cleanups in future PRs, e.g. in
ShapeRefinement.cpp (in this PR I did only the obvious ones). I also
plan to clean up casts, especially places that cast to `ShapedType`
where `RankedTensorType` could be used.

#1991
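
As a hedged sketch of the cast cleanup described above (placeholder names `value` and `op`, not taken from this commit): a defensive `dyn_cast` to `ShapedType` followed by a `hasRank()` check can collapse into a single `cast` to `RankedTensorType` once rankedness is guaranteed.

  // Before: dyn_cast plus an explicit rank check and match failure.
  auto argType = llvm::dyn_cast<ShapedType>(value.getType());
  if (!argType || !argType.hasRank())
    return rewriter.notifyMatchFailure(op, "expects known-rank args");

  // After: a plain cast, which asserts that the type is a ranked tensor.
  auto argType = llvm::cast<RankedTensorType>(value.getType());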
mlevesquedion authored Mar 25, 2024
1 parent c123a48 commit d89e06f
Showing 10 changed files with 189 additions and 324 deletions.
@@ -73,7 +73,7 @@ Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType,
   // new tensor initialization operation. This operation only needs the
   // dynamic sizes.
   SmallVector<Value> sizes;
-  if (resultType.hasRank() && !resultType.hasStaticShape()) {
+  if (!resultType.hasStaticShape()) {
     // Ask the op for its output shape.
     auto shapeSource = cast<InferShapedTypeOpInterface>(op);
     SmallVector<Value, 1> reifiedShapes;
@@ -966,11 +966,7 @@ struct RealDynamicSliceConverter final
       mlir::stablehlo::RealDynamicSliceOp realDynamicSliceOp, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     Location loc = realDynamicSliceOp.getLoc();
-    auto argType = llvm::dyn_cast<ShapedType>(adaptor.getOperand().getType());
-    if (!argType || !argType.hasRank()) {
-      return rewriter.notifyMatchFailure(realDynamicSliceOp,
-                                         "require known-rank args");
-    }
+    auto argType = llvm::cast<RankedTensorType>(adaptor.getOperand().getType());
 
     Type dimElementType = getElementTypeOrSelf(adaptor.getStartIndices());
     if (getElementTypeOrSelf(adaptor.getLimitIndices()) != dimElementType ||
@@ -1405,10 +1401,7 @@ struct SliceConverter final : OpConversionPattern<mlir::stablehlo::SliceOp> {
       mlir::stablehlo::SliceOp sliceOp, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     auto argType =
-        llvm::dyn_cast<ShapedType>(adaptor.getOperands()[0].getType());
-    if (!argType || !argType.hasRank()) {
-      return rewriter.notifyMatchFailure(sliceOp, "expects known-rank args");
-    }
+        llvm::cast<RankedTensorType>(adaptor.getOperands()[0].getType());
 
     SmallVector<OpFoldResult, 3> offsets, sizes, strides;
     auto startIndices = sliceOp.getStartIndices();
@@ -1442,11 +1435,7 @@ struct DynamicSliceConverter final
       mlir::stablehlo::DynamicSliceOp dynamicSliceOp, OpAdaptor adaptor,
       ConversionPatternRewriter &rewriter) const override {
     Location loc = dynamicSliceOp.getLoc();
-    auto argType = llvm::dyn_cast<ShapedType>(adaptor.getOperand().getType());
-    if (!argType || !argType.hasRank()) {
-      return rewriter.notifyMatchFailure(dynamicSliceOp,
-                                         "require known-rank args");
-    }
+    auto argType = llvm::cast<RankedTensorType>(adaptor.getOperand().getType());
 
     auto resultType = getTypeConverter()->convertType<RankedTensorType>(
         dynamicSliceOp.getType());
@@ -80,11 +80,11 @@ FailureOr<PointwiseConversionInfo> checkOperandsAndResults(
   }
 
   // Find result type, if on tensors.
-  auto resultTy = dyn_cast_or_null<ShapedType>(
+  auto resultTy = cast_or_null<RankedTensorType>(
       typeConverter.convertType(op->getResultTypes().front()));
 
   // Check result type compatibility.
-  if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank ||
+  if (!resultTy || resultTy.getRank() != maxRank ||
       !(resultTy.getElementType().isSignlessIntOrFloat() ||
         isa<ComplexType>(resultTy.getElementType()))) {
     return rewriter.notifyMatchFailure(
@@ -48,12 +48,8 @@ Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{
 }];
 
 Rewrite changeElementTypeToI1(type: Type) -> Type [{
-  auto tensorType = type.cast<mlir::TensorType>();
-  if (tensorType.hasRank()) {
-    return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type());
-  } else {
-    return UnrankedTensorType::get(rewriter.getI1Type());
-  }
+  auto tensorType = type.cast<mlir::RankedTensorType>();
+  return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type());
 }];
 
 // Nullary ops.
8 changes: 3 additions & 5 deletions stablehlo/dialect/Base.cpp
@@ -601,11 +601,9 @@ LogicalResult unflattenTupleTypes(TypeRange prototype, TypeRange types,
 
 ShapedType createShapedType(ShapedTypeComponents components) {
   if (!components.getElementType()) return ShapedType();
-  if (components.hasRank())
-    return RankedTensorType::get(components.getDims(),
-                                 components.getElementType(),
-                                 components.getAttribute());
-  return UnrankedTensorType::get(components.getElementType());
+  return RankedTensorType::get(components.getDims(),
+                               components.getElementType(),
+                               components.getAttribute());
 }
 
 bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) {
11 changes: 2 additions & 9 deletions stablehlo/dialect/StablehloOps.cpp
@@ -550,7 +550,6 @@ LogicalResult DotGeneralOp::reifyReturnTypeShapes(
     SmallVectorImpl<Value>& reifiedReturnShapes) {
   auto lhsType = getLhs().getType();
   auto rhsType = getRhs().getType();
-  if (!lhsType.hasRank() || !rhsType.hasRank()) return failure();
 
   Adaptor adaptor(operands);
   auto dimNumbers = getDotDimensionNumbers();
@@ -1520,14 +1519,8 @@ void ReduceOp::build(OpBuilder&, OperationState& odsState, ValueRange inputs,
   SmallVector<Type> inferredReturnTypes;
   for (auto [inputTy, elementTy] :
        llvm::zip(inputArgTensorTypes, elementTypes)) {
-    if (inputTy.hasRank()) {
-      inferredReturnTypes.push_back(
-          RankedTensorType::get(newDimensions, elementTy, encoding));
-    } else {
-      if (encoding != nullptr)
-        llvm::report_fatal_error("attribute not supported.");
-      inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
-    }
+    inferredReturnTypes.push_back(
+        RankedTensorType::get(newDimensions, elementTy, encoding));
   }
   odsState.addTypes(inferredReturnTypes);
 }