From e2f4638e24ad32e2fb24e75c68eac43d89fd47ff Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Sun, 25 Aug 2024 20:34:07 +0200 Subject: [PATCH] Fix `Kokkos::parallel_for_pushforward` signature for non-void work tag --- include/clad/Differentiator/KokkosBuiltins.h | 64 +++++++++++--------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index ef479a16d..14af4dfa0 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -168,11 +168,11 @@ template struct diff_parallel_for_MDP_call_dispatch { static void run(const ::std::string& str, const Policy& policy, const FunctorType& functor, const FunctorType& d_functor) { - ::Kokkos::parallel_for("_diff_" + str, policy, - [&functor, &d_functor](const T x, auto&&... args) { - functor.operator_call_pushforward( - x, args..., &d_functor, &x, 0, 0); - }); + ::Kokkos::parallel_for( + "_diff_" + str, policy, + [&functor, &d_functor](const auto x, auto&&... args) { + functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0); + }); } }; template @@ -189,11 +189,12 @@ template struct diff_parallel_for_MDP_call_dispatch { static void run(const ::std::string& str, const Policy& policy, const FunctorType& functor, const FunctorType& d_functor) { - ::Kokkos::parallel_for("_diff_" + str, policy, - [&functor, &d_functor](const T x, auto&&... args) { - functor.operator_call_pushforward( - x, args..., &d_functor, &x, 0, 0, 0); - }); + ::Kokkos::parallel_for( + "_diff_" + str, policy, + [&functor, &d_functor](const auto x, auto&&... args) { + functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, + 0); + }); } }; template @@ -210,11 +211,12 @@ template struct diff_parallel_for_MDP_call_dispatch { static void run(const ::std::string& str, const Policy& policy, const FunctorType& functor, const FunctorType& d_functor) { - ::Kokkos::parallel_for("_diff_" + str, policy, - [&functor, &d_functor](const T x, auto&&... args) { - functor.operator_call_pushforward( - x, args..., &d_functor, &x, 0, 0, 0, 0); - }); + ::Kokkos::parallel_for( + "_diff_" + str, policy, + [&functor, &d_functor](const auto x, auto&&... args) { + functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, 0, + 0); + }); } }; template @@ -231,11 +233,12 @@ template struct diff_parallel_for_MDP_call_dispatch { static void run(const ::std::string& str, const Policy& policy, const FunctorType& functor, const FunctorType& d_functor) { - ::Kokkos::parallel_for("_diff_" + str, policy, - [&functor, &d_functor](const T x, auto&&... args) { - functor.operator_call_pushforward( - x, args..., &d_functor, &x, 0, 0, 0, 0, 0); - }); + ::Kokkos::parallel_for( + "_diff_" + str, policy, + [&functor, &d_functor](const auto x, auto&&... args) { + functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, 0, + 0, 0); + }); } }; template @@ -252,11 +255,12 @@ template struct diff_parallel_for_MDP_call_dispatch { static void run(const ::std::string& str, const Policy& policy, const FunctorType& functor, const FunctorType& d_functor) { - ::Kokkos::parallel_for("_diff_" + str, policy, - [&functor, &d_functor](const T x, auto&&... args) { - functor.operator_call_pushforward( - x, args..., &d_functor, &x, 0, 0, 0, 0, 0, 0); - }); + ::Kokkos::parallel_for( + "_diff_" + str, policy, + [&functor, &d_functor](const auto x, auto&&... args) { + functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, 0, + 0, 0, 0); + }); } }; template @@ -294,11 +298,11 @@ template struct diff_parallel_for_OP_call_dispatch { static void run(const ::std::string& str, const Policy& policy, const FunctorType& functor, const FunctorType& d_functor) { - ::Kokkos::parallel_for("_diff_" + str, policy, - [&functor, &d_functor](const T x, auto&&... args) { - functor.operator_call_pushforward( - x, args..., &d_functor, &x, {}); - }); + ::Kokkos::parallel_for( + "_diff_" + str, policy, + [&functor, &d_functor](const auto x, auto&&... args) { + functor.operator_call_pushforward(x, args..., &d_functor, {}, {}); + }); } }; template