Skip to content

Commit

Permalink
Update GMRES/MLMG interface (#3779)
Browse files Browse the repository at this point in the history
For the curl-curl test, if beta = 1e-9 alpha/dx^2, the multigrid solver
is able to reduce the residual by 10 orders of magnitude in 10 V-cycles.
But for beta = 1e-14 alpha/dx^2, the multigrid solver's residual will
stall at about 5e-6 of the original residual. However, it can be solved
using GMRES with multigrid as the preconditioner.
  • Loading branch information
WeiqunZhang authored Mar 4, 2024
1 parent c440e4e commit cf712eb
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 51 deletions.
39 changes: 39 additions & 0 deletions Src/Base/AMReX_FabArrayUtility.H
Original file line number Diff line number Diff line change
Expand Up @@ -1616,6 +1616,13 @@ void setBndry (MF& dst, typename MF::value_type val, int scomp, int ncomp)
dst.setBndry(val, scomp, ncomp);
}

//! dst *= val
//!
//! Free-function wrapper that forwards to dst.mult(val, scomp, ncomp, nghost).
//! Only participates in overload resolution when IsMultiFabLike_v<MF> is true.
//! scomp/ncomp select the component range and nghost is passed through
//! unchanged (presumably the number of ghost cells to include -- confirm
//! against MF::mult).
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void Scale (MF& dst, typename MF::value_type val, int scomp, int ncomp, int nghost)
{
dst.mult(val, scomp, ncomp, nghost);
}

//! dst = src
template <class DMF, class SMF,
std::enable_if_t<IsMultiFabLike_v<DMF> &&
Expand Down Expand Up @@ -1650,6 +1657,16 @@ void Xpay (MF& dst, typename MF::value_type a, MF const& src, int scomp, int dco
MF::Xpay(dst, a, src, scomp, dcomp, ncomp, nghost);
}

//! dst = a*src_a + b*src_b
//!
//! Free-function wrapper that forwards every argument, in the same order, to
//! the static member MF::LinComb. Only participates in overload resolution
//! when IsMultiFabLike_v<MF> is true. acomp, bcomp and dcomp are the starting
//! components in src_a, src_b and dst respectively; ncomp and nghost are
//! passed through unchanged.
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void LinComb (MF& dst,
typename MF::value_type a, MF const& src_a, int acomp,
typename MF::value_type b, MF const& src_b, int bcomp,
int dcomp, int ncomp, IntVect const& nghost)
{
MF::LinComb(dst, a, src_a, acomp, b, src_b, bcomp, dcomp, ncomp, nghost);
}

//! dst = src w/ MPI communication
template <class MF, std::enable_if_t<IsMultiFabLike_v<MF>, int> = 0>
void ParallelCopy (MF& dst, MF const& src, int scomp, int dcomp, int ncomp,
Expand Down Expand Up @@ -1686,6 +1703,16 @@ void setBndry (Array<MF,N>& dst, typename MF::value_type val, int scomp, int nco
}
}

//! dst *= val, applied to each of the N elements of the array in order.
//! Every element is scaled through its own mult(val, scomp, ncomp, nghost).
template <class MF, std::size_t N, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void Scale (Array<MF,N>& dst, typename MF::value_type val, int scomp, int ncomp, int nghost)
{
    for (std::size_t idx = 0; idx < N; ++idx)
    {
        dst[idx].mult(val, scomp, ncomp, nghost);
    }
}

//! dst = src
template <class DMF, class SMF, std::size_t N,
std::enable_if_t<IsMultiFabLike_v<DMF> &&
Expand Down Expand Up @@ -1730,6 +1757,18 @@ void Xpay (Array<MF,N>& dst, typename MF::value_type a,
}
}

//! dst = a*src_a + b*src_b, evaluated element by element over the N entries
//! of the arrays: dst[i] is computed from src_a[i] and src_b[i] via
//! MF::LinComb with the same component and ghost arguments for every i.
template <class MF, std::size_t N, std::enable_if_t<IsMultiFabLike_v<MF>,int> = 0>
void LinComb (Array<MF,N>& dst,
              typename MF::value_type a, Array<MF,N> const& src_a, int acomp,
              typename MF::value_type b, Array<MF,N> const& src_b, int bcomp,
              int dcomp, int ncomp, IntVect const& nghost)
{
    std::size_t idx = 0;
    for (auto& target : dst)
    {
        MF::LinComb(target, a, src_a[idx], acomp, b, src_b[idx], bcomp,
                    dcomp, ncomp, nghost);
        ++idx;
    }
}

