Skip to content

Commit

Permalink
amrex::FFT
Browse files Browse the repository at this point in the history
Add parallel FFT capability to AMReX. It relies on FFTW3, cuFFT, rocFFT and
oneMKL, for CPU, CUDA, HIP and SYCL builds, respectively.
  • Loading branch information
WeiqunZhang committed Oct 15, 2024
1 parent 62c2a81 commit 3611aa9
Show file tree
Hide file tree
Showing 7 changed files with 455 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Src/Base/AMReX_FabArray.H
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <AMReX_FabFactory.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_Geometry.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_ParallelDescriptor.H>
#include <AMReX_Utility.H>
#include <AMReX_ccse-mpi.H>
Expand Down Expand Up @@ -3679,6 +3680,8 @@ FabArray<FAB>::norminf (FabArray<IFAB> const& mask, int comp, int ncomp,
return nm0;
}

using cMultiFab = FabArray<BaseFab<GpuComplex<Real> > >;

}

#endif /*BL_FABARRAY_H*/
143 changes: 143 additions & 0 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#ifndef AMREX_FFT_H_
#define AMREX_FFT_H_
#include <AMReX_Config.H>

#include <AMReX_MultiFab.H>
#include <fftw3.h>
#include <utility>

namespace amrex::FFT
{

enum struct Scaling { full, symmetric, none };

class R2C
{
public:
R2C (Box const& domain);

~R2C ();

R2C (R2C const&) = delete;
R2C (R2C &&) = delete;
R2C& operator= (R2C const&) = delete;
R2C& operator= (R2C &&) = delete;

template <typename F>
void forwardThenBackward (MultiFab const& inmf, MultiFab& outmf,
F const& post_forward)
{
this->forward_doit(inmf);
this->post_forward_doit(post_forward);
this->backward_doit(outmf);
}

struct Swap01
{
[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 operator() (Dim3 i) const noexcept
{
return {i.y, i.x, i.z};
}

[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 Inverse (Dim3 i) const noexcept
{
return {i.y, i.x, i.z};
}

[[nodiscard]] IndexType operator() (IndexType it) const noexcept
{
return it;
}

[[nodiscard]] IndexType Inverse (IndexType it) const noexcept
{
return it;
}
};

struct Swap02
{
[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 operator() (Dim3 i) const noexcept
{
return {i.z, i.y, i.x};
}

[[nodiscard]] AMREX_GPU_HOST_DEVICE Dim3 Inverse (Dim3 i) const noexcept
{
return {i.z, i.y, i.x};
}

[[nodiscard]] IndexType operator() (IndexType it) const noexcept
{
return it;
}

[[nodiscard]] IndexType Inverse (IndexType it) const noexcept
{
return it;
}
};

public: // public for cuda

template <typename F>
void post_forward_doit (F const& post_forward);

private:

void forward_doit (MultiFab const& inmf, Scaling scaling = Scaling::none);
void backward_doit (MultiFab& outmf, Scaling scaling = Scaling::none);

static std::pair<fftw_plan,fftw_plan>
make_c2c_plans (cMultiFab& inout, Box const& spectral_domain);

Box m_real_domain;
Box m_spectral_domain_x;
Box m_spectral_domain_y;
Box m_spectral_domain_z;

// assuming it's double for now
fftw_plan m_fftw_fwd_x = nullptr;
fftw_plan m_fftw_bwd_x = nullptr;
fftw_plan m_fftw_fwd_y = nullptr;
fftw_plan m_fftw_bwd_y = nullptr;
fftw_plan m_fftw_fwd_z = nullptr;
fftw_plan m_fftw_bwd_z = nullptr;

// Comm meta-data. In the forward phase, we start with (x,y,z),
// transpose to (y,x,z) and then (z,x,y). In the backward phase, we
// perform inverse transpose.
Swap01 m_dtos_x2y{};
std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2y;
//
Swap01 m_dtos_y2x{};
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2x;
//
Swap02 m_dtos_y2z{};
std::unique_ptr<MultiBlockCommMetaData> m_cmd_y2z;
//
Swap02 m_dtos_z2y{};
std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2y;

// Optionally we need to copy from m_cz to user provided cMultiFab. xxxxx todo

MultiFab m_rx;
cMultiFab m_cx;
cMultiFab m_cy;
cMultiFab m_cz;
};

template <typename F>
void R2C::post_forward_doit (F const& post_forward)
{
auto& spectral_fab = m_cz[ParallelDescriptor::MyProc()];
auto const& a = spectral_fab.array(); // m_cz's ordering is z,x,y
ParallelFor(spectral_fab.box(), [=] AMREX_GPU_DEVICE (int iz, int jx, int ky)
{
post_forward(jx,ky,iz,a(iz,jx,ky));
});
}

}

#endif
137 changes: 137 additions & 0 deletions Src/FFT/AMReX_FFT.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#include <AMReX_FFT.H>
#include <numeric>
#include <tuple>

namespace amrex::FFT
{

R2C::R2C (Box const& domain)
: m_real_domain(domain),
m_spectral_domain_x(IntVect(0), IntVect(AMREX_D_DECL(domain.length(0)/2,
domain.bigEnd(1),
domain.bigEnd(2)))),
m_spectral_domain_y(IntVect(0), IntVect(AMREX_D_DECL(domain.bigEnd(1),
domain.length(0)/2,
domain.bigEnd(2)))),
m_spectral_domain_z(IntVect(0), IntVect(AMREX_D_DECL(domain.bigEnd(2),
domain.length(0)/2,
domain.bigEnd(1))))
{
AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0);
AMREX_ALWAYS_ASSERT((std::is_same_v<double,Real>)); // xxxxx todo

int myproc = ParallelDescriptor::MyProc();
int nprocs = ParallelDescriptor::NProcs();

// xxxxx todo: need to handle cases there are more processes than 2d cells
// xxxxx todo: 1d & 2d

auto bax = amrex::decompose(m_real_domain, nprocs, {false,true,true});
Vector<int> pmx(bax.size());
std::iota(pmx.begin(), pmx.end(), 0);
DistributionMapping dmx(std::move(pmx));
m_rx.define(bax, dmx, 1, 0);

auto cbax = amrex::decompose(m_spectral_domain_x, nprocs, {false,true,true});
m_cx.define(cbax, dmx, 1, 0);

// plans for x-direction
{
double* in = m_rx[myproc].dataPtr();
auto* out = (fftw_complex*)(m_cx[myproc].dataPtr());

int n = m_real_domain.length(0);
int howmany = m_real_domain.length(1) * m_real_domain.length(2);

m_fftw_fwd_x = fftw_plan_many_dft_r2c
(1, &n, howmany, in, nullptr, 1, m_real_domain.length(0),
out, nullptr, 1, m_spectral_domain_x.length(0),
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);

m_fftw_bwd_x = fftw_plan_many_dft_c2r
(1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0),
in, nullptr, 1, m_real_domain.length(0),
FFTW_ESTIMATE | FFTW_DESTROY_INPUT);
}

auto cbay = amrex::decompose(m_spectral_domain_y, nprocs, {false,true,true});
DistributionMapping cdmy = dmx; // xxxxx todo
m_cy.define(cbay, cdmy, 1, 0);

std::tie(m_fftw_fwd_y, m_fftw_bwd_y) = make_c2c_plans(m_cy, m_spectral_domain_y);

// comm meta-data between x and y phases
m_cmd_x2y = std::make_unique<MultiBlockCommMetaData>
(m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y,
m_cx.boxArray(), m_cx.DistributionMap(), IntVect(0), m_dtos_x2y);
m_cmd_y2x = std::make_unique<MultiBlockCommMetaData>
(m_cx.boxArray(), m_cx.DistributionMap(), m_spectral_domain_x,
m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2x);

auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs, {false,true,true});
DistributionMapping cdmz = dmx; // xxxxx todo
m_cz.define(cbaz, cdmz, 1, 0);

std::tie(m_fftw_fwd_z, m_fftw_bwd_z) = make_c2c_plans(m_cz, m_spectral_domain_z);

// comm meta-data between y and z phases
m_cmd_y2z = std::make_unique<MultiBlockCommMetaData>
(m_cz.boxArray(), m_cz.DistributionMap(), m_spectral_domain_z,
m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2z);
m_cmd_z2y = std::make_unique<MultiBlockCommMetaData>
(m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y,
m_cz.boxArray(), m_cz.DistributionMap(), IntVect(0), m_dtos_z2y);
}

R2C::~R2C ()
{
fftw_destroy_plan(m_fftw_fwd_x);
fftw_destroy_plan(m_fftw_bwd_x);
fftw_destroy_plan(m_fftw_fwd_y);
fftw_destroy_plan(m_fftw_bwd_y);
fftw_destroy_plan(m_fftw_fwd_z);
fftw_destroy_plan(m_fftw_bwd_z);
}

void R2C::forward_doit (MultiFab const& inmf, Scaling scaling)
{
m_rx.ParallelCopy(inmf, 0, 0, 1);
fftw_execute(m_fftw_fwd_x);

ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y);
fftw_execute(m_fftw_fwd_y);

ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z);
fftw_execute(m_fftw_fwd_z);
}

