Skip to content

Commit

Permalink
WIP: HIP support
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Oct 16, 2024
1 parent 6628628 commit 39405c1
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 102 deletions.
10 changes: 10 additions & 0 deletions Src/Base/AMReX_GpuError.H
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ namespace Gpu {
std::string errStr(std::string("HIPRAND error in file ") + __FILE__ \
+ " line " + std::to_string(__LINE__)); \
amrex::Abort(errStr); }} while(0)

#define AMREX_ROCFFT_SAFE_CALL(call) { \
auto amrex_i_err = call; \
if (rocfft_status_success != amrex_i_err) { \
std::string errStr(std::string("rocFFT error ")+std::to_string(amrex_i_err) \
+ std::string(" in file ") + __FILE__ \
+ " line " + std::to_string(__LINE__)); \
amrex::Abort(errStr); \
}}

#endif

#define AMREX_GPU_ERROR_CHECK() amrex::Gpu::ErrorCheck(__FILE__, __LINE__)
Expand Down
168 changes: 68 additions & 100 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace amrex::FFT
{

enum struct Scaling { full, symmetric, none };
enum struct Direction { forward, backwark };
enum struct Direction { forward, backward };

template <typename T = Real>
class R2C
Expand Down Expand Up @@ -134,8 +134,7 @@ private:

static void exec_r2c (FFTPlan plan, MF& in, cMF& out);
static void exec_c2r (FFTPlan plan, cMF& in, MF& out);
static void exec_c2c_forward (FFTPlan plan, cMF& inout);
static void exec_c2c_backward (FFTPlan plan, cMF& inout);
static void exec_c2c (FFTPlan plan, cMF& inout, Direction direction);

static void destroy_plan (FFTPlan plan);
static std::pair<FFTPlan,FFTPlan> make_c2c_plans (cMF& inout);
Expand Down Expand Up @@ -189,7 +188,8 @@ R2C<T>::R2C (Box const& domain)
domain.length(0)/2,
domain.bigEnd(1))))
{
AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0 && domain.cellCentered());
static_assert(std::is_same_v<float,T> || std::is_same_v<double,T>);
AMREX_ALWAYS_ASSERT(m_real_domain.smallEnd() == 0 && m_real_domain.cellCentered());

int myproc = ParallelDescriptor::MyProc();
int nprocs = ParallelDescriptor::NProcs();
Expand Down Expand Up @@ -238,7 +238,16 @@ R2C<T>::R2C (Box const& domain)

#elif defined(AMREX_USE_HIP)

static_assert(false);
auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
const std::size_t length = n;
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&m_fft_fwd_x, rocfft_placement_notinplace,
rocfft_transform_type_real_forward, prec, 1, &length, howmany,
nullptr));
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&m_fft_bwd_x, rocfft_placement_notinplace,
rocfft_transform_type_real_inverse, prec, 1, &length, howmany,
nullptr));

#elif defined(AMREX_USE_SYCL)

Expand Down Expand Up @@ -304,7 +313,7 @@ void R2C<T>::destroy_plan (FFTPlan plan)
#if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan));
#elif defined(AMREX_USE_HIP)
static_assert(false);
AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan));
#elif defined(AMREX_USE_SYCL)
static_assert(false);
#else
Expand All @@ -327,6 +336,10 @@ R2C<T>::~R2C<T> ()
destroy_plan(m_fft_bwd_z);
}

#ifdef AMREX_USE_HIP
namespace detail { void execute (rocfft_plan plan, void **in, void **out); }
#endif

