Skip to content

Commit

Permalink
make R2C a class template
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Oct 15, 2024
1 parent aa7d9a9 commit 4eb5f43
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 177 deletions.
265 changes: 247 additions & 18 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,39 @@
#include <AMReX_Config.H>

#include <AMReX_MultiFab.H>
#include <fftw3.h>
#include <numeric>
#include <tuple>
#include <utility>

#if defined(AMREX_USE_CUDA)
# include <cufft.h>
# include <cuComplex.h>
#elif defined(AMREX_USE_HIP)
# if __has_include(<rocfft/rocfft.h>) // ROCm 5.3+
# include <rocfft/rocfft.h>
# else
# include <rocfft.h>
# endif
# include <hip/hip_complex.h>
#elif defined(AMREX_USE_SYCL)
# include <oneapi/mkl/dfti.hpp>
#else
# include <fftw3.h>
#endif

namespace amrex::FFT
{

//! Normalization choices for the FFT drivers.
//! NOTE(review): the exact factor behind each enumerator is not visible in
//! this part of the file (scaling is still marked as a todo in the drivers).
enum class Scaling { full, symmetric, none };

template <typename T = Real>
class R2C
{
public:
using MF = std::conditional_t<std::is_same_v<T,Real>,
MultiFab, FabArray<BaseFab<T> > >;
using cMF = FabArray<BaseFab<GpuComplex<T> > >;

R2C (Box const& domain);

~R2C ();
Expand All @@ -24,8 +46,7 @@ public:
R2C& operator= (R2C &&) = delete;

template <typename F>
void forwardThenBackward (MultiFab const& inmf, MultiFab& outmf,
F const& post_forward)
void forwardThenBackward (MF const& inmf, MF& outmf, F const& post_forward)
{
this->forward_doit(inmf);
this->post_forward_doit(post_forward);
Expand Down Expand Up @@ -85,24 +106,45 @@ public: // public for cuda

private:

void forward_doit (MultiFab const& inmf, Scaling scaling = Scaling::none);
void backward_doit (MultiFab& outmf, Scaling scaling = Scaling::none);
void forward_doit (MF const& inmf, Scaling scaling = Scaling::none);
void backward_doit (MF& outmf, Scaling scaling = Scaling::none);

#if defined(AMREX_USE_CUDA)
using FFTPlan = cufftHandle;
using FFTComplex = std::conditional_t<std::is_same_v<float,T>,
cuComplex, cuDoubleComplex>;
#elif defined(AMREX_USE_HIP)
using FFTPlan = rocfft_plan;
using FFTComplex = std::conditional_t<std::is_same_v<float,T>,
float2, double2>;
#elif defined(AMREX_USE_SYCL)
using FFTPlan = oneapi::mkl::dft::descriptor<
std::conditional_t<std::is_same_v<float,T>,
oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::precision::DOUBLE>,
oneapi::mkl::dft::domain::REAL> *;
using FFTComplex = GpuComplex<T>;
#else
using FFTPlan = std::conditional_t<std::is_same_v<float,T>,
fftwf_plan, fftw_plan>;
using FFTComplex = std::conditional_t<std::is_same_v<float,T>,
fftwf_complex, fftw_complex>;
#endif

static std::pair<fftw_plan,fftw_plan>
make_c2c_plans (cMultiFab& inout);
static std::pair<FFTPlan,FFTPlan> make_c2c_plans (cMF& inout);

Box m_real_domain;
Box m_spectral_domain_x;
Box m_spectral_domain_y;
Box m_spectral_domain_z;

// assuming it's double for now
fftw_plan m_fftw_fwd_x = nullptr;
fftw_plan m_fftw_bwd_x = nullptr;
fftw_plan m_fftw_fwd_y = nullptr;
fftw_plan m_fftw_bwd_y = nullptr;
fftw_plan m_fftw_fwd_z = nullptr;
fftw_plan m_fftw_bwd_z = nullptr;
FFTPlan m_fft_fwd_x = nullptr;
FFTPlan m_fft_bwd_x = nullptr;
FFTPlan m_fft_fwd_y = nullptr;
FFTPlan m_fft_bwd_y = nullptr;
FFTPlan m_fft_fwd_z = nullptr;
FFTPlan m_fft_bwd_z = nullptr;

// Comm meta-data. In the forward phase, we start with (x,y,z),
// transpose to (y,x,z) and then (z,x,y). In the backward phase, we
Expand All @@ -121,14 +163,201 @@ private:

// Optionally we need to copy from m_cz to user provided cMultiFab. xxxxx todo

MultiFab m_rx;
cMultiFab m_cx;
cMultiFab m_cy;
cMultiFab m_cz;
MF m_rx;
cMF m_cx;
cMF m_cy;
cMF m_cz;
};

/**
 * \brief Set up the data layouts, FFT plans and communication meta-data for
 * a real-to-complex transform on \p domain.
 *
 * Three complex layouts are created in addition to the real one:
 *  - m_cx: x-pencils holding the half spectrum, indices 0..length(0)/2 in x;
 *  - m_cy: ordering (y,x,z), used for the transform along y;
 *  - m_cz: ordering (z,x,y), used for the transform along z.
 *
 * \param domain index box of the real data; its small end must be zero.
 *
 * Only the FFTW (CPU) path with T = double is implemented so far.
 */
template <typename T>
R2C<T>::R2C (Box const& domain)
    : m_real_domain(domain),
      m_spectral_domain_x(IntVect(0), IntVect(AMREX_D_DECL(domain.length(0)/2,
                                                           domain.bigEnd(1),
                                                           domain.bigEnd(2)))),
      m_spectral_domain_y(IntVect(0), IntVect(AMREX_D_DECL(domain.bigEnd(1),
                                                           domain.length(0)/2,
                                                           domain.bigEnd(2)))),
      m_spectral_domain_z(IntVect(0), IntVect(AMREX_D_DECL(domain.bigEnd(2),
                                                           domain.length(0)/2,
                                                           domain.bigEnd(1))))
{
    AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0);
    // The plans below are created with the double-precision FFTW API, so the
    // requirement is on the transform precision T, not on amrex::Real.
    static_assert(std::is_same_v<double,T>,
                  "FFT::R2C: only T = double is supported for now"); // TODO: single precision

