Skip to content

Commit

Permalink
CUDA support
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Oct 16, 2024
1 parent d7c313f commit 5568844
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 60 deletions.
10 changes: 10 additions & 0 deletions Src/Base/AMReX_GpuError.H
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ namespace Gpu {
std::string errStr(std::string("CURAND error in file ") + __FILE__ \
+ " line " + std::to_string(__LINE__)); \
amrex::Abort(errStr); }} while(0)

#define AMREX_CUFFT_SAFE_CALL(call) { \
cufftResult_t amrex_i_err = call; \
if (CUFFT_SUCCESS != amrex_i_err) { \
std::string errStr(std::string("CUFFT error ")+std::to_string(amrex_i_err) \
+ std::string(" in file ") + __FILE__ \
+ " line " + std::to_string(__LINE__)); \
amrex::Abort(errStr); \
}}

#endif

#ifdef AMREX_USE_HIP
Expand Down
255 changes: 195 additions & 60 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace amrex::FFT
{

enum struct Scaling { full, symmetric, none };
enum struct Direction { forward, backwark };

template <typename T = Real>
class R2C
Expand Down Expand Up @@ -106,9 +107,6 @@ public: // public for cuda

private:

void forward_doit (MF const& inmf, Scaling scaling = Scaling::none);
void backward_doit (MF& outmf, Scaling scaling = Scaling::none);

#if defined(AMREX_USE_CUDA)
using FFTPlan = cufftHandle;
using FFTComplex = std::conditional_t<std::is_same_v<float,T>,
Expand All @@ -131,6 +129,15 @@ private:
fftwf_complex, fftw_complex>;
#endif

void forward_doit (MF const& inmf, Scaling scaling = Scaling::none);
void backward_doit (MF& outmf, Scaling scaling = Scaling::none);

static void exec_r2c (FFTPlan plan, MF& in, cMF& out);
static void exec_c2r (FFTPlan plan, cMF& in, MF& out);
static void exec_c2c_forward (FFTPlan plan, cMF& inout);
static void exec_c2c_backward (FFTPlan plan, cMF& inout);

static void destroy_plan (FFTPlan plan);
static std::pair<FFTPlan,FFTPlan> make_c2c_plans (cMF& inout);

Box m_real_domain;
Expand All @@ -139,12 +146,12 @@ private:
Box m_spectral_domain_z;

// assuming it's double for now
FFTPlan m_fft_fwd_x = nullptr;
FFTPlan m_fft_bwd_x = nullptr;
FFTPlan m_fft_fwd_y = nullptr;
FFTPlan m_fft_bwd_y = nullptr;
FFTPlan m_fft_fwd_z = nullptr;
FFTPlan m_fft_bwd_z = nullptr;
FFTPlan m_fft_fwd_x;
FFTPlan m_fft_bwd_x;
FFTPlan m_fft_fwd_y;
FFTPlan m_fft_bwd_y;
FFTPlan m_fft_fwd_z;
FFTPlan m_fft_bwd_z;

// Comm meta-data. In the forward phase, we start with (x,y,z),
// transpose to (y,x,z) and then (z,x,y). In the backward phase, we
Expand Down Expand Up @@ -216,7 +223,18 @@ R2C<T>::R2C (Box const& domain)

#if defined(AMREX_USE_CUDA)

static_assert(false);
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_R2C : CUFFT_D2Z;
cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2R : CUFFT_Z2D;
AMREX_CUFFT_SAFE_CALL
(cufftPlanMany(&m_fft_fwd_x, 1, &n, nullptr, 1, m_real_domain.length(0),
nullptr, 1, m_spectral_domain_x.length(0),
fwd_type, howmany));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(m_fft_fwd_x, Gpu::gpuStream()));
AMREX_CUFFT_SAFE_CALL
(cufftPlanMany(&m_fft_bwd_x, 1, &n, nullptr, 1, m_spectral_domain_x.length(0),
nullptr, 1, m_real_domain.length(0),
bwd_type, howmany));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(m_fft_bwd_x, Gpu::gpuStream()));

