From e26cf175503f2bb089ae0556931d50f92226d272 Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Fri, 18 Oct 2024 14:39:17 -0700 Subject: [PATCH] Add FFT Poisson solvers --- Src/FFT/AMReX_FFT.H | 93 ++++++++++---- Src/FFT/AMReX_FFT.cpp | 14 +- Src/FFT/AMReX_FFT_Helper.H | 51 +------- Src/FFT/AMReX_FFT_Poisson.H | 248 ++++++++++++++++++++++++++++++++++++ Src/FFT/CMakeLists.txt | 1 + Src/FFT/Make.package | 2 +- Tests/FFT/Poisson/main.cpp | 49 ++++--- 7 files changed, 367 insertions(+), 91 deletions(-) create mode 100644 Src/FFT/AMReX_FFT_Poisson.H diff --git a/Src/FFT/AMReX_FFT.H b/Src/FFT/AMReX_FFT.H index dda2d151b0..57819472cb 100644 --- a/Src/FFT/AMReX_FFT.H +++ b/Src/FFT/AMReX_FFT.H @@ -47,11 +47,19 @@ public: template void forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward) { - this->forward_doit(inmf); + this->forward(inmf); this->post_forward_doit(post_forward); - this->backward_doit(outmf); + this->backward(outmf); } + void forward (MF const& inmf, Scaling scaling = Scaling::none); + void forward (MF const& inmf, cMF& outmf, Scaling scaling = Scaling::none); + + void backward (MF& outmf, Scaling scaling = Scaling::none); + void backward (cMF const& inmf, MF& outmf, Scaling scaling = Scaling::none); + + std::pair getSpectralData (); + struct Swap01 { [[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 operator() (Dim3 i) const noexcept @@ -153,9 +161,6 @@ private: } } - void forward_doit (MF const& inmf, Scaling scaling = Scaling::none); - void backward_doit (MF& outmf, Scaling scaling = Scaling::none); - static void exec_r2c (Plan plan, MF& in, cMF& out); static void exec_c2r (Plan plan, cMF& in, MF& out); template @@ -175,10 +180,10 @@ private: // Comm meta-data. In the forward phase, we start with (x,y,z), // transpose to (y,x,z) and then (z,x,y). In the backward phase, we // perform inverse transpose. 
- std::unique_ptr m_cmd_x2y; - std::unique_ptr m_cmd_y2x; - std::unique_ptr m_cmd_y2z; - std::unique_ptr m_cmd_z2y; + std::unique_ptr m_cmd_x2y; // (x,y,z) -> (y,x,z) + std::unique_ptr m_cmd_y2x; // (y,x,z) -> (x,y,z) + std::unique_ptr m_cmd_y2z; // (y,x,z) -> (z,x,y) + std::unique_ptr m_cmd_z2y; // (z,x,y) -> (y,x,z) Swap01 m_dtos_x2y{}; Swap01 m_dtos_y2x{}; Swap02 m_dtos_y2z{}; @@ -232,12 +237,7 @@ R2C::R2C (Box const& domain, Info const& info) int nprocs = ParallelDescriptor::NProcs(); auto bax = amrex::decompose(m_real_domain, nprocs, {AMREX_D_DECL(false,true,true)}); - DistributionMapping dmx; - { - Vector pm(bax.size()); - std::iota(pm.begin(), pm.end(), 0); - dmx.define(std::move(pm)); - } + DistributionMapping dmx = detail::make_iota_distromap(bax.size()); m_rx.define(bax, dmx, 1, 0); { @@ -346,9 +346,7 @@ R2C::R2C (Box const& domain, Info const& info) if (cbay.size() == dmx.size()) { cdmy = dmx; } else { - Vector pm(cbay.size()); - std::iota(pm.begin(), pm.end(), 0); - cdmy.define(std::move(pm)); + cdmy = detail::make_iota_distromap(cbay.size()); } m_cy.define(cbay, cdmy, 1, 0); @@ -365,7 +363,7 @@ R2C::R2C (Box const& domain, Info const& info) #if (AMREX_SPACEDIM == 3) if (m_real_domain.length(1) > 1 && - (! m_info.batch_mode || m_real_domain.length(2) > 1)) + (! 
m_info.batch_mode && m_real_domain.length(2) > 1)) { auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs, {false,true,true}); DistributionMapping cdmz; @@ -374,9 +372,7 @@ R2C::R2C (Box const& domain, Info const& info) } else if (cbaz.size() == cdmy.size()) { cdmz = cdmy; } else { - Vector pm(cbaz.size()); - std::iota(pm.begin(), pm.end(), 0); - cdmz.define(std::move(pm)); + cdmz = detail::make_iota_distromap(cbaz.size()); } m_cz.define(cbaz, cdmz, 1, 0); @@ -563,8 +559,10 @@ void R2C::exec_c2c (Plan2 plan, cMF& inout) } template -void R2C::forward_doit (MF const& inmf, Scaling /*scaling*/) +void R2C::forward (MF const& inmf, Scaling scaling) { + AMREX_ALWAYS_ASSERT(scaling == Scaling::none); // xxxxx TODO + m_rx.ParallelCopy(inmf, 0, 0, 1); exec_r2c(m_fft_fwd_x, m_rx, m_cx); @@ -580,8 +578,10 @@ void R2C::forward_doit (MF const& inmf, Scaling /*scaling*/) } template -void R2C::backward_doit (MF& outmf, Scaling /*scaling*/) +void R2C::backward (MF& outmf, Scaling scaling) { + AMREX_ALWAYS_ASSERT(scaling == Scaling::none); // xxxxx TODO + exec_c2c(m_fft_bwd_z, m_cz); if ( m_cmd_z2y) { ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y); @@ -716,6 +716,51 @@ void R2C::post_forward_doit (F const& post_forward) } } +template +std::pair::cMF *, IntVect> +R2C::getSpectralData () +{ + if (!m_cz.empty()) { + return std::make_pair(&m_cz, IntVect{AMREX_D_DECL(2,0,1)}); + } else if (!m_cy.empty()) { + return std::make_pair(&m_cy, IntVect{AMREX_D_DECL(1,0,2)}); + } else { + return std::make_pair(&m_cx, IntVect{AMREX_D_DECL(0,1,2)}); + } +} + +template +void R2C::forward (MF const& inmf, cMF& outmf, Scaling scaling) +{ + forward(inmf); + if (!m_cz.empty()) { // m_cz's ordering is z,x,y + amrex::Abort("xxxxx todo, forward m_cz"); + } else if (!m_cy.empty()) { // m_cy's order (y,x,z) -> (x,y,z) + MultiBlockCommMetaData cmd + (outmf.boxArray(), outmf.DistributionMap(), m_spectral_domain_x, + m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2x); + 
ParallelCopy(outmf, m_cy, cmd, 0, 0, 1, m_dtos_y2x); + } else { + outmf.ParallelCopy(m_cx, 0, 0, 1); + } +} + +template +void R2C::backward (cMF const& inmf, MF& outmf, Scaling scaling) +{ + if (!m_cz.empty()) { // m_cz's ordering is z,x,y + amrex::Abort("xxxxx todo, backward m_cz"); + } else if (!m_cy.empty()) { // (x,y,z) -> m_cy's ordering (y,x,z) + MultiBlockCommMetaData cmd + (m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y, + inmf.boxArray(), inmf.DistributionMap(), IntVect(0), m_dtos_x2y); + ParallelCopy(m_cy, inmf, cmd, 0, 0, 1, m_dtos_x2y); + } else { + m_cx.ParallelCopy(inmf, 0, 0, 1); + } + backward(outmf); +} + } #endif diff --git a/Src/FFT/AMReX_FFT.cpp b/Src/FFT/AMReX_FFT.cpp index e6204454f7..c89de51a0b 100644 --- a/Src/FFT/AMReX_FFT.cpp +++ b/Src/FFT/AMReX_FFT.cpp @@ -1,11 +1,18 @@ #include +#include -namespace amrex::FFT +namespace amrex::FFT::detail { -#ifdef AMREX_USE_HIP -namespace detail +DistributionMapping make_iota_distromap (Long n) { + AMREX_ASSERT(n <= ParallelDescriptor::NProcs()); + Vector pm(n); + std::iota(pm.begin(), pm.end(), 0); + return DistributionMapping(std::move(pm)); +} + +#ifdef AMREX_USE_HIP void hip_execute (rocfft_plan plan, void **in, void **out) { rocfft_execution_info execinfo = nullptr; @@ -26,7 +33,6 @@ void hip_execute (rocfft_plan plan, void **in, void **out) AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_destroy(execinfo)); } -} #endif } diff --git a/Src/FFT/AMReX_FFT_Helper.H b/Src/FFT/AMReX_FFT_Helper.H index 5152f0eef9..a880af3b06 100644 --- a/Src/FFT/AMReX_FFT_Helper.H +++ b/Src/FFT/AMReX_FFT_Helper.H @@ -2,11 +2,7 @@ #define AMREX_FFT_HELPER_H_ #include -#include -#include -#include -#include -#include +#include namespace amrex::FFT { @@ -24,49 +20,10 @@ struct Info Info& setBatchMode (bool x) { batch_mode = x; return *this; } }; -template -struct PoissonSpectral +namespace detail { - PoissonSpectral (Geometry const& geom) - : fac({AMREX_D_DECL(T(2)*Math::pi()/T(geom.ProbLength(0)), - 
T(2)*Math::pi()/T(geom.ProbLength(1)), - T(2)*Math::pi()/T(geom.ProbLength(2)))}), - dx({AMREX_D_DECL(T(geom.CellSize(0)), - T(geom.CellSize(1)), - T(geom.CellSize(2)))}), - scale(T(1.0/geom.Domain().d_numPts())), - len(geom.Domain().length()) - { - static_assert(std::is_floating_point_v); - } - - AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE - void operator() (int i, int j, int k, GpuComplex& spectral_data) const - { - amrex::ignore_unused(i,j,k); - // the values in the upper-half of the spectral array in y and z - // are here interpreted as negative wavenumbers - AMREX_D_TERM(T a = fac[0]*i;, - T b = (j < len[1]/2) ? fac[1]*j : fac[1]*(len[1]-j);, - T c = (k < len[2]/2) ? fac[2]*k : fac[2]*(len[2]-k)); - T k2 = AMREX_D_TERM(T(2)*(std::cos(a*dx[0])-T(1))/(dx[0]*dx[0]), - +T(2)*(std::cos(b*dx[1])-T(1))/(dx[1]*dx[1]), - +T(2)*(std::cos(c*dx[2])-T(1))/(dx[2]*dx[2])); - if (k2 != T(0)) { - spectral_data /= k2; - } else { - // interpretation here is that the average value of the - // solution is zero - spectral_data = 0; - } - spectral_data *= scale; - } - - GpuArray fac; - GpuArray dx; - T scale; - IntVect len; -}; + DistributionMapping make_iota_distromap (Long n); +} } diff --git a/Src/FFT/AMReX_FFT_Poisson.H b/Src/FFT/AMReX_FFT_Poisson.H new file mode 100644 index 0000000000..e1df441038 --- /dev/null +++ b/Src/FFT/AMReX_FFT_Poisson.H @@ -0,0 +1,248 @@ +#ifndef AMREX_FFT_POISSON_H_ +#define AMREX_FFT_POISSON_H_ + +#include +#include + +namespace amrex::FFT +{ + +template +class Poisson +{ +public: + + template ,int> = 0> + explicit Poisson (Geometry const& geom) + : m_geom(geom), m_r2c(geom.Domain()) + { + AMREX_ALWAYS_ASSERT(geom.isAllPeriodic()); + } + + void solve (MF& soln, MF const& rhs); + +private: + Geometry m_geom; + R2C m_r2c; +}; + +template +class PoissonHybrid +{ +public: + + template ,int> = 0> + explicit PoissonHybrid (Geometry const& geom) + : m_geom(geom), m_r2c(geom.Domain(), Info().setBatchMode(true)) + { +#if (AMREX_SPACEDIM == 3) + 
AMREX_ALWAYS_ASSERT(geom.isPeriodic(0) && geom.isPeriodic(1)); +#else + amrex::Abort("FFT::PoissonHybrid: 1D & 2D todo"); +#endif + } + + void solve (MF& soln, MF const& rhs); + +private: + Geometry m_geom; + R2C m_r2c; +}; + +template +void Poisson::solve (MF& soln, MF const& rhs) +{ + using T = typename MF::value_type; + + GpuArray fac + {AMREX_D_DECL(T(2)*Math::pi()/T(m_geom.ProbLength(0)), + T(2)*Math::pi()/T(m_geom.ProbLength(1)), + T(2)*Math::pi()/T(m_geom.ProbLength(2)))}; + GpuArray dx + {AMREX_D_DECL(T(m_geom.CellSize(0)), + T(m_geom.CellSize(1)), + T(m_geom.CellSize(2)))}; + auto scale = T(1.0/m_geom.Domain().d_numPts()); + auto const& len = m_geom.Domain().length(); + + m_r2c.forwardThenBackward(rhs, soln, + [=] AMREX_GPU_DEVICE (int i, int j, int k, + GpuComplex& spectral_data) + { + amrex::ignore_unused(i,j,k); + // the values in the upper-half of the spectral array in y and z + // are here interpreted as negative wavenumbers + AMREX_D_TERM(T a = fac[0]*i;, + T b = (j < len[1]/2) ? fac[1]*j : fac[1]*(len[1]-j);, + T c = (k < len[2]/2) ? 
fac[2]*k : fac[2]*(len[2]-k)); + T k2 = AMREX_D_TERM(T(2)*(std::cos(a*dx[0])-T(1))/(dx[0]*dx[0]), + +T(2)*(std::cos(b*dx[1])-T(1))/(dx[1]*dx[1]), + +T(2)*(std::cos(c*dx[2])-T(1))/(dx[2]*dx[2])); + if (k2 != T(0)) { + spectral_data /= k2; + } else { + // interpretation here is that the average value of the + // solution is zero + spectral_data = 0; + } + spectral_data *= scale; + }); +} + +template +void PoissonHybrid::solve (MF& soln, MF const& rhs) +{ +#if (AMREX_SPACEDIM < 3) + amrex::ignore_unused(soln, rhs); +#else + using T = typename MF::value_type; + + auto facx = T(2)*Math::pi()/T(m_geom.ProbLength(0)); + auto facy = T(2)*Math::pi()/T(m_geom.ProbLength(1)); + auto dx = T(m_geom.CellSize(0)); + auto dy = T(m_geom.CellSize(1)); + auto scale = T(1.0)/(T(m_geom.Domain().length(0)) * + T(m_geom.Domain().length(1))); + auto ny = m_geom.Domain().length(1); + auto nz = m_geom.Domain().length(2); // z extent: sizes delzv and bounds the tridiagonal k-loops + + Gpu::DeviceVector delzv(nz, T(m_geom.CellSize(2))); + auto const* delz = delzv.data(); + + Box cdomain = m_geom.Domain(); + cdomain.setBig(0,cdomain.length(0)/2); + auto cba = amrex::decompose(cdomain, ParallelDescriptor::NProcs(), + {AMREX_D_DECL(true,true,false)}); + DistributionMapping dm = detail::make_iota_distromap(cba.size()); + FabArray > > spmf(cba, dm, 1, 0); + + m_r2c.forward(rhs, spmf); + + for (MFIter mfi(spmf); mfi.isValid(); ++mfi) + { + auto const& spectral = spmf.array(mfi); + auto const& box = mfi.validbox(); + auto const& xybox = amrex::makeSlab(box, 2, 0); + +#ifdef AMREX_USE_GPU + // xxxxx TODO: We need to explore how to optimize this + // function. Maybe we can use cusparse. Maybe we should make + // z-direction to be the unit stride direction. 
+ + FArrayBox tridiag_workspace(box,4); + auto const& ald = tridiag_workspace.array(0); + auto const& bd = tridiag_workspace.array(1); + auto const& cud = tridiag_workspace.array(2); + auto const& scratch = tridiag_workspace.array(3); + + amrex::ParallelFor(xybox, [=] AMREX_GPU_DEVICE (int i, int j, int) + { + T a = facx*i; + T b = (j < ny/2) ? facy*j : facy*(ny-j); + + T k2 = T(2)*(std::cos(a*dx)-T(1))/(dx*dx) + + T(2)*(std::cos(b*dy)-T(1))/(dy*dy); + + // Tridiagonal solve with homogeneous Neumann + for(int k=0; k < nz; k++) { + if(k==0) { + ald(i,j,k) = 0.; + cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1])); + bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k); + } else if (k == nz-1) { + ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1])); + cud(i,j,k) = 0.; + bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k); + if (i == 0 && j == 0) { + bd(i,j,k) *= 2.0; + } + } else { + ald(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k-1])); + cud(i,j,k) = 2.0 /(delz[k]*(delz[k]+delz[k+1])); + bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k); + } + } + + scratch(i,j,0) = cud(i,j,0)/bd(i,j,0); + spectral(i,j,0) = spectral(i,j,0)/bd(i,j,0); + + for (int k = 1; k < nz; k++) { + if (k < nz-1){ + scratch(i,j,k) = cud(i,j,k) / (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1)); + } + spectral(i,j,k) = (spectral(i,j,k) - ald(i,j,k) * spectral(i,j,k - 1)) / (bd(i,j,k) - ald(i,j,k) * scratch(i,j,k-1)); + } + + for (int k = nz - 2; k >= 0; k--) { + spectral(i,j,k) -= scratch(i,j,k) * spectral(i,j,k + 1); + } + + for (int k = 0; k < nz; ++k) { + spectral(i,j,k) *= scale; + } + }); + Gpu::streamSynchronize(); + +#else + + Gpu::DeviceVector> ald(nz); + Gpu::DeviceVector> bd(nz); + Gpu::DeviceVector> cud(nz); + Gpu::DeviceVector> scratch(nz); + + amrex::LoopOnCpu(xybox, [&] (int i, int j, int) + { + T a = facx*i; + T b = (j < ny/2) ? 
facy*j : facy*(ny-j); + + T k2 = T(2)*(std::cos(a*dx)-T(1))/(dx*dx) + + T(2)*(std::cos(b*dy)-T(1))/(dy*dy); + + // Tridiagonal solve with homogeneous Neumann + for(int k=0; k < nz; k++) { + if(k==0) { + ald[k] = 0.; + cud[k] = 2.0 /(delz[k]*(delz[k]+delz[k+1])); + bd[k] = k2 -ald[k]-cud[k]; + } else if (k == nz-1) { + ald[k] = 2.0 /(delz[k]*(delz[k]+delz[k-1])); + cud[k] = 0.; + bd[k] = k2 -ald[k]-cud[k]; + if (i == 0 && j == 0) { + bd[k] *= 2.0; + } + } else { + ald[k] = 2.0 /(delz[k]*(delz[k]+delz[k-1])); + cud[k] = 2.0 /(delz[k]*(delz[k]+delz[k+1])); + bd[k] = k2 -ald[k]-cud[k]; + } + } + + scratch[0] = cud[0]/bd[0]; + spectral(i,j,0) = spectral(i,j,0)/bd[0]; + + for (int k = 1; k < nz; k++) { + if (k < nz-1){ + scratch[k] = cud[k] / (bd[k] - ald[k] * scratch[k-1]); + } + spectral(i,j,k) = (spectral(i,j,k) - ald[k] * spectral(i,j,k - 1)) / (bd[k] - ald[k] * scratch[k-1]); + } + + for (int k = nz - 2; k >= 0; k--) { + spectral(i,j,k) -= scratch[k] * spectral(i,j,k + 1); + } + + for (int k = 0; k < nz; ++k) { + spectral(i,j,k) *= scale; + } + }); +#endif + } + + m_r2c.backward(spmf, soln); +#endif +} + +} + +#endif diff --git a/Src/FFT/CMakeLists.txt b/Src/FFT/CMakeLists.txt index 89450d33b6..2c695a9aec 100644 --- a/Src/FFT/CMakeLists.txt +++ b/Src/FFT/CMakeLists.txt @@ -8,6 +8,7 @@ foreach(D IN LISTS AMReX_SPACEDIM) AMReX_FFT.H AMReX_FFT.cpp AMReX_FFT_Helper.H + AMReX_FFT_Poisson.H ) endforeach() diff --git a/Src/FFT/Make.package b/Src/FFT/Make.package index 1702840790..1dcd714f64 100644 --- a/Src/FFT/Make.package +++ b/Src/FFT/Make.package @@ -1,7 +1,7 @@ ifndef AMREX_FFT_MAKE AMREX_FFT_MAKE := 1 -CEXE_headers += AMReX_FFT.H AMReX_FFT_Helper.H +CEXE_headers += AMReX_FFT.H AMReX_FFT_Helper.H AMReX_FFT_Poisson.H CEXE_sources += AMReX_FFT.cpp VPATH_LOCATIONS += $(AMREX_HOME)/Src/FFT diff --git a/Tests/FFT/Poisson/main.cpp b/Tests/FFT/Poisson/main.cpp index 2800d8c4e9..c71bbc5f22 100644 --- a/Tests/FFT/Poisson/main.cpp +++ b/Tests/FFT/Poisson/main.cpp @@ -1,4 +1,4 
@@ -#include // Put this at the top for testing +#include // Put this at the top for testing #include #include @@ -24,11 +24,14 @@ int main (int argc, char* argv[]) Real prob_hi_y = 1.;, Real prob_hi_z = 1.); + int solver_type = 0; + { ParmParse pp; AMREX_D_TERM(pp.query("n_cell_x", n_cell_x);, pp.query("n_cell_y", n_cell_y);, pp.query("n_cell_z", n_cell_z)); + pp.query("solver_type", solver_type); } Box domain(IntVect(0),IntVect(AMREX_D_DECL(n_cell_x-1,n_cell_y-1,n_cell_z-1))); @@ -60,22 +63,38 @@ int main (int argc, char* argv[]) auto rhosum = rhs.sum(0); rhs.plus(-rhosum/geom.Domain().d_numPts(), 0, 1); - auto t0 = amrex::second(); - - FFT::R2C fft(geom.Domain()); - FFT::PoissonSpectral post_forward(geom); - - auto t1 = amrex::second(); - - double tsolve; - for (int n = 0; n < 2; ++n) { - auto ta = amrex::second(); - fft.forwardThenBackward(rhs, soln, post_forward); - auto tb = amrex::second(); - tsolve = tb-ta; + double tsetup, tsolve; + +#if (AMREX_SPACEDIM == 3) + if (solver_type == 1) { + auto t0 = amrex::second(); + FFT::PoissonHybrid fft_poisson(geom); + auto t1 = amrex::second(); + tsetup = t1-t0; + + for (int n = 0; n < 2; ++n) { + auto ta = amrex::second(); + fft_poisson.solve(soln, rhs); + auto tb = amrex::second(); + tsolve = tb-ta; + } + } else +#endif + { + auto t0 = amrex::second(); + FFT::Poisson fft_poisson(geom); + auto t1 = amrex::second(); + tsetup = t1-t0; + + for (int n = 0; n < 2; ++n) { + auto ta = amrex::second(); + fft_poisson.solve(soln, rhs); + auto tb = amrex::second(); + tsolve = tb-ta; + } } - amrex::Print() << " AMReX FFT setup time: " << t1-t0 << ", solve time " + amrex::Print() << " AMReX FFT setup time: " << tsetup << ", solve time " << tsolve << "\n"; {