    int myproc = ParallelDescriptor::MyProc();
    int nprocs = ParallelDescriptor::NProcs();

    // TODO: need to handle cases where there are more processes than 2d cells
    // TODO: 1d & 2d

    // Real-space layout. The {false,true,true} flags presumably keep the
    // x-direction undivided so every box is a full x-pencil -- TODO confirm
    // against amrex::decompose.
    auto bax = amrex::decompose(m_real_domain, nprocs, {false,true,true});
    // Box i is assigned to rank i.
    Vector<int> pmx(bax.size());
    std::iota(pmx.begin(), pmx.end(), 0);
    DistributionMapping dmx(std::move(pmx));
    m_rx.define(bax, dmx, 1, 0);

    {
        // Same boxes as the real layout, but truncated in x to the half
        // spectrum produced by the real-to-complex transform.
        BoxList bl = bax.boxList();
        for (auto & b : bl) {
            b.setBig(0, m_spectral_domain_x.bigEnd(0));
        }
        BoxArray cbax(std::move(bl));
        m_cx.define(cbax, dmx, 1, 0);
    }

    // plans for x-direction
    {
        T* in = m_rx[myproc].dataPtr();
        auto* out = reinterpret_cast<FFTComplex*>(m_cx[myproc].dataPtr());

        Box const local_box = m_rx.boxArray()[myproc];
        int n = local_box.length(0);
        // One 1D transform per (y,z) line of the local box.
        int howmany = local_box.length(1) * local_box.length(2);

#if defined(AMREX_USE_CUDA)

        static_assert(false);

#elif defined(AMREX_USE_HIP)

        static_assert(false);

#elif defined(AMREX_USE_SYCL)

        static_assert(false);

#else
        // Consecutive lines are separated by the x-extent of the respective
        // (real or spectral) domain; elements within a line are contiguous.
        m_fft_fwd_x = fftw_plan_many_dft_r2c
            (1, &n, howmany, in, nullptr, 1, m_real_domain.length(0),
             out, nullptr, 1, m_spectral_domain_x.length(0),
             FFTW_ESTIMATE | FFTW_DESTROY_INPUT);

        m_fft_bwd_x = fftw_plan_many_dft_c2r
            (1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0),
             in, nullptr, 1, m_real_domain.length(0),
             FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
#endif
    }