template <typename T>
void R2C<T>::exec_r2c (FFTPlan plan, MF& in, cMF& out)
{
Expand All @@ -337,30 +350,23 @@ void R2C<T>::exec_r2c (FFTPlan plan, MF& in, cMF& out)
amrex::ignore_unused(in,out);
#endif

if constexpr (std::is_same_v<float,T>)
{
# if defined(AMREX_USE_CUDA)
#if defined(AMREX_USE_CUDA)
if constexpr (std::is_same_v<float,T>) {
AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pin, pout));
# elif defined(AMREX_USE_HIP)
static_assert(false);
# elif defined(AMREX_USE_SYCL)
static_assert(false);
# else
fftwf_execute(plan);
# endif
}
else
{
# if defined(AMREX_USE_CUDA)
} else {
AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pin, pout));
}
#elif defined(AMREX_USE_HIP)
static_assert(false);
detail::execute(plan, (void**)&pin, (void**)&pout);
#elif defined(AMREX_USE_SYCL)
static_assert(false);
static_assert(false);
#else
if constexpr (std::is_same_v<float,T>) {
fftwf_execute(plan);
} else {
fftw_execute(plan);
#endif
}
#endif
}

template <typename T>
Expand All @@ -373,100 +379,51 @@ void R2C<T>::exec_c2r (FFTPlan plan, cMF& in, MF& out)
amrex::ignore_unused(in,out);
#endif

if constexpr (std::is_same_v<float,T>)
{
# if defined(AMREX_USE_CUDA)
#if defined(AMREX_USE_CUDA)
if constexpr (std::is_same_v<float,T>) {
AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, pin, pout));
# elif defined(AMREX_USE_HIP)
static_assert(false);
# elif defined(AMREX_USE_SYCL)
static_assert(false);
# else
fftwf_execute(plan);
# endif
}
else
{
# if defined(AMREX_USE_CUDA)
} else {
AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, pin, pout));
#elif defined(AMREX_USE_HIP)
static_assert(false);
#elif defined(AMREX_USE_SYCL)
static_assert(false);
#else
fftw_execute(plan);
#endif
}
}

template <typename T>
void R2C<T>::exec_c2c_forward (FFTPlan plan, cMF& inout)
{
#if defined(AMREX_USE_GPU)
auto* p = (FFTComplex*)(inout[ParallelDescriptor::MyProc()].dataPtr());
#else
amrex::ignore_unused(inout);
#endif

if constexpr (std::is_same_v<float,T>)
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, CUFFT_FORWARD));
# elif defined(AMREX_USE_HIP)
static_assert(false);
# elif defined(AMREX_USE_SYCL)
static_assert(false);
# else
fftwf_execute(plan);
# endif
}
else
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, CUFFT_FORWARD));
#elif defined(AMREX_USE_HIP)
static_assert(false);
detail::execute(plan, (void**)&pin, (void**)&pout);
#elif defined(AMREX_USE_SYCL)
static_assert(false);
static_assert(false);
#else
if constexpr (std::is_same_v<float,T>) {
fftwf_execute(plan);
} else {
fftw_execute(plan);
#endif
}
#endif
}

template <typename T>
void R2C<T>::exec_c2c_backward (FFTPlan plan, cMF& inout)
void R2C<T>::exec_c2c (FFTPlan plan, cMF& inout, Direction direction)
{
amrex::ignore_unused(inout, direction);
#if defined(AMREX_USE_GPU)
auto* p = (FFTComplex*)(inout[ParallelDescriptor::MyProc()].dataPtr());
#else
amrex::ignore_unused(inout);
#endif

if constexpr (std::is_same_v<float,T>)
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, CUFFT_INVERSE));
# elif defined(AMREX_USE_HIP)
static_assert(false);
# elif defined(AMREX_USE_SYCL)
static_assert(false);
# else
fftwf_execute(plan);
# endif
#if defined(AMREX_USE_CUDA)
auto cufft_direction = (direction == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE;
if constexpr (std::is_same_v<float,T>) {
AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, cufft_direction));
} else {
AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, cufft_direction));
}
else
{
# if defined(AMREX_USE_CUDA)
AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, CUFFT_INVERSE));
#elif defined(AMREX_USE_HIP)
static_assert(false);
detail::execute(plan, (void**)&p, (void**)&p);
#elif defined(AMREX_USE_SYCL)
static_assert(false);
static_assert(false);
#else
if constexpr (std::is_same_v<float,T>) {
fftwf_execute(plan);
} else {
fftw_execute(plan);
#endif
}
#endif
}

