forked from AMReX-Codes/amrex
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add parallel FFT capability to AMReX. It relies on FFTW3, cuFFT, rocFFT and oneMKL, for CPU, CUDA, HIP and SYCL builds, respectively.
- Loading branch information
1 parent
62c2a81
commit 3611aa9
Showing
7 changed files
with
455 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
#ifndef AMREX_FFT_H_ | ||
#define AMREX_FFT_H_ | ||
#include <AMReX_Config.H> | ||
|
||
#include <AMReX_MultiFab.H> | ||
#include <fftw3.h> | ||
#include <utility> | ||
|
||
namespace amrex::FFT | ||
{ | ||
|
||
//! Scaling options accepted by the transform routines. NOTE(review): the
//! visible implementation does not apply any scaling yet, so the exact
//! meaning of each option is still to be defined.
enum class Scaling
{
    full,
    symmetric,
    none
};
|
||
class R2C | ||
{ | ||
public: | ||
R2C (Box const& domain); | ||
|
||
~R2C (); | ||
|
||
R2C (R2C const&) = delete; | ||
R2C (R2C &&) = delete; | ||
R2C& operator= (R2C const&) = delete; | ||
R2C& operator= (R2C &&) = delete; | ||
|
||
template <typename F> | ||
void forwardThenBackward (MultiFab const& inmf, MultiFab& outmf, | ||
F const& post_forward) | ||
{ | ||
this->forward_doit(inmf); | ||
this->post_forward_doit(post_forward); | ||
this->backward_doit(outmf); | ||
} | ||
|
||
struct Swap01 | ||
{ | ||
[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 operator() (Dim3 i) const noexcept | ||
{ | ||
return {i.y, i.x, i.z}; | ||
} | ||
|
||
[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 Inverse (Dim3 i) const noexcept | ||
{ | ||
return {i.y, i.x, i.z}; | ||
} | ||
|
||
[[nodiscard]] IndexType operator() (IndexType it) const noexcept | ||
{ | ||
return it; | ||
} | ||
|
||
[[nodiscard]] IndexType Inverse (IndexType it) const noexcept | ||
{ | ||
return it; | ||
} | ||
}; | ||
|
||
struct Swap02 | ||
{ | ||
[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 operator() (Dim3 i) const noexcept | ||
{ | ||
return {i.z, i.y, i.x}; | ||
} | ||
|
||
[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 Inverse (Dim3 i) const noexcept | ||
{ | ||
return {i.z, i.y, i.x}; | ||
} | ||
|
||
[[nodiscard]] IndexType operator() (IndexType it) const noexcept | ||
{ | ||
return it; | ||
} | ||
|
||
[[nodiscard]] IndexType Inverse (IndexType it) const noexcept | ||
{ | ||
return it; | ||
} | ||
}; | ||
|
||
public: // public for cuda | ||
|
||
template <typename F> | ||
void post_forward_doit (F const& post_forward); | ||
|
||
private: | ||
|
||
void forward_doit (MultiFab const& inmf, Scaling scaling = Scaling::none); | ||
void backward_doit (MultiFab& outmf, Scaling scaling = Scaling::none); | ||
|
||
static std::pair<fftw_plan,fftw_plan> | ||
make_c2c_plans (cMultiFab& inout, Box const& spectral_domain); | ||
|
||
Box m_real_domain; | ||
Box m_spectral_domain_x; | ||
Box m_spectral_domain_y; | ||
Box m_spectral_domain_z; | ||
|
||
// assuming it's double for now | ||
fftw_plan m_fftw_fwd_x = nullptr; | ||
fftw_plan m_fftw_bwd_x = nullptr; | ||
fftw_plan m_fftw_fwd_y = nullptr; | ||
fftw_plan m_fftw_bwd_y = nullptr; | ||
fftw_plan m_fftw_fwd_z = nullptr; | ||
fftw_plan m_fftw_bwd_z = nullptr; | ||
|
||
// Comm meta-data. In the forward phase, we start with (x,y,z), | ||
// transpose to (y,x,z) and then (z,x,y). In the backward phase, we | ||
// perform inverse transpose. | ||
Swap01 m_dtos_x2y{}; | ||
std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2y; | ||
// | ||
Swap01 m_dtos_y2x{}; | ||
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2x; | ||
// | ||
Swap02 m_dtos_y2z{}; | ||
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2z; | ||
// | ||
Swap02 m_dtos_z2y{}; | ||
std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2y; | ||
|
||
// Optionally we need to copy from m_cz to user provided cMultiFab. xxxxx todo | ||
|
||
MultiFab m_rx; | ||
cMultiFab m_cx; | ||
cMultiFab m_cy; | ||
cMultiFab m_cz; | ||
}; | ||
|
||
template <typename F> | ||
void R2C::post_forward_doit (F const& post_forward) | ||
{ | ||
auto& spectral_fab = m_cz[ParallelDescriptor::MyProc()]; | ||
auto const& a = spectral_fab.array(); // m_cz's ordering is z,x,y | ||
ParallelFor(spectral_fab.box(), [=] AMREX_GPU_DEVICE (int iz, int jx, int ky) | ||
{ | ||
post_forward(jx,ky,iz,a(iz,jx,ky)); | ||
}); | ||
} | ||
|
||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
#include <AMReX_FFT.H> | ||
#include <numeric> | ||
#include <tuple> | ||
|
||
namespace amrex::FFT | ||
{ | ||
|
||
/**
 * \brief Build the work buffers, FFTW plans and transpose meta-data.
 *
 * The three spectral domains describe the same complex data with the
 * transformed direction moved to the first index: (x,y,z), (y,x,z), (z,x,y).
 * All buffers are decomposed in their second and third directions only, so
 * every 1D transform line is entirely local.
 *
 * The code is currently written for 3D and double precision only, and it
 * indexes each buffer with the rank id, i.e. it expects one box per process.
 *
 * \param domain index box of the real-space data; its lower corner must be zero
 */
R2C::R2C (Box const& domain)
    : m_real_domain(domain),
      m_spectral_domain_x(IntVect(0), IntVect(AMREX_D_DECL(domain.length(0)/2,
                                                           domain.bigEnd(1),
                                                           domain.bigEnd(2)))),
      m_spectral_domain_y(IntVect(0), IntVect(AMREX_D_DECL(domain.bigEnd(1),
                                                           domain.length(0)/2,
                                                           domain.bigEnd(2)))),
      m_spectral_domain_z(IntVect(0), IntVect(AMREX_D_DECL(domain.bigEnd(2),
                                                           domain.length(0)/2,
                                                           domain.bigEnd(1))))
{
    AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0);
    AMREX_ALWAYS_ASSERT((std::is_same_v<double,Real>)); // xxxxx todo

    int myproc = ParallelDescriptor::MyProc();
    int nprocs = ParallelDescriptor::NProcs();

    // xxxxx todo: need to handle cases where there are more processes than 2d cells
    // xxxxx todo: 1d & 2d

    // Real-space buffer: box i lives on process i.
    auto bax = amrex::decompose(m_real_domain, nprocs, {false,true,true});
    Vector<int> pmx(bax.size());
    std::iota(pmx.begin(), pmx.end(), 0);
    DistributionMapping dmx(std::move(pmx));
    m_rx.define(bax, dmx, 1, 0);

    auto cbax = amrex::decompose(m_spectral_domain_x, nprocs, {false,true,true});
    m_cx.define(cbax, dmx, 1, 0);

    // plans for x-direction
    {
        double* in = m_rx[myproc].dataPtr();
        auto* out = reinterpret_cast<fftw_complex*>(m_cx[myproc].dataPtr());

        // The x-direction is never decomposed, so the transform length and the
        // distance between consecutive lines come from the global domains.
        int n = m_real_domain.length(0);
        int nc = m_spectral_domain_x.length(0);

        // The number of lines must be counted on the local box: the buffers
        // behind `in` and `out` only hold this process' share of the y-z
        // plane, so counting over the whole domain would run past their end
        // whenever more than one process is used.
        Box const& local_box = m_rx[myproc].box();
        int howmany = local_box.length(1) * local_box.length(2);

        m_fftw_fwd_x = fftw_plan_many_dft_r2c
            (1, &n, howmany, in, nullptr, 1, n,
             out, nullptr, 1, nc,
             FFTW_ESTIMATE | FFTW_DESTROY_INPUT);

        m_fftw_bwd_x = fftw_plan_many_dft_c2r
            (1, &n, howmany, out, nullptr, 1, nc,
             in, nullptr, 1, n,
             FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
    }

    auto cbay = amrex::decompose(m_spectral_domain_y, nprocs, {false,true,true});
    DistributionMapping cdmy = dmx; // xxxxx todo
    m_cy.define(cbay, cdmy, 1, 0);

    std::tie(m_fftw_fwd_y, m_fftw_bwd_y) = make_c2c_plans(m_cy, m_spectral_domain_y);

    // comm meta-data between x and y phases
    m_cmd_x2y = std::make_unique<MultiBlockCommMetaData>
        (m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y,
         m_cx.boxArray(), m_cx.DistributionMap(), IntVect(0), m_dtos_x2y);
    m_cmd_y2x = std::make_unique<MultiBlockCommMetaData>
        (m_cx.boxArray(), m_cx.DistributionMap(), m_spectral_domain_x,
         m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2x);

    auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs, {false,true,true});
    DistributionMapping cdmz = dmx; // xxxxx todo
    m_cz.define(cbaz, cdmz, 1, 0);

    std::tie(m_fftw_fwd_z, m_fftw_bwd_z) = make_c2c_plans(m_cz, m_spectral_domain_z);

    // comm meta-data between y and z phases
    m_cmd_y2z = std::make_unique<MultiBlockCommMetaData>
        (m_cz.boxArray(), m_cz.DistributionMap(), m_spectral_domain_z,
         m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2z);
    m_cmd_z2y = std::make_unique<MultiBlockCommMetaData>
        (m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y,
         m_cz.boxArray(), m_cz.DistributionMap(), IntVect(0), m_dtos_z2y);
}
|
||
/**
 * \brief Destroy the six FFTW plans held by this object.
 */
R2C::~R2C ()
{
    // Same order as before: forward/backward pairs for x, then y, then z.
    for (fftw_plan plan : {m_fftw_fwd_x, m_fftw_bwd_x,
                           m_fftw_fwd_y, m_fftw_bwd_y,
                           m_fftw_fwd_z, m_fftw_bwd_z})
    {
        fftw_destroy_plan(plan);
    }
}
|
||
void R2C::forward_doit (MultiFab const& inmf, Scaling scaling) | ||
{ | ||
m_rx.ParallelCopy(inmf, 0, 0, 1); | ||
fftw_execute(m_fftw_fwd_x); | ||
|
||
ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y); | ||
fftw_execute(m_fftw_fwd_y); | ||
|
||
ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z); | ||
fftw_execute(m_fftw_fwd_z); | ||
} | ||
|
||
void R2C::backward_doit (MultiFab& outmf, Scaling scaling) | ||
{ | ||
// xxxxx todo: scaling | ||
|
||
fftw_execute(m_fftw_bwd_z); | ||
ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y); | ||
|
||
fftw_execute(m_fftw_bwd_y); | ||
ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x); | ||
|
||
fftw_execute(m_fftw_bwd_x); | ||
outmf.ParallelCopy(m_rx, 0, 0, 1); | ||
} | ||
|
||
std::pair<fftw_plan,fftw_plan> | ||
R2C::make_c2c_plans (cMultiFab& inout, Box const& spectral_domain) | ||
{ | ||
auto* pinout = (fftw_complex*)(inout[ParallelDescriptor::MyProc()].dataPtr()); | ||
|
||
int n = spectral_domain.length(0); | ||
int howmany = spectral_domain.length(1) * spectral_domain.length(2); | ||
|
||
fftw_plan fwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, | ||
pinout, nullptr, 1, n, -1, FFTW_ESTIMATE); | ||
fftw_plan bwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, | ||
pinout, nullptr, 1, n, +1, FFTW_ESTIMATE); | ||
return {fwd,bwd}; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
|
||
# Files of the FFT package and the directory where make can find them.
FFT_PACKAGE_DIR = $(AMREX_HOME)/Src/FFT