//! dst = src w/ MPI communication
template <class MF, std::size_t N, std::enable_if_t<IsMultiFabLike_v<MF>, int> = 0>
void ParallelCopy (Array<MF,N>& dst, Array<MF,N> const& src,
Expand Down
23 changes: 15 additions & 8 deletions Src/LinearSolvers/AMReX_GMRES.H
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ namespace amrex {
* - void precond(V& lhs, V const& rhs)\n
* applies preconditioner to rhs. If there is no preconditioner,
* this function should do lhs = rhs.
* - void setVal(V& v, RT value)\n
* v = value.
* - void setToZero(V& v)\n
* v = 0.
*/
template <typename V, typename M>
class GMRES
{
public:

using RT = typename V::value_type; // double or float
using RT = typename M::RT; // double or float

GMRES ();

Expand All @@ -87,6 +87,9 @@ public:
//! Sets restart length. The default is 30.
void setRestartLength (int rl);

//! Sets m_maxiter, i.e. the limit on the number of iterations (this is not
//! the iteration count m_its reported by getNumIters).
void setNumIters (int niters) { m_maxiter = niters; }

//! Gets the number of iterations (returns the counter m_its).
[[nodiscard]] int getNumIters () const { return m_its; }

Expand Down Expand Up @@ -202,9 +205,9 @@ void GMRES<V,M>::solve (V& a_sol, V const& a_rhs, RT a_tol_rel, RT a_tol_abs, in
m_v_tmp_lhs = std::make_unique<V>(m_linop->makeVecLHS());
}
if (m_vv.empty()) {
m_vv.resize(m_restrtlen+1);
for (auto& v : m_vv) {
v = m_linop->makeVecRHS();
m_vv.reserve(m_restrtlen+1);
for (int i = 0; i < 2; ++i) { // to save space, start with just 2
m_vv.emplace_back(m_linop->makeVecRHS());
}
}

Expand All @@ -216,7 +219,7 @@ void GMRES<V,M>::solve (V& a_sol, V const& a_rhs, RT a_tol_rel, RT a_tol_abs, in
auto rnorm0 = RT(0);

m_linop->assign(m_vv[0], a_rhs);
m_linop->setVal(a_sol, RT(0.0));
m_linop->setToZero(a_sol);

m_its = 0;
m_status = -1;
Expand Down Expand Up @@ -269,6 +272,10 @@ void GMRES<V,M>::cycle (V& a_xx, int& a_status, int& a_itcount, RT& a_rnorm0)

if (a_status == 0) { break; }

while (m_vv.size() < it+2) {
m_vv.emplace_back(m_linop->makeVecRHS());
}

auto const& vv_it = m_vv[it ];
auto & vv_it1 = m_vv[it+1];

Expand Down Expand Up @@ -384,7 +391,7 @@ void GMRES<V,M>::build_solution (V& a_xx, int const it)
m_grs[k] = tt / m_hh(k,k);
}

m_linop->setVal(*m_v_tmp_rhs, RT(0.0));
m_linop->setToZero(*m_v_tmp_rhs);
for (int ii = 0; ii < it+1; ++ii) {
m_linop->increment(*m_v_tmp_rhs, m_vv[ii], m_grs[ii]);
}
Expand Down
39 changes: 23 additions & 16 deletions Src/LinearSolvers/AMReX_GMRES_MLMG.H
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class GMRESMLMGT
{
public:
using MF = typename M::MFType; // typically MultiFab
using RT = typename MF::value_type; // double or float
using RT = typename M::RT; // double or float

explicit GMRESMLMGT (M& mlmg);

Expand All @@ -29,8 +29,8 @@ public:

RT dotProduct (MF const& mf1, MF const& mf2) const;

//! lhs = value
static void setVal (MF& lhs, RT value);
//! lhs = 0
static void setToZero (MF& lhs);

//! lhs = rhs
static void assign (MF& lhs, MF const& rhs);
Expand Down Expand Up @@ -58,6 +58,8 @@ template <typename M>
//! Construct the adapter around an existing MLMG-type object.
//! Initializes m_mlmg from mlmg and m_linop from mlmg.getLinOp(), then
//! configures the MLMG object for use from within GMRES.
GMRESMLMGT<M>::GMRESMLMGT (M& mlmg)
: m_mlmg(mlmg), m_linop(mlmg.getLinOp())
{
// Set both verbosity levels of the wrapped MLMG object to 0.
m_mlmg.setVerbose(0);
m_mlmg.setBottomVerbose(0);
// Have MLMG prepare its linear operator up front.
m_mlmg.prepareLinOp();
}

Expand All @@ -71,7 +73,7 @@ template <typename M>
auto GMRESMLMGT<M>::makeVecLHS () const -> MF
{
auto mf = m_linop.make(0, 0, IntVect(1));
mf.setBndry(0);
setBndry(mf, RT(0), 0, nComp(mf));
return mf;
}

Expand All @@ -85,7 +87,7 @@ auto GMRESMLMGT<M>::norm2 (MF const& mf) const -> RT
template <typename M>
void GMRESMLMGT<M>::scale (MF& mf, RT scale_factor)
{
mf.mult(scale_factor, 0, mf.nComp());
Scale(mf, scale_factor, 0, nComp(mf), 0);
}

template <typename M>
Expand All @@ -95,27 +97,27 @@ auto GMRESMLMGT<M>::dotProduct (MF const& mf1, MF const& mf2) const -> RT
}

template <typename M>
void GMRESMLMGT<M>::setVal (MF& lhs, RT value)
void GMRESMLMGT<M>::setToZero (MF& lhs)
{
lhs.setVal(value);
setVal(lhs, RT(0.0));
}

template <typename M>
void GMRESMLMGT<M>::assign (MF& lhs, MF const& rhs)
{
MF::Copy(lhs, rhs, 0, 0, lhs.nComp(), IntVect(0));
LocalCopy(lhs, rhs, 0, 0, nComp(lhs), IntVect(0));
}

template <typename M>
void GMRESMLMGT<M>::increment (MF& lhs, MF const& rhs, RT a)
{
MF::Saxpy(lhs, a, rhs, 0, 0, lhs.nComp(), IntVect(0));
Saxpy(lhs, a, rhs, 0, 0, nComp(lhs), IntVect(0));
}

template <typename M>
void GMRESMLMGT<M>::linComb (MF& lhs, RT a, MF const& rhs_a, RT b, MF const& rhs_b)
{
MF::LinComb(lhs, a, rhs_a, 0, b, rhs_b, 0, 0, lhs.nComp(), IntVect(0));
LinComb(lhs, a, rhs_a, 0, b, rhs_b, 0, 0, nComp(lhs), IntVect(0));
}

template <typename M>
Expand All @@ -130,13 +132,18 @@ template <typename M>
void GMRESMLMGT<M>::precond (MF& lhs, MF const& rhs) const
{
if (m_use_precond) {
// for now, let's just do some smoothing
lhs.setVal(RT(0.0));
for (int m = 0; m < 4; ++m) {
m_linop.smooth(0, 0, lhs, rhs, (m==0) ? true : false);
}
AMREX_ALWAYS_ASSERT(m_linop.NAMRLevels() == 1);

m_mlmg.prepareMGcycle();

LocalCopy(m_mlmg.res[0][0], rhs, 0, 0, nComp(rhs), IntVect(0));

m_mlmg.mgVcycle(0,0);

LocalCopy(lhs, m_mlmg.cor[0][0], 0, 0, nComp(rhs), IntVect(0));

} else {
amrex::Copy(lhs, rhs, 0, 0, lhs.nComp(), IntVect(0));
LocalCopy(lhs, rhs, 0, 0, nComp(lhs), IntVect(0));
}
}

Expand Down
11 changes: 7 additions & 4 deletions Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.H
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,12 @@ namespace amrex {
* scalar, and beta is a non-negative scalar.
*
* It's the caller's responsibility to make sure rhs has consistent nodal
* data. If needed, one could use FabArray::OverrideSync to synchronize
* nodal data.
* data. If needed, one could call prepareRHS for this.
*
* The smoother is based on the 4-color Gauss-Seidel smoother of Li
* et al. 2020. "An Efficient Preconditioner for 3-D Finite Difference
* Modeling of the Electromagnetic Diffusion Process in the Frequency
* Domain", IEEE Transactions on Geoscience and Remote Sensing, 58, 500-509.
*
* TODO: If beta is zero, the system could be singular.
*/
class MLCurlCurl
: public MLLinOpT<Array<MultiFab,3> >
Expand All @@ -48,6 +45,12 @@ public:

void setScalars (RT a_alpha, RT a_beta) noexcept;

//! Synchronize RHS on nodal points and set to zero on Dirichlet
//! boundaries. If the user can guarantee these requirements on RHS,
//! this function does not need to be called. If this is called, it
//! should only be called after setDomainBC is called.
void prepareRHS (Vector<MF*> const& rhs) const;

[[nodiscard]] std::string name () const override {
return std::string("curl of curl");
}
Expand Down
56 changes: 56 additions & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,62 @@ void MLCurlCurl::setScalars (RT a_alpha, RT a_beta) noexcept
{
m_alpha = a_alpha;
m_beta = a_beta;
AMREX_ASSERT(m_beta > RT(0));
}

//! Prepare the right-hand side on every AMR level. Each MultiFab in
//! *rhs[amrlev] is first passed through OverrideSync with the level's
//! periodicity, and then set to zero on the domain faces that are Dirichlet
//! boundaries in a direction where that MultiFab is nodal.
//!
//! \param rhs one pointer per AMR level; the pointed-to objects are modified
//!            in place (the function itself is const on the operator).
void MLCurlCurl::prepareRHS (Vector<MF*> const& rhs) const
{
MFItInfo mfi_info{};
#ifdef AMREX_USE_GPU
// GPU build: collect (array, box) pairs in `tags` while looping and zero
// them all in the single ParallelFor at the end of this function.
Vector<Array4BoxTag<RT>> tags;
mfi_info.DisableDeviceSync();
#endif

for (int amrlev = 0; amrlev < m_num_amr_levels; ++amrlev) {
// Iterate over the individual MultiFabs that make up this level's RHS.
for (auto& mf : *rhs[amrlev]) {
mf.OverrideSync(m_geom[amrlev][0].periodicity());

// Domain box converted to this MultiFab's index type so that its faces
// can be compared directly with the faces of the valid boxes.
auto const idxtype = mf.ixType();
Box const domain = amrex::convert(m_geom[amrlev][0].Domain(), idxtype);

#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion())
#endif
for (MFIter mfi(mf,mfi_info); mfi.isValid(); ++mfi) {
auto const& vbx = mfi.validbox();
auto const& a = mf.array(mfi);
// Visit every face (low/high side in each direction) of the valid box.
for (OrientationIter oit; oit; ++oit) {
Orientation const face = oit();
int const idim = face.coordDir();
// Note: index 0 of m_lobc/m_hibc is used regardless of which MultiFab
// of the array is being processed.
bool is_dirichlet = face.isLow()
? m_lobc[0][idim] == LinOpBCType::Dirichlet
: m_hibc[0][idim] == LinOpBCType::Dirichlet;
// Act only if the face is Dirichlet, the valid box touches the domain
// boundary on this face, and the data is nodal in this direction.
if (is_dirichlet && domain[face] == vbx[face] &&
idxtype.nodeCentered(idim))
{
// Restrict the box to the single index plane lying on the face.
Box b = vbx;
b.setRange(idim, vbx[face], 1);
#ifdef AMREX_USE_GPU
tags.emplace_back(Array4BoxTag<RT>{a,b});
#else
amrex::LoopOnCpu(b, [&] (int i, int j, int k)
{
a(i,j,k) = RT(0.0);
});
#endif
}
}
}
}
}

#ifdef AMREX_USE_GPU
// Zero all collected boundary planes in one launch.
ParallelFor(tags,
[=] AMREX_GPU_DEVICE (int i, int j, int k, Array4BoxTag<RT> const& tag) noexcept
{
tag.dfab(i,j,k) = RT(0.0);
});
#endif
}

void MLCurlCurl::setLevelBC (int amrlev, const MF* levelbcdata, // TODO
Expand Down
Loading

0 comments on commit cf712eb

Please sign in to comment.