Skip to content

Commit

Permalink
MLMG: Use free functions instead of MF member functions (AMReX-Codes#…
Browse files Browse the repository at this point in the history
…3681)

Note that the use of unqualified functions (e.g., setVal instead of
amrex::setVal) is intentional. With ADL, these calls in MLMG could work
with user defined data.
  • Loading branch information
WeiqunZhang authored Dec 21, 2023
1 parent 3407e87 commit 75571e2
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 160 deletions.
72 changes: 36 additions & 36 deletions Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class MLCGSolverT
{
public:

using FAB = typename MF::fab_type;
using RT = typename MF::value_type;
using FAB = typename MLLinOpT<MF>::FAB;
using RT = typename MLLinOpT<MF>::RT;

enum struct Type { BiCGStab, CG };

Expand Down Expand Up @@ -99,12 +99,12 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
BL_PROFILE("MLCGSolver::bicgstab");

const int ncomp = sol.nComp();
const int ncomp = nComp(sol);

MF p = Lp.make(amrlev, mglev, sol.nGrowVect());
MF r = Lp.make(amrlev, mglev, sol.nGrowVect());
p.setVal(RT(0.0)); // Make sure all entries are initialized to avoid errors
r.setVal(RT(0.0));
MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
MF r = Lp.make(amrlev, mglev, nGrowVect(sol));
setVal(p, RT(0.0)); // Make sure all entries are initialized to avoid errors
setVal(r, RT(0.0));

MF rh = Lp.make(amrlev, mglev, nghost);
MF v = Lp.make(amrlev, mglev, nghost);
Expand All @@ -114,19 +114,19 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
MF sorig;

if ( initial_vec_zeroed ) {
r.LocalCopy(rhs,0,0,ncomp,nghost);
LocalCopy(r,rhs,0,0,ncomp,nghost);
} else {
sorig = Lp.make(amrlev, mglev, nghost);

Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);

sorig.LocalCopy(sol,0,0,ncomp,nghost);
sol.setVal(RT(0.0));
LocalCopy(sorig,sol,0,0,ncomp,nghost);
setVal(sol, RT(0.0));
}

// Then normalize
Lp.normalize(amrlev, mglev, r);
rh.LocalCopy (r ,0,0,ncomp,nghost);
LocalCopy(rh, r, 0,0,ncomp,nghost);

RT rnorm = norm_inf(r);
const RT rnorm0 = rnorm;
Expand Down Expand Up @@ -159,13 +159,13 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
}
if ( iter == 1 )
{
p.LocalCopy(r,0,0,ncomp,nghost);
LocalCopy(p,r,0,0,ncomp,nghost);
}
else
{
const RT beta = (rho/rho_1)*(alpha/omega);
MF::Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
}
Lp.apply(amrlev, mglev, v, p, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
Lp.normalize(amrlev, mglev, v);
Expand All @@ -179,8 +179,8 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
ret = 2; break;
}
MF::Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
MF::Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v

rnorm = norm_inf(r);
rnorm = norm_inf(r);
Expand Down Expand Up @@ -216,8 +216,8 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
ret = 3; break;
}
MF::Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
MF::Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t

rnorm = norm_inf(r);

Expand Down Expand Up @@ -257,14 +257,14 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
{
if ( !initial_vec_zeroed ) {
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
}
}
else
{
sol.setVal(RT(0.0));
setVal(sol, RT(0.0));
if ( !initial_vec_zeroed ) {
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
}
}

Expand All @@ -277,25 +277,25 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
{
BL_PROFILE("MLCGSolver::cg");

const int ncomp = sol.nComp();
const int ncomp = nComp(sol);

MF p = Lp.make(amrlev, mglev, sol.nGrowVect());
p.setVal(RT(0.0));
MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
setVal(p, RT(0.0));

MF r = Lp.make(amrlev, mglev, nghost);
MF q = Lp.make(amrlev, mglev, nghost);

MF sorig;

if ( initial_vec_zeroed ) {
r.LocalCopy(rhs,0,0,ncomp,nghost);
LocalCopy(r,rhs,0,0,ncomp,nghost);
} else {
sorig = Lp.make(amrlev, mglev, nghost);

Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);

sorig.LocalCopy(sol,0,0,ncomp,nghost);
sol.setVal(RT(0.0));
LocalCopy(sorig,sol,0,0,ncomp,nghost);
setVal(sol, RT(0.0));
}

RT rnorm = norm_inf(r);
Expand Down Expand Up @@ -330,12 +330,12 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
}
if (iter == 1)
{
p.LocalCopy(r,0,0,ncomp,nghost);
LocalCopy(p,r,0,0,ncomp,nghost);
}
else
{
RT beta = rho/rho_1;
MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
}
Lp.apply(amrlev, mglev, q, p, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);

Expand All @@ -357,8 +357,8 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
<< " rho " << rho
<< " alpha " << alpha << '\n';
}
MF::Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
MF::Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
rnorm = norm_inf(r);

if ( verbose > 2 )
Expand Down Expand Up @@ -393,14 +393,14 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
{
if ( !initial_vec_zeroed ) {
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
}
}
else
{
sol.setVal(RT(0.0));
setVal(sol, RT(0.0));
if ( !initial_vec_zeroed ) {
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
}
}

