Skip to content

Commit

Permalink
Fix Kokkos::parallel_for_pushforward signature for non-void work tag
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Aug 25, 2024
1 parent cc21b98 commit f59ec7d
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 2> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0);
});
}
};
template <class Policy, class FunctorType>
Expand All @@ -122,11 +122,12 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 3> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0,
0);
});
}
};
template <class Policy, class FunctorType>
Expand All @@ -143,11 +144,12 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 4> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, 0,
0);
});
}
};
template <class Policy, class FunctorType>
Expand All @@ -164,11 +166,12 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 5> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, 0,
0, 0);
});
}
};
template <class Policy, class FunctorType>
Expand All @@ -185,11 +188,12 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 6> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0, 0, 0);
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, 0, 0, 0,
0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
Expand Down Expand Up @@ -227,11 +231,11 @@ template <class Policy, class FunctorType, class T>
struct diff_parallel_for_OP_call_dispatch {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, {});
});
::Kokkos::parallel_for(
"_diff_" + str, policy,
[&functor, &d_functor](const auto x, auto&&... args) {
functor.operator_call_pushforward(x, args..., &d_functor, {}, {});
});
}
};
template <class Policy, class FunctorType>
Expand Down

0 comments on commit f59ec7d

Please sign in to comment.