Skip to content

Commit

Permalink
Use constexpr if instead of the call_f way
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Jul 26, 2023
1 parent a3b69fa commit 704098e
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 191 deletions.
86 changes: 30 additions & 56 deletions Src/Base/AMReX_GpuLaunchFunctsC.H
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,6 @@

namespace amrex {

namespace detail {
template <typename F, typename N>
AMREX_GPU_DEVICE
auto call_f (F const& f, N i)
noexcept -> decltype(f(0))
{
f(i);
}

template <typename F, typename N>
AMREX_GPU_DEVICE
auto call_f (F const& f, N i)
noexcept -> decltype(f(0,Gpu::Handler{}))
{
f(i,Gpu::Handler{});
}

template <typename F>
AMREX_GPU_DEVICE
auto call_f (F const& f, int i, int j, int k)
noexcept -> decltype(f(0,0,0))
{
f(i,j,k);
}

template <typename F>
AMREX_GPU_DEVICE
auto call_f (F const& f, int i, int j, int k)
noexcept -> decltype(f(0,0,0,Gpu::Handler{}))
{
f(i,j,k,Gpu::Handler{});
}

template <typename F, typename T>
AMREX_GPU_DEVICE
auto call_f (F const& f, int i, int j, int k, T n)
noexcept -> decltype(f(0,0,0,0))
{
f(i,j,k,n);
}

template <typename F, typename T>
AMREX_GPU_DEVICE
auto call_f (F const& f, int i, int j, int k, T n)
noexcept -> decltype(f(0,0,0,0,Gpu::Handler{}))
{
f(i,j,k,n,Gpu::Handler{});
}
}

template<typename T, typename L>
void launch (T const& n, L&& f) noexcept
{
Expand All @@ -71,7 +21,11 @@ template <typename T, typename L, typename M=std::enable_if_t<std::is_integral<T
void For (T n, L&& f) noexcept
{
for (T i = 0; i < n; ++i) {
detail::call_f(f,i);
if constexpr (IsCallable<L,T,Gpu::Handler>::value) {
f(i,Gpu::Handler{});
} else {
f(i);
}
}
}

Expand Down Expand Up @@ -100,7 +54,11 @@ void ParallelFor (T n, L&& f) noexcept
{
AMREX_PRAGMA_SIMD
for (T i = 0; i < n; ++i) {
detail::call_f(f,i);
if constexpr (IsCallable<L,T,Gpu::Handler>::value) {
f(i,Gpu::Handler{});
} else {
f(i);
}
}
}

Expand Down Expand Up @@ -132,7 +90,11 @@ void For (Box const& box, L&& f) noexcept
for (int k = lo.z; k <= hi.z; ++k) {
for (int j = lo.y; j <= hi.y; ++j) {
for (int i = lo.x; i <= hi.x; ++i) {
detail::call_f(f,i,j,k);
if constexpr (IsCallable<L,int,int,int,Gpu::Handler>::value) {
f(i,j,k,Gpu::Handler{});
} else {
f(i,j,k);
}
}}}
}

Expand Down Expand Up @@ -165,7 +127,11 @@ void ParallelFor (Box const& box, L&& f) noexcept
for (int j = lo.y; j <= hi.y; ++j) {
AMREX_PRAGMA_SIMD
for (int i = lo.x; i <= hi.x; ++i) {
detail::call_f(f,i,j,k);
if constexpr (IsCallable<L,int,int,int,Gpu::Handler>::value) {
f(i,j,k,Gpu::Handler{});
} else {
f(i,j,k);
}
}}}
}

Expand Down Expand Up @@ -198,7 +164,11 @@ void For (Box const& box, T ncomp, L&& f) noexcept
for (int k = lo.z; k <= hi.z; ++k) {
for (int j = lo.y; j <= hi.y; ++j) {
for (int i = lo.x; i <= hi.x; ++i) {
detail::call_f(f,i,j,k,n);
if constexpr (IsCallable<L,int,int,int,int,Gpu::Handler>::value) {
f(i,j,k,n,Gpu::Handler{});
} else {
f(i,j,k,n);
}
}}}
}
}
Expand Down Expand Up @@ -233,7 +203,11 @@ void ParallelFor (Box const& box, T ncomp, L&& f) noexcept
for (int j = lo.y; j <= hi.y; ++j) {
AMREX_PRAGMA_SIMD
for (int i = lo.x; i <= hi.x; ++i) {
detail::call_f(f,i,j,k,n);
if constexpr (IsCallable<L,int,int,int,int,Gpu::Handler>::value) {
f(i,j,k,n,Gpu::Handler{});
} else {
f(i,j,k,n);
}
}}}
}
}
Expand Down
Loading

0 comments on commit 704098e

Please sign in to comment.