#elif defined(AMREX_USE_HIP)

Expand Down Expand Up @@ -281,89 +299,199 @@ R2C<T>::R2C (Box const& domain)
}

template <typename T>
R2C<T>::~R2C<T> ()
void R2C<T>::destroy_plan (FFTPlan plan)
{
#if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan));
#elif defined(AMREX_USE_HIP)
static_assert(false);
#elif defined(AMREX_USE_SYCL)
static_assert(false);
#else
if constexpr (std::is_same_v<float,T>) {
fftwf_destroy_plan(plan);
} else {
fftw_destroy_plan(plan);
}
#endif
}

static_assert(false);
template <typename T>
R2C<T>::~R2C<T> ()
{
destroy_plan(m_fft_fwd_x);
destroy_plan(m_fft_bwd_x);
destroy_plan(m_fft_fwd_y);
destroy_plan(m_fft_bwd_y);
destroy_plan(m_fft_fwd_z);
destroy_plan(m_fft_bwd_z);
}

#elif defined(AMREX_USE_HIP)
template <typename T>
void R2C<T>::exec_r2c (FFTPlan plan, MF& in, cMF& out)
{
#if defined(AMREX_USE_GPU)
auto* pin = in[ParallelDescriptor::MyProc()].dataPtr();
auto* pout = (FFTComplex*)(out[ParallelDescriptor::MyProc()].dataPtr());
#else
amrex::ignore_unused(in,out);
#endif

if constexpr (std::is_same_v<float,T>)
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pin, pout));
# elif defined(AMREX_USE_HIP)
static_assert(false);
# elif defined(AMREX_USE_SYCL)
static_assert(false);
# else
fftwf_execute(plan);
# endif
}
else
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pin, pout));
#elif defined(AMREX_USE_HIP)
static_assert(false);
#elif defined(AMREX_USE_SYCL)
static_assert(false);
#else
fftw_execute(plan);
#endif
}
}

template <typename T>
void R2C<T>::exec_c2r (FFTPlan plan, cMF& in, MF& out)
{
#if defined(AMREX_USE_GPU)
auto* pin = (FFTComplex*)(in[ParallelDescriptor::MyProc()].dataPtr());
auto* pout = out[ParallelDescriptor::MyProc()].dataPtr();
#else
amrex::ignore_unused(in,out);
#endif

if constexpr (std::is_same_v<float,T>)
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, pin, pout));
# elif defined(AMREX_USE_HIP)
static_assert(false);
# elif defined(AMREX_USE_SYCL)
static_assert(false);
# else
fftwf_execute(plan);
# endif
}
else
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, pin, pout));
#elif defined(AMREX_USE_HIP)
static_assert(false);
#elif defined(AMREX_USE_SYCL)
static_assert(false);
#else
fftw_execute(plan);
#endif
}
}

template <typename T>
void R2C<T>::exec_c2c_forward (FFTPlan plan, cMF& inout)
{
#if defined(AMREX_USE_GPU)
auto* p = (FFTComplex*)(inout[ParallelDescriptor::MyProc()].dataPtr());
#else
amrex::ignore_unused(inout);
#endif

if constexpr (std::is_same_v<float,T>)
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, CUFFT_FORWARD));
# elif defined(AMREX_USE_HIP)
static_assert(false);
# elif defined(AMREX_USE_SYCL)
static_assert(false);
# else
fftwf_execute(plan);
# endif
}
else
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, CUFFT_FORWARD));
#elif defined(AMREX_USE_HIP)
static_assert(false);
#elif defined(AMREX_USE_SYCL)
static_assert(false);
#else
fftw_execute(plan);
#endif
}
}