void R2C::backward_doit (MultiFab& outmf, Scaling scaling)
{
// xxxxx todo: scaling

fftw_execute(m_fftw_bwd_z);
ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y);

fftw_execute(m_fftw_bwd_y);
ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x);

fftw_execute(m_fftw_bwd_x);
outmf.ParallelCopy(m_rx, 0, 0, 1);
}

std::pair<fftw_plan,fftw_plan>
R2C::make_c2c_plans (cMultiFab& inout, Box const& spectral_domain)
{
auto* pinout = (fftw_complex*)(inout[ParallelDescriptor::MyProc()].dataPtr());

int n = spectral_domain.length(0);
int howmany = spectral_domain.length(1) * spectral_domain.length(2);

fftw_plan fwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n,
pinout, nullptr, 1, n, -1, FFTW_ESTIMATE);
fftw_plan bwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n,
pinout, nullptr, 1, n, +1, FFTW_ESTIMATE);
return {fwd,bwd};
}

}
6 changes: 6 additions & 0 deletions Src/FFT/Make.package
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

CEXE_headers += AMReX_FFT.H
CEXE_sources += AMReX_FFT.cpp

VPATH_LOCATIONS += $(AMREX_HOME)/Src/FFT
INCLUDE_LOCATIONS += $(AMREX_HOME)/Src/FFT
47 changes: 47 additions & 0 deletions Tests/FFT/Poisson/GNUmakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
AMREX_HOME := ../../..

