Skip to content

Commit

Permalink
Add support for Kokkos::deep_copy in the rvs mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Oct 12, 2024
1 parent a14a3f6 commit fdc3dac
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 8 deletions.
116 changes: 116 additions & 0 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,122 @@ inline void deep_copy_pushforward(const View1& dst, const View2& src, T param,
deep_copy(dst, src);
deep_copy(d_dst, d_src);
}
template <typename View, int Rank> struct iterate_over_all_view_elements {
template <typename F> static void run(const View& v, F func) {}
};
template <typename View> struct iterate_over_all_view_elements<View, 1> {
template <typename F> static void run(const View& v, F func) {
::Kokkos::parallel_for("iterate_over_all_view_elements", v.extent(0), func);
}
};
template <typename View> struct iterate_over_all_view_elements<View, 2> {
template <typename F> static void run(const View& v, F func) {
::Kokkos::parallel_for("iterate_over_all_view_elements",
::Kokkos::MDRangePolicy<::Kokkos::Rank<2>>(
{0, 0}, {v.extent(0), v.extent(1)}),
func);
}
};
template <typename View> struct iterate_over_all_view_elements<View, 3> {
template <typename F> static void run(const View& v, F func) {
::Kokkos::parallel_for(
"iterate_over_all_view_elements",
::Kokkos::MDRangePolicy<::Kokkos::Rank<3>>(
{0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2)}),
func);
}
};
template <typename View> struct iterate_over_all_view_elements<View, 4> {
template <typename F> static void run(const View& v, F func) {
::Kokkos::parallel_for(
"iterate_over_all_view_elements",
::Kokkos::MDRangePolicy<::Kokkos::Rank<4>>(
{0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2), v.extent(3)}),
func);
}
};
template <typename View> struct iterate_over_all_view_elements<View, 5> {
template <typename F> static void run(const View& v, F func) {
::Kokkos::parallel_for(
"iterate_over_all_view_elements",
::Kokkos::MDRangePolicy<::Kokkos::Rank<5>>(
{0, 0, 0, 0, 0},
{v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4)}),
func);
}
};
template <typename View> struct iterate_over_all_view_elements<View, 6> {
template <typename F> static void run(const View& v, F func) {
::Kokkos::parallel_for(
"iterate_over_all_view_elements",
::Kokkos::MDRangePolicy<::Kokkos::Rank<6>>(
{0, 0, 0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2),
v.extent(3), v.extent(4), v.extent(5)}),
func);
}
};
template <typename View> struct iterate_over_all_view_elements<View, 7> {
template <typename F> static void run(const View& v, F func) {
::Kokkos::parallel_for(
"iterate_over_all_view_elements",
::Kokkos::MDRangePolicy<::Kokkos::Rank<7>>(
{0, 0, 0, 0, 0, 0, 0},
{v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4),
v.extent(5), v.extent(6)}),
func);
}
};
template <typename... ViewArgs>
void deep_copy_pullback(
const ::Kokkos::View<ViewArgs...>& dst,
typename ::Kokkos::ViewTraits<ViewArgs...>::const_value_type& /*value*/,
::std::enable_if_t<::std::is_same<
typename ::Kokkos::ViewTraits<ViewArgs...>::specialize, void>::value>*,
::Kokkos::View<ViewArgs...>* d_dst,
typename ::Kokkos::ViewTraits<ViewArgs...>::value_type* d_value,
::std::enable_if_t<
::std::is_same<typename ::Kokkos::ViewTraits<ViewArgs...>::specialize,
void>::value>*) {
typename ::Kokkos::ViewTraits<ViewArgs...>::value_type res = 0;

iterate_over_all_view_elements<
::Kokkos::View<ViewArgs...>,
::Kokkos::ViewTraits<ViewArgs...>::rank>::run(dst,
[&res,
&d_dst](auto&&... args) {
res += (*d_dst)(args...);
(*d_dst)(args...) = 0;
});

(*d_value) += res;
}
template <typename... ViewArgs1, typename... ViewArgs2>
inline void deep_copy_pullback(
const ::Kokkos::View<ViewArgs1...>& dst,
const ::Kokkos::View<ViewArgs2...>& /*src*/,
::std::enable_if_t<
(::std::is_void<
typename ::Kokkos::ViewTraits<ViewArgs1...>::specialize>::value &&
::std::is_void<
typename ::Kokkos::ViewTraits<ViewArgs2...>::specialize>::value &&
((unsigned int)(::Kokkos::ViewTraits<ViewArgs1...>::rank) != 0 ||
(unsigned int)(::Kokkos::ViewTraits<ViewArgs2...>::rank) != 0))>*,
::Kokkos::View<ViewArgs1...>* d_dst, ::Kokkos::View<ViewArgs2...>* d_src,
::std::enable_if_t<
(::std::is_void<
typename ::Kokkos::ViewTraits<ViewArgs1...>::specialize>::value &&
::std::is_void<
typename ::Kokkos::ViewTraits<ViewArgs2...>::specialize>::value &&
((unsigned int)(::Kokkos::ViewTraits<ViewArgs1...>::rank) != 0 ||
(unsigned int)(::Kokkos::ViewTraits<ViewArgs2...>::rank) != 0))>*) {
iterate_over_all_view_elements<::Kokkos::View<ViewArgs1...>,
::Kokkos::ViewTraits<ViewArgs1...>::rank>::
run(dst, [&d_src, &d_dst](auto&&... args) {
(*d_src)(args...) += (*d_dst)(args...);
(*d_dst)(args...) = 0;
});
}

