Skip to content

Commit

Permalink
Fix Kokkos::parallel_for_pushforward signature for non-void work tag
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch authored and vgvassilev committed Aug 27, 2024
1 parent dde2f08 commit e2f4638
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,11 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 2> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0);
});
}
};
template <class Policy, class FunctorType>
Expand All @@ -189,11 +189,12 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 3> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0,
0);
});
}
};
template <class Policy, class FunctorType>
Expand All @@ -210,11 +211,12 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 4> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, 0,
0);
});
}
};
template <class Policy, class FunctorType>
Expand All @@ -231,11 +233,12 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 5> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, 0,
0, 0);
});
}
};
template <class Policy, class FunctorType>
Expand All @@ -252,11 +255,12 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 6> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, 0,
0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
Expand Down Expand Up @@ -294,11 +298,11 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_OP_call_dispatch {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, {});
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, {});
});
}
};
template <class Policy, class FunctorType>
Expand Down

0 comments on commit e2f4638

Please sign in to comment.