template <typename T>
void R2C<T>::exec_c2c_backward (FFTPlan plan, cMF& inout)
{
#if defined(AMREX_USE_GPU)
auto* p = (FFTComplex*)(inout[ParallelDescriptor::MyProc()].dataPtr());
#else
amrex::ignore_unused(inout);
#endif

if constexpr (std::is_same_v<float,T>) {
fftwf_destroy_plan(m_fft_fwd_x);
fftwf_destroy_plan(m_fft_bwd_x);
fftwf_destroy_plan(m_fft_fwd_y);
fftwf_destroy_plan(m_fft_bwd_y);
fftwf_destroy_plan(m_fft_fwd_z);
fftwf_destroy_plan(m_fft_bwd_z);
} else {
fftw_destroy_plan(m_fft_fwd_x);
fftw_destroy_plan(m_fft_bwd_x);
fftw_destroy_plan(m_fft_fwd_y);
fftw_destroy_plan(m_fft_bwd_y);
fftw_destroy_plan(m_fft_fwd_z);
fftw_destroy_plan(m_fft_bwd_z);
if constexpr (std::is_same_v<float,T>)
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, CUFFT_INVERSE));
# elif defined(AMREX_USE_HIP)
static_assert(false);
# elif defined(AMREX_USE_SYCL)
static_assert(false);
# else
fftwf_execute(plan);
# endif
}
else
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, CUFFT_INVERSE));
#elif defined(AMREX_USE_HIP)
static_assert(false);
#elif defined(AMREX_USE_SYCL)
static_assert(false);
#else
fftw_execute(plan);
#endif
}
}

template <typename T>
void R2C<T>::forward_doit (MF const& inmf, Scaling scaling)
{
m_rx.ParallelCopy(inmf, 0, 0, 1);
if constexpr (std::is_same_v<float,T>) {
fftwf_execute(m_fft_fwd_x);
} else {
fftw_execute(m_fft_fwd_x);
}
exec_r2c(m_fft_fwd_x, m_rx, m_cx);

ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y);
if constexpr (std::is_same_v<float,T>) {
fftwf_execute(m_fft_fwd_y);
} else {
fftw_execute(m_fft_fwd_y);
}
exec_c2c_forward(m_fft_fwd_y, m_cy);

ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z);
if constexpr (std::is_same_v<float,T>) {
fftwf_execute(m_fft_fwd_z);
} else {
fftw_execute(m_fft_fwd_z);
}
exec_c2c_forward(m_fft_fwd_z, m_cz);
}

template <typename T>
void R2C<T>::backward_doit (MF& outmf, Scaling scaling)
{
// xxxxx todo: scaling

if constexpr (std::is_same_v<float,T>) {
fftwf_execute(m_fft_bwd_z);
} else {
fftw_execute(m_fft_bwd_z);
}
exec_c2c_backward(m_fft_bwd_z, m_cz);
ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y);

if constexpr (std::is_same_v<float,T>) {
fftwf_execute(m_fft_bwd_y);
} else {
fftw_execute(m_fft_bwd_y);
}
exec_c2c_backward(m_fft_bwd_y, m_cy);
ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x);

if constexpr (std::is_same_v<float,T>) {
fftwf_execute(m_fft_bwd_x);
} else {
fftw_execute(m_fft_bwd_x);
}
exec_c2r(m_fft_bwd_x, m_cx, m_rx);
outmf.ParallelCopy(m_rx, 0, 0, 1);
}

Expand All @@ -383,7 +511,14 @@ R2C<T>::make_c2c_plans (cMF& inout)

#if defined(AMREX_USE_CUDA)

static_assert(false);
cufftType fwd_type = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
cufftType bwd_type = std::is_same_v<float,T> ? CUFFT_C2C : CUFFT_Z2Z;
AMREX_CUFFT_SAFE_CALL
(cufftPlanMany(&fwd, 1, &n, nullptr, 1, n, nullptr, 1, n, fwd_type, howmany));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(fwd, Gpu::gpuStream()));
AMREX_CUFFT_SAFE_CALL
(cufftPlanMany(&bwd, 1, &n, nullptr, 1, n, nullptr, 1, n, bwd_type, howmany));
AMREX_CUFFT_SAFE_CALL(cufftSetStream(bwd, Gpu::gpuStream()));

#elif defined(AMREX_USE_HIP)

Expand Down

0 comments on commit 5568844

Please sign in to comment.