From d7c313fc422f7efc97bea377969889074f33aba8 Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Tue, 15 Oct 2024 16:36:56 -0700 Subject: [PATCH] Add single precision support --- Src/FFT/AMReX_FFT.H | 111 +++++++++++++++++++++++++--------- Tests/FFT/Poisson/GNUmakefile | 2 +- 2 files changed, 84 insertions(+), 29 deletions(-) diff --git a/Src/FFT/AMReX_FFT.H b/Src/FFT/AMReX_FFT.H index e60cf5fb62..6da7326a26 100644 --- a/Src/FFT/AMReX_FFT.H +++ b/Src/FFT/AMReX_FFT.H @@ -182,8 +182,7 @@ R2C::R2C (Box const& domain) domain.length(0)/2, domain.bigEnd(1)))) { - AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0); - AMREX_ALWAYS_ASSERT((std::is_same_v)); // xxxxx todo + AMREX_ALWAYS_ASSERT(domain.smallEnd() == 0 && domain.cellCentered()); int myproc = ParallelDescriptor::MyProc(); int nprocs = ParallelDescriptor::NProcs(); @@ -208,7 +207,7 @@ R2C::R2C (Box const& domain) // plans for x-direction { - double* in = m_rx[myproc].dataPtr(); + auto* in = m_rx[myproc].dataPtr(); auto* out = (FFTComplex*)(m_cx[myproc].dataPtr()); Box const local_box = m_rx.boxArray()[myproc]; @@ -228,15 +227,27 @@ R2C::R2C (Box const& domain) static_assert(false); #else - m_fft_fwd_x = fftw_plan_many_dft_r2c - (1, &n, howmany, in, nullptr, 1, m_real_domain.length(0), - out, nullptr, 1, m_spectral_domain_x.length(0), - FFTW_ESTIMATE | FFTW_DESTROY_INPUT); - - m_fft_bwd_x = fftw_plan_many_dft_c2r - (1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0), - in, nullptr, 1, m_real_domain.length(0), - FFTW_ESTIMATE | FFTW_DESTROY_INPUT); + if constexpr (std::is_same_v) { + m_fft_fwd_x = fftwf_plan_many_dft_r2c + (1, &n, howmany, in, nullptr, 1, m_real_domain.length(0), + out, nullptr, 1, m_spectral_domain_x.length(0), + FFTW_ESTIMATE | FFTW_DESTROY_INPUT); + + m_fft_bwd_x = fftwf_plan_many_dft_c2r + (1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0), + in, nullptr, 1, m_real_domain.length(0), + FFTW_ESTIMATE | FFTW_DESTROY_INPUT); + } else { + m_fft_fwd_x = fftw_plan_many_dft_r2c + (1, &n, howmany, in, nullptr, 1, m_real_domain.length(0), + out, nullptr, 1, m_spectral_domain_x.length(0), + FFTW_ESTIMATE | FFTW_DESTROY_INPUT); + + m_fft_bwd_x = fftw_plan_many_dft_c2r + (1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0), + in, nullptr, 1, m_real_domain.length(0), + FFTW_ESTIMATE | FFTW_DESTROY_INPUT); + } #endif } @@ -285,12 +296,22 @@ R2C::~R2C () static_assert(false); #else - fftw_destroy_plan(m_fft_fwd_x); - fftw_destroy_plan(m_fft_bwd_x); - fftw_destroy_plan(m_fft_fwd_y); - fftw_destroy_plan(m_fft_bwd_y); - fftw_destroy_plan(m_fft_fwd_z); - fftw_destroy_plan(m_fft_bwd_z); + + if constexpr (std::is_same_v) { + fftwf_destroy_plan(m_fft_fwd_x); + fftwf_destroy_plan(m_fft_bwd_x); + fftwf_destroy_plan(m_fft_fwd_y); + fftwf_destroy_plan(m_fft_bwd_y); + fftwf_destroy_plan(m_fft_fwd_z); + fftwf_destroy_plan(m_fft_bwd_z); + } else { + fftw_destroy_plan(m_fft_fwd_x); + fftw_destroy_plan(m_fft_bwd_x); + fftw_destroy_plan(m_fft_fwd_y); + fftw_destroy_plan(m_fft_bwd_y); + fftw_destroy_plan(m_fft_fwd_z); + fftw_destroy_plan(m_fft_bwd_z); + } #endif } @@ -298,13 +319,25 @@ template void R2C::forward_doit (MF const& inmf, Scaling scaling) { m_rx.ParallelCopy(inmf, 0, 0, 1); - fftw_execute(m_fft_fwd_x); + if constexpr (std::is_same_v) { + fftwf_execute(m_fft_fwd_x); + } else { + fftw_execute(m_fft_fwd_x); + } ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y); - fftw_execute(m_fft_fwd_y); + if constexpr (std::is_same_v) { + fftwf_execute(m_fft_fwd_y); + } else { + fftw_execute(m_fft_fwd_y); + } ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z); - fftw_execute(m_fft_fwd_z); + if constexpr (std::is_same_v) { + fftwf_execute(m_fft_fwd_z); + } else { + fftw_execute(m_fft_fwd_z); + } } template @@ -312,13 +345,25 @@ void R2C::backward_doit (MF& outmf, Scaling scaling) { // xxxxx todo: scaling - fftw_execute(m_fft_bwd_z); + if constexpr (std::is_same_v) { + fftwf_execute(m_fft_bwd_z); + } else { + fftw_execute(m_fft_bwd_z); + } ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y); - fftw_execute(m_fft_bwd_y); + if constexpr (std::is_same_v) { + fftwf_execute(m_fft_bwd_y); + } else { + fftw_execute(m_fft_bwd_y); + } ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x); - fftw_execute(m_fft_bwd_x); + if constexpr (std::is_same_v) { + fftwf_execute(m_fft_bwd_x); + } else { + fftw_execute(m_fft_bwd_x); + } outmf.ParallelCopy(m_rx, 0, 0, 1); } @@ -333,6 +378,9 @@ R2C::make_c2c_plans (cMF& inout) int n = local_box.length(0); int howmany = local_box.length(1) * local_box.length(2); + FFTPlan fwd; + FFTPlan bwd; + #if defined(AMREX_USE_CUDA) static_assert(false); @@ -346,10 +394,17 @@ R2C::make_c2c_plans (cMF& inout) static_assert(false); #else - auto* fwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, - pinout, nullptr, 1, n, -1, FFTW_ESTIMATE); - auto* bwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, - pinout, nullptr, 1, n, +1, FFTW_ESTIMATE); + if constexpr (std::is_same_v) { + fwd = fftwf_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, + pinout, nullptr, 1, n, -1, FFTW_ESTIMATE); + bwd = fftwf_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, + pinout, nullptr, 1, n, +1, FFTW_ESTIMATE); + } else { + fwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, + pinout, nullptr, 1, n, -1, FFTW_ESTIMATE); + bwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, + pinout, nullptr, 1, n, +1, FFTW_ESTIMATE); + } #endif return {fwd,bwd}; diff --git a/Tests/FFT/Poisson/GNUmakefile b/Tests/FFT/Poisson/GNUmakefile index eac89a3a6c..5be04b0f3f 100644 --- a/Tests/FFT/Poisson/GNUmakefile +++ b/Tests/FFT/Poisson/GNUmakefile @@ -30,7 +30,7 @@ else ifeq ($(USE_HIP),TRUE) LIBRARY_LOCATIONS += $(ROC_PATH)/rocfft/lib LIBRARIES += -L$(ROC_PATH)/rocfft/lib -lrocfft else - libraries += -lfftw3 + libraries += -lfftw3f -lfftw3 endif include $(AMREX_HOME)/Tools/GNUMake/Make.rules