Skip to content

Commit

Permalink
Add support for Kokkos::parallel_for in the fwd mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Aug 10, 2024
1 parent 1b81084 commit fd3c361
Show file tree
Hide file tree
Showing 2 changed files with 360 additions and 9 deletions.
254 changes: 245 additions & 9 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define CLAD_DIFFERENTIATOR_KOKKOSBUILTINS_H

#include <Kokkos_Core.hpp>
#include <type_traits>
#include "clad/Differentiator/Differentiator.h"

namespace clad::custom_derivatives {
Expand All @@ -29,7 +30,7 @@ constructor_pushforward(
}
} // namespace class_functions

/// Kokkos functions
/// Kokkos functions (view utils)
namespace Kokkos {
template <typename View1, typename View2, typename T>
inline void deep_copy_pushforward(const View1& dst, const View2& src, T param,
Expand Down Expand Up @@ -66,15 +67,250 @@ inline void resize_pushforward(const I& arg, View& v, const size_t n0,
::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}

template <class ExecPolicy, class FunctorType>
inline void
parallel_for_pushforward(const ::std::string& str, const ExecPolicy& policy,
const FunctorType& functor, const ::std::string& d_str,
const ExecPolicy& d_policy,
const FunctorType& d_functor) {
// TODO: implement parallel_for_pushforward
return;
/// Parallel for
///
/// Pushforward of ::Kokkos::parallel_for for one-dimensional range policies.
/// Launches the original kernel first and then a second kernel, labelled
/// "_diff_" + str, which calls `operator_call_pushforward` on the functor for
/// every index of the same policy.
///
/// \param str label of the original kernel.
/// \param policy range policy used for both launches.
/// \param functor functor of the original kernel.
/// \param d_functor functor whose address is handed to
///        `operator_call_pushforward` right after the index.
/// The derivative label and the derivative policy are accepted but unused.
template <class... PolicyParams, class FunctorType> // range policy
void parallel_for_pushforward(
    const ::std::string& str,
    const ::Kokkos::RangePolicy<PolicyParams...>& policy,
    const FunctorType& functor, const ::std::string& /*d_str*/,
    const ::Kokkos::RangePolicy<PolicyParams...>& /*d_policy*/,
    const FunctorType& d_functor) {
  // Take the index in the policy's own index type. A hard-coded `int` would
  // silently truncate the index for policies whose index type is wider.
  using index_type =
      typename ::Kokkos::RangePolicy<PolicyParams...>::index_type;
  ::Kokkos::parallel_for(str, policy, functor);
  ::Kokkos::parallel_for(
      "_diff_" + str, policy, [&functor, &d_functor](const index_type i) {
        // The trailing 0 is passed in the position that follows the
        // derivative functor (presumably the derivative of the index).
        functor.operator_call_pushforward(i, &d_functor, 0);
      });
}

// This structure is used to dispatch parallel for pushforward calls based on
// the rank and the work tag of the MDPolicy.
//
// The primary template is the fallback for every (work tag, rank) pair that
// has no dedicated partial specialization. Using it means the dispatch did not
// match, so calling run() is rejected at compile time; a runtime assert would
// be compiled out under NDEBUG and the derivative kernel would silently never
// be launched.
template <class Policy, class FunctorType, class T, int Rank>
struct diff_parallel_for_MDP_call_dispatch {
  static void run(const ::std::string& /*str*/, const Policy& /*policy*/,
                  const FunctorType& /*functor*/,
                  const FunctorType& /*d_functor*/) {
    // The condition depends on Rank, so it is only evaluated when this
    // fallback run() is actually instantiated.
    static_assert(Rank >= 2 && Rank <= 6,
                  "Some parallel_for misuse happened during the compilation "
                  "(templates have not been matched properly).");
  }
};
// Rank 2, non-void T: the derivative kernel takes a leading argument of type T
// followed by the remaining kernel arguments.
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 2> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor, the
    // address of the leading argument and two literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](const T tag,
                                                    auto&&... idx) {
      functor.operator_call_pushforward(tag, idx..., &d_functor, &tag, 0, 0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// Rank 2, T = void: the derivative kernel forwards all of its arguments as-is.
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 2> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor and
    // two literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](auto&&... idx) {
      functor.operator_call_pushforward(idx..., &d_functor, 0, 0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// Rank 3, non-void T: the derivative kernel takes a leading argument of type T
// followed by the remaining kernel arguments.
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 3> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor, the
    // address of the leading argument and three literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](const T tag,
                                                    auto&&... idx) {
      functor.operator_call_pushforward(tag, idx..., &d_functor, &tag, 0, 0,
                                        0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// Rank 3, T = void: the derivative kernel forwards all of its arguments as-is.
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 3> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor and
    // three literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](auto&&... idx) {
      functor.operator_call_pushforward(idx..., &d_functor, 0, 0, 0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// Rank 4, non-void T: the derivative kernel takes a leading argument of type T
// followed by the remaining kernel arguments.
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 4> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor, the
    // address of the leading argument and four literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](const T tag,
                                                    auto&&... idx) {
      functor.operator_call_pushforward(tag, idx..., &d_functor, &tag, 0, 0, 0,
                                        0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// Rank 4, T = void: the derivative kernel forwards all of its arguments as-is.
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 4> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor and
    // four literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](auto&&... idx) {
      functor.operator_call_pushforward(idx..., &d_functor, 0, 0, 0, 0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// Rank 5, non-void T: the derivative kernel takes a leading argument of type T
// followed by the remaining kernel arguments.
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 5> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor, the
    // address of the leading argument and five literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](const T tag,
                                                    auto&&... idx) {
      functor.operator_call_pushforward(tag, idx..., &d_functor, &tag, 0, 0, 0,
                                        0, 0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// Rank 5, T = void: the derivative kernel forwards all of its arguments as-is.
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 5> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor and
    // five literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](auto&&... idx) {
      functor.operator_call_pushforward(idx..., &d_functor, 0, 0, 0, 0, 0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// Rank 6, non-void T: the derivative kernel takes a leading argument of type T
// followed by the remaining kernel arguments.
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 6> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor, the
    // address of the leading argument and six literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](const T tag,
                                                    auto&&... idx) {
      functor.operator_call_pushforward(tag, idx..., &d_functor, &tag, 0, 0, 0,
                                        0, 0, 0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// Rank 6, T = void: the derivative kernel forwards all of its arguments as-is.
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 6> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor and
    // six literal zeros (one per rank).
    const auto diff_kernel = [&functor, &d_functor](auto&&... idx) {
      functor.operator_call_pushforward(idx..., &d_functor, 0, 0, 0, 0, 0, 0);
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};

/// Pushforward of ::Kokkos::parallel_for for multi-dimensional range policies.
/// Launches the original kernel, then hands the derivative launch over to
/// diff_parallel_for_MDP_call_dispatch, selected by the policy's work tag and
/// rank. The derivative label and the derivative policy are unused.
template <class PolicyP, class... PolicyParams,
          class FunctorType> // multi-dimensional policy
void parallel_for_pushforward(
    const ::std::string& str,
    const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
    const FunctorType& functor, const ::std::string& /*d_str*/,
    const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& /*d_policy*/,
    const FunctorType& d_functor) {
  using policy_type = ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>;
  using dispatcher = diff_parallel_for_MDP_call_dispatch<
      policy_type, FunctorType, typename policy_type::work_tag,
      policy_type::rank>;

  // Original kernel first, derivative kernel second.
  ::Kokkos::parallel_for(str, policy, functor);
  dispatcher::run(str, policy, functor, d_functor);
}

// Dispatches parallel for pushforward calls for the remaining policy types,
// based on the policy's work tag T.
//
// Primary template (non-void T): the derivative kernel takes a leading
// argument of type T followed by the remaining kernel arguments.
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_OP_call_dispatch {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor, the
    // address of the leading argument and a value-initialized last argument.
    const auto diff_kernel = [&functor, &d_functor](const T tag,
                                                    auto&&... params) {
      functor.operator_call_pushforward(tag, params..., &d_functor, &tag, {});
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};
// T = void: the derivative kernel forwards all of its arguments as-is.
template <class Policy, class FunctorType>
struct diff_parallel_for_OP_call_dispatch<Policy, FunctorType, void> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Arguments passed on: the kernel arguments, the derivative functor and a
    // value-initialized last argument.
    const auto diff_kernel = [&functor, &d_functor](auto&&... params) {
      functor.operator_call_pushforward(params..., &d_functor, {});
    };
    ::Kokkos::parallel_for("_diff_" + str, policy, diff_kernel);
  }
};

// This structure is used to dispatch parallel for pushforward calls for
// integral policies
template <class Policy, class FunctorType, bool IsInt>
struct diff_parallel_for_int_call_dispatch {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
diff_parallel_for_OP_call_dispatch<
Policy, FunctorType, typename Policy::work_tag>::run(str, policy,
functor,
d_functor);
}
};
// IsInt = true: `policy` is an integral value that is passed straight to
// ::Kokkos::parallel_for, and the derivative kernel receives a single index.
template <class Policy, class FunctorType>
struct diff_parallel_for_int_call_dispatch<Policy, FunctorType, true> {
  static void run(const ::std::string& str, const Policy& policy,
                  const FunctorType& functor, const FunctorType& d_functor) {
    // Take the index in the same integral type as `policy`. A hard-coded `int`
    // would truncate the index whenever that type is wider than int.
    ::Kokkos::parallel_for(
        "_diff_" + str, policy, [&functor, &d_functor](const Policy i) {
          // The trailing 0 is passed in the position that follows the
          // derivative functor (presumably the derivative of the index).
          functor.operator_call_pushforward(i, &d_functor, 0);
        });
  }
};

template <class Policy, class FunctorType> // other policy type
void parallel_for_pushforward(const ::std::string& str, const Policy& policy,
const FunctorType& functor,
const ::std::string& /*d_str*/,
const Policy& /*d_policy*/,
const FunctorType& d_functor) {
::Kokkos::parallel_for(str, policy, functor);
diff_parallel_for_int_call_dispatch<
Policy, FunctorType, ::std::is_integral<Policy>::value>::run(str, policy,
functor,
d_functor);
}

/// Pushforward of the unlabelled ::Kokkos::parallel_for(policy, functor) form.
/// Forwards to the labelled overloads using the fixed label
/// "anonymous_parallel_for" and an empty derivative label.
template <class Policy, class FunctorType> // anonymous loop
void parallel_for_pushforward(const Policy& policy, const FunctorType& functor,
                              const Policy& d_policy,
                              const FunctorType& d_functor) {
  const ::std::string label("anonymous_parallel_for");
  const ::std::string d_label("");
  parallel_for_pushforward(label, policy, functor, d_label, d_policy,
                           d_functor);
}

/// Pushforward of the unlabelled parallel_for form that carries an extra
/// enable_if pointer parameter restricted to execution policies. Both pointer
/// parameters are ignored; the call is forwarded to the labelled overloads
/// using the fixed label "anonymous_parallel_for" and an empty derivative
/// label.
template <class Policy, class FunctorType> // anonymous loop
void parallel_for_pushforward(
    const Policy& policy, const FunctorType& functor,
    ::std::enable_if_t<::Kokkos::is_execution_policy<Policy>::value>* /*param*/,
    const Policy& d_policy, const FunctorType& d_functor,
    ::std::enable_if_t<
        ::Kokkos::is_execution_policy<Policy>::value>* /*d_param*/) {
  const ::std::string label("anonymous_parallel_for");
  const ::std::string d_label("");
  parallel_for_pushforward(label, policy, functor, d_label, d_policy,
                           d_functor);
}

/// Pushforward of the ::Kokkos::parallel_for(work_count, functor, str) form.
/// Launches the original kernel, then a derivative kernel labelled
/// "_diff_" + str over the same work count. The derivative work count and the
/// derivative label are unused.
template <class FunctorType> // alternative signature
void parallel_for_pushforward(const size_t work_count,
                              const FunctorType& functor,
                              const ::std::string& str,
                              const size_t /*_d_work_count*/,
                              const FunctorType& d_functor,
                              const ::std::string& /*_d_str*/) {
  // Arguments passed on: the kernel arguments, the derivative functor and a
  // literal 0.
  const auto diff_kernel = [&functor, &d_functor](auto&&... idx) {
    functor.operator_call_pushforward(idx..., &d_functor, 0);
  };

  ::Kokkos::parallel_for(work_count, functor, str);
  ::Kokkos::parallel_for(work_count, diff_kernel, "_diff_" + str);
}

} // namespace Kokkos
} // namespace clad::custom_derivatives

Expand Down
Loading

0 comments on commit fd3c361

Please sign in to comment.