diff --git a/Src/Base/AMReX_GpuError.H b/Src/Base/AMReX_GpuError.H index 11186be11b..65457c8f4e 100644 --- a/Src/Base/AMReX_GpuError.H +++ b/Src/Base/AMReX_GpuError.H @@ -110,6 +110,16 @@ namespace Gpu { std::string errStr(std::string("HIPRAND error in file ") + __FILE__ \ + " line " + std::to_string(__LINE__)); \ amrex::Abort(errStr); }} while(0) + +#define AMREX_ROCFFT_SAFE_CALL(call) { \ + auto amrex_i_err = call; \ + if (rocfft_status_success != amrex_i_err) { \ + std::string errStr(std::string("rocFFT error ")+std::to_string(amrex_i_err) \ + + std::string(" in file ") + __FILE__ \ + + " line " + std::to_string(__LINE__)); \ + amrex::Abort(errStr); \ + }} + #endif #define AMREX_GPU_ERROR_CHECK() amrex::Gpu::ErrorCheck(__FILE__, __LINE__) diff --git a/Src/FFT/AMReX_FFT.H b/Src/FFT/AMReX_FFT.H index 25c67114da..dcdc0d1427 100644 --- a/Src/FFT/AMReX_FFT.H +++ b/Src/FFT/AMReX_FFT.H @@ -27,7 +27,7 @@ namespace amrex::FFT { enum struct Scaling { full, symmetric, none }; -enum struct Direction { forward, backwark }; +enum struct Direction { forward, backward }; template class R2C @@ -134,8 +134,7 @@ private: static void exec_r2c (FFTPlan plan, MF& in, cMF& out); static void exec_c2r (FFTPlan plan, cMF& in, MF& out); - static void exec_c2c_forward (FFTPlan plan, cMF& inout); - static void exec_c2c_backward (FFTPlan plan, cMF& inout); + static void exec_c2c (FFTPlan plan, cMF& inout, Direction direction); static void destroy_plan (FFTPlan plan); static std::pair make_c2c_plans (cMF& inout); @@ -189,7 +188,8 @@ R2C::R2C (Box const& domain) domain.length(0)/2, domain.bigEnd(1)))) { - AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0 && domain.cellCentered()); + static_assert(std::is_same_v || std::is_same_v); + AMREX_ALWAYS_ASSERT(m_real_domain.smallEnd() == 0 && m_real_domain.cellCentered()); int myproc = ParallelDescriptor::MyProc(); int nprocs = ParallelDescriptor::NProcs(); @@ -238,7 +238,16 @@ R2C::R2C (Box const& domain) #elif defined(AMREX_USE_HIP) - static_assert(false); + auto prec = std::is_same_v ? rocfft_precision_single : rocfft_precision_double; + const std::size_t length = n; + AMREX_ROCFFT_SAFE_CALL + (rocfft_plan_create(&m_fft_fwd_x, rocfft_placement_notinplace, + rocfft_transform_type_real_forward, prec, 1, &length, howmany, + nullptr)); + AMREX_ROCFFT_SAFE_CALL + (rocfft_plan_create(&m_fft_bwd_x, rocfft_placement_notinplace, + rocfft_transform_type_real_inverse, prec, 1, &length, howmany, + nullptr)); #elif defined(AMREX_USE_SYCL) @@ -304,7 +313,7 @@ void R2C::destroy_plan (FFTPlan plan) #if defined(AMREX_USE_CUDA) AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan)); #elif defined(AMREX_USE_HIP) - static_assert(false); + AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan)); #elif defined(AMREX_USE_SYCL) static_assert(false); #else @@ -327,6 +336,10 @@ R2C::~R2C () destroy_plan(m_fft_bwd_z); } +#ifdef AMREX_USE_HIP +namespace detail { void execute (rocfft_plan plan, void **in, void **out); } +#endif + template void R2C::exec_r2c (FFTPlan plan, MF& in, cMF& out) { @@ -337,30 +350,23 @@ void R2C::exec_r2c (FFTPlan plan, MF& in, cMF& out) amrex::ignore_unused(in,out); #endif - if constexpr (std::is_same_v) - { -# if defined(AMREX_USE_CUDA) +#if defined(AMREX_USE_CUDA) + if constexpr (std::is_same_v) { AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pin, pout)); -# elif defined(AMREX_USE_HIP) - static_assert(false); -# elif defined(AMREX_USE_SYCL) - static_assert(false); -# else - fftwf_execute(plan); -# endif - } - else - { -# if defined(AMREX_USE_CUDA) + } else { AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pin, pout)); + } #elif defined(AMREX_USE_HIP) - static_assert(false); + detail::execute(plan, (void**)&pin, (void**)&pout); #elif defined(AMREX_USE_SYCL) - static_assert(false); + static_assert(false); #else + if constexpr (std::is_same_v) { + fftwf_execute(plan); + } else { fftw_execute(plan); -#endif } +#endif } template @@ -373,100 +379,51 @@ void R2C::exec_c2r (FFTPlan plan, cMF& in, MF& out) amrex::ignore_unused(in,out); #endif - if constexpr (std::is_same_v) - { -# if defined(AMREX_USE_CUDA) +#if defined(AMREX_USE_CUDA) + if constexpr (std::is_same_v) { AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, pin, pout)); -# elif defined(AMREX_USE_HIP) - static_assert(false); -# elif defined(AMREX_USE_SYCL) - static_assert(false); -# else - fftwf_execute(plan); -# endif - } - else - { -# if defined(AMREX_USE_CUDA) + } else { AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, pin, pout)); -#elif defined(AMREX_USE_HIP) - static_assert(false); -#elif defined(AMREX_USE_SYCL) - static_assert(false); -#else - fftw_execute(plan); -#endif - } -} - -template -void R2C::exec_c2c_forward (FFTPlan plan, cMF& inout) -{ -#if defined(AMREX_USE_GPU) - auto* p = (FFTComplex*)(inout[ParallelDescriptor::MyProc()].dataPtr()); -#else - amrex::ignore_unused(inout); -#endif - - if constexpr (std::is_same_v) - { -# if defined(AMREX_USE_CUDA) - AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, CUFFT_FORWARD)); -# elif defined(AMREX_USE_HIP) - static_assert(false); -# elif defined(AMREX_USE_SYCL) - static_assert(false); -# else - fftwf_execute(plan); -# endif } - else - { -# if defined(AMREX_USE_CUDA) - AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, CUFFT_FORWARD)); #elif defined(AMREX_USE_HIP) - static_assert(false); + detail::execute(plan, (void**)&pin, (void**)&pout); #elif defined(AMREX_USE_SYCL) - static_assert(false); + static_assert(false); #else + if constexpr (std::is_same_v) { + fftwf_execute(plan); + } else { fftw_execute(plan); -#endif } +#endif } template -void R2C::exec_c2c_backward (FFTPlan plan, cMF& inout) +void R2C::exec_c2c (FFTPlan plan, cMF& inout, Direction direction) { + amrex::ignore_unused(inout, direction); #if defined(AMREX_USE_GPU) auto* p = (FFTComplex*)(inout[ParallelDescriptor::MyProc()].dataPtr()); -#else - amrex::ignore_unused(inout); #endif - if constexpr (std::is_same_v) - { -# if defined(AMREX_USE_CUDA) - AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, CUFFT_INVERSE)); -# elif defined(AMREX_USE_HIP) - static_assert(false); -# elif defined(AMREX_USE_SYCL) - static_assert(false); -# else - fftwf_execute(plan); -# endif +#if defined(AMREX_USE_CUDA) + auto cufft_direction = (direction == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE; + if constexpr (std::is_same_v) { + AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, p, p, cufft_direction)); + } else { + AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, cufft_direction)); } - else - { -# if defined(AMREX_USE_CUDA) - AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, p, p, CUFFT_INVERSE)); #elif defined(AMREX_USE_HIP) - static_assert(false); + detail::execute(plan, (void**)&p, (void**)&p); #elif defined(AMREX_USE_SYCL) - static_assert(false); + static_assert(false); #else + if constexpr (std::is_same_v) { + fftwf_execute(plan); + } else { fftw_execute(plan); -#endif } +#endif } template @@ -476,19 +433,19 @@ void R2C::forward_doit (MF const& inmf, Scaling scaling) exec_r2c(m_fft_fwd_x, m_rx, m_cx); ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y); - exec_c2c_forward(m_fft_fwd_y, m_cy); + exec_c2c(m_fft_fwd_y, m_cy, Direction::forward); ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z); - exec_c2c_forward(m_fft_fwd_z, m_cz); + exec_c2c(m_fft_fwd_z, m_cz, Direction::forward); } template void R2C::backward_doit (MF& outmf, Scaling scaling) { - exec_c2c_backward(m_fft_bwd_z, m_cz); + exec_c2c(m_fft_bwd_z, m_cz, Direction::backward); ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y); - exec_c2c_backward(m_fft_bwd_y, m_cy); + exec_c2c(m_fft_bwd_y, m_cy, Direction::backward); ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x); exec_c2r(m_fft_bwd_x, m_cx, m_rx); @@ -522,7 +479,16 @@ R2C::make_c2c_plans (cMF& inout) #elif defined(AMREX_USE_HIP) - static_assert(false); + auto prec = std::is_same_v ? rocfft_precision_single : rocfft_precision_double; + const std::size_t length = n; + AMREX_ROCFFT_SAFE_CALL + (rocfft_plan_create(&fwd, rocfft_placement_inplace, + rocfft_transform_type_complex_forward, prec, 1, &length, howmany, + nullptr)); + AMREX_ROCFFT_SAFE_CALL + (rocfft_plan_create(&bwd, rocfft_placement_inplace, + rocfft_transform_type_complex_inverse, prec, 1, &length, howmany, + nullptr)); #elif defined(AMREX_USE_SYCL) @@ -557,6 +523,8 @@ void R2C::post_forward_doit (F const& post_forward) }); } +extern template class R2C; + } #endif diff --git a/Src/FFT/AMReX_FFT.cpp b/Src/FFT/AMReX_FFT.cpp new file mode 100644 index 0000000000..d3e1e277d2 --- /dev/null +++ b/Src/FFT/AMReX_FFT.cpp @@ -0,0 +1,34 @@ +#include + +namespace amrex::FFT +{ + +template class R2C; + +#ifdef AMREX_USE_HIP +namespace detail +{ +void execute (rocfft_plan plan, void **in, void **out) +{ + rocfft_execution_info execinfo = nullptr; + AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_create(&execinfo)); + + std::size_t buffersize = 0; + AMREX_ROCFFT_SAFE_CALL(rocfft_plan_get_work_buffer_size(plan, &buffersize)); + + auto* buffer = (void*)amrex::The_Arena()->alloc(buffersize); + AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_set_work_buffer(execinfo, buffer, buffersize)); + + AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_set_stream(execinfo, amrex::Gpu::gpuStream())); + + AMREX_ROCFFT_SAFE_CALL(rocfft_execute(plan, in, out, execinfo)); + + amrex::Gpu::streamSynchronize(); + amrex::The_Arena()->free(buffer); + + AMREX_ROCFFT_SAFE_CALL(rocfft_execution_info_destroy(execinfo)); +} +} +#endif + +} diff --git a/Src/FFT/Make.package b/Src/FFT/Make.package index 3a7c7335c7..54322566b7 100644 --- a/Src/FFT/Make.package +++ b/Src/FFT/Make.package @@ -2,6 +2,7 @@ ifndef AMREX_FFT_MAKE AMREX_FFT_MAKE := 1 CEXE_headers += AMReX_FFT.H +CEXE_sources += AMReX_FFT.cpp VPATH_LOCATIONS += $(AMREX_HOME)/Src/FFT INCLUDE_LOCATIONS += $(AMREX_HOME)/Src/FFT diff --git a/Tools/GNUMake/comps/hip.mak b/Tools/GNUMake/comps/hip.mak index 87bb3e93f5..26dff7f94f 100644 --- a/Tools/GNUMake/comps/hip.mak +++ b/Tools/GNUMake/comps/hip.mak @@ -119,8 +119,8 @@ ifeq ($(HIP_COMPILER),clang) endif # Generic HIP info - ROC_PATH=$(realpath $(dir $(HIP_PATH))) - SYSTEM_INCLUDE_LOCATIONS += $(ROC_PATH)/include $(HIP_PATH)/include + ROC_PATH=$(realpath $(HIP_PATH)) + SYSTEM_INCLUDE_LOCATIONS += $(ROC_PATH)/include # rocRand SYSTEM_INCLUDE_LOCATIONS += $(ROC_PATH)/include/hiprand $(ROC_PATH)/include/rocrand