Skip to content

Commit

Permalink
add precondition to SharedAllocationDisableTrackingGuard and change w…
Browse files Browse the repository at this point in the history
…ith_shared_allocation_tracking_disabled to a factory function instead
  • Loading branch information
nmm0 committed Jan 24, 2024
1 parent d778994 commit 7b3f79f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 35 deletions.
32 changes: 15 additions & 17 deletions core/src/Kokkos_Parallel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ inline void parallel_for(const std::string& str, const ExecPolicy& policy,
ExecPolicy inner_policy = policy;
Kokkos::Tools::Impl::begin_parallel_for(inner_policy, functor, str, kpID);

auto closure = Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() {
return Impl::ParallelFor<FunctorType, ExecPolicy>(functor, inner_policy);
});
auto closure =
Kokkos::Impl::construct_with_shared_allocation_tracking_disabled<
Impl::ParallelFor<FunctorType, ExecPolicy>>(functor, inner_policy);

closure.execute();

Expand Down Expand Up @@ -352,10 +352,10 @@ inline void parallel_scan(const std::string& str, const ExecutionPolicy& policy,
ExecutionPolicy inner_policy = policy;
Kokkos::Tools::Impl::begin_parallel_scan(inner_policy, functor, str, kpID);

auto closure = Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() {
return Impl::ParallelScan<FunctorType, ExecutionPolicy>(functor,
auto closure =
Kokkos::Impl::construct_with_shared_allocation_tracking_disabled<
Impl::ParallelScan<FunctorType, ExecutionPolicy>>(functor,
inner_policy);
});

closure.execute();

Expand Down Expand Up @@ -399,20 +399,18 @@ inline void parallel_scan(const std::string& str, const ExecutionPolicy& policy,

if constexpr (Kokkos::is_view<ReturnType>::value) {
auto closure =
Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() {
return Impl::ParallelScanWithTotal<FunctorType, ExecutionPolicy,
typename ReturnType::value_type>(
functor, inner_policy, return_value);
});
Kokkos::Impl::construct_with_shared_allocation_tracking_disabled<
Impl::ParallelScanWithTotal<FunctorType, ExecutionPolicy,
typename ReturnType::value_type>>(
functor, inner_policy, return_value);
closure.execute();
} else {
Kokkos::View<ReturnType, Kokkos::HostSpace> view(&return_value);
auto closure =
Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() {
Kokkos::View<ReturnType, Kokkos::HostSpace> view(&return_value);
return Impl::ParallelScanWithTotal<FunctorType, ExecutionPolicy,
ReturnType>(functor, inner_policy,
view);
});
Kokkos::Impl::construct_with_shared_allocation_tracking_disabled<
Impl::ParallelScanWithTotal<FunctorType, ExecutionPolicy,
ReturnType>>(functor, inner_policy,
view);
closure.execute();
}

Expand Down
24 changes: 9 additions & 15 deletions core/src/Kokkos_Parallel_Reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1503,21 +1503,15 @@ struct ParallelReduceAdaptor {
PolicyType, typename ReducerSelector::type,
typename return_value_adapter::value_type>;

auto closure =
Kokkos::Impl::with_shared_allocation_tracking_disabled([&]() {
CombinedFunctorReducer functor_reducer(
functor, typename Analysis::Reducer(
ReducerSelector::select(functor, return_value)));

// FIXME Remove "Wrapper" once all backends implement the new
// interface
return Impl::ParallelReduce<
decltype(functor_reducer), PolicyType,
typename Impl::FunctorPolicyExecutionSpace<
FunctorType, PolicyType>::execution_space>(
functor_reducer, inner_policy,
return_value_adapter::return_value(return_value, functor));
});
CombinedFunctorReducer functor_reducer(
functor, typename Analysis::Reducer(
ReducerSelector::select(functor, return_value)));
auto closure = construct_with_shared_allocation_tracking_disabled<
Impl::ParallelReduce<decltype(functor_reducer), PolicyType,
typename Impl::FunctorPolicyExecutionSpace<
FunctorType, PolicyType>::execution_space>>(
functor_reducer, inner_policy,
return_value_adapter::return_value(return_value, functor));
closure.execute();

Kokkos::Tools::Impl::end_parallel_reduce<PassedReducerType>(
Expand Down
8 changes: 5 additions & 3 deletions core/src/Kokkos_View.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ static_assert(false,
#include <algorithm>
#include <initializer_list>
#include <functional>
#include <cassert>

#include <Kokkos_Core_fwd.hpp>
#include <Kokkos_HostSpace.hpp>
Expand Down Expand Up @@ -1880,6 +1881,7 @@ namespace Kokkos {
namespace Impl {
struct SharedAllocationDisableTrackingGuard {
SharedAllocationDisableTrackingGuard() {
assert( ( Kokkos::Impl::SharedAllocationRecord< void, void >::tracking_enabled() ) );
Kokkos::Impl::SharedAllocationRecord<void, void>::tracking_disable();
}

Expand All @@ -1888,10 +1890,10 @@ struct SharedAllocationDisableTrackingGuard {
}
};

template <class F>
inline decltype(auto) with_shared_allocation_tracking_disabled(F&& fun) {
template <class FunctorType, class... Args>
inline FunctorType construct_with_shared_allocation_tracking_disabled(Args&&... args) {
[[maybe_unused]] auto guard = SharedAllocationDisableTrackingGuard{};
return std::invoke(std::forward<F>(fun));
return {std::forward<Args>(args)...};
}

} /* namespace Impl */
Expand Down

0 comments on commit 7b3f79f

Please sign in to comment.