Skip to content

Commit

Permalink
Add support for Kokkos::parallel_for in the fwd mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Aug 4, 2024
1 parent 7d1e26c commit f5dd53a
Show file tree
Hide file tree
Showing 2 changed files with 326 additions and 9 deletions.
220 changes: 211 additions & 9 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define CLAD_DIFFERENTIATOR_KOKKOSBUILTINS_H

#include <Kokkos_Core.hpp>
#include <type_traits>
#include "clad/Differentiator/Differentiator.h"

namespace clad::custom_derivatives {
Expand All @@ -29,7 +30,7 @@ constructor_pushforward(
}
} // namespace class_functions

/// Kokkos functions
/// Kokkos functions (view utils)
namespace Kokkos {
template <typename View1, typename View2, typename T>
inline void deep_copy_pushforward(const View1& dst, const View2& src, T param,
Expand Down Expand Up @@ -66,15 +67,216 @@ inline void resize_pushforward(const I& arg, View& v, const size_t n0,
::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}

template <class ExecPolicy, class FunctorType>
inline void
parallel_for_pushforward(const ::std::string& str, const ExecPolicy& policy,
const FunctorType& functor, const ::std::string& d_str,
const ExecPolicy& d_policy,
const FunctorType& d_functor) {
// TODO: implement parallel_for_pushforward
return;
/// Parallel for
template <class... PolicyParams, class FunctorType> // range policy
inline void parallel_for_pushforward(
const ::std::string& str,
const ::Kokkos::RangePolicy<PolicyParams...>& policy,
const FunctorType& functor, const ::std::string& /*d_str*/,
const ::Kokkos::RangePolicy<PolicyParams...>& /*d_policy*/,
const FunctorType& d_functor) {
::Kokkos::parallel_for(str, policy, functor);
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const int i) {
functor.operator_call_pushforward(i, &d_functor, 0);
});
}

// This structure is used to dispatch parallel for pushforward calls based on
// the rank and the work tag of the MDPolicy
template <class Policy, class FunctorType, class T, int Rank>
struct diff_parallel_for_MDP_call_dispatch {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
assert(false && "Some parallel_for misuse happened during the compilation "
"(templates have not been matched properly).");
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 2> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 2> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0);
});
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 3> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 3> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 4> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 4> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 5> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 5> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, 0, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, T, 6> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, 0, 0, 0, 0, 0, 0);
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_MDP_call_dispatch<Policy, FunctorType, void, 6> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(
args..., &d_functor, 0, 0, 0, 0, 0, 0);
});
}
};

template <class PolicyP, class... PolicyParams,
class FunctorType> // multi-dimensional policy
inline void parallel_for_pushforward(
const ::std::string& str,
const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& policy,
const FunctorType& functor, const ::std::string& /*d_str*/,
const ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>& /*d_policy*/,
const FunctorType& d_functor) {
::Kokkos::parallel_for(str, policy, functor);
diff_parallel_for_MDP_call_dispatch<
::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>, FunctorType,
typename ::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>::work_tag,
::Kokkos::MDRangePolicy<PolicyP, PolicyParams...>::rank>::run(str, policy,
functor,
d_functor);
}

// This structure is used to dispatch parallel for pushforward calls based on
// the work tag of other types of policies
template <class Policy, class FunctorType, class T>
struct diff_parallel_for_OP_call_dispatch {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for("_diff_" + str, policy,
[&functor, &d_functor](const T x, auto&&... args) {
functor.operator_call_pushforward(
x, args..., &d_functor, &x, {});
});
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_OP_call_dispatch<Policy, FunctorType, void> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](auto&&... args) {
functor.operator_call_pushforward(args..., &d_functor, {});
});
}
};

// This structure is used to dispatch parallel for pushforward calls for
// integral policies
template <class Policy, class FunctorType, bool IsInt>
struct diff_parallel_for_int_call_dispatch {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
diff_parallel_for_OP_call_dispatch<
Policy, FunctorType, typename Policy::work_tag>::run(str, policy,
functor,
d_functor);
}
};
template <class Policy, class FunctorType>
struct diff_parallel_for_int_call_dispatch<Policy, FunctorType, true> {
static void run(const ::std::string& str, const Policy& policy,
const FunctorType& functor, const FunctorType& d_functor) {
::Kokkos::parallel_for(
"_diff_" + str, policy, [&functor, &d_functor](const int i) {
functor.operator_call_pushforward(i, &d_functor, 0);
});
}
};

template <class Policy, class FunctorType> // other policy type
inline void parallel_for_pushforward(const ::std::string& str,
const Policy& policy,
const FunctorType& functor,
const ::std::string& /*d_str*/,
const Policy& /*d_policy*/,
const FunctorType& d_functor) {
::Kokkos::parallel_for(str, policy, functor);
diff_parallel_for_int_call_dispatch<
Policy, FunctorType, ::std::is_integral<Policy>::value>::run(str, policy,
functor,
d_functor);
}

} // namespace Kokkos
} // namespace clad::custom_derivatives

