Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FFT: Support complex to complex transforms #4329

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Docs/sphinx_documentation/source/FFT.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ in an :cpp:`FFT::Info` object passed to the constructor of

r2c.backward(cmf, mf);

.. _sec:FFT:c2c:

FFT::C2C Class
==============

:cpp:`FFT::C2C` is a class template that supports complex-to-complex Fourier
transforms. Its interface is similar to that of :cpp:`FFT::R2C`.

.. _sec:FFT:localr2c:

FFT::LocalR2C Class
Expand Down
78 changes: 63 additions & 15 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct Info

//! For the automatic strategy, this is the size per process below which we
//! switch from slab to pencil.
int pencil_threshold = 8;
int pencil_threshold = 4;

//! Supported only in 3D. When twod_mode is true, FFT is performed on
//! the first two dimensions only and the third dimension size is the
Expand Down Expand Up @@ -310,7 +310,7 @@ struct Plan
void init_r2c (IntVectND<M> const& fft_size, void*, void*, bool cache, int ncomp = 1);

template <Direction D>
void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1)
void init_c2c (Box const& box, VendorComplex* p, int ncomp = 1, int ndims = 1)
{
static_assert(D == Direction::forward || D == Direction::backward);

Expand All @@ -319,9 +319,35 @@ struct Plan
pf = (void*)p;
pb = (void*)p;

n = box.length(0);
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
howmany *= ncomp;
int len[3] = {};

if (ndims == 1) {
n = box.length(0);
howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2));
howmany *= ncomp;
len[0] = box.length(0);
}
#if (AMREX_SPACEDIM >= 2)
else if (ndims == 2) {
n = box.length(0) * box.length(1);
#if (AMREX_SPACEDIM == 2)
howmany = ncomp;
#else
howmany = box.length(2) * ncomp;
#endif
len[0] = box.length(1);
len[1] = box.length(0);
}
#if (AMREX_SPACEDIM == 3)
else if (ndims == 3) {
n = box.length(0) * box.length(1) * box.length(2);
howmany = ncomp;
len[0] = box.length(2);
len[1] = box.length(1);
len[2] = box.length(0);
}
#endif
#endif

#if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftCreate(&plan));
Expand All @@ -330,22 +356,39 @@ struct Plan
cufftType t = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
std::size_t work_size;
AMREX_CUFFT_SAFE_CALL
(cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));
(cufftMakePlanMany(plan, ndims, len, nullptr, 1, n, nullptr, 1, n, t, howmany, &work_size));

#elif defined(AMREX_USE_HIP)

auto prec = std::is_same_v<float,T> ? rocfft_precision_single
: rocfft_precision_double;
auto dir= (D == Direction::forward) ? rocfft_transform_type_complex_forward
: rocfft_transform_type_complex_inverse;
const std::size_t length = n;
std::size_t length[3];
if (ndims == 1) {
length[0] = len[0];
} else if (ndims == 2) {
length[0] = len[1];
length[1] = len[0];
} else {
length[0] = len[2];
length[1] = len[1];
length[2] = len[0];
}
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, 1,
&length, howmany, nullptr));
(rocfft_plan_create(&plan, rocfft_placement_inplace, dir, prec, ndims,
length, howmany, nullptr));

#elif defined(AMREX_USE_SYCL)

auto* pp = new mkl_desc_c(n);
mkl_desc_c* pp;
if (ndims == 1) {
pp = new mkl_desc_c(n);
} else if (ndims == 2) {
pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1])});
} else {
pp = new mkl_desc_c({std::int64_t(len[0]), std::int64_t(len[1]), std::int64_t(len[2])});
}
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::INPLACE);
Expand All @@ -355,7 +398,12 @@ struct Plan
pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany);
pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n);
pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n);
std::vector<std::int64_t> strides = {0,1};
std::vector<std::int64_t> strides(ndims+1);
strides[0] = 0;
strides[ndims] = 1;
for (int i = ndims-1; i >= 1; --i) {
strides[i] = strides[i+1] * len[ndims-1-i];
}
#ifndef AMREX_USE_MKL_DFTI_2024
pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides);
pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides);
Expand All @@ -373,21 +421,21 @@ struct Plan
if constexpr (std::is_same_v<float,T>) {
if constexpr (D == Direction::forward) {
plan = fftwf_plan_many_dft
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
FFTW_ESTIMATE);
} else {
plan = fftwf_plan_many_dft
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
FFTW_ESTIMATE);
}
} else {
if constexpr (D == Direction::forward) {
plan = fftw_plan_many_dft
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, -1,
FFTW_ESTIMATE);
} else {
plan = fftw_plan_many_dft
(1, &n, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
(ndims, len, howmany, p, nullptr, 1, n, p, nullptr, 1, n, +1,
FFTW_ESTIMATE);
}
}
Expand Down
Loading
Loading