Expand All @@ -422,8 +422,8 @@ template <typename MF>
auto
MLCGSolverT<MF>::norm_inf (const MF& res, bool local) -> RT
{
int ncomp = res.nComp();
RT result = res.norminf(0,ncomp,IntVect(0),true);
int ncomp = nComp(res);
RT result = norminf(res,0,ncomp,IntVect(0),true);
if (!local) {
BL_PROFILE("MLCGSolver::ParallelAllReduce");
ParallelAllReduce::Max(result, Lp.BottomCommunicator());
Expand Down
83 changes: 60 additions & 23 deletions Src/LinearSolvers/MLMG/AMReX_MLLinOp.H
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ struct LinOpEnumType
enum struct Location { FaceCenter, FaceCentroid, CellCenter, CellCentroid };
};

template <typename T, class Enable = void> struct LinOpData {};
//
template <typename T>
struct LinOpData <T, std::enable_if_t<IsMultiFabLike_v<T> > >
{
using fab_type = typename T::fab_type;
using value_type = typename T::value_type;
};

template <typename T> class MLMGT;
template <typename T> class MLCGSolverT;
template <typename T> class MLPoissonT;
Expand All @@ -100,8 +109,8 @@ public:
template <typename T> friend class MLPoissonT;
template <typename T> friend class MLABecLaplacianT;

using FAB = typename MF::fab_type;
using RT = typename MF::value_type;
using FAB = typename LinOpData<MF>::fab_type;
using RT = typename LinOpData<MF>::value_type;

using BCType = LinOpBCType;
using BCMode = LinOpEnumType::BCMode;
Expand Down Expand Up @@ -1375,53 +1384,81 @@ template <typename MF>
void
MLLinOpT<MF>::make (Vector<Vector<MF> >& mf, IntVect const& ng) const
{
mf.clear();
mf.resize(m_num_amr_levels);
for (int alev = 0; alev < m_num_amr_levels; ++alev) {
mf[alev].resize(m_num_mg_levels[alev]);
for (int mlev = 0; mlev < m_num_mg_levels[alev]; ++mlev) {
mf[alev][mlev] = make(alev, mlev, ng);
if constexpr (IsMultiFabLike_v<MF>) {
mf.clear();
mf.resize(m_num_amr_levels);
for (int alev = 0; alev < m_num_amr_levels; ++alev) {
mf[alev].resize(m_num_mg_levels[alev]);
for (int mlev = 0; mlev < m_num_mg_levels[alev]; ++mlev) {
mf[alev][mlev] = make(alev, mlev, ng);
}
}
} else {
amrex::ignore_unused(mf, ng);
amrex::Abort("MLLinOpT::make: how did we get here?");
}
}

template <typename MF>
MF
MLLinOpT<MF>::make (int amrlev, int mglev, IntVect const& ng) const
{
return MF(amrex::convert(m_grids[amrlev][mglev], m_ixtype),
m_dmap[amrlev][mglev], getNComp(), ng, MFInfo(),
*m_factory[amrlev][mglev]);
if constexpr (IsMultiFabLike_v<MF>) {
return MF(amrex::convert(m_grids[amrlev][mglev], m_ixtype),
m_dmap[amrlev][mglev], getNComp(), ng, MFInfo(),
*m_factory[amrlev][mglev]);
} else {
amrex::ignore_unused(amrlev, mglev, ng);
amrex::Abort("MLLinOpT::make: how did we get here?");
return {};
}
}

template <typename MF>
MF
MLLinOpT<MF>::makeAlias (MF const& mf) const
{
return MF(mf, amrex::make_alias, 0, mf.nComp());
if constexpr (IsMultiFabLike_v<MF>) {
return MF(mf, amrex::make_alias, 0, mf.nComp());
} else {
amrex::ignore_unused(mf);
amrex::Abort("MLLinOpT::makeAlias: how did we get here?");
return {};
}
}

template <typename MF>
MF
MLLinOpT<MF>::makeCoarseMG (int amrlev, int mglev, IntVect const& ng) const
{
BoxArray cba = m_grids[amrlev][mglev];
IntVect ratio = (amrlev > 0) ? IntVect(2) : mg_coarsen_ratio_vec[mglev];
cba.coarsen(ratio);
cba.convert(m_ixtype);
return MF(cba, m_dmap[amrlev][mglev], getNComp(), ng);

if constexpr (IsMultiFabLike_v<MF>) {
BoxArray cba = m_grids[amrlev][mglev];
IntVect ratio = (amrlev > 0) ? IntVect(2) : mg_coarsen_ratio_vec[mglev];
cba.coarsen(ratio);
cba.convert(m_ixtype);
return MF(cba, m_dmap[amrlev][mglev], getNComp(), ng);
} else {
amrex::ignore_unused(amrlev, mglev, ng);
amrex::Abort("MLLinOpT::makeCoarseMG: how did we get here?");
return {};
}
}

template <typename MF>
MF
MLLinOpT<MF>::makeCoarseAmr (int famrlev, IntVect const& ng) const
{
BoxArray cba = m_grids[famrlev][0];
IntVect ratio(AMRRefRatio(famrlev-1));
cba.coarsen(ratio);
cba.convert(m_ixtype);
return MF(cba, m_dmap[famrlev][0], getNComp(), ng);
if constexpr (IsMultiFabLike_v<MF>) {
BoxArray cba = m_grids[famrlev][0];
IntVect ratio(AMRRefRatio(famrlev-1));
cba.coarsen(ratio);
cba.convert(m_ixtype);
return MF(cba, m_dmap[famrlev][0], getNComp(), ng);
} else {
amrex::ignore_unused(famrlev, ng);
amrex::Abort("MLLinOpT::makeCoarseAmr: how did we get here?");
return {};
}
}

template <typename MF>
Expand Down
Loading

0 comments on commit 75571e2

Please sign in to comment.