template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3, typename Idx4, typename Idx5, typename Idx6,
typename Idx7>
Expand Down
42 changes: 34 additions & 8 deletions unittests/Kokkos/ViewAccess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ double f(double x, double y) {
Kokkos::View<double* [4], Kokkos::LayoutLeft, Kokkos::HostSpace> b("b", N1);

a(0, 0) = x;
b(0, 0) = y;
b(1, 1) = y;

b(0, 0) += a(0, 0) * b(0, 0);
b(1, 1) += a(0, 0) * b(1, 1);

return a(0, 0) * a(0, 0) * b(0, 0) + b(0, 0);
return a(0, 0) * a(0, 0) * b(1, 1) + b(1, 1);
}

double f_2(double x, double y) {
Expand All @@ -37,6 +37,22 @@ double f_2(double x, double y) {
return a(0, 0);
}

double f_3(double x, double y) {

const int N1 = 4;

Kokkos::View<double* [4], Kokkos::LayoutLeft, Kokkos::HostSpace> a("a", N1);
Kokkos::View<double* [4], Kokkos::LayoutLeft, Kokkos::HostSpace> b("b", N1);

Kokkos::deep_copy(a, 3 * x + y);
b(0, 0) = y;
Kokkos::deep_copy(b, a);

b(0, 0) += a(0, 0) * b(0, 0);

return a(0, 0) + b(0, 0);
}

TEST(ViewAccess, Test1) {
EXPECT_NEAR(f(0, 1), 1, 1e-8);
EXPECT_NEAR(f(0, 2), 2, 1e-8);
Expand All @@ -51,7 +67,6 @@ TEST(ViewAccess, Test2) {

std::function<double(double)> f_tmp = [](double x) { return f(x, 4.); };
double dx_f_FD = finite_difference_tangent(f_tmp, 3., epsilon);

EXPECT_NEAR(f_x.execute(3, 4), dx_f_FD, tolerance * dx_f_FD);

auto f_2_x = clad::differentiate(f_2, "x");
Expand All @@ -60,13 +75,24 @@ TEST(ViewAccess, Test2) {
double dx_f_2_FD = finite_difference_tangent(f_2_tmp, 3., epsilon);
EXPECT_NEAR(f_2_x.execute(3, 4), dx_f_2_FD, tolerance * dx_f_2_FD);

auto f_3_y = clad::differentiate(f_3, "y");

std::function<double(double)> f_3_tmp = [](double y) { return f_3(3., y); };
double dy_f_3_FD = finite_difference_tangent(f_3_tmp, 4., epsilon);
EXPECT_NEAR(f_3_y.execute(3, 4), dy_f_3_FD, tolerance * dy_f_3_FD);

auto f_grad_exe = clad::gradient(f);
double dx, dy;
f_grad_exe.execute(3., 4., &dx, &dy);
EXPECT_NEAR(f_x.execute(3, 4), dx, tolerance * dx);

// double dx_2, dy_2;
// auto f_2_grad_exe = clad::gradient(f_2);
// f_2_grad_exe.execute(3., 4., &dx_2, &dy_2);
// EXPECT_NEAR(f_2_x.execute(3, 4),dx_2,tolerance*dx_2);
double dx_2, dy_2;
auto f_2_grad_exe = clad::gradient(f_2);
f_2_grad_exe.execute(3., 4., &dx_2, &dy_2);
EXPECT_NEAR(f_2_x.execute(3, 4), dx_2, tolerance * dx_2);

double dx_3, dy_3;
auto f_3_grad_exe = clad::gradient(f_3);
f_3_grad_exe.execute(3., 4., &dx_3, &dy_3);
EXPECT_NEAR(f_3_y.execute(3, 4), dy_3, tolerance * dy_3);
}

0 comments on commit fdc3dac

Please sign in to comment.