From d89e06f1ee77660aee8b65efbeb7b3d54f73738f Mon Sep 17 00:00:00 2001 From: mlevesquedion Date: Mon, 25 Mar 2024 14:56:32 -0700 Subject: [PATCH] Clean up hasRank checks (#2119) These checks are no longer necessary in most cases (they are still necessary in some cases, e.g. when dealing with custom call args/results or CHLO op args/results). I plan to make further cleanups in future PRs, e.g. in ShapeRefinement.cpp (in this PR I did only the obvious ones). Also I plan to clean up casts, especially casting to `ShapedType` instead of `RankedTensorType`. #1991 --- .../transforms/LegalizeToLinalgUtils.cpp | 2 +- .../transforms/StablehloLegalizeToLinalg.cpp | 17 +- .../transforms/StablehloToLinalgPointwise.cpp | 4 +- .../transforms/StablehloLegalizeToTosa.pdll | 8 +- stablehlo/dialect/Base.cpp | 8 +- stablehlo/dialect/StablehloOps.cpp | 11 +- stablehlo/dialect/TypeInference.cpp | 442 +++++++----------- stablehlo/tests/ops_stablehlo.mlir | 2 +- stablehlo/tests/verify_reduce.mlir | 4 +- .../transforms/StablehloRefineShapes.cpp | 15 +- 10 files changed, 189 insertions(+), 324 deletions(-) diff --git a/stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.cpp b/stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.cpp index 239ba3a251..a0fc2b3d7d 100644 --- a/stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.cpp +++ b/stablehlo/conversions/linalg/transforms/LegalizeToLinalgUtils.cpp @@ -73,7 +73,7 @@ Value getEmptyTensorFor(OpBuilder &b, Location loc, ShapedType resultType, // new tensor initialization operation. This operation only needs the // dynamic sizes. SmallVector sizes; - if (resultType.hasRank() && !resultType.hasStaticShape()) { + if (!resultType.hasStaticShape()) { // Ask the op for its output shape. auto shapeSource = cast(op); SmallVector reifiedShapes; diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 4f45029046..412ad9ab53 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -966,11 +966,7 @@ struct RealDynamicSliceConverter final mlir::stablehlo::RealDynamicSliceOp realDynamicSliceOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = realDynamicSliceOp.getLoc(); - auto argType = llvm::dyn_cast(adaptor.getOperand().getType()); - if (!argType || !argType.hasRank()) { - return rewriter.notifyMatchFailure(realDynamicSliceOp, - "require known-rank args"); - } + auto argType = llvm::cast(adaptor.getOperand().getType()); Type dimElementType = getElementTypeOrSelf(adaptor.getStartIndices()); if (getElementTypeOrSelf(adaptor.getLimitIndices()) != dimElementType || @@ -1405,10 +1401,7 @@ struct SliceConverter final : OpConversionPattern { mlir::stablehlo::SliceOp sliceOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto argType = - llvm::dyn_cast(adaptor.getOperands()[0].getType()); - if (!argType || !argType.hasRank()) { - return rewriter.notifyMatchFailure(sliceOp, "expects known-rank args"); - } + llvm::cast(adaptor.getOperands()[0].getType()); SmallVector offsets, sizes, strides; auto startIndices = sliceOp.getStartIndices(); @@ -1442,11 +1435,7 @@ struct DynamicSliceConverter final mlir::stablehlo::DynamicSliceOp dynamicSliceOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = dynamicSliceOp.getLoc(); - auto argType = 
llvm::dyn_cast(adaptor.getOperand().getType()); - if (!argType || !argType.hasRank()) { - return rewriter.notifyMatchFailure(dynamicSliceOp, - "require known-rank args"); - } + auto argType = llvm::cast(adaptor.getOperand().getType()); auto resultType = getTypeConverter()->convertType( dynamicSliceOp.getType()); diff --git a/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp b/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp index 7830fa60d2..49bec1b863 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp @@ -80,11 +80,11 @@ FailureOr checkOperandsAndResults( } // Find result type, if on tensors. - auto resultTy = dyn_cast_or_null( + auto resultTy = cast_or_null( typeConverter.convertType(op->getResultTypes().front())); // Check result type compatibility. - if (!resultTy || !resultTy.hasRank() || resultTy.getRank() != maxRank || + if (!resultTy || resultTy.getRank() != maxRank || !(resultTy.getElementType().isSignlessIntOrFloat() || isa(resultTy.getElementType()))) { return rewriter.notifyMatchFailure( diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll index b232898dbf..f1004ffc38 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll @@ -48,12 +48,8 @@ Rewrite positiveFloatInfinityLike(op: Op, type: Type) -> Op [{ }]; Rewrite changeElementTypeToI1(type: Type) -> Type [{ - auto tensorType = type.cast(); - if (tensorType.hasRank()) { - return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); - } else { - return UnrankedTensorType::get(rewriter.getI1Type()); - } + auto tensorType = type.cast(); + return RankedTensorType::get(tensorType.getShape(), rewriter.getI1Type()); }]; // Nullary ops. 
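(Aside, not part of the patch itself: a minimal C++ sketch of the simplification pattern this change applies, assuming tensor types are always ranked so the unranked branch can be dropped. The cast target shown is an assumption; per the description above, ShapedType casts are to be tightened to RankedTensorType in a follow-up.)

    // Standalone rendering of the changeElementTypeToI1 rewrite body above;
    // illustrative only, with the cast target (ShapedType) assumed.
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinTypes.h"

    mlir::Type changeElementTypeToI1(mlir::Type type, mlir::OpBuilder &rewriter) {
      // No hasRank() branch: every tensor type is treated as ranked now.
      auto tensorType = type.cast<mlir::ShapedType>();
      return mlir::RankedTensorType::get(tensorType.getShape(),
                                         rewriter.getI1Type());
    }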
diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp index 1e79069e17..56e395690b 100644 --- a/stablehlo/dialect/Base.cpp +++ b/stablehlo/dialect/Base.cpp @@ -601,11 +601,9 @@ LogicalResult unflattenTupleTypes(TypeRange prototype, TypeRange types, ShapedType createShapedType(ShapedTypeComponents components) { if (!components.getElementType()) return ShapedType(); - if (components.hasRank()) - return RankedTensorType::get(components.getDims(), - components.getElementType(), - components.getAttribute()); - return UnrankedTensorType::get(components.getElementType()); + return RankedTensorType::get(components.getDims(), + components.getElementType(), + components.getAttribute()); } bool isSplatArray(ArrayRef arr, int64_t val) { diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index a63d15b603..3cc68c5ff0 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -550,7 +550,6 @@ LogicalResult DotGeneralOp::reifyReturnTypeShapes( SmallVectorImpl& reifiedReturnShapes) { auto lhsType = getLhs().getType(); auto rhsType = getRhs().getType(); - if (!lhsType.hasRank() || !rhsType.hasRank()) return failure(); Adaptor adaptor(operands); auto dimNumbers = getDotDimensionNumbers(); @@ -1520,14 +1519,8 @@ void ReduceOp::build(OpBuilder&, OperationState& odsState, ValueRange inputs, SmallVector inferredReturnTypes; for (auto [inputTy, elementTy] : llvm::zip(inputArgTensorTypes, elementTypes)) { - if (inputTy.hasRank()) { - inferredReturnTypes.push_back( - RankedTensorType::get(newDimensions, elementTy, encoding)); - } else { - if (encoding != nullptr) - llvm::report_fatal_error("attribute not supported."); - inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy)); - } + inferredReturnTypes.push_back( + RankedTensorType::get(newDimensions, elementTy, encoding)); } odsState.addTypes(inferredReturnTypes); } diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 48cfb4d80b..0e3b448794 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -510,39 +510,22 @@ LogicalResult verifyReduceOpInputsAndInferShape( std::optional location, SmallVector inputTypes, ArrayRef dimensions, SmallVector& newDimensions, Attribute& encoding) { - // Check for unranked tensors in input operands. - uint64_t numInputs = inputTypes.size(); - int64_t rankedInputIdx = -1; - for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { - if (inputTypes[inputIdx].hasRank()) { - rankedInputIdx = inputIdx; - break; - } - } - bool allInputsUnranked = (rankedInputIdx == -1); // reduce_c1 - if (!allInputsUnranked) { - for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) - if (failed(mlir::verifyCompatibleShape(inputTypes[rankedInputIdx], - inputTypes[inputIdx]))) - return emitOptionalError( - location, "expects all inputs to have compatible shapes. Shape at", - " input-index ", inputIdx, - " is not compatible with shape at input-index ", rankedInputIdx); - } + auto witnessType = inputTypes[0].cast(); + for (size_t i = 1; i < inputTypes.size(); i++) + if (failed(mlir::verifyCompatibleShape(witnessType, inputTypes[i]))) + return emitOptionalError( + location, + "expects all inputs to have compatible shapes. 
Shape at input-index ", + i, " is not compatible with shape at input-index 0"); DenseSet dimensionsToReduceSet; for (int64_t dimension : dimensions) { // reduce_c4 - if ((!allInputsUnranked && - dimension >= inputTypes[rankedInputIdx].getRank()) || - dimension < 0) - return emitOptionalError( - location, "Out-of-bounds dimension ", dimension, ", expected to be ", - allInputsUnranked - ? "> 0" - : "less than the input-tensor rank " + - std::to_string(inputTypes[rankedInputIdx].getRank())); + if (dimension < 0 || dimension >= witnessType.getRank()) + return emitOptionalError(location, "Out-of-bounds dimension ", dimension, + ", expected to be in range [0, ", + witnessType.getRank(), ')'); // reduce_c5 if (!dimensionsToReduceSet.insert(dimension).second) @@ -550,22 +533,19 @@ LogicalResult verifyReduceOpInputsAndInferShape( "Duplicate reduction dimension: ", dimension); } - if (!allInputsUnranked) { - auto rankedInput = inputTypes[rankedInputIdx].cast(); - ArrayRef inputBounds = encodingToBounds(rankedInput.getEncoding()); - SmallVector newBounds; - for (int inputIdx = 0; inputIdx < rankedInput.getRank(); ++inputIdx) { - if (!dimensionsToReduceSet.count(inputIdx)) { - newDimensions.push_back(rankedInput.getDimSize(inputIdx)); - if (!inputBounds.empty()) newBounds.push_back(inputBounds[inputIdx]); - } + ArrayRef inputBounds = encodingToBounds(witnessType.getEncoding()); + SmallVector newBounds; + for (int inputIdx = 0; inputIdx < witnessType.getRank(); ++inputIdx) { + if (!dimensionsToReduceSet.count(inputIdx)) { + newDimensions.push_back(witnessType.getDimSize(inputIdx)); + if (!inputBounds.empty()) newBounds.push_back(inputBounds[inputIdx]); } - - // Set encoding based on the bounds only if the bounds is not empty. - encoding = nullptr; - if (!newBounds.empty()) - encoding = boundsToEncoding(rankedInput.getEncoding(), newBounds); } + + // Set encoding based on the bounds only if the bounds is not empty. + encoding = nullptr; + if (!newBounds.empty()) + encoding = boundsToEncoding(witnessType.getEncoding(), newBounds); return success(); } @@ -683,11 +663,7 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, block.getArgument(numInputs + inputIdx).getType())); Type blockArgType = block.getArgument(numInputs + inputIdx).getType(); - auto blockArgTensorTy = blockArgType.cast(); - - auto allInputsUnranked = llvm::none_of( - inputTypes, [&](ShapedType type) { return type.hasRank(); }); - if (allInputsUnranked || !blockArgTensorTy.hasRank()) return success(); + auto blockArgTensorTy = blockArgType.cast(); auto argShape = blockArgTensorTy.getShape(); // reduce_c6, reduce_window_c13, select_and_scatter_c10 @@ -733,27 +709,14 @@ LogicalResult verifyReduceWindowOpInputsAndInferWindow( if (inputTypes.empty()) return emitOptionalError(location, "requires at least 1 input value"); - // Check for unranked tensors in input operands. - uint64_t numInputs = inputTypes.size(); - int64_t rankedInputIdx = -1; - for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { - if (inputTypes[inputIdx].hasRank()) { - rankedInputIdx = inputIdx; - break; - } - } - bool allInputsUnranked = (rankedInputIdx == -1); - + auto witnessType = inputTypes[0].cast(); // reduce_window_c2 - if (!allInputsUnranked) { - for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) - if (failed(mlir::verifyCompatibleShape(inputTypes[rankedInputIdx], - inputTypes[inputIdx]))) - return emitOptionalError( - location, "expects all inputs to have compatible shapes. 
Shape at", - " input-index ", inputIdx, - " is not compatible with shape at input-index ", rankedInputIdx); - } + for (size_t i = 1; i < inputTypes.size(); i++) + if (failed(mlir::verifyCompatibleShape(witnessType, inputTypes[i]))) + return emitOptionalError( + location, + "expects all inputs to have compatible shapes. Shape at input-index ", + i, " is not compatible with shape at input-index 0"); // reduce_window_c12, reduce_window_i7 auto paddingOrErr = convertPaddingAttribute(padding, location); @@ -761,7 +724,6 @@ LogicalResult verifyReduceWindowOpInputsAndInferWindow( // reduce_window_c4 for (const auto inputType : inputTypes) { - if (!inputType.hasRank()) continue; if (inputType.getRank() != static_cast(windowDimensions.size())) return emitOptionalError( location, "expects window-dimensions size == input rank, but got ", @@ -1124,35 +1086,32 @@ static LogicalResult verifyGather( // gather_c10 for (int64_t i = 0; i < static_cast(startIndexMap.size()); ++i) - if (startIndexMap[i] < 0 || - (operandShape.hasRank() && startIndexMap[i] >= operandShape.getRank())) + if (startIndexMap[i] < 0 || startIndexMap[i] >= operandShape.getRank()) return emitOptionalError( location, "start_index_map[", i, "]: ", startIndexMap[i], " is out of bounds for ", "operand rank ", operandShape.getRank()); - if (startIndicesShape.hasRank()) { - // gather_c2 - // index_vector_dim == start_indices.rank implies a trailing 1 on the shape - // of start_indices. - if (indexVectorDim > startIndicesShape.getRank() || indexVectorDim < 0) - return emitOptionalError(location, "index_vector_dim ", indexVectorDim, - " is out of bounds for start indices with rank ", - startIndicesShape.getRank()); - - // gather_c3 - bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank(); - if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) { - int64_t effectiveDimSize; - if (impliedTrailingDim) - effectiveDimSize = 1; - else - effectiveDimSize = startIndicesShape.getDimSize(indexVectorDim); - if (effectiveDimSize != static_cast(startIndexMap.size())) - return emitOptionalError( - location, "start_index_map size (", startIndexMap.size(), - ") is not equal to size of index dimension (", indexVectorDim, - ") of start_indices (", effectiveDimSize, ")"); - } + // gather_c2 + // index_vector_dim == start_indices.rank implies a trailing 1 on the + // shape of start_indices. 
+ if (indexVectorDim > startIndicesShape.getRank() || indexVectorDim < 0) + return emitOptionalError(location, "index_vector_dim ", indexVectorDim, + " is out of bounds for start indices with rank ", + startIndicesShape.getRank()); + + // gather_c3 + bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank(); + if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) { + int64_t effectiveDimSize; + if (impliedTrailingDim) + effectiveDimSize = 1; + else + effectiveDimSize = startIndicesShape.getDimSize(indexVectorDim); + if (effectiveDimSize != static_cast(startIndexMap.size())) + return emitOptionalError( + location, "start_index_map size (", startIndexMap.size(), + ") is not equal to size of index dimension (", indexVectorDim, + ") of start_indices (", effectiveDimSize, ")"); } // gather_c4 @@ -1175,14 +1134,14 @@ static LogicalResult verifyGather( // gather_c1 int64_t impliedOperandRank = offsetDims.size() + collapsedSliceDims.size(); - if (operandShape.hasRank() && operandShape.getRank() != impliedOperandRank) + if (operandShape.getRank() != impliedOperandRank) return emitOptionalError( location, "offset_dims size (", offsetDims.size(), ") plus collapse_slice_dims size (", collapsedSliceDims.size(), ") is not equal to operand rank (", operandShape.getRank(), ")"); // gather_i7 - if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1) + if (sliceSizesShape.getRank() != 1) return emitOptionalError(location, "slice_sizes.rank != 1 (got ", sliceSizesShape.getRank(), ')'); if (sliceSizesShape.hasStaticShape()) { @@ -1282,13 +1241,6 @@ static LogicalResult inferGatherReturnTypeComponents( Type elementType = operandShape.getElementType(); ShapeAdaptor startIndicesShape(startIndices.getType()); - // We need this to determine the result rank. We could still place bounds on - // the result rank if that was something ShapedTypeComponents could express. - if (!startIndicesShape.hasRank()) { - inferredReturnShapes.push_back(elementType); - return success(); - } - int64_t startIndicesRank = startIndicesShape.getRank(); // If index_vector_dim == start_indices.rank, then an implicit trailing 1 is // appended to start_indices shape. 
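(Aside, not part of the patch itself: the same pattern recurs in the verifier hunks below, e.g. verifyDimInBounds in the next hunk. A minimal sketch, assuming all tensor types are ranked; the diagnostic plumbing through emitOptionalError is elided and replaced with plain failure() here.)

    // Simplified bounds check: no hasRank() guard before calling getRank().
    #include <cstdint>
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/Support/LogicalResult.h"

    mlir::LogicalResult verifyDimInBounds(mlir::ShapedType type, int64_t dim) {
      if (dim < 0) return mlir::failure();                // negative dimension attribute
      if (dim >= type.getRank()) return mlir::failure();  // outside [0, rank)
      return mlir::success();
    }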
@@ -1391,7 +1343,7 @@ LogicalResult verifyDimInBounds(std::optional loc, ShapedType type, if (dim < 0) return emitOptionalError( loc, "requires non-negative dimension attribute; found (", dim, ")"); - if (type.hasRank() && dim >= type.getRank()) + if (dim >= type.getRank()) return emitOptionalError(loc, "requires dimension attribute in range [0, ", type.getRank(), "); found (", dim, ")"); return success(); @@ -1635,9 +1587,8 @@ LogicalResult inferCompareOp( inferredReturnShapes.emplace_back(IntegerType::get(context, /*width=*/1)); auto argTy = lhs.getType().cast(); // compare_c2 - if (argTy.hasRank()) - components = - ShapedTypeComponents(argTy.getShape(), components.getElementType()); + components = + ShapedTypeComponents(argTy.getShape(), components.getElementType()); return success(); } @@ -1657,54 +1608,37 @@ LogicalResult inferConcatenateOp(std::optional location, // concatenate_c4 if (dimension < 0) return emitOptionalError(location, "dimension ", dimension, " is negative"); - RankedTensorType firstRankedType; - int firstRankedIndex = -1; - for (uint64_t i = 0; i < inputTypes.size(); i++) { - auto secondType = inputTypes[i].cast(); - if (!secondType.hasRank()) continue; - if (!firstRankedType) { - firstRankedType = secondType.cast(); - firstRankedIndex = i; - // concatenate_c4 - if (firstRankedType.getRank() == 0) - return emitOptionalError(location, - "rank-0 values cannot be concatenated"); - // concatenate_c4 - if (dimension >= firstRankedType.getRank()) - return emitOptionalError(location, "dimension ", dimension, - " is out-of-bounds for input rank ", - firstRankedType.getRank()); - continue; - } - // concatenate_c2 - if (firstRankedType.getRank() != secondType.getRank()) - return emitOptionalError(location, "operands (", firstRankedIndex, - ") and (", i, ") do not match rank"); - - auto firstShape = firstRankedType.getShape(); - auto secondShape = secondType.getShape(); - for (int d = 0; d < firstRankedType.getRank(); ++d) { - // concatenate_c2 - if (d != dimension && - !verifyCompatibleDims(firstShape[d], secondShape[d])) + + auto witnessType = inputTypes[0].cast(); + int64_t rank = witnessType.getRank(); + + // concatenate_c4 + if (rank == 0) + return emitOptionalError(location, "rank-0 values cannot be concatenated"); + if (dimension >= rank) + return emitOptionalError(location, "dimension ", dimension, + " is out-of-bounds for input rank ", rank); + + // concatenate_c2 + for (size_t i = 0; i < inputTypes.size(); i++) { + auto type = inputTypes[i].cast(); + if (type.getRank() != rank) + return emitOptionalError(location, "operands (0) and (", i, + ") do not match rank rank"); + + auto witnessShape = witnessType.getShape(); + auto shape = type.getShape(); + for (int d = 0; d < rank; ++d) { + if (d != dimension && !verifyCompatibleDims(witnessShape[d], shape[d])) return emitOptionalError( - location, "shapes of operand (", firstRankedIndex, ") and (", i, - ") do not match at non-concat " - "index: (", - llvm::make_range(firstShape.begin(), firstShape.end()), ") != (", - llvm::make_range(secondShape.begin(), secondShape.end()), - ") at non-concat index ", d); + location, "shapes of operand (", 0, ") and (", i, + ") are not compatible at non-concat index ", d, ": (", + llvm::make_range(witnessShape.begin(), witnessShape.end()), + ") != (", llvm::make_range(shape.begin(), shape.end()), ")"); } } - // concatenate_c5 - auto elementType = inputTypes[0].cast().getElementType(); - if (!firstRankedType) { - inferredReturnTypes.push_back(UnrankedTensorType::get(elementType)); - return 
success(); - } // Infer the most specific (size, bound) of all dimensions of the return type - auto rank = firstRankedType.getRank(); SmallVector inferredSizes(rank, ShapedType::kDynamic); SmallVector inferredBounds(rank, ShapedType::kDynamic); // Note: for the concatenate dimension, 0 should be the identity element: @@ -1746,9 +1680,9 @@ LogicalResult inferConcatenateOp(std::optional location, } // concatenate_c5, concatenate_c6 inferredReturnTypes.push_back(RankedTensorType::get( - inferredSizes, elementType, + inferredSizes, witnessType.getElementType(), boundsToEncoding( - firstRankedType.getEncoding(), + witnessType.getEncoding(), // Empty array as argument is an indicator to boundsToEncoding() that // there are no bounds at all in inputs, thus sparsity attributes will // be included in the return type @@ -1767,8 +1701,7 @@ LogicalResult inferConvertOp( SmallVectorImpl& inferredReturnShapes) { auto operandType = operand.getType().cast(); // convert_c1 - inferredReturnShapes.emplace_back( - operandType.hasRank() ? operandType.getShape() : ArrayRef{}); + inferredReturnShapes.emplace_back(operandType.getShape()); return success(); } @@ -2138,16 +2071,14 @@ LogicalResult inferDynamicUpdateSliceOp( auto updateType = update.getType().cast(); // dynamic_update_slice_c3 - if (updateType.hasRank() && operandType.hasRank() && - updateType.getRank() != operandType.getRank()) + if (updateType.getRank() != operandType.getRank()) return emitOptionalError( location, "update rank does not match operand rank: ", updateType.getRank(), " vs ", operandType.getRank(), "."); // dynamic_update_slice_c4 - if (operandType.hasRank() && - (int64_t)startIndices.size() != operandType.getRank()) + if ((int64_t)startIndices.size() != operandType.getRank()) return emitOptionalError( location, "expects number of start_indices to match operand rank: ", startIndices.size(), " vs ", operandType.getRank(), "."); @@ -2158,31 +2089,27 @@ LogicalResult inferDynamicUpdateSliceOp( "start indices must have same element type"); // dynamic_update_slice_c6 - if (operandType.hasRank() && updateType.hasRank()) - for (auto [index, dims] : llvm::enumerate( - llvm::zip(operandType.getShape(), updateType.getShape()))) { - auto [operandDim, updateDim] = dims; - if (isDynamicDimSize(updateDim)) continue; - if (isStaticDimSize(operandDim)) { - if (updateDim < 0 || updateDim > operandDim) - return emitOptionalError(location, "expects size at dimension ", - index, " of update to be in range [0, ", - operandDim, "]. Got: ", updateDim, "."); - } else { - if (updateDim < 0) - return emitOptionalError( - location, "expects size at dimension ", index, - " of update to be non-negative. Got: ", updateDim, "."); - } + for (auto [index, dims] : llvm::enumerate( + llvm::zip(operandType.getShape(), updateType.getShape()))) { + auto [operandDim, updateDim] = dims; + if (isDynamicDimSize(updateDim)) continue; + if (isStaticDimSize(operandDim)) { + if (updateDim < 0 || updateDim > operandDim) + return emitOptionalError(location, "expects size at dimension ", index, + " of update to be in range [0, ", operandDim, + "]. Got: ", updateDim, "."); + } else { + if (updateDim < 0) + return emitOptionalError( + location, "expects size at dimension ", index, + " of update to be non-negative. 
Got: ", updateDim, "."); } + } // dynamic_update_slice_c1 - if (operandType.hasRank()) - inferredReturnShapes.emplace_back( - operandType.getShape(), operandType.getElementType(), - operandType.cast().getEncoding()); - else - inferredReturnShapes.emplace_back(operandType.getElementType()); + inferredReturnShapes.emplace_back( + operandType.getShape(), operandType.getElementType(), + operandType.cast().getEncoding()); return success(); } @@ -2312,16 +2239,14 @@ LogicalResult inferGatherOp( } // gather_c12 - if (operandShape.hasRank()) { - for (const auto& it : llvm::enumerate(sliceSizes)) { - if (operandShape.isDynamicDim(it.index())) continue; - auto operandDimSize = operandShape.getDimSize(it.index()); - auto sliceDimSize = it.value(); - if (sliceDimSize < 0 || sliceDimSize > operandDimSize) - return emitOptionalError(location, "slice size (", sliceDimSize, - ") is out of bounds for operand dimension (", - operandDimSize, ") at index ", it.index()); - } + for (const auto& it : llvm::enumerate(sliceSizes)) { + if (operandShape.isDynamicDim(it.index())) continue; + auto operandDimSize = operandShape.getDimSize(it.index()); + auto sliceDimSize = it.value(); + if (sliceDimSize < 0 || sliceDimSize > operandDimSize) + return emitOptionalError(location, "slice size (", sliceDimSize, + ") is out of bounds for operand dimension (", + operandDimSize, ") at index ", it.index()); } auto getSliceDim = [&sliceSizes](int64_t index) -> int64_t { @@ -2445,28 +2370,21 @@ LogicalResult inferMapOp( // map_c3 ArrayRef resultShape; - bool allInputsUnranked = true; for (auto operand : inputs) { auto operandType = operand.getType().cast(); - if (operandType.hasRank()) { - if (dimensions.size() != operandType.getShape().size()) - return emitOptionalError( - location, - "applied to a subset of dimensions currently not supported: " - "operand dimensions = ", - operandType.getShape().size(), - ", requested map dimensions size = ", dimensions.size()); - resultShape = operandType.getShape(); - allInputsUnranked = false; - } + if (dimensions.size() != operandType.getShape().size()) + return emitOptionalError( + location, + "applied to a subset of dimensions currently not supported: operand " + "dimensions = ", + operandType.getShape().size(), + ", requested map dimensions size = ", dimensions.size()); + resultShape = operandType.getShape(); } // map_c4 - if (allInputsUnranked) - inferredReturnShapes.emplace_back(computationOutputType.getElementType()); - else - inferredReturnShapes.emplace_back(resultShape, - computationOutputType.getElementType()); + inferredReturnShapes.emplace_back(resultShape, + computationOutputType.getElementType()); return success(); } @@ -2583,12 +2501,8 @@ LogicalResult inferReduceOp( auto accumulatorTypesOrErr = getAccumulatorTypes(location, body); if (failed(accumulatorTypesOrErr)) return failure(); for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) { - ShapedType inputType = inputArgTensorTypes[inputIdx]; Type elementType = (*accumulatorTypesOrErr)[inputIdx].getElementType(); - if (inputType.hasRank()) - inferredReturnShapes.emplace_back(newDimensions, elementType, encoding); - else - inferredReturnShapes.emplace_back(elementType); + inferredReturnShapes.emplace_back(newDimensions, elementType, encoding); } return success(); @@ -2720,7 +2634,7 @@ LogicalResult inferSelectOp( location, "requires compatible types for non-predicate operands"); // select_c1 - bool predCannotBeScalar = predType.hasRank() && predType.getRank() != 0; + bool predCannotBeScalar = 
predType.getRank() != 0; if (predCannotBeScalar) if (failed(verifyCompatibleShape(predType, trueType))) return emitOptionalError(location, @@ -3042,7 +2956,7 @@ LogicalResult inferUniformDequantizeOp( auto operandType = operand.getType().cast(); // Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType; auto quantType = operandType.getElementType().cast(); - auto shape = operandType.cast().getShape(); + auto shape = operandType.getShape(); // uniform_dequantize_c1, uniform_dequantize_c2 inferredReturnShapes.emplace_back(shape, quantType.getExpressedType()); return success(); @@ -3053,8 +2967,7 @@ LogicalResult inferUniformQuantizeOp( SmallVectorImpl& inferredReturnShapes) { auto operandType = operand.getType().cast(); // uniform_quantize_c1 - inferredReturnShapes.emplace_back( - operandType.hasRank() ? operandType.getShape() : ArrayRef{}); + inferredReturnShapes.emplace_back(operandType.getShape()); return success(); } @@ -3445,8 +3358,7 @@ LogicalResult verifyConvolutionOp( auto inferredShape = inferredReturnShapes[0]; auto shapedResultType = resultType.cast(); - if (inferredShape.hasRank() && shapedResultType.hasRank() && - failed(verifyCompatibleShape(inferredShape.getDims(), + if (failed(verifyCompatibleShape(inferredShape.getDims(), shapedResultType.getShape()))) return emitOptionalError(location, "inferred shape '", dimSizesToString(inferredShape.getDims()), "' ", @@ -3467,8 +3379,7 @@ LogicalResult verifyDotOp(std::optional location, auto inferredShape = inferredReturnShapes[0]; auto resultType = result.getType().cast(); - if (inferredShape.hasRank() && resultType.hasRank() && - failed(verifyCompatibleShape(inferredShape.getDims(), + if (failed(verifyCompatibleShape(inferredShape.getDims(), resultType.getShape()))) return emitOptionalError( location, "inferred shape '", dimSizesToString(inferredShape.getDims()), @@ -3493,8 +3404,7 @@ LogicalResult verifyDotGeneralOp(std::optional location, Value lhs, auto inferredShape = inferredReturnShapes[0]; auto resultType = result.getType().cast(); - if (inferredShape.hasRank() && resultType.hasRank() && - failed(verifyCompatibleShape(inferredShape.getDims(), + if (failed(verifyCompatibleShape(inferredShape.getDims(), resultType.getShape()))) return emitOptionalError( location, "inferred shape '", dimSizesToString(inferredShape.getDims()), @@ -3687,8 +3597,8 @@ LogicalResult verifyInfeedOp(HloDialectInterface* dialect, // infeed_c3 if (!dialect->isTokenType(results.back().getType())) return emitOptionalError(location, - "last element of result types is expected to " - "be of token type, but got ", + "last element of result types is expected to be " + "of token type, but got ", results.back().getType()); if (!layout.has_value()) return success(); @@ -3697,19 +3607,18 @@ LogicalResult verifyInfeedOp(HloDialectInterface* dialect, "layout-attribute expected to be of array-type."); if (layout.value().size() != resultTypes.size() - 1) - return emitOptionalError(location, "layout-attribute size must be ", - resultTypes.size() - 1, - " (which is the number of " - "op-results - 1 (for token result)), but got ", - layout.value().size()); + return emitOptionalError( + location, "layout-attribute size must be ", resultTypes.size() - 1, + " (which is the number of op-results - 1 (for token result)), but got ", + layout.value().size()); for (auto childLayout : layout.value()) { mlir::ArrayAttr childLayoutArr = childLayout.dyn_cast(); if (!childLayoutArr) - return emitOptionalError(location, - "layout-attribute expected to have " - "elements of type 
array, but got ", - childLayout); + return emitOptionalError( + location, + "layout-attribute expected to have elements of type array, but got ", + childLayout); for (auto i : childLayoutArr) { mlir::IntegerAttr attr = i.dyn_cast(); @@ -3727,7 +3636,6 @@ LogicalResult verifyInfeedOp(HloDialectInterface* dialect, LogicalResult verifyIotaOp(std::optional location, int64_t iotaDimension, Value result) { auto shape = result.getType().cast(); - if (!shape.hasRank()) return success(); if (shape.getRank() == 0) return emitOptionalError(location, "does not support scalars."); @@ -3862,7 +3770,6 @@ LogicalResult verifyReduceScatterOp(std::optional location, return failure(); auto resultType = result.getType().cast(); - if (!operandType.hasRank() || !resultType.hasRank()) return success(); // reduce_scatter_c8 if (operandType.getRank() != resultType.getRank()) return emitOptionalError(location, @@ -3879,10 +3786,10 @@ LogicalResult verifyReduceScatterOp(std::optional location, // reduce_scatter_c6 if (useGlobalDeviceIds && channelId <= 0) - return emitOptionalError( - location, - "channel_id must be positive when useGlobalDeviceIds is set but got: ", - channelId); + return emitOptionalError(location, + "channel_id must be positive when " + "useGlobalDeviceIds is set but got: ", + channelId); if (operandType.isDynamicDim(scatterDimension) || resultType.isDynamicDim(scatterDimension)) @@ -4233,16 +4140,15 @@ LogicalResult verifySelectAndScatterOp( location, "expects select-region to return single value, but got: ", selectResult.size()); - auto selectResultType = selectResult[0].getType().dyn_cast(); + auto selectResultType = + selectResult[0].getType().dyn_cast(); // select_and_scatter_c9 if (!selectResultType || !selectResultType.getElementType().isInteger(1) || - (selectResultType.hasRank() && - selectResultType.cast().getRank() != 0)) + selectResultType.getRank() != 0) return emitOptionalError( location, "expects the return-type of select-region to be tensor, but got: ", selectResult[0].getType()); - // select_and_scatter_c10 if (failed(verifyReducerShape( location, scatter.front(), @@ -4252,16 +4158,14 @@ LogicalResult verifySelectAndScatterOp( return failure(); auto windowDims = windowDimensionsOpt.value_or(SmallVector{}); - if (operandType.hasRank()) { - // select_and_scatter_c4 - if (operandType.getRank() != static_cast(windowDims.size())) - return emitOptionalError( - location, - "expects window-dimensions size == operand rank, but got " - "window-dimensions size: ", - windowDims.size(), " and operand-type: ", operandType, - " with rank = ", operandType.getRank(), "."); - } + // select_and_scatter_c4 + if (operandType.getRank() != static_cast(windowDims.size())) + return emitOptionalError(location, + "expects window-dimensions size == operand rank, " + "but got window-dimensions size: ", + windowDims.size(), + " and operand-type: ", operandType, + " with rank = ", operandType.getRank(), "."); auto windowStrides = windowStridesOpt.value_or(SmallVector{}); @@ -4276,12 +4180,9 @@ LogicalResult verifySelectAndScatterOp( if (failed(windowOrErr)) return failure(); ShapedType windowResultType; - if (!operandType.hasRank()) - windowResultType = UnrankedTensorType::get(operandType.getElementType()); - else - windowResultType = RankedTensorType::get( - inferWindowOutputShape(operandType.getShape(), *windowOrErr), - operandType.getElementType()); + windowResultType = RankedTensorType::get( + inferWindowOutputShape(operandType.getShape(), *windowOrErr), + operandType.getElementType()); // 
select_and_scatter_c1, select_and_scatter_c2 if (!compatibleShapeAndElementType(windowResultType, sourceType, @@ -4297,17 +4198,15 @@ LogicalResult verifySortOp(std::optional location, ValueRange inputs, auto operandTypes = inputs.getTypes(); for (auto operandType : operandTypes) { auto operandShapedType = operandType.cast(); - if (operandShapedType.hasRank()) { - int64_t cmpDim = dimension; - int64_t rank = operandShapedType.getRank(); - // sort_c4 - if (cmpDim < -rank || cmpDim >= rank) - return emitOptionalError( - location, "dimension attribute value must be in range [-", rank, - ", ", rank, "), but found ", cmpDim); - else - break; // ODS SameOperandsAndResultShape asserts inputs have same shape - } + int64_t cmpDim = dimension; + int64_t rank = operandShapedType.getRank(); + // sort_c4 + if (cmpDim < -rank || cmpDim >= rank) + return emitOptionalError(location, + "dimension attribute value must be in range [-", + rank, ", ", rank, "), but found ", cmpDim); + // ODS SameOperandsAndResultShape asserts inputs have same shape + break; } Block& block = comparator.front(); @@ -4339,7 +4238,7 @@ LogicalResult verifySortOp(std::optional location, ValueRange inputs, comparatorResult.size()); // sort_c5 auto comparatorResultType = comparatorResult[0].getType().cast(); - if ((comparatorResultType.hasRank() && comparatorResultType.getRank() != 0) || + if (comparatorResultType.getRank() != 0 || !comparatorResultType.getElementType().isInteger(1)) return emitOptionalError(location, "comparator must return tensor but got ", @@ -4379,8 +4278,7 @@ LogicalResult verifyWhileOp(std::optional location, condReturnTypes.size()); // while_c1 auto operandType = condReturnTypes[0].cast(); - if ((operandType.hasRank() && operandType.getRank() != 0) || - !operandType.getElementType().isInteger(1)) + if (operandType.getRank() != 0 || !operandType.getElementType().isInteger(1)) return emitOptionalError( location, "expect condition block return a zero-ranked tensor of i1 but got ", diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 0c80dfc59b..9732c8bf0c 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -1501,7 +1501,7 @@ func.func @concatenate_c4(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor func.func @concatenate_c6(%arg0: tensor<1x3xi32>, %arg1: tensor<2x2xi32>) -> tensor<3x3xi32> { // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{shapes of operand (0) and (1) do not match at non-concat index: (1, 3) != (2, 2) at non-concat index 1}} + // expected-error@+1 {{shapes of operand (0) and (1) are not compatible at non-concat index 1: (1, 3) != (2, 2)}} %0 = "stablehlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x3xi32>, tensor<2x2xi32>) -> tensor<3x3xi32> func.return %0 : tensor<3x3xi32> } diff --git a/stablehlo/tests/verify_reduce.mlir b/stablehlo/tests/verify_reduce.mlir index 49b8d9fa08..2e2a7d2023 100644 --- a/stablehlo/tests/verify_reduce.mlir +++ b/stablehlo/tests/verify_reduce.mlir @@ -166,7 +166,7 @@ func.func @reduce_c2(%arg0: tensor, %arg1: tensor, func.func @reduce_c4(%arg0: tensor, %arg1 : tensor) -> (tensor) { - // expected-error@+1 {{Out-of-bounds dimension -1, expected to be less than the input-tensor rank 2}} + // expected-error@+1 {{Out-of-bounds dimension -1, expected to be in range [0, 2)}} %0 = "stablehlo.reduce"(%arg0, %arg1) ({ ^bb0(%arg2: tensor, %arg3: tensor ): @@ -182,7 +182,7 @@ func.func @reduce_c4(%arg0: tensor, %arg1 : tensor) func.func 
@reduce_c4(%arg0: tensor, %arg1 : tensor) -> (tensor) { - // expected-error@+1 {{Out-of-bounds dimension 2, expected to be less than the input-tensor rank 2}} + // expected-error@+1 {{Out-of-bounds dimension 2, expected to be in range [0, 2)}} %0 = "stablehlo.reduce"(%arg0, %arg1) ({ ^bb0(%arg2: tensor, %arg3: tensor ): diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index 7d50781439..5800bc5277 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -276,8 +276,7 @@ template LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op, FuncType fn) { auto resultType = op.getType(); - if (!resultType.hasRank() || - !resultType.getElementType().template isa()) + if (!resultType.getElementType().template isa()) return rewriter.notifyMatchFailure(op, "expected integer result tensor type"); @@ -345,7 +344,7 @@ struct EvalBroadcastInDimOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(BroadcastInDimOp op, PatternRewriter& rewriter) const override { auto operandType = op.getOperand().getType(); - if (!operandType.hasRank() || operandType.getRank() != 0) + if (operandType.getRank() != 0) return rewriter.notifyMatchFailure(op, "expected 0-dimensional type"); SmallVector operand; @@ -463,7 +462,7 @@ struct EvalConcatenateOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(ConcatenateOp op, PatternRewriter& rewriter) const override { auto resultType = op.getType(); - if (!resultType.hasRank() || op.getDimension() != 0) + if (op.getDimension() != 0) return rewriter.notifyMatchFailure(op, "expected dimension = 0"); SmallVector result; @@ -508,8 +507,6 @@ struct EvalGetDimensionSizeOpPattern LogicalResult matchAndRewrite(GetDimensionSizeOp op, PatternRewriter& rewriter) const override { auto operandType = op.getOperand().getType(); - if (!operandType.hasRank()) - return rewriter.notifyMatchFailure(op, "expected ranked operand"); if (operandType.isDynamicDim(op.getDimension())) return rewriter.notifyMatchFailure(op, "expected static dimension"); @@ -694,8 +691,6 @@ struct RefineAllGatherOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(AllGatherOp op, PatternRewriter& rewriter) const override { auto operandType = op.getOperand().getType(); - if (!operandType.hasRank()) - return rewriter.notifyMatchFailure(op, "expected ranked operand type"); // This represents the cross_replica_and_partition process grouping strategy // that requires num_partitions to compute shardCount. Since we don't know @@ -718,8 +713,6 @@ struct RefineBitcastConvertOpPattern LogicalResult matchAndRewrite(BitcastConvertOp op, PatternRewriter& rewriter) const override { auto operandType = op.getOperand().getType(); - if (!operandType.hasRank()) - return rewriter.notifyMatchFailure(op, "expected ranked operand type"); // If bit widths of the operand and the result are different, then // operand and result shapes have different ranks. @@ -1016,8 +1009,6 @@ struct RefineReduceScatterOpPattern : public OpRewritePattern { LogicalResult matchAndRewrite(ReduceScatterOp op, PatternRewriter& rewriter) const override { auto operandType = op.getOperand().getType(); - if (!operandType.hasRank()) - return rewriter.notifyMatchFailure(op, "expected ranked operand type"); // This represents the cross_replica_and_partition process grouping strategy // that requires num_partitions to compute shardCount. Since we don't know