Skip to content

Commit

Permalink
Eliminating Matrix operations in MLMG CG bottom solver if initial vec…
Browse files Browse the repository at this point in the history
…tor is zero (#3668)

A matrix multiplication and a few copy operations can be avoided if the
input vector is zero. MLMG calls all the the bottom solvers with zeroed
`x` vector, and thus the initial residual calculation `b - Ax` is `b`.
Furthermore, it also eliminates the memory requirement of storing the
initial vector.
  • Loading branch information
ankithadas authored Dec 20, 2023
1 parent ef38229 commit 85462ce
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
60 changes: 46 additions & 14 deletions Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ public:
void setMaxIter (int _maxiter) { maxiter = _maxiter; }
[[nodiscard]] int getMaxIter () const { return maxiter; }


/**
* Is the initial guess provided to the solver zero ?
* If so, set this to true.
* The solver will avoid a few operations if this is true.
* Default is false.
*/
void setInitSolnZeroed (bool _sol_zeroed) { initial_vec_zeroed = _sol_zeroed; }
[[nodiscard]] bool getInitSolnZeroed () const { return initial_vec_zeroed; }

void setNGhost(int _nghost) {nghost = IntVect(_nghost);}
[[nodiscard]] int getNGhost() {return nghost[0];}

Expand All @@ -62,6 +72,7 @@ private:
int maxiter = 100;
IntVect nghost = IntVect(0);
int iter = -1;
bool initial_vec_zeroed = false;
};

template <typename MF>
Expand Down Expand Up @@ -95,21 +106,28 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
p.setVal(RT(0.0)); // Make sure all entries are initialized to avoid errors
r.setVal(RT(0.0));

MF sorig = Lp.make(amrlev, mglev, nghost);
MF rh = Lp.make(amrlev, mglev, nghost);
MF v = Lp.make(amrlev, mglev, nghost);
MF t = Lp.make(amrlev, mglev, nghost);

Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);

MF sorig;

if ( initial_vec_zeroed ) {
r.LocalCopy(rhs,0,0,ncomp,nghost);
} else {
sorig = Lp.make(amrlev, mglev, nghost);

Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);

sorig.LocalCopy(sol,0,0,ncomp,nghost);
sol.setVal(RT(0.0));
}

// Then normalize
Lp.normalize(amrlev, mglev, r);

sorig.LocalCopy(sol,0,0,ncomp,nghost);
rh.LocalCopy (r ,0,0,ncomp,nghost);

sol.setVal(RT(0.0));

RT rnorm = norm_inf(r);
const RT rnorm0 = rnorm;

Expand Down Expand Up @@ -238,12 +256,16 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)

if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
{
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
if ( !initial_vec_zeroed ) {
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
}
}
else
{
sol.setVal(RT(0.0));
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
if ( !initial_vec_zeroed ) {
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
}
}

return ret;
Expand All @@ -260,15 +282,21 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
MF p = Lp.make(amrlev, mglev, sol.nGrowVect());
p.setVal(RT(0.0));

MF sorig = Lp.make(amrlev, mglev, nghost);
MF r = Lp.make(amrlev, mglev, nghost);
MF q = Lp.make(amrlev, mglev, nghost);

sorig.LocalCopy(sol,0,0,ncomp,nghost);
MF sorig;

if ( initial_vec_zeroed ) {
r.LocalCopy(rhs,0,0,ncomp,nghost);
} else {
sorig = Lp.make(amrlev, mglev, nghost);

Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);

sol.setVal(RT(0.0));
sorig.LocalCopy(sol,0,0,ncomp,nghost);
sol.setVal(RT(0.0));
}

RT rnorm = norm_inf(r);
const RT rnorm0 = rnorm;
Expand Down Expand Up @@ -364,12 +392,16 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)

if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
{
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
if ( !initial_vec_zeroed ) {
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
}
}
else
{
sol.setVal(RT(0.0));
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
if ( !initial_vec_zeroed ) {
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
}
}

return ret;
Expand Down
1 change: 1 addition & 0 deletions Src/LinearSolvers/MLMG/AMReX_MLMG.H
Original file line number Diff line number Diff line change
Expand Up @@ -1526,6 +1526,7 @@ MLMGT<MF>::bottomSolveWithCG (MF& x, const MF& b, typename MLCGSolverT<MF>::Type
cg_solver.setSolver(type);
cg_solver.setVerbose(bottom_verbose);
cg_solver.setMaxIter(bottom_maxiter);
cg_solver.setInitSolnZeroed(true);
if (cf_strategy == CFStrategy::ghostnodes) { cg_solver.setNGhost(linop.getNGrow()); }

int ret = cg_solver.solve(x, b, bottom_reltol, bottom_abstol);
Expand Down

0 comments on commit 85462ce

Please sign in to comment.