    auto cbay = amrex::decompose(m_spectral_domain_y, nprocs, {false,true,true});
    DistributionMapping cdmy = dmx; // TODO
    m_cy.define(cbay, cdmy, 1, 0);

    std::tie(m_fft_fwd_y, m_fft_bwd_y) = make_c2c_plans(m_cy);

    // comm meta-data between x and y phases
    m_cmd_x2y = std::make_unique<MultiBlockCommMetaData>
        (m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y,
         m_cx.boxArray(), m_cx.DistributionMap(), IntVect(0), m_dtos_x2y);
    m_cmd_y2x = std::make_unique<MultiBlockCommMetaData>
        (m_cx.boxArray(), m_cx.DistributionMap(), m_spectral_domain_x,
         m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2x);

    auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs, {false,true,true});
    DistributionMapping cdmz = dmx; // TODO
    m_cz.define(cbaz, cdmz, 1, 0);

    std::tie(m_fft_fwd_z, m_fft_bwd_z) = make_c2c_plans(m_cz);

    // comm meta-data between y and z phases
    m_cmd_y2z = std::make_unique<MultiBlockCommMetaData>
        (m_cz.boxArray(), m_cz.DistributionMap(), m_spectral_domain_z,
         m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2z);
    m_cmd_z2y = std::make_unique<MultiBlockCommMetaData>
        (m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y,
         m_cz.boxArray(), m_cz.DistributionMap(), IntVect(0), m_dtos_z2y);
}

/**
 * \brief Release the six FFT plans (forward and backward for x, y and z)
 * created by the constructor.
 *
 * Only the FFTW (CPU) path is implemented so far.
 */
template <typename T>
R2C<T>::~R2C () // plain injected-class-name: `~R2C<T>` is not a valid declarator in C++20
{
#if defined(AMREX_USE_CUDA)

    static_assert(false);

#elif defined(AMREX_USE_HIP)

    static_assert(false);

#elif defined(AMREX_USE_SYCL)

    static_assert(false);

#else
    fftw_destroy_plan(m_fft_fwd_x);
    fftw_destroy_plan(m_fft_bwd_x);
    fftw_destroy_plan(m_fft_fwd_y);
    fftw_destroy_plan(m_fft_bwd_y);
    fftw_destroy_plan(m_fft_fwd_z);
    fftw_destroy_plan(m_fft_bwd_z);
#endif
}

/**
 * \brief Forward transform: copy \p inmf into the internal real layout, then
 * transform along x, y and z in turn, re-distributing the data between the
 * internal complex layouts (m_cx -> m_cy -> m_cz) before each of the last
 * two passes. The result is left in m_cz.
 *
 * \param inmf    real input data
 * \param scaling currently not used by this function
 */
