Skip to content

Commit

Permalink
Add FFT Poisson solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Oct 18, 2024
1 parent e40a9ec commit e26cf17
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 91 deletions.
93 changes: 69 additions & 24 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,19 @@ public:
// Convenience driver: transform inmf forward, hand post_forward to
// post_forward_doit, then transform backward into outmf.
// NOTE(review): this hunk shows the pre-change calls (forward_doit /
// backward_doit) side by side with the calls that replace them (forward /
// backward). Presumably only one call of each pair exists in the real
// header -- confirm against the actual file before relying on this listing.
template <typename F>
void forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward)
{
this->forward_doit(inmf);
this->forward(inmf);
this->post_forward_doit(post_forward);
this->backward_doit(outmf);
this->backward(outmf);
}

//! Forward transform of inmf. The definition copies inmf into internal
//! storage and runs the r2c plan; the result stays inside this object.
//! The definition currently asserts scaling == Scaling::none.
void forward (MF const& inmf, Scaling scaling = Scaling::none);
//! Forward transform of inmf followed by a copy of the spectral data into
//! outmf. The definition aborts when the internal m_cz container is in use
//! (that path is marked as todo).
void forward (MF const& inmf, cMF& outmf, Scaling scaling = Scaling::none);

//! Backward transform starting from the internally stored spectral data.
//! The definition currently asserts scaling == Scaling::none.
//! NOTE(review): presumably the real-space result is written to outmf --
//! the end of the definition is not visible in this listing.
void backward (MF& outmf, Scaling scaling = Scaling::none);
//! Copy inmf into the internal spectral storage, then call
//! backward(outmf). The definition aborts when the internal m_cz container
//! is in use (that path is marked as todo).
void backward (cMF const& inmf, MF& outmf, Scaling scaling = Scaling::none);

//! Return a pointer to the internal spectral container currently holding
//! the data (m_cz, else m_cy, else m_cx) together with an IntVect that
//! encodes that container's index ordering, e.g. (0,1,2) for m_cx.
//! The pointer refers to a member of this object, not to a copy.
std::pair<cMF*,IntVect> getSpectralData ();

struct Swap01
{
[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 operator() (Dim3 i) const noexcept
Expand Down Expand Up @@ -153,9 +161,6 @@ private:
}
}

void forward_doit (MF const& inmf, Scaling scaling = Scaling::none);
void backward_doit (MF& outmf, Scaling scaling = Scaling::none);

static void exec_r2c (Plan plan, MF& in, cMF& out);
static void exec_c2r (Plan plan, cMF& in, MF& out);
template <Direction direction>
Expand All @@ -175,10 +180,10 @@ private:
// Comm meta-data. In the forward phase, we start with (x,y,z),
// transpose to (y,x,z) and then (z,x,y). In the backward phase, we
// perform inverse transpose.
std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2y;
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2x;
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2z;
std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2y;
std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2y; // (x,y,z) -> (y,x,z)
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2x; // (y,x,z) -> (x,y,z)
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2z; // (y,x,z) -> (z,x,y)
std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2y; // (z,x,y) -> (y,x,z)
Swap01 m_dtos_x2y{};
Swap01 m_dtos_y2x{};
Swap02 m_dtos_y2z{};
Expand Down Expand Up @@ -232,12 +237,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
int nprocs = ParallelDescriptor::NProcs();

auto bax = amrex::decompose(m_real_domain, nprocs, {AMREX_D_DECL(false,true,true)});
DistributionMapping dmx;
{
Vector<int> pm(bax.size());
std::iota(pm.begin(), pm.end(), 0);
dmx.define(std::move(pm));
}
DistributionMapping dmx = detail::make_iota_distromap(bax.size());
m_rx.define(bax, dmx, 1, 0);

{
Expand Down Expand Up @@ -346,9 +346,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
if (cbay.size() == dmx.size()) {
cdmy = dmx;
} else {
Vector<int> pm(cbay.size());
std::iota(pm.begin(), pm.end(), 0);
cdmy.define(std::move(pm));
cdmy = detail::make_iota_distromap(cbay.size());
}
m_cy.define(cbay, cdmy, 1, 0);

Expand All @@ -365,7 +363,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)

#if (AMREX_SPACEDIM == 3)
if (m_real_domain.length(1) > 1 &&
(! m_info.batch_mode || m_real_domain.length(2) > 1))
(! m_info.batch_mode && m_real_domain.length(2) > 1))
{
auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs, {false,true,true});
DistributionMapping cdmz;
Expand All @@ -374,9 +372,7 @@ R2C<T>::R2C (Box const& domain, Info const& info)
} else if (cbaz.size() == cdmy.size()) {
cdmz = cdmy;
} else {
Vector<int> pm(cbaz.size());
std::iota(pm.begin(), pm.end(), 0);
cdmz.define(std::move(pm));
cdmz = detail::make_iota_distromap(cbaz.size());
}
m_cz.define(cbaz, cdmz, 1, 0);

