Skip to content

Commit

Permalink
call_f refactor (#3452)
Browse files Browse the repository at this point in the history
Using `if constexpr` in these functions allows us to substantially
reduce the number of overloads needed.

The proposed changes:
- [ ] fix a bug or incorrect behavior in AMReX
- [ ] add new capabilities to AMReX
- [ ] changes answers in the test suite to more than roundoff level
- [ ] are likely to significantly affect the results of downstream AMReX
users
- [ ] include documentation in the code and/or rst files, if appropriate
  • Loading branch information
atmyers authored Jul 27, 2023
1 parent 49dd703 commit 34c0ae3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 191 deletions.
18 changes: 6 additions & 12 deletions Src/Particle/AMReX_DenseBins.H
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,12 @@ private:

template <typename F, typename I>
AMREX_GPU_HOST_DEVICE
static auto call_f(F const& f, const_pointer_input_type v, I& index)
noexcept -> decltype(f(v,index))
{
return f(v,index);
}

template <typename F, typename I>
AMREX_GPU_HOST_DEVICE
static auto call_f(F const& f, const_pointer_input_type v, I& index)
noexcept -> decltype(f(v[index]))
{
return f(v[index]);
static auto call_f (F const& f, const_pointer_input_type v, I& index) {
if constexpr (IsCallable<F, decltype(v), I>::value) {
return f(v, index);
} else {
return f(v[index]);
}
}

public:
Expand Down
27 changes: 26 additions & 1 deletion Src/Particle/AMReX_ParticleMesh.H
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,33 @@
#include <AMReX_MultiFab.H>
#include <AMReX_ParticleUtil.H>

namespace amrex
namespace amrex {

namespace particle_detail {

template <typename F, typename T, typename T_ParticleType, template<class, int, int> class PTDType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const PTDType<T_ParticleType, NAR, NAI>& p,
const int i, Array4<T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const& plo,
GpuArray<Real,AMREX_SPACEDIM> const& dxi) noexcept
{
if constexpr (IsCallable<F, decltype(p.m_aos[i]), decltype(fabarr), decltype(plo), decltype(dxi)>::value) {
return f(p.m_aos[i], fabarr, plo, dxi);
} else if constexpr (IsCallable<F, decltype(p.m_aos[i]), decltype(fabarr)>::value) {
return f(p.m_aos[i], fabarr);
} else if constexpr (IsCallable<F, decltype(p.getSuperParticle(i)), decltype(fabarr), decltype(plo), decltype(dxi)>::value) {
return f(p.getSuperParticle(i), fabarr, plo, dxi);
} else if constexpr (IsCallable<F, decltype(p.getSuperParticle(i)), decltype(fabarr)>::value) {
return f(p.getSuperParticle(i), fabarr);
} else if constexpr (IsCallable<F, decltype(p), decltype(fabarr), decltype(plo), decltype(dxi)>::value) {
return f(p, i, fabarr, plo, dxi);
} else {
return f(p, i, fabarr);
}
}
}

template <class PC, class MF, class F, std::enable_if_t<IsParticleContainer<PC>::value, int> foo = 0>
void
Expand Down
19 changes: 18 additions & 1 deletion Src/Particle/AMReX_ParticleReduce.H
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,25 @@

#include <limits>

namespace amrex
namespace amrex {

namespace particle_detail {

template <typename F, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i) noexcept
{
if constexpr (IsCallable<F, decltype(p.m_aos[i])>::value) {
return f(p.m_aos[i]);
} else if constexpr (IsCallable<F, decltype(p.getSuperParticle(i))>::value) {
return f(p.getSuperParticle(i));
} else {
return f(p, i);
}
}
}

/**
* \brief A general reduction method for the particles in a ParticleContainer that can run on either CPUs or GPUs.
Expand Down
178 changes: 1 addition & 177 deletions Src/Particle/AMReX_ParticleUtil.H
Original file line number Diff line number Diff line change
Expand Up @@ -17,183 +17,7 @@

#include <limits>

namespace amrex
{

namespace particle_detail {

// The next several functions are used by ParticleReduce

// Lambda takes a Particle
template <typename F, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i) noexcept
-> decltype(f(p.m_aos[i]))
{
return f(p.m_aos[i]);
}

// Lambda takes a SuperParticle
template <typename F, typename T_ParticleType, int NAR, int NAI,
typename std::enable_if<NAR != 0 || NAI != 0, int>::type = 0>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i) noexcept
-> decltype(f(p.getSuperParticle(i)))
{
return f(p.getSuperParticle(i));
}

// Lambda takes a ConstParticleTileData
template <typename F, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i) noexcept
-> decltype(f(p, i))
{
return f(p, i);
}

// These next several functions are used by ParticleToMesh and MeshToParticle

// Lambda takes a Particle
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const& plo,
GpuArray<Real,AMREX_SPACEDIM> const& dxi) noexcept
-> decltype(f(p.m_aos[i], fabarr, plo, dxi))
{
return f(p.m_aos[i], fabarr, plo, dxi);
}

// Lambda takes a Particle
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const&,
GpuArray<Real,AMREX_SPACEDIM> const&) noexcept
-> decltype(f(p.m_aos[i], fabarr))
{
return f(p.m_aos[i], fabarr);
}

// Lambda takes a Particle
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<const T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const& plo,
GpuArray<Real,AMREX_SPACEDIM> const& dxi) noexcept
-> decltype(f(p.m_aos[i], fabarr, plo, dxi))
{
return f(p.m_aos[i], fabarr, plo, dxi);
}

// Lambda takes a Particle
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<const T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const&,
GpuArray<Real,AMREX_SPACEDIM> const&) noexcept
-> decltype(f(p.m_aos[i], fabarr))
{
return f(p.m_aos[i], fabarr);
}

// Lambda takes a SuperParticle
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI,
typename std::enable_if<(NAR != 0) || (NAI != 0), int>::type = 0>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const& plo,
GpuArray<Real,AMREX_SPACEDIM> const& dxi) noexcept
-> decltype(f(p.getSuperParticle(i), fabarr, plo, dxi))
{
return f(p.getSuperParticle(i), fabarr, plo, dxi);
}

// Lambda takes a SuperParticle
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI,
typename std::enable_if<(NAR != 0) || (NAI != 0), int>::type = 0>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const&,
GpuArray<Real,AMREX_SPACEDIM> const&) noexcept
-> decltype(f(p.getSuperParticle(i), fabarr))
{
return f(p.getSuperParticle(i), fabarr);
}

// Lambda takes a ConstParticleTileData
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const&,
GpuArray<Real,AMREX_SPACEDIM> const&) noexcept
-> decltype(f(p, i, fabarr))
{
return f(p, i, fabarr);
}

// Lambda takes a ConstParticleTileData
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ConstParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const& plo,
GpuArray<Real,AMREX_SPACEDIM> const& dxi) noexcept
-> decltype(f(p, i, fabarr, plo, dxi))
{
return f(p, i, fabarr, plo, dxi);
}

// Lambda takes a ParticleTileData
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<const T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const& plo,
GpuArray<Real,AMREX_SPACEDIM> const& dxi) noexcept
-> decltype(f(p, i, fabarr, plo, dxi))
{
return f(p, i, fabarr, plo, dxi);
}

// Lambda takes a ParticleTileData
template <typename F, typename T, typename T_ParticleType, int NAR, int NAI>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f,
const ParticleTileData<T_ParticleType, NAR, NAI>& p,
const int i, Array4<const T> const& fabarr,
GpuArray<Real,AMREX_SPACEDIM> const&,
GpuArray<Real,AMREX_SPACEDIM> const&) noexcept
-> decltype(f(p, i, fabarr))
{
return f(p, i, fabarr);
}


}
namespace amrex {

/**
* \brief Returns the number of particles that are more than nGrow cells
Expand Down

0 comments on commit 34c0ae3

Please sign in to comment.