Expand Down
115 changes: 115 additions & 0 deletions unittests/Kokkos/ParallelFor.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <Kokkos_Core.hpp>
#include "clad/Differentiator/Differentiator.h"
#include "clad/Differentiator/KokkosBuiltins.h"
#include "gtest/gtest.h"
// #include "TestUtils.h"
#include "ParallelAdd.h"
Expand Down Expand Up @@ -89,4 +90,118 @@ TEST(ParallelFor, ParallelPolynomialReverse) {
// f_grad.execute(x, &dx);
// EXPECT_NEAR(dx_f_true, dx, abs(tau*dx));
// }
}

template <typename View> struct Foo {
View& res;
double& x;

Foo(View& _res, double& _x) : res(_res), x(_x) {}

KOKKOS_INLINE_FUNCTION
void operator()(const int i) const { res(i) = x * i; }
};

double parallel_for_functor_simplest_case_intpol(double x) {
Kokkos::View<double[5], Kokkos::HostSpace> res("res");

Foo<Kokkos::View<double[5], Kokkos::HostSpace>> f(res, x);

f(0);

Kokkos::parallel_for("polynomial", 5, f);
Kokkos::parallel_for(5, f);

return res(3);
}

double parallel_for_functor_simplest_case_rangepol(double x) {
Kokkos::View<double[5], Kokkos::HostSpace> res("res");

Foo<Kokkos::View<double[5], Kokkos::HostSpace>> f(res, x);

f(0);

Kokkos::parallel_for(
"polynomial",
Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace>(1, 5), f);
// Overwrite with another parallel_for (not named)
Kokkos::parallel_for(
Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace>(1, 5), f);

return res(3);
}

template <typename View> struct Foo2 {
View& res;
double& x;

Foo2(View& _res, double& _x) : res(_res), x(_x) {}

KOKKOS_INLINE_FUNCTION
void operator()(const int i, const int j) const { res(i, j) = x * i * j; }
};

double parallel_for_functor_simplest_case_mdpol(double x) {
Kokkos::View<double[5][5], Kokkos::HostSpace> res("res");

Foo2<Kokkos::View<double[5][5], Kokkos::HostSpace>> f(res, x);

f(0, 0);

Kokkos::parallel_for(
"polynomial",
Kokkos::MDRangePolicy<
Kokkos::Rank<2, Kokkos::Iterate::Right, Kokkos::Iterate::Left>>(
{1, 1}, {5, 5}, {1, 1}),
f);

return res(3, 4);
}

double parallel_for_functor_simplest_case_mdpol_space_and_anon(double x) {
Kokkos::View<double[5][5], Kokkos::HostSpace> res("res");

Foo2<Kokkos::View<double[5][5], Kokkos::HostSpace>> f(res, x);

f(0, 0);

Kokkos::parallel_for(
"polynomial",
Kokkos::MDRangePolicy<
Kokkos::DefaultHostExecutionSpace,
Kokkos::Rank<2, Kokkos::Iterate::Right, Kokkos::Iterate::Left>>(
{1, 1}, {5, 5}, {1, 1}),
f);
// Overwrite with another parallel_for (not named)
Kokkos::parallel_for(
Kokkos::MDRangePolicy<
Kokkos::DefaultHostExecutionSpace,
Kokkos::Rank<2, Kokkos::Iterate::Right, Kokkos::Iterate::Left>>(
{1, 1}, {5, 5}, {1, 1}),
f);

return res(3, 4);
}

TEST(ParallelFor, FunctorSimplestCases) {
const double eps = 1e-8;

auto df1 = clad::differentiate(parallel_for_functor_simplest_case_intpol, 0);
for (double x = 3; x <= 5; x += 1)
EXPECT_NEAR(df1.execute(x), 3, eps);

auto df2 =
clad::differentiate(parallel_for_functor_simplest_case_rangepol, 0);
for (double x = 3; x <= 5; x += 1)
EXPECT_NEAR(df2.execute(x), 3, eps);

auto df3 = clad::differentiate(parallel_for_functor_simplest_case_mdpol, 0);
for (double x = 3; x <= 5; x += 1)
EXPECT_NEAR(df3.execute(x), 12, eps);

auto df4 = clad::differentiate(
parallel_for_functor_simplest_case_mdpol_space_and_anon, 0);
for (double x = 3; x <= 5; x += 1)
EXPECT_NEAR(df4.execute(x), 12, eps);
}

0 comments on commit f5dd53a

Please sign in to comment.