Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Kokkos::parallel_for in the fwd mode #1022

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 248 additions & 9 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define CLAD_DIFFERENTIATOR_KOKKOSBUILTINS_H

#include <Kokkos_Core.hpp>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'Kokkos_Core.hpp' file not found [clang-diagnostic-error]

#include <Kokkos_Core.hpp>
         ^

#include <type_traits>
#include "clad/Differentiator/Differentiator.h"

namespace clad::custom_derivatives {
Expand All @@ -29,7 +30,7 @@ constructor_pushforward(
}
} // namespace class_functions

/// Kokkos functions
/// Kokkos functions (view utils)
namespace Kokkos {
template <typename View1, typename View2, typename T>
inline void deep_copy_pushforward(const View1& dst, const View2& src, T param,
Expand Down Expand Up @@ -66,15 +67,253 @@ inline void resize_pushforward(const I& arg, View& v, const size_t n0,
::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}

template <class ExecPolicy, class FunctorType>
inline void
parallel_for_pushforward(const ::std::string& str, const ExecPolicy& policy,
const FunctorType& functor, const ::std::string& d_str,
const ExecPolicy& d_policy,
const FunctorType& d_functor) {
// TODO: implement parallel_for_pushforward
return;
/// Parallel for
template <class... PolicyParams, class FunctorType> // range policy
void parallel_for_pushforward(
const ::std::string& str,
const ::Kokkos::RangePolicy<PolicyParams...>& policy,
const FunctorType& functor, const ::std::string& /*d_str*/,
const ::Kokkos::RangePolicy<PolicyParams...>& /*d_policy*/,
const FunctorType& d_functor) {
::Kokkos::parallel_for(str, policy, functor);
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const int i) {
functor.operator_call_pushforward(i, &d_functor, 0);
});
}

// This structure is used to dispatch parallel for pushforward calls based on
// the rank and the work tag of the MDPolicy
template <class Policy, class FunctorType, class T, int Rank>
struct diff_parallel_for_MDP_call_dispatch {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
assert(false && "Some parallel_for misuse happened during the compilation "
"(templates have not been matched properly).");
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 2> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 2> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0);
});
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 3> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 3> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 4> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 4> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 5> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 5> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 6> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 6> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(
args..., &d_functor, 0, 0, 0, 0, 0, 0);
});
}
};

template <class PolicyP, class... PolicyParams,
class FunctorType> // multi-dimensional policy
void parallel_for_pushforward(
const ::std::string& str,
const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
const FunctorType& functor, const ::std::string& /*d_str*/,
const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& /*d_policy*/,
const FunctorType& d_functor) {
::Kokkos::parallel_for(str, policy, functor);
diff_parallel_for_MDP_call_dispatch<
::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType,
typename ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>::work_tag,
::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>::rank>::run(str, policy,
functor,
d_functor);
}

// This structure is used to dispatch parallel for pushforward calls based on
// the work tag of other types of policies
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_OP_call_dispatch {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, {});
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_OP_call_dispatch<Policy, FunctorType, void> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, {});
});
}
};

// This structure is used to dispatch parallel for pushforward calls for
// integral policies
template <class Policy, class FunctorType, bool IsInt>
struct diff_parallel_for_int_call_dispatch {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
diff_parallel_for_OP_call_dispatch<
Policy, FunctorType, typename Policy::work_tag>::run(str, policy,
functor,
d_functor);
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_int_call_dispatch<Policy, FunctorType, true> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](const int i) {
functor.operator_call_pushforward(i, &d_functor, 0);
});
}
};

template <class Policy, class FunctorType> // other policy type
void parallel_for_pushforward(const ::std::string& str, const Policy& policy,
const FunctorType& functor,
const ::std::string& /*d_str*/,
const Policy& /*d_policy*/,
const FunctorType& d_functor) {
::Kokkos::parallel_for(str, policy, functor);
diff_parallel_for_int_call_dispatch<
Policy, FunctorType, ::std::is_integral<Policy>::value>::run(str, policy,
functor,
d_functor);
}

template <class Policy, class FunctorType> // anonymous loop
void parallel_for_pushforward(const Policy& policy, const FunctorType& functor,
const Policy& d_policy,
const FunctorType& d_functor) {
parallel_for_pushforward(::std::string("anonymous_parallel_for"), policy,
functor, ::std::string(""), d_policy, d_functor);
}

template <class Policy, class FunctorType> // anonymous loop
void parallel_for_pushforward(
const Policy& policy, const FunctorType& functor,
::std::enable_if_t<::Kokkos::is_execution_policy<Policy>::value>* /*param*/,
const Policy& d_policy, const FunctorType& d_functor,
::std::enable_if_t<
::Kokkos::is_execution_policy<Policy>::value>* /*d_param*/) {
parallel_for_pushforward(::std::string("anonymous_parallel_for"), policy,
functor, ::std::string(""), d_policy, d_functor);
}

template <typename Policy, class FunctorType> // alternative signature
void parallel_for_pushforward(Policy policy, const FunctorType& functor,
const ::std::string& str, Policy d_policy,
const FunctorType& d_functor,
const ::std::string& d_str) {
parallel_for_pushforward(str, policy, functor, d_str, d_policy, d_functor);
}

template <typename Policy, class FunctorType> // alternative signature
void parallel_for_pushforward(
const Policy& policy, const FunctorType& functor, const ::std::string& str,
::std::enable_if_t<::Kokkos::is_execution_policy<Policy>::value>* /*param*/,
const Policy& d_policy, const FunctorType& d_functor,
const ::std::string& d_str,
::std::enable_if_t<
::Kokkos::is_execution_policy<Policy>::value>* /*d_param*/) {
parallel_for_pushforward(str, policy, functor, d_str, d_policy, d_functor);
}

} // namespace Kokkos
} // namespace clad::custom_derivatives

Expand Down
Loading
Loading