Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for Fft (#2225)
Browse files Browse the repository at this point in the history
The logic is the same as for unary ops, except that if the type of FFT
is RFFT or IRFFT then the last `fft_length.size()` dimensions in the
operand need to be static (for speculatability), because there are extra
constraints on them that could turn out to be false at runtime if they
are dynamic.

Indeed, there is a general constraint that `shape(result) =
shape(operand)`. For RFFT, the operand type element type is float, so
the last `fft_length.size()` dims in the operand have to be static,
because `fft_length` is static (it is an attribute). For IRFFT, the
result element type is float, so the last `fft_length.size()` dims in
the result are inferred from the `fft_length`, and they need to match
the operand, so those dims have to be static in the operand. There are
also constraints on the last dimension, but `size(fft_length) >= 1` so
that is already covered by the previous check.

The relevant constraints from the spec are:
```
(C4) If among operand and result, there is a tensor real of a floating-point type, then shape(real)[-size(fft_length):] = fft_length.
(C5) shape(result) = shape(operand) except for:
* If fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
* If fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.
```
  • Loading branch information
mlevesquedion authored Apr 17, 2024
1 parent 6515513 commit 9fb78c1
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
18 changes: 18 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,24 @@ LogicalResult FftOp::inferReturnTypeComponents(
adaptor.getFftLength(), inferredReturnShapes);
}

mlir::Speculation::Speculatability FftOp::getSpeculatability() {
// This is the same logic as SpeculatableIfStaticDimInOutputIsStaticInInput,
// except that for RFFT and IRFFT the last `fft_length.size()` dimensions in
// the operand need to be static.
auto inputType = getOperand().getType();
auto resultType = getType();
size_t minStaticDim = inputType.getRank();
if (getFftType() == FftType::RFFT || getFftType() == FftType::IRFFT)
minStaticDim = minStaticDim - getFftLength().size();
for (size_t i : llvm::seq(inputType.getRank())) {
if (i >= minStaticDim && inputType.isDynamicDim(i))
return mlir::Speculation::NotSpeculatable;
if (!resultType.isDynamicDim(i) && inputType.isDynamicDim(i))
return mlir::Speculation::NotSpeculatable;
}
return mlir::Speculation::Speculatable;
}

//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 8 additions & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2455,7 +2455,9 @@ def StableHLO_UnaryEinsumOp: StableHLO_Op<"unary_einsum", [Pure]> {
}];
}