template <typename T>
Expand All @@ -476,19 +433,19 @@ void R2C<T>::forward_doit (MF const& inmf, Scaling scaling)
exec_r2c(m_fft_fwd_x, m_rx, m_cx);

ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y);
exec_c2c_forward(m_fft_fwd_y, m_cy);
exec_c2c(m_fft_fwd_y, m_cy, Direction::forward);

ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z);
exec_c2c_forward(m_fft_fwd_z, m_cz);
exec_c2c(m_fft_fwd_z, m_cz, Direction::forward);
}

template <typename T>
void R2C<T>::backward_doit (MF& outmf, Scaling scaling)
{
exec_c2c_backward(m_fft_bwd_z, m_cz);
exec_c2c(m_fft_bwd_z, m_cz, Direction::backward);
ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y);

exec_c2c_backward(m_fft_bwd_y, m_cy);
exec_c2c(m_fft_bwd_y, m_cy, Direction::backward);
ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x);

exec_c2r(m_fft_bwd_x, m_cx, m_rx);
Expand Down Expand Up @@ -522,7 +479,16 @@ R2C<T>::make_c2c_plans (cMF& inout)

#elif defined(AMREX_USE_HIP)

static_assert(false);
auto prec = std::is_same_v<float,T> ? rocfft_precision_single : rocfft_precision_double;
const std::size_t length = n;
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&fwd, rocfft_placement_inplace,
rocfft_transform_type_complex_forward, prec, 1, &length, howmany,
nullptr));
AMREX_ROCFFT_SAFE_CALL
(rocfft_plan_create(&bwd, rocfft_placement_inplace,
rocfft_transform_type_complex_inverse, prec, 1, &length, howmany,
nullptr));

#elif defined(AMREX_USE_SYCL)

Expand Down Expand Up @@ -557,6 +523,8 @@ void R2C<T>::post_forward_doit (F const& post_forward)
});
}

extern template class R2C<Real>;

}

#endif
34 changes: 34 additions & 0 deletions Src/FFT/AMReX_FFT.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#include <AMReX_FFT.H>

namespace amrex::FFT
{

template class R2C<Real>;

#ifdef AMREX_USE_HIP
namespace detail
{
void execute (rocfft_plan plan, void **in, void **out)
{
rocfft_execution_info execinfo = nullptr;
AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_create(&execinfo));

std::size_t buffersize = 0;
AMREX_ROCFFT_SAFE_CALL(rocfft_plan_get_work_buffer_size(plan, &buffersize));

auto* buffer = (void*)amrex::The_Arena()->alloc(buffersize);
AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize));

AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream()));

AMREX_ROCFFT_SAFE_CALL(rocfft_execute(plan, in, out, execinfo));

amrex::Gpu::streamSynchronize();
amrex::The_Arena()->free(buffer);

AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_destroy(execinfo));
}
}
#endif

}
1 change: 1 addition & 0 deletions Src/FFT/Make.package
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ ifndef AMREX_FFT_MAKE
AMREX_FFT_MAKE := 1

CEXE_headers += AMReX_FFT.H
CEXE_sources += AMReX_FFT.cpp

VPATH_LOCATIONS += $(AMREX_HOME)/Src/FFT
INCLUDE_LOCATIONS += $(AMREX_HOME)/Src/FFT
Expand Down
4 changes: 2 additions & 2 deletions Tools/GNUMake/comps/hip.mak
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ ifeq ($(HIP_COMPILER),clang)
endif

# Generic HIP info
ROC_PATH=$(realpath $(dir $(HIP_PATH)))
SYSTEM_INCLUDE_LOCATIONS += $(ROC_PATH)/include $(HIP_PATH)/include
ROC_PATH=$(realpath $(HIP_PATH))
SYSTEM_INCLUDE_LOCATIONS += $(ROC_PATH)/include

# rocRand
SYSTEM_INCLUDE_LOCATIONS += $(ROC_PATH)/include/hiprand $(ROC_PATH)/include/rocrand
Expand Down

0 comments on commit 39405c1

Please sign in to comment.