Skip to content

Commit

Permalink
FFT: Add raw pointer interfaces
Browse files Browse the repository at this point in the history
This allows the users to use amrex::FFT without using any amrex specific
data container.
  • Loading branch information
WeiqunZhang committed Feb 23, 2025
1 parent bfd1f11 commit bfaa584
Show file tree
Hide file tree
Showing 16 changed files with 912 additions and 43 deletions.
117 changes: 117 additions & 0 deletions Docs/sphinx_documentation/source/FFT.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,120 @@ support non-uniform cell size in the z-direction. For most applications,
Similar to :cpp:`FFT::R2C`, the Poisson solvers should be cached for reuse,
and one might need to use :cpp:`std::unique_ptr<FFT::Poisson<MultiFab>>`
because there is no default constructor.

.. _sec:FFT:rawptr:

Raw Pointer Interface
=====================

If you only want to use AMReX as an FFT library without using other
functionalities and data containers, you could use the raw pointer
interface. Below is an example.

.. highlight:: c++

::

MPI_Init(&argc, &argv);

// We don't need to call the full-blown amrex::Initialize
amrex::Init_FFT(MPI_COMM_WORLD);

int nprocs, myproc;
MPI_Comm_size(MPI_COMM_WORLD, &nprocs);
MPI_Comm_rank(MPI_COMM_WORLD, &myproc);

using RT = double;
using CT = std::complex<RT>; // or cufftDoubleComplex, etc.

std::array<int,3> domain_size{128,128,128};

// FFT between real and complex.
// Domain decomposition is flexible. The only constraint for the raw
// pointer interface is that there can be only zero or one local box
// per process, whereas the MultiFab interface can take any number of
// boxes. In this case, we choose to do manual domain decomposition for
// the real (i.e., forward) domain, and use the domain decomposition
// provided by amrex for the complex (i.e., backward) domain.
{
amrex::FFT::R2C<RT,amrex::FFT::Direction::both> r2c(domain_size);

int nx = (domain_size[0] + nprocs - 1) / nprocs;
int xlo = nx * myproc;
nx = std::max(std::min(nx,domain_size[0]-xlo), 0);
std::array<int,3> local_start{xlo,0,0};
std::array<int,3> local_size{nx,domain_size[1],domain_size[2]};

// Let amrex know the domain decomposition in the forward domain.
r2c.setLocalDomain(local_start,local_size);

// Use amrex's domain decomposition in the backward domain.
auto const& [local_start_sp, local_size_sp] = r2c.getLocalSpectralDomain();

auto nr = std::size_t(local_size[0])
* std::size_t(local_size[1])
* std::size_t(local_size[2]);
auto* pr = (RT*)std::malloc(sizeof(RT)*nr); // or use cudaMalloc
// Initialize data ...

auto nc = std::size_t(local_size_sp[0])
* std::size_t(local_size_sp[1])
* std::size_t(local_size_sp[2]);
auto* pc = (CT*)std::malloc(sizeof(CT)*nc); // or use cudaMalloc

r2c.forward(pr, pc); // forward transform from real to complex

// work on the complex data pointed by pc ...

r2c.backward(pc, pr); // backward transform from complex to real

std::free(pr);
std::free(pc);
}

// Batched FFT between complex and complex.
// In this case, we choose to use the domain decomposition provided
// by amrex for the forward domain, and do manual domain decomposition
// for the backward domain.
int nbatch = 3; // batch size
{
amrex::FFT::Info info{};
info.setBatchSize(nbatch);
amrex::FFT::C2C<RT,amrex::FFT::Direction::both> c2c(domain_size,info);

// Use amrex's domain decomposition in the forward domain.
auto const& [local_start, local_size] = c2c.getLocalDomain();

int nx = (domain_size[0] + nprocs - 1) / nprocs;
int xlo = nx * myproc;
nx = std::max(std::min(nx,domain_size[0]-xlo), 0);
std::array<int,3> local_start_sp{xlo,0,0};
std::array<int,3> local_size_sp{nx,domain_size[1],domain_size[2]};

// Let amrex know the domain decomposition in the backward domain.
c2c.setLocalSpectralDomain(local_start_sp, local_size_sp);

auto nf = std::size_t(local_size[0])
* std::size_t(local_size[1])
* std::size_t(local_size[2]);
auto* pf = (CT*)std::malloc(sizeof(CT)*nf*nbatch); // or use cudaMalloc
// Initialize data ...

auto nb = std::size_t(local_size_sp[0])
* std::size_t(local_size_sp[1])
* std::size_t(local_size_sp[2]);
auto* pb = (CT*)std::malloc(sizeof(CT)*nb*nbatch);

c2c.forward(pf, pb); // forward transform

// work on the data pointed by pb

c2c.backward(pb, pf); // backward transform

std::free(pf);
std::free(pb);
}