template <typename T>
void R2C<T>::forward_doit (MF const& inmf, Scaling scaling)
{
    // All copies move the single component stored at index 0.
    constexpr int icomp = 0;
    constexpr int ncomp = 1;

    // x pass: real -> complex, using the plan built on m_rx / m_cx.
    m_rx.ParallelCopy(inmf, icomp, icomp, ncomp);
    fftw_execute(m_fft_fwd_x);

    // y pass, after moving the data into the m_cy layout.
    ParallelCopy(m_cy, m_cx, *m_cmd_x2y, icomp, icomp, ncomp, m_dtos_x2y);
    fftw_execute(m_fft_fwd_y);

    // z pass, after moving the data into the m_cz layout.
    ParallelCopy(m_cz, m_cy, *m_cmd_y2z, icomp, icomp, ncomp, m_dtos_y2z);
    fftw_execute(m_fft_fwd_z);
}

/**
 * \brief Backward transform: starting from the data in m_cz, transform along
 * z, y and x in turn, re-distributing the data between the internal layouts
 * (m_cz -> m_cy -> m_cx) after each of the first two passes, and finally
 * copy the real result into \p outmf.
 *
 * \param outmf   destination for the real output data
 * \param scaling not applied yet (see the todo below)
 */
template <typename T>
void R2C<T>::backward_doit (MF& outmf, Scaling scaling)
{
    // TODO: scaling

    // All copies move the single component stored at index 0.
    constexpr int icomp = 0;
    constexpr int ncomp = 1;

    // z pass, then move the data back into the m_cy layout.
    fftw_execute(m_fft_bwd_z);
    ParallelCopy(m_cy, m_cz, *m_cmd_z2y, icomp, icomp, ncomp, m_dtos_z2y);

    // y pass, then move the data back into the m_cx layout.
    fftw_execute(m_fft_bwd_y);
    ParallelCopy(m_cx, m_cy, *m_cmd_y2x, icomp, icomp, ncomp, m_dtos_y2x);

    // x pass: complex -> real into m_rx, then hand the result to the caller.
    fftw_execute(m_fft_bwd_x);
    outmf.ParallelCopy(m_rx, icomp, icomp, ncomp);
}

/**
 * \brief Create the forward and backward in-place complex-to-complex plans
 * for the fab of \p inout indexed by this process's rank.
 *
 * Each plan performs a batch of 1D transforms along the first index of the
 * local box, one per line in the other two directions.
 *
 * \param inout complex data the plans are bound to (transformed in place)
 * \return {forward plan, backward plan}
 *
 * Only the FFTW (CPU) path is implemented so far.
 */
template <typename T>
std::pair<typename R2C<T>::FFTPlan, typename R2C<T>::FFTPlan>
R2C<T>::make_c2c_plans (cMF& inout)
{
    auto& local_fab = inout[ParallelDescriptor::MyProc()];
    Box const& bx = local_fab.box();
    auto* data = reinterpret_cast<FFTComplex*>(local_fab.dataPtr());

    int len = bx.length(0);                   // length of each 1D transform
    int batch = bx.length(1) * bx.length(2);  // number of 1D transforms

#if defined(AMREX_USE_CUDA)

    static_assert(false);

#elif defined(AMREX_USE_HIP)

    static_assert(false);

#elif defined(AMREX_USE_SYCL)

    static_assert(false);

#else
    // Lines are contiguous (stride 1) and packed back to back (distance len).
    auto* fwd = fftw_plan_many_dft(1, &len, batch,
                                   data, nullptr, 1, len,
                                   data, nullptr, 1, len,
                                   FFTW_FORWARD, FFTW_ESTIMATE);
    auto* bwd = fftw_plan_many_dft(1, &len, batch,
                                   data, nullptr, 1, len,
                                   data, nullptr, 1, len,
                                   FFTW_BACKWARD, FFTW_ESTIMATE);
#endif

    return {fwd,bwd};
}

template <typename T>
template <typename F>
void R2C::post_forward_doit (F const& post_forward)
void R2C<T>::post_forward_doit (F const& post_forward)
{
auto& spectral_fab = m_cz[ParallelDescriptor::MyProc()];
auto const& a = spectral_fab.array(); // m_cz's ordering is z,x,y
Expand Down
Loading

0 comments on commit 4eb5f43

Please sign in to comment.