Expand Down Expand Up @@ -563,8 +559,10 @@ void R2C<T>::exec_c2c (Plan2 plan, cMF& inout)
}

template <typename T>
void R2C<T>::forward_doit (MF const& inmf, Scaling /*scaling*/)
void R2C<T>::forward (MF const& inmf, Scaling scaling)
{
AMREX_ALWAYS_ASSERT(scaling == Scaling::none); // xxxxx TODO

m_rx.ParallelCopy(inmf, 0, 0, 1);
exec_r2c(m_fft_fwd_x, m_rx, m_cx);

Expand All @@ -580,8 +578,10 @@ void R2C<T>::forward_doit (MF const& inmf, Scaling /*scaling*/)
}

template <typename T>
void R2C<T>::backward_doit (MF& outmf, Scaling /*scaling*/)
void R2C<T>::backward (MF& outmf, Scaling scaling)
{
AMREX_ALWAYS_ASSERT(scaling == Scaling::none); // xxxxx TODO

exec_c2c<Direction::backward>(m_fft_bwd_z, m_cz);
if ( m_cmd_z2y) {
ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y);
Expand Down Expand Up @@ -716,6 +716,51 @@ void R2C<T>::post_forward_doit (F const& post_forward)
}
}

/**
 * \brief Give access to the internal spectral data and its index ordering.
 *
 * The container is chosen by priority: m_cz if it is non-empty, otherwise
 * m_cy if it is non-empty, otherwise m_cx. The accompanying IntVect encodes
 * the index ordering of the chosen container: (2,0,1) for m_cz (z,x,y),
 * (1,0,2) for m_cy (y,x,z) and (0,1,2) for m_cx (x,y,z).
 *
 * \return pair of (pointer to the member container, ordering).
 */
template <typename T>
std::pair<typename R2C<T>::cMF *, IntVect>
R2C<T>::getSpectralData ()
{
    // Start from the x-ordered container and override it when a later
    // transpose stage has been set up.
    auto* spectral = &m_cx;
    IntVect ordering{AMREX_D_DECL(0,1,2)};

    if (!m_cz.empty()) {
        spectral = &m_cz;
        ordering = IntVect{AMREX_D_DECL(2,0,1)};
    } else if (!m_cy.empty()) {
        spectral = &m_cy;
        ordering = IntVect{AMREX_D_DECL(1,0,2)};
    }

    return std::make_pair(spectral, ordering);
}

/**
 * \brief Forward transform of inmf, then copy the spectral data into outmf.
 *
 * \param inmf    real-space input, passed on to forward(inmf, scaling).
 * \param outmf   destination for the spectral data in (x,y,z) ordering.
 * \param scaling passed on to forward(inmf, scaling). That overload asserts
 *                scaling == Scaling::none, so an unsupported request is
 *                reported there instead of being silently ignored.
 *
 * Aborts if m_cz is in use; copying out of the (z,x,y) layout is still todo.
 */
template <typename T>
void R2C<T>::forward (MF const& inmf, cMF& outmf, Scaling scaling)
{
    // Forward the caller's scaling choice; dropping it here would hide the
    // fact that scaling is not applied.
    forward(inmf, scaling);
    if (!m_cz.empty()) { // m_cz's ordering is z,x,y
        amrex::Abort("xxxxx todo, forward m_cz");
    } else if (!m_cy.empty()) { // m_cy's order (y,x,z) -> (x,y,z)
        // Build the communication pattern on the fly because it depends on
        // the layout of the caller's outmf.
        MultiBlockCommMetaData cmd
            (outmf.boxArray(), outmf.DistributionMap(), m_spectral_domain_x,
             m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2x);
        ParallelCopy(outmf, m_cy, cmd, 0, 0, 1, m_dtos_y2x);
    } else {
        // Only m_cx is in use: same ordering as outmf, plain copy.
        outmf.ParallelCopy(m_cx, 0, 0, 1);
    }
}

/**
 * \brief Load spectral data from inmf, then transform backward into outmf.
 *
 * \param inmf    spectral input in (x,y,z) ordering; copied into the
 *                internal container (m_cy or m_cx) before the transform.
 * \param outmf   passed on to backward(outmf, scaling).
 * \param scaling passed on to backward(outmf, scaling). That overload
 *                asserts scaling == Scaling::none, so an unsupported request
 *                is reported there instead of being silently ignored.
 *
 * Aborts if m_cz is in use; copying into the (z,x,y) layout is still todo.
 */
template <typename T>
void R2C<T>::backward (cMF const& inmf, MF& outmf, Scaling scaling)
{
    if (!m_cz.empty()) { // m_cz's ordering is z,x,y
        amrex::Abort("xxxxx todo, backward m_cz");
    } else if (!m_cy.empty()) { // (x,y,z) -> m_cy's ordering (y,x,z)
        // Build the communication pattern on the fly because it depends on
        // the layout of the caller's inmf.
        MultiBlockCommMetaData cmd
            (m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y,
             inmf.boxArray(), inmf.DistributionMap(), IntVect(0), m_dtos_x2y);
        ParallelCopy(m_cy, inmf, cmd, 0, 0, 1, m_dtos_x2y);
    } else {
        // Only m_cx is in use: same ordering as inmf, plain copy.
        m_cx.ParallelCopy(inmf, 0, 0, 1);
    }
    // Forward the caller's scaling choice rather than dropping it.
    backward(outmf, scaling);
}

}