CEXE_headers += AMReX_FFT.H
CEXE_sources += AMReX_FFT.cpp

VPATH_LOCATIONS   += $(FFT_PACKAGE_DIR)
INCLUDE_LOCATIONS += $(FFT_PACKAGE_DIR)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Location of the AMReX source tree, relative to this directory.
AMREX_HOME := ../../..

# Build configuration.
DEBUG = TRUE
DIM   = 3
COMP  = gcc

# Parallel / accelerator backends.
USE_MPI  = FALSE
USE_OMP  = FALSE
USE_CUDA = FALSE
USE_HIP  = FALSE
USE_SYCL = FALSE

BL_NO_FORT   = TRUE
TINY_PROFILE = FALSE

include $(AMREX_HOME)/Tools/GNUMake/Make.defs

# Sources: this directory first, then the AMReX packages that are needed.
include ./Make.package
include $(AMREX_HOME)/Src/Base/Make.package
include $(AMREX_HOME)/Src/FFT/Make.package

# Optional heFFTe support.
ifeq ($(USE_HEFFTE),TRUE)
  HEFFTE_HOME ?= ../../../../heffte/build
  VPATH_LOCATIONS   += $(HEFFTE_HOME)/include
  INCLUDE_LOCATIONS += $(HEFFTE_HOME)/include
  LIBRARY_LOCATIONS += $(HEFFTE_HOME)/lib
  libraries += -lheffte
  ifneq ($(USE_GPU),TRUE)
    libraries += -lfftw3_mpi
  endif
endif

# FFT library matching the selected backend.
ifeq ($(USE_CUDA),TRUE)
  libraries += -lcufft
else
  ifeq ($(USE_HIP),TRUE)
    # Use rocFFT. ROC_PATH is defined in amrex
    INCLUDE_LOCATIONS += $(ROC_PATH)/rocfft/include
    LIBRARY_LOCATIONS += $(ROC_PATH)/rocfft/lib
    LIBRARIES += -L$(ROC_PATH)/rocfft/lib -lrocfft
  else
    libraries += -lfftw3
  endif
endif

include $(AMREX_HOME)/Tools/GNUMake/Make.rules
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Add main.cpp to the list of C++ sources built for this executable.
CEXE_sources += main.cpp | ||
Oops, something went wrong.