DEBUG = TRUE

DIM = 3

COMP = gcc

USE_MPI = FALSE
USE_OMP = FALSE
USE_CUDA = FALSE
USE_HIP = FALSE
USE_SYCL = FALSE

BL_NO_FORT = TRUE

TINY_PROFILE = FALSE

include $(AMREX_HOME)/Tools/GNUMake/Make.defs

include ./Make.package
include $(AMREX_HOME)/Src/Base/Make.package
include $(AMREX_HOME)/Src/FFT/Make.package

ifeq ($(USE_HEFFTE),TRUE)
HEFFTE_HOME ?= ../../../../heffte/build
VPATH_LOCATIONS += $(HEFFTE_HOME)/include
INCLUDE_LOCATIONS += $(HEFFTE_HOME)/include
LIBRARY_LOCATIONS += $(HEFFTE_HOME)/lib
libraries += -lheffte
ifneq ($(USE_GPU),TRUE)
libraries += -lfftw3_mpi
endif
endif

ifeq ($(USE_CUDA),TRUE)
libraries += -lcufft
else ifeq ($(USE_HIP),TRUE)
# Use rocFFT. ROC_PATH is defined in amrex
INCLUDE_LOCATIONS += $(ROC_PATH)/rocfft/include
LIBRARY_LOCATIONS += $(ROC_PATH)/rocfft/lib
LIBRARIES += -L$(ROC_PATH)/rocfft/lib -lrocfft
else
libraries += -lfftw3
endif

include $(AMREX_HOME)/Tools/GNUMake/Make.rules
1 change: 1 addition & 0 deletions Tests/FFT/Poisson/Make.package
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CEXE_sources += main.cpp
Loading

0 comments on commit 3611aa9

Please sign in to comment.