From 3b71a9efae0e278ff88097d4f9f71f0b0824c431 Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Thu, 17 Oct 2024 18:33:41 -0700 Subject: [PATCH] 1D & 2D support --- Src/FFT/AMReX_FFT.H | 424 +++++++++++++++++++++++++++----------------- 1 file changed, 263 insertions(+), 161 deletions(-) diff --git a/Src/FFT/AMReX_FFT.H b/Src/FFT/AMReX_FFT.H index 97451dd5a8..dfcaf8862d 100644 --- a/Src/FFT/AMReX_FFT.H +++ b/Src/FFT/AMReX_FFT.H @@ -105,74 +105,84 @@ public: private: #if defined(AMREX_USE_CUDA) - using FFTPlan = cufftHandle; - using FFTPlan2 = FFTPlan; + using VendorPlan = cufftHandle; + using VendorPlan2 = VendorPlan; using FFTComplex = std::conditional_t, cuComplex, cuDoubleComplex>; #elif defined(AMREX_USE_HIP) - using FFTPlan = rocfft_plan; - using FFTPlan2 = FFTPlan; + using VendorPlan = rocfft_plan; + using VendorPlan2 = VendorPlan; using FFTComplex = std::conditional_t, float2, double2>; #elif defined(AMREX_USE_SYCL) - using FFTPlan = oneapi::mkl::dft::descriptor< + using VendorPlan = oneapi::mkl::dft::descriptor< std::is_same_v ? oneapi::mkl::dft::precision::SINGLE : oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::REAL> *; - using FFTPlan2 = oneapi::mkl::dft::descriptor< + using VendorPlan2 = oneapi::mkl::dft::descriptor< std::is_same_v ? oneapi::mkl::dft::precision::SINGLE : oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::COMPLEX> *; using FFTComplex = GpuComplex; #else - using FFTPlan = std::conditional_t, - fftwf_plan, fftw_plan>; - using FFTPlan2 = FFTPlan; + using VendorPlan = std::conditional_t, + fftwf_plan, fftw_plan>; + using VendorPlan2 = VendorPlan; using FFTComplex = std::conditional_t, fftwf_complex, fftw_complex>; #endif + struct Plan { + bool defined = false; + VendorPlan plan = 0; // NOLINT + }; + + struct Plan2 { + bool defined = false; + VendorPlan2 plan = 0; // NOLINT + }; + + template + static typename FA::FABType::value_type * + get_fab (FA& fa) { + auto myproc = ParallelDescriptor::MyProc(); + if (myproc < fa.size()) { + return fa.fabPtr(myproc); + } else { + return nullptr; + } + } + void forward_doit (MF const& inmf, Scaling scaling = Scaling::none); void backward_doit (MF& outmf, Scaling scaling = Scaling::none); - static void exec_r2c (FFTPlan plan, MF& in, cMF& out); - static void exec_c2r (FFTPlan plan, cMF& in, MF& out); + static void exec_r2c (Plan plan, MF& in, cMF& out); + static void exec_c2r (Plan plan, cMF& in, MF& out); template - static void exec_c2c (FFTPlan2 plan, cMF& inout); + static void exec_c2c (Plan2 plan, cMF& inout); template static void destroy_plan (P plan); - static std::pair make_c2c_plans (cMF& inout); + static std::pair make_c2c_plans (cMF& inout); - Info m_info; - - Box m_real_domain; - Box m_spectral_domain_x; - Box m_spectral_domain_y; - Box m_spectral_domain_z; - - // assuming it's double for now - FFTPlan m_fft_fwd_x; - FFTPlan m_fft_bwd_x; - FFTPlan2 m_fft_fwd_y; - FFTPlan2 m_fft_bwd_y; - FFTPlan2 m_fft_fwd_z; - FFTPlan2 m_fft_bwd_z; + Plan m_fft_fwd_x{}; + Plan m_fft_bwd_x{}; + Plan2 m_fft_fwd_y{}; + Plan2 m_fft_bwd_y{}; + Plan2 m_fft_fwd_z{}; + Plan2 m_fft_bwd_z{}; // Comm meta-data. In the forward phase, we start with (x,y,z), // transpose to (y,x,z) and then (z,x,y). In the backward phase, we // perform inverse transpose. - Swap01 m_dtos_x2y{}; std::unique_ptr m_cmd_x2y; - // - Swap01 m_dtos_y2x{}; std::unique_ptr m_cmd_y2x; - // - Swap02 m_dtos_y2z{}; std::unique_ptr m_cmd_y2z; - // - Swap02 m_dtos_z2y{}; std::unique_ptr m_cmd_z2y; + Swap01 m_dtos_x2y{}; + Swap01 m_dtos_y2x{}; + Swap02 m_dtos_y2z{}; + Swap02 m_dtos_z2y{}; // Optionally we need to copy from m_cz to user provided cMultiFab. xxxxx todo @@ -180,38 +190,51 @@ private: cMF m_cx; cMF m_cy; cMF m_cz; + + Box m_real_domain; + Box m_spectral_domain_x; + Box m_spectral_domain_y; + Box m_spectral_domain_z; + + Info m_info; }; template R2C::R2C (Box const& domain, Info const& info) - : m_info(info), - m_real_domain(domain), + : m_real_domain(domain), m_spectral_domain_x(IntVect(0), IntVect(AMREX_D_DECL(domain.length(0)/2, domain.bigEnd(1), domain.bigEnd(2)))), +#if (AMREX_SPACEDIM >= 2) m_spectral_domain_y(IntVect(0), IntVect(AMREX_D_DECL(domain.bigEnd(1), domain.length(0)/2, domain.bigEnd(2)))), +#if (AMREX_SPACEDIM == 3) m_spectral_domain_z(IntVect(0), IntVect(AMREX_D_DECL(domain.bigEnd(2), domain.length(0)/2, - domain.bigEnd(1)))) + domain.bigEnd(1)))), +#endif +#endif + m_info(info) { static_assert(std::is_same_v || std::is_same_v); AMREX_ALWAYS_ASSERT(m_real_domain.smallEnd() == 0 && m_real_domain.cellCentered()); -#if (AMREX_SPACEDIM != 3) +#if (AMREX_SPACEDIM == 3) + AMREX_ALWAYS_ASSERT(m_real_domain.length(2) > 1 || false == m_info.batch_mode); +#else AMREX_ALWAYS_ASSERT(false == m_info.batch_mode); #endif int myproc = ParallelDescriptor::MyProc(); int nprocs = ParallelDescriptor::NProcs(); - // xxxxx todo: need to handle cases there are more processes than 2d cells - // xxxxx todo: 1d & 2d - auto bax = amrex::decompose(m_real_domain, nprocs, {AMREX_D_DECL(false,true,true)}); - Vector pmx(bax.size()); - std::iota(pmx.begin(), pmx.end(), 0); - DistributionMapping dmx(std::move(pmx)); + DistributionMapping dmx; + { + Vector pm(bax.size()); + std::iota(pm.begin(), pm.end(), 0); + dmx.define(std::move(pm)); + } m_rx.define(bax, dmx, 1, 0); { @@ -224,60 +247,63 @@ R2C::R2C (Box const& domain, Info const& info) } // plans for x-direction + if (myproc < m_rx.size()) { Box const local_box = m_rx.boxArray()[myproc]; int n = local_box.length(0); - int howmany = local_box.length(1) * local_box.length(2); + int howmany = AMREX_D_TERM(1, *local_box.length(1), *local_box.length(2)); #if defined(AMREX_USE_CUDA) cufftType fwd_type = std::is_same_v ? CUFFT_R2C : CUFFT_D2Z; cufftType bwd_type = std::is_same_v ? CUFFT_C2R : CUFFT_Z2D; AMREX_CUFFT_SAFE_CALL - (cufftPlanMany(&m_fft_fwd_x, 1, &n, nullptr, 1, m_real_domain.length(0), + (cufftPlanMany(&m_fft_fwd_x.plan, 1, &n, + nullptr, 1, m_real_domain.length(0), nullptr, 1, m_spectral_domain_x.length(0), fwd_type, howmany)); - AMREX_CUFFT_SAFE_CALL(cufftSetStream(m_fft_fwd_x, Gpu::gpuStream())); + AMREX_CUFFT_SAFE_CALL(cufftSetStream(m_fft_fwd_x.plan, Gpu::gpuStream())); AMREX_CUFFT_SAFE_CALL - (cufftPlanMany(&m_fft_bwd_x, 1, &n, nullptr, 1, m_spectral_domain_x.length(0), + (cufftPlanMany(&m_fft_bwd_x.plan, 1, &n, + nullptr, 1, m_spectral_domain_x.length(0), nullptr, 1, m_real_domain.length(0), bwd_type, howmany)); - AMREX_CUFFT_SAFE_CALL(cufftSetStream(m_fft_bwd_x, Gpu::gpuStream())); + AMREX_CUFFT_SAFE_CALL(cufftSetStream(m_fft_bwd_x.plan, Gpu::gpuStream())); #elif defined(AMREX_USE_HIP) auto prec = std::is_same_v ? rocfft_precision_single : rocfft_precision_double; const std::size_t length = n; AMREX_ROCFFT_SAFE_CALL - (rocfft_plan_create(&m_fft_fwd_x, rocfft_placement_notinplace, - rocfft_transform_type_real_forward, prec, 1, &length, howmany, - nullptr)); + (rocfft_plan_create(&m_fft_fwd_x.plan, rocfft_placement_notinplace, + rocfft_transform_type_real_forward, prec, 1, + &length, howmany, nullptr)); AMREX_ROCFFT_SAFE_CALL - (rocfft_plan_create(&m_fft_bwd_x, rocfft_placement_notinplace, - rocfft_transform_type_real_inverse, prec, 1, &length, howmany, - nullptr)); + (rocfft_plan_create(&m_fft_bwd_x.plan, rocfft_placement_notinplace, + rocfft_transform_type_real_inverse, prec, 1, + &length, howmany, nullptr)); #elif defined(AMREX_USE_SYCL) - m_fft_fwd_x = new std::remove_pointer_t(n); - m_fft_fwd_x->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_NOT_INPLACE); - m_fft_fwd_x->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, - howmany); - m_fft_fwd_x->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - m_real_domain.length(0)); - m_fft_fwd_x->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, - m_spectral_domain_x.length(0)); + m_fft_fwd_x.plan = new std::remove_pointer_t(n); + m_fft_fwd_x.plan->set_value(oneapi::mkl::dft::config_param::PLACEMENT, + DFTI_NOT_INPLACE); + m_fft_fwd_x.plan->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, + howmany); + m_fft_fwd_x.plan->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, + m_real_domain.length(0)); + m_fft_fwd_x.plan->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, + m_spectral_domain_x.length(0)); std::array strides{0,1}; - m_fft_fwd_x->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, - strides.data()); - m_fft_fwd_x->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, - strides.data()); - m_fft_fwd_x->set_value(oneapi::mkl::dft::config_param::WORKSPACE, - oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); - m_fft_fwd_x->commit(amrex::Gpu::Device::streamQueue()); + m_fft_fwd_x.plan->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, + strides.data()); + m_fft_fwd_x.plan->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, + strides.data()); + m_fft_fwd_x.plan->set_value(oneapi::mkl::dft::config_param::WORKSPACE, + oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); + m_fft_fwd_x.plan->commit(amrex::Gpu::Device::streamQueue()); - m_fft_bwd_x = m_fft_fwd_x; + m_fft_bwd_x.plan = m_fft_fwd_x.plan; #else /* FFTW */ @@ -285,31 +311,41 @@ R2C::R2C (Box const& domain, Info const& info) auto* out = (FFTComplex*)(m_cx[myproc].dataPtr()); if constexpr (std::is_same_v) { - m_fft_fwd_x = fftwf_plan_many_dft_r2c + m_fft_fwd_x.plan = fftwf_plan_many_dft_r2c (1, &n, howmany, in, nullptr, 1, m_real_domain.length(0), out, nullptr, 1, m_spectral_domain_x.length(0), FFTW_ESTIMATE | FFTW_DESTROY_INPUT); - m_fft_bwd_x = fftwf_plan_many_dft_c2r + m_fft_bwd_x.plan = fftwf_plan_many_dft_c2r (1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0), in, nullptr, 1, m_real_domain.length(0), FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } else { - m_fft_fwd_x = fftw_plan_many_dft_r2c + m_fft_fwd_x.plan = fftw_plan_many_dft_r2c (1, &n, howmany, in, nullptr, 1, m_real_domain.length(0), out, nullptr, 1, m_spectral_domain_x.length(0), FFTW_ESTIMATE | FFTW_DESTROY_INPUT); - m_fft_bwd_x = fftw_plan_many_dft_c2r + m_fft_bwd_x.plan = fftw_plan_many_dft_c2r (1, &n, howmany, out, nullptr, 1, m_spectral_domain_x.length(0), in, nullptr, 1, m_real_domain.length(0), FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } #endif + m_fft_fwd_x.defined = true; + m_fft_bwd_x.defined = true; } +#if (AMREX_SPACEDIM >= 2) auto cbay = amrex::decompose(m_spectral_domain_y, nprocs, {AMREX_D_DECL(false,true,true)}); - DistributionMapping const& cdmy = dmx; // xxxxx todo + DistributionMapping cdmy; + if (cbay.size() == dmx.size()) { + cdmy = dmx; + } else { + Vector pm(cbay.size()); + std::iota(pm.begin(), pm.end(), 0); + cdmy.define(std::move(pm)); + } m_cy.define(cbay, cdmy, 1, 0); std::tie(m_fft_fwd_y, m_fft_bwd_y) = make_c2c_plans(m_cy); @@ -322,38 +358,56 @@ R2C::R2C (Box const& domain, Info const& info) (m_cx.boxArray(), m_cx.DistributionMap(), m_spectral_domain_x, m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2x); - auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs, {AMREX_D_DECL(false,true,true)}); - DistributionMapping const& cdmz = dmx; // xxxxx todo - m_cz.define(cbaz, cdmz, 1, 0); +#if (AMREX_SPACEDIM == 3) + if (false == m_info.batch_mode || m_real_domain.length(2) > 1) { + auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs, {false,true,true}); + DistributionMapping cdmz; + if (cbaz.size() == dmx.size()) { + cdmz = dmx; + } else if (cbaz.size() == cdmy.size()) { + cdmz = cdmy; + } else { + Vector pm(cbaz.size()); + std::iota(pm.begin(), pm.end(), 0); + cdmz.define(std::move(pm)); + } + m_cz.define(cbaz, cdmz, 1, 0); - std::tie(m_fft_fwd_z, m_fft_bwd_z) = make_c2c_plans(m_cz); + std::tie(m_fft_fwd_z, m_fft_bwd_z) = make_c2c_plans(m_cz); - // comm meta-data between y and z phases - m_cmd_y2z = std::make_unique - (m_cz.boxArray(), m_cz.DistributionMap(), m_spectral_domain_z, - m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2z); - m_cmd_z2y = std::make_unique - (m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y, - m_cz.boxArray(), m_cz.DistributionMap(), IntVect(0), m_dtos_z2y); + // comm meta-data between y and z phases + m_cmd_y2z = std::make_unique + (m_cz.boxArray(), m_cz.DistributionMap(), m_spectral_domain_z, + m_cy.boxArray(), m_cy.DistributionMap(), IntVect(0), m_dtos_y2z); + m_cmd_z2y = std::make_unique + (m_cy.boxArray(), m_cy.DistributionMap(), m_spectral_domain_y, + m_cz.boxArray(), m_cz.DistributionMap(), IntVect(0), m_dtos_z2y); + } +#endif +#endif } template template void R2C::destroy_plan (P plan) { + if (! plan.defined) { return; } + #if defined(AMREX_USE_CUDA) - AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan)); + AMREX_CUFFT_SAFE_CALL(cufftDestroy(plan.plan)); #elif defined(AMREX_USE_HIP) - AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan)); + AMREX_ROCFFT_SAFE_CALL(rocfft_plan_destroy(plan.plan)); #elif defined(AMREX_USE_SYCL) - delete plan; + delete plan.plan; #else if constexpr (std::is_same_v) { - fftwf_destroy_plan(plan); + fftwf_destroy_plan(plan.plan); } else { - fftw_destroy_plan(plan); + fftw_destroy_plan(plan.plan); } #endif + + plan.defined = false; } template @@ -407,8 +461,10 @@ void sycl_execute (P plan, TI* in, TO* out) #endif template -void R2C::exec_r2c (FFTPlan plan, MF& in, cMF& out) +void R2C::exec_r2c (Plan plan, MF& in, cMF& out) { + if (! plan.defined) { return; } + #if defined(AMREX_USE_GPU) auto* pin = in[ParallelDescriptor::MyProc()].dataPtr(); auto* pout = out[ParallelDescriptor::MyProc()].dataPtr(); @@ -418,26 +474,28 @@ void R2C::exec_r2c (FFTPlan plan, MF& in, cMF& out) #if defined(AMREX_USE_CUDA) if constexpr (std::is_same_v) { - AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan, pin, (FFTComplex*)pout)); + AMREX_CUFFT_SAFE_CALL(cufftExecR2C(plan.plan, pin, (FFTComplex*)pout)); } else { - AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan, pin, (FFTComplex*)pout)); + AMREX_CUFFT_SAFE_CALL(cufftExecD2Z(plan.plan, pin, (FFTComplex*)pout)); } #elif defined(AMREX_USE_HIP) - detail::hip_execute(plan, (void**)&pin, (void**)&pout); + detail::hip_execute(plan.plan, (void**)&pin, (void**)&pout); #elif defined(AMREX_USE_SYCL) - detail::sycl_execute(plan, pin, (std::complex*)pout); + detail::sycl_execute(plan.plan, pin, (std::complex*)pout); #else if constexpr (std::is_same_v) { - fftwf_execute(plan); + fftwf_execute(plan.plan); } else { - fftw_execute(plan); + fftw_execute(plan.plan); } #endif } template -void R2C::exec_c2r (FFTPlan plan, cMF& in, MF& out) +void R2C::exec_c2r (Plan plan, cMF& in, MF& out) { + if (! plan.defined) { return; } + #if defined(AMREX_USE_GPU) auto* pin = in[ParallelDescriptor::MyProc()].dataPtr(); auto* pout = out[ParallelDescriptor::MyProc()].dataPtr(); @@ -447,27 +505,29 @@ void R2C::exec_c2r (FFTPlan plan, cMF& in, MF& out) #if defined(AMREX_USE_CUDA) if constexpr (std::is_same_v) { - AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan, (FFTComplex*)pin, pout)); + AMREX_CUFFT_SAFE_CALL(cufftExecC2R(plan.plan, (FFTComplex*)pin, pout)); } else { - AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan, (FFTComplex*)pin, pout)); + AMREX_CUFFT_SAFE_CALL(cufftExecZ2D(plan.plan, (FFTComplex*)pin, pout)); } #elif defined(AMREX_USE_HIP) - detail::hip_execute(plan, (void**)&pin, (void**)&pout); + detail::hip_execute(plan.plan, (void**)&pin, (void**)&pout); #elif defined(AMREX_USE_SYCL) - detail::sycl_execute(plan, (std::complex*)pin, pout); + detail::sycl_execute(plan.plan, (std::complex*)pin, pout); #else if constexpr (std::is_same_v) { - fftwf_execute(plan); + fftwf_execute(plan.plan); } else { - fftw_execute(plan); + fftw_execute(plan.plan); } #endif } template template -void R2C::exec_c2c (FFTPlan2 plan, cMF& inout) +void R2C::exec_c2c (Plan2 plan, cMF& inout) { + if (! plan.defined) { return; } + amrex::ignore_unused(inout); #if defined(AMREX_USE_GPU) auto* p = inout[ParallelDescriptor::MyProc()].dataPtr(); @@ -476,21 +536,21 @@ void R2C::exec_c2c (FFTPlan2 plan, cMF& inout) #if defined(AMREX_USE_CUDA) auto cufft_direction = (direction == Direction::forward) ? CUFFT_FORWARD : CUFFT_INVERSE; if constexpr (std::is_same_v) { - AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan, (FFTComplex*)p, (FFTComplex*)p, + AMREX_CUFFT_SAFE_CALL(cufftExecC2C(plan.plan, (FFTComplex*)p, (FFTComplex*)p, cufft_direction)); } else { - AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan, (FFTComplex*)p, (FFTComplex*)p, + AMREX_CUFFT_SAFE_CALL(cufftExecZ2Z(plan.plan, (FFTComplex*)p, (FFTComplex*)p, cufft_direction)); } #elif defined(AMREX_USE_HIP) - detail::hip_execute(plan, (void**)&p, (void**)&p); + detail::hip_execute(plan.plan, (void**)&p, (void**)&p); #elif defined(AMREX_USE_SYCL) - detail::sycl_execute(plan, (std::complex*)p, (std::complex*)p); + detail::sycl_execute(plan.plan, (std::complex*)p, (std::complex*)p); #else if constexpr (std::is_same_v) { - fftwf_execute(plan); + fftwf_execute(plan.plan); } else { - fftw_execute(plan); + fftw_execute(plan.plan); } #endif } @@ -501,10 +561,14 @@ void R2C::forward_doit (MF const& inmf, Scaling /*scaling*/) m_rx.ParallelCopy(inmf, 0, 0, 1); exec_r2c(m_fft_fwd_x, m_rx, m_cx); - ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y); + if ( m_cmd_x2y) { + ParallelCopy(m_cy, m_cx, *m_cmd_x2y, 0, 0, 1, m_dtos_x2y); + } exec_c2c(m_fft_fwd_y, m_cy); - ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z); + if ( m_cmd_y2z) { + ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z); + } exec_c2c(m_fft_fwd_z, m_cz); } @@ -512,86 +576,95 @@ template void R2C::backward_doit (MF& outmf, Scaling /*scaling*/) { exec_c2c(m_fft_bwd_z, m_cz); - ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y); + if ( m_cmd_z2y) { + ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y); + } exec_c2c(m_fft_bwd_y, m_cy); - ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x); + if ( m_cmd_y2x) { + ParallelCopy(m_cx, m_cy, *m_cmd_y2x, 0, 0, 1, m_dtos_y2x); + } exec_c2r(m_fft_bwd_x, m_cx, m_rx); outmf.ParallelCopy(m_rx, 0, 0, 1); } template -std::pair::FFTPlan2, typename R2C::FFTPlan2> +std::pair::Plan2, typename R2C::Plan2> R2C::make_c2c_plans (cMF& inout) { - auto& fab = inout[ParallelDescriptor::MyProc()]; - Box const& local_box = fab.box(); + Plan2 fwd; + Plan2 bwd; - int n = local_box.length(0); - int howmany = local_box.length(1) * local_box.length(2); + auto* fab = get_fab(inout); + if (!fab) { return {fwd, bwd};} - FFTPlan2 fwd; - FFTPlan2 bwd; + Box const& local_box = fab->box(); + + int n = local_box.length(0); + int howmany = AMREX_D_TERM(1, *local_box.length(1), *local_box.length(2)); #if defined(AMREX_USE_CUDA) cufftType fwd_type = std::is_same_v ? CUFFT_C2C : CUFFT_Z2Z; cufftType bwd_type = std::is_same_v ? CUFFT_C2C : CUFFT_Z2Z; AMREX_CUFFT_SAFE_CALL - (cufftPlanMany(&fwd, 1, &n, nullptr, 1, n, nullptr, 1, n, fwd_type, howmany)); - AMREX_CUFFT_SAFE_CALL(cufftSetStream(fwd, Gpu::gpuStream())); + (cufftPlanMany(&fwd.plan, 1, &n, nullptr, 1, n, nullptr, 1, n, fwd_type, howmany)); + AMREX_CUFFT_SAFE_CALL(cufftSetStream(fwd.plan, Gpu::gpuStream())); AMREX_CUFFT_SAFE_CALL - (cufftPlanMany(&bwd, 1, &n, nullptr, 1, n, nullptr, 1, n, bwd_type, howmany)); - AMREX_CUFFT_SAFE_CALL(cufftSetStream(bwd, Gpu::gpuStream())); + (cufftPlanMany(&bwd.plan, 1, &n, nullptr, 1, n, nullptr, 1, n, bwd_type, howmany)); + AMREX_CUFFT_SAFE_CALL(cufftSetStream(bwd.plan, Gpu::gpuStream())); #elif defined(AMREX_USE_HIP) auto prec = std::is_same_v ? rocfft_precision_single : rocfft_precision_double; const std::size_t length = n; AMREX_ROCFFT_SAFE_CALL - (rocfft_plan_create(&fwd, rocfft_placement_inplace, + (rocfft_plan_create(&fwd.plan, rocfft_placement_inplace, rocfft_transform_type_complex_forward, prec, 1, &length, howmany, nullptr)); AMREX_ROCFFT_SAFE_CALL - (rocfft_plan_create(&bwd, rocfft_placement_inplace, + (rocfft_plan_create(&bwd.plan, rocfft_placement_inplace, rocfft_transform_type_complex_inverse, prec, 1, &length, howmany, nullptr)); #elif defined(AMREX_USE_SYCL) - fwd = new std::remove_pointer_t(n); - fwd->set_value(oneapi::mkl::dft::config_param::PLACEMENT, - DFTI_INPLACE); - fwd->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, - howmany); - fwd->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n); - fwd->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n); - std::array strides{0,1}; - fwd->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data()); - fwd->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data()); - fwd->set_value(oneapi::mkl::dft::config_param::WORKSPACE, - oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); - fwd->commit(amrex::Gpu::Device::streamQueue()); - - bwd = fwd; + fwd.plan = new std::remove_pointer_t(n); + fwd.plan->set_value(oneapi::mkl::dft::config_param::PLACEMENT, + DFTI_INPLACE); + fwd.plan->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, + howmany); + fwd.plan->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n); + fwd.plan->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, n); + std::array strides{0,1}; + fwd.plan->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides.data()); + fwd.plan->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides.data()); + fwd.plan->set_value(oneapi::mkl::dft::config_param::WORKSPACE, + oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); + fwd.plan->commit(amrex::Gpu::Device::streamQueue()); + + bwd.plan = fwd.plan; #else - auto* pinout = (FFTComplex*)fab.dataPtr(); + auto* pinout = (FFTComplex*)fab->dataPtr(); if constexpr (std::is_same_v) { - fwd = fftwf_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, - pinout, nullptr, 1, n, -1, FFTW_ESTIMATE); - bwd = fftwf_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, - pinout, nullptr, 1, n, +1, FFTW_ESTIMATE); + fwd.plan = fftwf_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, + pinout, nullptr, 1, n, -1, FFTW_ESTIMATE); + bwd.plan = fftwf_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, + pinout, nullptr, 1, n, +1, FFTW_ESTIMATE); } else { - fwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, - pinout, nullptr, 1, n, -1, FFTW_ESTIMATE); - bwd = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, - pinout, nullptr, 1, n, +1, FFTW_ESTIMATE); + fwd.plan = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, + pinout, nullptr, 1, n, -1, FFTW_ESTIMATE); + bwd.plan = fftw_plan_many_dft(1, &n, howmany, pinout, nullptr, 1, n, + pinout, nullptr, 1, n, +1, FFTW_ESTIMATE); } #endif + fwd.defined = true; + bwd.defined = true; + return {fwd,bwd}; } @@ -599,12 +672,41 @@ template template void R2C::post_forward_doit (F const& post_forward) { - auto& spectral_fab = m_cz[ParallelDescriptor::MyProc()]; - auto const& a = spectral_fab.array(); // m_cz's ordering is z,x,y - ParallelFor(spectral_fab.box(), [=] AMREX_GPU_DEVICE (int iz, int jx, int ky) - { - post_forward(jx,ky,iz,a(iz,jx,ky)); - }); + if (m_info.batch_mode) { + amrex::Abort("xxxxx todo: post_forward"); + } else { + if ( ! m_cz.empty()) { + auto* spectral_fab = get_fab(m_cz); + if (spectral_fab) { + auto const& a = spectral_fab->array(); // m_cz's ordering is z,x,y + ParallelFor(spectral_fab->box(), + [=] AMREX_GPU_DEVICE (int iz, int jx, int ky) + { + post_forward(jx,ky,iz,a(iz,jx,ky)); + }); + } + } else if ( ! m_cy.empty()) { + auto* spectral_fab = get_fab(m_cy); + if (spectral_fab) { + auto const& a = spectral_fab->array(); // m_cy's ordering is y,x,z + ParallelFor(spectral_fab->box(), + [=] AMREX_GPU_DEVICE (int iy, int jx, int k) + { + post_forward(jx,iy,k,a(iy,jx,k)); + }); + } + } else { + auto* spectral_fab = get_fab(m_cx); + if (spectral_fab) { + auto const& a = spectral_fab->array(); + ParallelFor(spectral_fab->box(), + [=] AMREX_GPU_DEVICE (int i, int j, int k) + { + post_forward(i,j,k,a(i,j,k)); + }); + } + } + } } }