amrex::Finalize_FFT();

MPI_Finalize();
16 changes: 16 additions & 0 deletions Src/Base/AMReX.H
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ namespace amrex
std::ostream& a_oserr = std::cerr,
ErrorHandler a_errhandler = nullptr);

// \brief Minimal version of initialization.
//
// This version is intended for users who only need AMReX for some
// specific functionalities such as FFT. It's the user's responsibility
// to initialize MPI. For multiple-GPU systems, it's the user's
// responsibility to properly set the GPU devices to be used. We will not
// try to pre-allocate memory arenas. We will not install a signal
// handler. Functionalities like random number generator and async I/O
// will be work. However, functionalities like FFT and linear solvers do
// work.
void Init_minimal (MPI_Comm mpi_comm = MPI_COMM_WORLD);

/**
\brief Returns true if there are any currently-active and initialized
AMReX instances (i.e. one for which amrex::Initialize has been called,
Expand All @@ -96,6 +108,10 @@ namespace amrex

void Finalize (AMReX* pamrex);
void Finalize (); // Finalize the current top
// For initialization with Init_minimal, Finalize_minimal should be used
// for finalization.
void Finalize_minimal();

/**
* \brief We maintain a stack of functions that need to be called in Finalize().
* The functions are called in LIFO order. The idea here is to allow
Expand Down
81 changes: 64 additions & 17 deletions Src/Base/AMReX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ namespace system
}
}

namespace {
long long init_minimal_called = 0;
bool initialization_by_init_minimal = false;
}

namespace {
std::string command_line;
std::vector<std::string> command_arguments;
Expand Down Expand Up @@ -339,20 +344,37 @@ amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,
ErrorHandler a_errhandler)
{
system::exename.clear();
if (initialization_by_init_minimal) {
system::verbose = 0;
system::regtest_reduction = false;
system::signal_handling = false;
system::handle_sigsegv = false;
system::handle_sigterm = false;
system::handle_sigint = false;
system::handle_sigabrt = false;
system::handle_sigfpe = false;
system::handle_sigill = false;
system::call_addr2line = false;
system::throw_exception = false;
system::osout = &std::cout;
system::oserr = &std::cerr;
system::error_handler = nullptr;
} else {
// system::verbose = 0;
system::regtest_reduction = false;
system::signal_handling = true;
system::handle_sigsegv = true;
system::handle_sigterm = false;
system::handle_sigint = true;
system::handle_sigabrt = true;
system::handle_sigfpe = true;
system::handle_sigill = true;
system::call_addr2line = true;
system::throw_exception = false;
system::osout = &a_osout;
system::oserr = &a_oserr;
system::error_handler = a_errhandler;
system::regtest_reduction = false;
system::signal_handling = true;
system::handle_sigsegv = true;
system::handle_sigterm = false;
system::handle_sigint = true;
system::handle_sigabrt = true;
system::handle_sigfpe = true;
system::handle_sigill = true;
system::call_addr2line = true;
system::throw_exception = false;
system::osout = &a_osout;
system::oserr = &a_oserr;
system::error_handler = a_errhandler;
}

ParallelDescriptor::StartParallel(&argc, &argv, mpi_comm);

Expand Down Expand Up @@ -510,7 +532,7 @@ amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,

#ifdef AMREX_USE_GPU
// Initialize after ParmParse so that we can read inputs.
Gpu::Device::Initialize();
Gpu::Device::Initialize(initialization_by_init_minimal);
#ifdef AMREX_USE_CUPTI
CuptiInitialize();
#endif
Expand Down Expand Up @@ -640,13 +662,15 @@ amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,
ParallelDescriptor::Initialize();

BL_TINY_PROFILE_MEMORYINITIALIZE();
Arena::Initialize();
Arena::Initialize(initialization_by_init_minimal);
amrex_mempool_init();

//
// Initialize random seed after we're running in parallel.
//
amrex::InitRandom(ParallelDescriptor::MyProc()+1, ParallelDescriptor::NProcs());
if (!initialization_by_init_minimal) {
amrex::InitRandom(ParallelDescriptor::MyProc()+1, ParallelDescriptor::NProcs());
}

// For thread safety, we should do these initializations here.
BaseFab_Initialize();
Expand All @@ -658,7 +682,9 @@ amrex::Initialize (int& argc, char**& argv, bool build_parm_parse,
MultiFab::Initialize();
iMultiFab::Initialize();
VisMF::Initialize();
AsyncOut::Initialize();
if (!initialization_by_init_minimal) {
AsyncOut::Initialize();
}
VectorGrowthStrategy::Initialize();

#ifdef AMREX_USE_FFT
Expand Down Expand Up @@ -998,4 +1024,25 @@ FPExcept enableFPExcept (FPExcept excepts)
return prev;
}

void Init_minimal (MPI_Comm mpi_comm)
{
++init_minimal_called;

if (Initialized()) { return; }

initialization_by_init_minimal = true;
Initialize(mpi_comm);
}

void Finalize_minimal ()
{
if (init_minimal_called > 0) {
--init_minimal_called;
}
if (init_minimal_called == 0 && initialization_by_init_minimal) {
Finalize();
initialization_by_init_minimal = false;
}
}

}
2 changes: 1 addition & 1 deletion Src/Base/AMReX_Arena.H
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public:
*/
static std::size_t align (std::size_t sz);

