From b9118430c760e1283d2e48791e80065b7dcc0db5 Mon Sep 17 00:00:00 2001 From: mlevesquedion Date: Thu, 18 Apr 2024 11:19:22 -0700 Subject: [PATCH] Implement ConditionallySpeculatable for DynamicReshape (#2231) --- stablehlo/dialect/StablehloOps.cpp | 23 +++++++++++++++++- stablehlo/dialect/StablehloOps.td | 8 +++++- stablehlo/dialect/TypeInference.cpp | 19 ++++++++++++++- stablehlo/dialect/TypeInference.h | 3 ++- stablehlo/tests/ops_speculatability.mlir | 31 ++++++++++++++++++++++++ stablehlo/tests/ops_stablehlo.mlir | 20 +++++++-------- 6 files changed, 90 insertions(+), 14 deletions(-) diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 0dfde89765..dc94f410f5 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -1496,7 +1496,8 @@ mlir::Speculation::Speculatability ConcatenateOp::getSpeculatability() { //===----------------------------------------------------------------------===// LogicalResult DynamicReshapeOp::verify() { - return hlo::verifyDynamicReshapeOp(getLoc(), getOutputShape(), getResult()); + return hlo::verifyDynamicReshapeOp(getLoc(), getOperand(), getOutputShape(), + getResult()); } LogicalResult DynamicReshapeOp::reifyReturnTypeShapes( @@ -1508,6 +1509,26 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes( return success(); } +mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() { + // If the output type's shape is fully dynamic, there is no expectation + // for the shape so the op is speculatable. + if (llvm::all_of(llvm::seq(getType().getRank()), + [this](int64_t i) { return getType().isDynamicDim(i); })) + return mlir::Speculation::Speculatable; + + // If the input is static and the shape operand is constant, the output + // shape can be inferred and any mismatch will be caught statically. + // If any dimension in the input is dynamic, the number of elements may + // disagree with either the shape or the output. 
+ // If the shape operand is not constant, it could disagree with the output, + // which has at least 1 static dimension at this point in the function. + if (getOperand().getType().hasStaticShape() && + matchPattern(getOutputShape(), m_Constant())) + return mlir::Speculation::Speculatable; + + return mlir::Speculation::NotSpeculatable; +} + //===----------------------------------------------------------------------===// // DynamicSliceOp //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 701a35c0a4..2f5384b60d 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2637,7 +2637,8 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape", }]; } -def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [Pure]> { +def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", + [ConditionallySpeculatable, NoMemoryEffect]> { let summary = "DynamicReshape operation"; let description = [{ This operation is a work in progress, so it is not yet included in @@ -2659,6 +2660,11 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [ let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = commonClassDeclaration # [{ + /// Interface method for ConditionallySpeculatable. 
mlir::Speculation::Speculatability getSpeculatability(); }]; } def StableHLO_ScatterOp: StableHLO_Op<"scatter", diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 1198a1d37f..42c181e441 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -3780,13 +3780,30 @@ LogicalResult verifyDynamicPadOp(std::optional<Location> location, } LogicalResult verifyDynamicReshapeOp(std::optional<Location> location, - Value outputShape, Value result) { + Value operand, Value outputShape, + Value result) { auto resultType = cast<ShapedType>(result.getType()); auto outputShapeType = cast<ShapedType>(outputShape.getType()); if (outputShapeType.getDimSize(0) != resultType.getRank()) return emitOptionalError(location, "output should have a rank equal to the number of " "elements in output_shape"); + + auto operandType = cast<ShapedType>(operand.getType()); + if (SmallVector<int64_t> shape; operandType.hasStaticShape() && + matchInts(outputShape, shape).succeeded()) { + int64_t operandCount = operandType.getNumElements(); + int64_t shapeCount = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies<int64_t>()); + if (operandCount != shapeCount) { + return emitOptionalError(location, + "output_shape is incompatible with input type " + "of operation: input has ", + operandCount, " elements, but output_shape has ", + shapeCount); + } + } + if (!isCompatibleForHloTypeInference(outputShape, resultType)) return emitOptionalError( location, "output_shape is incompatible with return type of operation ", diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index 92ffbe07e2..b6e4bc18b3 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -457,7 +457,8 @@ LogicalResult verifyDynamicPadOp(std::optional<Location> location, Value interiorPadding, Value result); LogicalResult verifyDynamicReshapeOp(std::optional<Location> location, - Value outputShape, Value result); + Value operand, Value outputShape, + Value result); LogicalResult 
verifyInfeedOp(HloDialectInterface* dialect, std::optional location, diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir index f142c77922..7457903fb0 100644 --- a/stablehlo/tests/ops_speculatability.mlir +++ b/stablehlo/tests/ops_speculatability.mlir @@ -1701,6 +1701,37 @@ func.func @dynamic_iota(%unknown_shape: tensor<2xi32>) { // ----- +// CHECK-LABEL: func @dynamic_reshape +// CHECK-NEXT: return +func.func @dynamic_reshape( + %static_arg: tensor<4x5xf64>, %dynamic_arg: tensor, + %unknown_shape: tensor<2xi32> +) { + %constant_shape = stablehlo.constant dense<[5, 4]> : tensor<2xi32> + + // Static input, constant shape + %0 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64> + "hlo_test_speculatability.is_speculatable"(%0) : (tensor<5x4xf64>) -> () + %1 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input + %2 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor, tensor<2xi32>) -> tensor<5x4xf64> + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<5x4xf64>) -> () + %3 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor, tensor<2xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%3) : (tensor) -> () + + // Unknown shape + %4 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64> + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<5x4xf64>) -> () + %5 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%5) : (tensor) -> () + + return +} + +// ----- + // Recursively speculatable ops // ----- diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 3d187a3b0e..97b431f951 100644 --- 
a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -3164,18 +3164,9 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor, %shape: ten // ----- -func.func @dynamic_reshape_output_shape_negative_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}} - %0 = stablehlo.constant dense<[-1, 1]> : tensor<2xi64> - %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32> - return %1 : tensor<1x4xf32> -} - -// ----- - func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { // expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}} - %0 = stablehlo.constant dense<[1, 1]> : tensor<2xi64> + %0 = stablehlo.constant dense<[2, 2]> : tensor<2xi64> %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32> return %1 : tensor<1x4xf32> } @@ -3190,6 +3181,15 @@ func.func @dynamic_reshape_dynamic_output_shape(%arg0: tensor, %shape: te // ----- +func.func @dynamic_reshape_input_count_mismatch_shape_count(%arg0: tensor<2x5xf32>) -> tensor { + %0 = stablehlo.constant dense<[2, 3, 4]> : tensor<3xi32> + // expected-error@+1 {{output_shape is incompatible with input type of operation: input has 10 elements, but output_shape has 24}} + %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<2x5xf32>, tensor<3xi32>) -> tensor + return %1 : tensor +} + +// ----- + func.func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { %0 = "stablehlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32> func.return %0 : tensor<2x4xf32>