From 7b3f79ffe8e2003c77ac5a30ad747d9d90b90eb9 Mon Sep 17 00:00:00 2001 From: Nicolas Morales Date: Wed, 24 Jan 2024 09:59:23 -0800 Subject: [PATCH] add precondition to SharedAllocationDisableTrackingGuard and change with_shared_allocation_tracking_disabled to a factory function instead --- core/src/Kokkos_Parallel.hpp | 32 ++++++++++++++--------------- core/src/Kokkos_Parallel_Reduce.hpp | 24 ++++++++-------------- core/src/Kokkos_View.hpp | 8 +++++--- 3 files changed, 29 insertions(+), 35 deletions(-) diff --git a/core/src/Kokkos_Parallel.hpp b/core/src/Kokkos_Parallel.hpp index fa7bfd6e833..122239df790 100644 --- a/core/src/Kokkos_Parallel.hpp +++ b/core/src/Kokkos_Parallel.hpp @@ -137,9 +137,9 @@ inline void parallel_for(const std::string& str, const ExecPolicy& policy, ExecPolicy inner_policy = policy; Kokkos::Tools::Impl::begin_parallel_for(inner_policy, functor, str, kpID); - auto closure = Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() { - return Impl::ParallelFor(functor, inner_policy); - }); + auto closure = + Kokkos::Impl::construct_with_shared_allocation_tracking_disabled< + Impl::ParallelFor>(functor, inner_policy); closure.execute(); @@ -352,10 +352,10 @@ inline void parallel_scan(const std::string& str, const ExecutionPolicy& policy, ExecutionPolicy inner_policy = policy; Kokkos::Tools::Impl::begin_parallel_scan(inner_policy, functor, str, kpID); - auto closure = Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() { - return Impl::ParallelScan(functor, + auto closure = + Kokkos::Impl::construct_with_shared_allocation_tracking_disabled< + Impl::ParallelScan>(functor, inner_policy); - }); closure.execute(); @@ -399,20 +399,18 @@ inline void parallel_scan(const std::string& str, const ExecutionPolicy& policy, if constexpr (Kokkos::is_view::value) { auto closure = - Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() { - return Impl::ParallelScanWithTotal( - functor, inner_policy, return_value); - }); + Kokkos::Impl::construct_with_shared_allocation_tracking_disabled< + Impl::ParallelScanWithTotal>( + functor, inner_policy, return_value); closure.execute(); } else { + Kokkos::View view(&return_value); auto closure = - Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() { - Kokkos::View view(&return_value); - return Impl::ParallelScanWithTotal(functor, inner_policy, - view); - }); + Kokkos::Impl::construct_with_shared_allocation_tracking_disabled< + Impl::ParallelScanWithTotal>(functor, inner_policy, + view); closure.execute(); } diff --git a/core/src/Kokkos_Parallel_Reduce.hpp b/core/src/Kokkos_Parallel_Reduce.hpp index e4075d7377c..df145a8e81d 100644 --- a/core/src/Kokkos_Parallel_Reduce.hpp +++ b/core/src/Kokkos_Parallel_Reduce.hpp @@ -1503,21 +1503,15 @@ struct ParallelReduceAdaptor { PolicyType, typename ReducerSelector::type, typename return_value_adapter::value_type>; - auto closure = - Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() { - CombinedFunctorReducer functor_reducer( - functor, typename Analysis::Reducer( - ReducerSelector::select(functor, return_value))); - - // FIXME Remove "Wrapper" once all backends implement the new - // interface - return Impl::ParallelReduce< - decltype(functor_reducer), PolicyType, - typename Impl::FunctorPolicyExecutionSpace< - FunctorType, PolicyType>::execution_space>( - functor_reducer, inner_policy, - return_value_adapter::return_value(return_value, functor)); - }); + CombinedFunctorReducer functor_reducer( + functor, typename Analysis::Reducer( + ReducerSelector::select(functor, return_value))); + auto closure = construct_with_shared_allocation_tracking_disabled< + Impl::ParallelReduce::execution_space>>( + functor_reducer, inner_policy, + return_value_adapter::return_value(return_value, functor)); closure.execute(); Kokkos::Tools::Impl::end_parallel_reduce( diff --git a/core/src/Kokkos_View.hpp b/core/src/Kokkos_View.hpp index 672b1bf9f23..046ee68f216 100644 --- a/core/src/Kokkos_View.hpp +++ b/core/src/Kokkos_View.hpp @@ -27,6 +27,7 @@ static_assert(false, #include #include #include +#include #include #include @@ -1880,6 +1881,7 @@ namespace Kokkos { namespace Impl { struct SharedAllocationDisableTrackingGuard { SharedAllocationDisableTrackingGuard() { + assert( ( Kokkos::Impl::SharedAllocationRecord< void, void >::tracking_enabled() ) ); Kokkos::Impl::SharedAllocationRecord::tracking_disable(); } @@ -1888,10 +1890,10 @@ struct SharedAllocationDisableTrackingGuard { } }; -template -inline decltype(auto) with_shared_allocation_tracking_disabled(F&& fun) { +template +inline FunctorType construct_with_shared_allocation_tracking_disabled(Args&&... args) { [[maybe_unused]] auto guard = SharedAllocationDisableTrackingGuard{}; - return std::invoke(std::forward(fun)); + return {std::forward(args)...}; } } /* namespace Impl */