diff --git a/Src/Base/AMReX_GpuLaunch.H b/Src/Base/AMReX_GpuLaunch.H index 435a11f342..c8dbc53950 100644 --- a/Src/Base/AMReX_GpuLaunch.H +++ b/Src/Base/AMReX_GpuLaunch.H @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -176,6 +177,31 @@ namespace Gpu { { return makeExecutionConfig(box.numPts()); } + + template + Vector> makeNExecutionConfigs (Long N) noexcept + { + Long numblocks_max = std::numeric_limits::max(); // Max # of blocks in a kernel launch + Long nmax = Long(MT) * numblocks_max; // Max # of threads in a kernel launch + auto nlaunches = (N+nmax-1)/nmax; // # of launches needed for N elements + Vector> r(nlaunches); + for (int i = 0; i < nlaunches; ++i) { + Long nblocks; + if (N <= nmax) { + nblocks = (N+MT-1) / MT; + } else { + nblocks = numblocks_max; + } + r[i].first = nblocks * MT; // Total # of threads in this launch + r[i].second = int(nblocks); // # of blocks in this launch + } + } + + template + Vector> makeNExecutionConfigs (Box const& box) noexcept + { + return makeNExecutionConfigs(box.numPts()); + } #endif } diff --git a/Src/Base/AMReX_GpuLaunchFunctsG.H b/Src/Base/AMReX_GpuLaunchFunctsG.H index 7955410f8b..3409576a25 100644 --- a/Src/Base/AMReX_GpuLaunchFunctsG.H +++ b/Src/Base/AMReX_GpuLaunchFunctsG.H @@ -766,16 +766,20 @@ std::enable_if_t::value> ParallelFor (Gpu::KernelInfo const&, T n, L const& f) noexcept { if (amrex::isEmpty(n)) { return; } - const auto ec = Gpu::makeExecutionConfig(n); - AMREX_LAUNCH_KERNEL(MT, ec.numBlocks, ec.numThreads, 0, Gpu::gpuStream(), - [=] AMREX_GPU_DEVICE () noexcept { - for (Long i = Long(blockDim.x)*blockIdx.x+threadIdx.x, stride = Long(blockDim.x)*gridDim.x; - i < Long(n); i += stride) { - detail::call_f_scalar_handler(f, T(i), - Gpu::Handler(amrex::min((std::uint64_t(n)-i+(std::uint64_t)threadIdx.x), - (std::uint64_t)blockDim.x))); - } - }); + const auto& nec = Gpu::makeNExecutionConfigs(n); + Long ndone = 0; + for (auto const& ec : nec) { + AMREX_LAUNCH_KERNEL(MT, ec.second, MT, 0, Gpu::gpuStream(), + [=] AMREX_GPU_DEVICE () noexcept { + auto i = Long(blockDim.x)*blockIdx.x+threadIdx.x + ndone; + if (i < Long(n)) { + detail::call_f_scalar_handler(f, T(i), + Gpu::Handler(amrex::min((std::uint64_t(n)-i+(std::uint64_t)threadIdx.x), + (std::uint64_t)blockDim.x))); + } + }); + ndone += ec.first; + } AMREX_GPU_ERROR_CHECK(); }