#endif
14 changes: 10 additions & 4 deletions Src/FFT/AMReX_FFT.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
#include <AMReX_FFT.H>
#include <algorithm>

namespace amrex::FFT
namespace amrex::FFT::detail
{

#ifdef AMREX_USE_HIP
namespace detail
// Build a DistributionMapping from the identity sequence 0, 1, ..., n-1.
// Presumably entry i is the rank that owns box i; the debug assertion
// requires n to be no larger than the number of processes.
DistributionMapping make_iota_distromap (Long n)
{
    AMREX_ASSERT(n <= ParallelDescriptor::NProcs());

    Vector<int> ranks(n);
    int next = 0;
    for (int& slot : ranks) {
        slot = next++;
    }

    return DistributionMapping(std::move(ranks));
}

#ifdef AMREX_USE_HIP
void hip_execute (rocfft_plan plan, void **in, void **out)
{
rocfft_execution_info execinfo = nullptr;
Expand All @@ -26,7 +33,6 @@ void hip_execute (rocfft_plan plan, void **in, void **out)

AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_destroy(execinfo));
}
}
#endif

}
51 changes: 4 additions & 47 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
#define AMREX_FFT_HELPER_H_
#include <AMReX_Config.H>

#include <AMReX.H>
#include <AMReX_Geometry.H>
#include <AMReX_Gpu.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_Math.H>
#include <AMReX_DistributionMapping.H>

namespace amrex::FFT
{
Expand All @@ -24,49 +20,10 @@ struct Info
Info& setBatchMode (bool x) { batch_mode = x; return *this; }
};

template <typename T>
struct PoissonSpectral
namespace detail
{
// Precompute the per-dimension constants used by operator() from geom:
//   fac[d] = 2*pi / geom.ProbLength(d)
//   dx[d]  = geom.CellSize(d)
//   scale  = 1 / geom.Domain().d_numPts()
//   len    = geom.Domain().length()
// All values are converted to T.
PoissonSpectral (Geometry const& geom)
: fac({AMREX_D_DECL(T(2)*Math::pi<T>()/T(geom.ProbLength(0)),
T(2)*Math::pi<T>()/T(geom.ProbLength(1)),
T(2)*Math::pi<T>()/T(geom.ProbLength(2)))}),
dx({AMREX_D_DECL(T(geom.CellSize(0)),
T(geom.CellSize(1)),
T(geom.CellSize(2)))}),
scale(T(1.0/geom.Domain().d_numPts())),
len(geom.Domain().length())
{
// T must be a floating-point type.
static_assert(std::is_floating_point_v<T>);
}

// Rescale one spectral coefficient in place.
//   i, j, k        index of the coefficient
//   spectral_data  coefficient to modify
// For each dimension a value (a, b, c) is formed from the index times fac,
// then k2 sums 2*(cos(value*dx)-1)/dx^2 over the dimensions. This has the
// form of the Fourier symbol of a 3-point second difference.
// The coefficient is divided by k2 when k2 is non-zero, set to zero
// otherwise, and finally multiplied by scale.
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void operator() (int i, int j, int k, GpuComplex<T>& spectral_data) const
{
// Silences unused-parameter warnings; AMREX_D_TERM may not reference every
// index (presumably in builds with fewer than three dimensions).
amrex::ignore_unused(i,j,k);
// the values in the upper-half of the spectral array in y and z
// are here interpreted as negative wavenumbers
AMREX_D_TERM(T a = fac[0]*i;,
T b = (j < len[1]/2) ? fac[1]*j : fac[1]*(len[1]-j);,
T c = (k < len[2]/2) ? fac[2]*k : fac[2]*(len[2]-k));
T k2 = AMREX_D_TERM(T(2)*(std::cos(a*dx[0])-T(1))/(dx[0]*dx[0]),
+T(2)*(std::cos(b*dx[1])-T(1))/(dx[1]*dx[1]),
+T(2)*(std::cos(c*dx[2])-T(1))/(dx[2]*dx[2]));
if (k2 != T(0)) {
spectral_data /= k2;
} else {
// interpretation here is that the average value of the
// solution is zero
spectral_data = 0;
}
// Applied on both branches, after the division or the zeroing.
spectral_data *= scale;
}

GpuArray<T,AMREX_SPACEDIM> fac;
GpuArray<T,AMREX_SPACEDIM> dx;
T scale;
IntVect len;
};
DistributionMapping make_iota_distromap (Long n);
}

}

Expand Down
Loading

0 comments on commit e26cf17

Please sign in to comment.