def StableHLO_FftOp: StableHLO_Op<"fft", [InferTensorType, Pure]> {
def StableHLO_FftOp: StableHLO_Op<"fft",
[ConditionallySpeculatable, NoMemoryEffect,
InferTensorType]> {
let summary = "Fft operation";
let description = [{
Performs the forward and inverse Fourier transforms for real and complex
Expand All @@ -2481,6 +2483,11 @@ def StableHLO_FftOp: StableHLO_Op<"fft", [InferTensorType, Pure]> {
$operand `,` `type` `=` $fft_type `,` `length` `=` custom<DenseI64Array>($fft_length)
attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_GatherOp: StableHLO_Op<"gather",
Expand Down
47 changes: 47 additions & 0 deletions stablehlo/tests/ops_speculatability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,53 @@ func.func @all_gather(%static_arg: tensor<2x2xf64>, %dynamic_arg: tensor<?x?xf64

// -----

// CHECK-LABEL: func @fft
// CHECK-NEXT: return
func.func @fft(
%static_arg: tensor<3x9xcomplex<f64>>, %dynamic_arg: tensor<?x?xcomplex<f64>>,
%static_rfft_arg: tensor<3x9xf64>, %dynamic_rfft_arg: tensor<?x?xf64>,
%static_irfft_arg: tensor<3x5xcomplex<f64>>, %dynamic_irfft_arg: tensor<?x?xcomplex<f64>>
) {
%fft_0 = stablehlo.fft %static_arg, type = FFT, length = [9] : (tensor<3x9xcomplex<f64>>) -> tensor<3x9xcomplex<f64>>
%fft_1 = stablehlo.fft %static_arg, type = FFT, length = [9] : (tensor<3x9xcomplex<f64>>) -> tensor<?x?xcomplex<f64>>
%fft_2 = stablehlo.fft %dynamic_arg, type = FFT, length = [9] : (tensor<?x?xcomplex<f64>>) -> tensor<3x9xcomplex<f64>>
%fft_3 = stablehlo.fft %dynamic_arg, type = FFT, length = [9] : (tensor<?x?xcomplex<f64>>) -> tensor<?x?xcomplex<f64>>
"hlo_test_speculatability.is_speculatable"(%fft_0) : (tensor<3x9xcomplex<f64>>) -> ()
"hlo_test_speculatability.is_speculatable"(%fft_1) : (tensor<?x?xcomplex<f64>>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%fft_2) : (tensor<3x9xcomplex<f64>>) -> ()
"hlo_test_speculatability.is_speculatable"(%fft_3) : (tensor<?x?xcomplex<f64>>) -> ()

%ifft_0 = stablehlo.fft %static_arg, type = IFFT, length = [9] : (tensor<3x9xcomplex<f64>>) -> tensor<3x9xcomplex<f64>>
%ifft_1 = stablehlo.fft %static_arg, type = IFFT, length = [9] : (tensor<3x9xcomplex<f64>>) -> tensor<?x?xcomplex<f64>>
%ifft_2 = stablehlo.fft %dynamic_arg, type = IFFT, length = [9] : (tensor<?x?xcomplex<f64>>) -> tensor<3x9xcomplex<f64>>
%ifft_3 = stablehlo.fft %dynamic_arg, type = IFFT, length = [9] : (tensor<?x?xcomplex<f64>>) -> tensor<?x?xcomplex<f64>>
"hlo_test_speculatability.is_speculatable"(%ifft_0) : (tensor<3x9xcomplex<f64>>) -> ()
"hlo_test_speculatability.is_speculatable"(%ifft_1) : (tensor<?x?xcomplex<f64>>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%ifft_2) : (tensor<3x9xcomplex<f64>>) -> ()
"hlo_test_speculatability.is_speculatable"(%ifft_3) : (tensor<?x?xcomplex<f64>>) -> ()

%rfft_0 = stablehlo.fft %static_rfft_arg, type = RFFT, length = [9] : (tensor<3x9xf64>) -> tensor<3x5xcomplex<f64>>
%rfft_1 = stablehlo.fft %static_rfft_arg, type = RFFT, length = [9] : (tensor<3x9xf64>) -> tensor<?x?xcomplex<f64>>
%rfft_2 = stablehlo.fft %dynamic_rfft_arg, type = RFFT, length = [9] : (tensor<?x?xf64>) -> tensor<3x5xcomplex<f64>>
%rfft_3 = stablehlo.fft %dynamic_rfft_arg, type = RFFT, length = [9] : (tensor<?x?xf64>) -> tensor<?x?xcomplex<f64>>
"hlo_test_speculatability.is_speculatable"(%rfft_0) : (tensor<3x5xcomplex<f64>>) -> ()
"hlo_test_speculatability.is_speculatable"(%rfft_1) : (tensor<?x?xcomplex<f64>>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%rfft_2) : (tensor<3x5xcomplex<f64>>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%rfft_3) : (tensor<?x?xcomplex<f64>>) -> ()

%irfft_0 = stablehlo.fft %static_irfft_arg, type = IRFFT, length = [9] : (tensor<3x5xcomplex<f64>>) -> tensor<3x9xf64>
%irfft_1 = stablehlo.fft %static_irfft_arg, type = IRFFT, length = [9] : (tensor<3x5xcomplex<f64>>) -> tensor<?x?xf64>
%irfft_2 = stablehlo.fft %dynamic_irfft_arg, type = IRFFT, length = [9] : (tensor<?x?xcomplex<f64>>) -> tensor<3x9xf64>
%irfft_3 = stablehlo.fft %dynamic_irfft_arg, type = IRFFT, length = [9] : (tensor<?x?xcomplex<f64>>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%irfft_0) : (tensor<3x9xf64>) -> ()
"hlo_test_speculatability.is_speculatable"(%irfft_1) : (tensor<?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%irfft_2) : (tensor<3x9xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%irfft_3) : (tensor<?x?xf64>) -> ()
return
}

// -----

// CHECK-LABEL: func @reduce_scatter
// CHECK-NEXT: return
func.func @reduce_scatter(%static_arg: tensor<2x4xf64>, %dynamic_arg: tensor<?x?xf64>) {
Expand Down

0 comments on commit 9fb78c1

Please sign in to comment.