static void Initialize ();
static void Initialize (bool minimal);
static void PrintUsage ();
static void PrintUsageToFiles (std::string const& filename, std::string const& message);
static void Finalize ();
Expand Down
15 changes: 11 additions & 4 deletions Src/Base/AMReX_Arena.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace {
Arena* the_cpu_arena = nullptr;
Arena* the_comms_arena = nullptr;

Long the_arena_init_size = 0L;
Long the_arena_init_size = 1024*1024*8;
Long the_device_arena_init_size = 1024*1024*8;
Long the_managed_arena_init_size = 1024*1024*8;
Long the_pinned_arena_init_size = 1024*1024*8;
Expand Down Expand Up @@ -280,7 +280,7 @@ namespace {
}

void
Arena::Initialize ()
Arena::Initialize (bool minimal)
{
if (initialized) { return; }
initialized = true;
Expand All @@ -294,11 +294,18 @@ Arena::Initialize ()
BL_ASSERT(the_cpu_arena == nullptr || the_cpu_arena == The_BArena());
BL_ASSERT(the_comms_arena == nullptr || the_comms_arena == The_BArena());

if (minimal) {
the_pinned_arena_init_size = 0;
} else {
#ifdef AMREX_USE_GPU
the_arena_init_size = Gpu::Device::totalGlobalMem() / Gpu::Device::numDevicePartners() / 4L * 3L;
the_arena_init_size = Gpu::Device::totalGlobalMem() / Gpu::Device::numDevicePartners() / 4L * 3L;
#ifdef AMREX_USE_SYCL
the_arena_init_size = std::min(the_arena_init_size, Gpu::Device::maxMemAllocSize());
the_arena_init_size = std::min(the_arena_init_size, Gpu::Device::maxMemAllocSize());
#endif
#endif
}

#ifdef AMREX_USE_GPU
the_pinned_arena_release_threshold = Gpu::Device::totalGlobalMem() / Gpu::Device::numDevicePartners() / 2L;
#endif

Expand Down
4 changes: 2 additions & 2 deletions Src/Base/AMReX_GpuDevice.H
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Device

public:

static void Initialize ();
static void Initialize (bool minimal);
static void Finalize ();

#if defined(AMREX_USE_GPU)
Expand Down Expand Up @@ -184,7 +184,7 @@ public:

private:

static void initialize_gpu ();
static void initialize_gpu (bool minimal);

static AMREX_EXPORT int device_id;
static AMREX_EXPORT int num_devices_used;
Expand Down
Loading

0 comments on commit bfaa584

Please sign in to comment.