Skip to content

Commit

Permalink
Add support for Kokkos::resize in the rvs mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Oct 14, 2024
1 parent f86eede commit 8fd700c
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 20 deletions.
77 changes: 61 additions & 16 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ template <typename View, int Rank> struct iterate_over_all_view_elements {
template <typename View> struct iterate_over_all_view_elements<View, 1> {
template <typename F> static void run(const View& v, F func) {
::Kokkos::parallel_for("iterate_over_all_view_elements", v.extent(0), func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 2> {
Expand All @@ -348,6 +349,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 2> {
::Kokkos::MDRangePolicy<::Kokkos::Rank<2>>(
{0, 0}, {v.extent(0), v.extent(1)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 3> {
Expand All @@ -357,6 +359,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 3> {
::Kokkos::MDRangePolicy<::Kokkos::Rank<3>>(
{0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 4> {
Expand All @@ -366,6 +369,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 4> {
::Kokkos::MDRangePolicy<::Kokkos::Rank<4>>(
{0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2), v.extent(3)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 5> {
Expand All @@ -376,6 +380,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 5> {
{0, 0, 0, 0, 0},
{v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 6> {
Expand All @@ -386,6 +391,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 6> {
{0, 0, 0, 0, 0, 0}, {v.extent(0), v.extent(1), v.extent(2),
v.extent(3), v.extent(4), v.extent(5)}),
func);
::Kokkos::fence();
}
};
template <typename View> struct iterate_over_all_view_elements<View, 7> {
Expand All @@ -397,6 +403,7 @@ template <typename View> struct iterate_over_all_view_elements<View, 7> {
{v.extent(0), v.extent(1), v.extent(2), v.extent(3), v.extent(4),
v.extent(5), v.extent(6)}),
func);
::Kokkos::fence();
}
};
template <typename... ViewArgs>
Expand Down Expand Up @@ -452,30 +459,68 @@ inline void deep_copy_pullback(

template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3, typename Idx4, typename Idx5, typename Idx6,
typename Idx7>
inline void
resize_pushforward(View& v, const Idx0 n0, const Idx1 n1, const Idx2 n2,
const Idx3 n3, const Idx4 n4, const Idx5 n5, const Idx6 n6,
const Idx7 n7, View& d_v, const Idx0 /*d_n*/,
const Idx1 /*d_n*/, const Idx2 /*d_n*/, const Idx3 /*d_n*/,
const Idx4 /*d_n*/, const Idx5 /*d_n*/, const Idx6 /*d_n*/,
const Idx7 /*d_n*/) {
typename Idx7, typename dIdx0, typename dIdx1, typename dIdx2,
typename dIdx3, typename dIdx4, typename dIdx5, typename dIdx6,
typename dIdx7>
inline void resize_pushforward(View& v, const Idx0 n0, const Idx1 n1,
const Idx2 n2, const Idx3 n3, const Idx4 n4,
const Idx5 n5, const Idx6 n6, const Idx7 n7,
View& d_v, const dIdx0 /*d_n*/,
const dIdx1 /*d_n*/, const dIdx2 /*d_n*/,
const dIdx3 /*d_n*/, const dIdx4 /*d_n*/,
const dIdx5 /*d_n*/, const dIdx6 /*d_n*/,
const dIdx7 /*d_n*/) {
::Kokkos::resize(v, n0, n1, n2, n3, n4, n5, n6, n7);
::Kokkos::resize(d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}
template <class I, class dI, class View, typename Idx0, typename Idx1,
typename Idx2, typename Idx3, typename Idx4, typename Idx5,
typename Idx6, typename Idx7>
inline void
resize_pushforward(const I& arg, View& v, const Idx0 n0, const Idx1 n1,
const Idx2 n2, const Idx3 n3, const Idx4 n4, const Idx5 n5,
const Idx6 n6, const Idx7 n7, const dI& /*d_arg*/, View& d_v,
const Idx0 /*d_n*/, const Idx1 /*d_n*/, const Idx2 /*d_n*/,
const Idx3 /*d_n*/, const Idx4 /*d_n*/, const Idx5 /*d_n*/,
const Idx6 /*d_n*/, const Idx7 /*d_n*/) {
typename Idx6, typename Idx7, typename dIdx0, typename dIdx1,
typename dIdx2, typename dIdx3, typename dIdx4, typename dIdx5,
typename dIdx6, typename dIdx7>
inline void resize_pushforward(const I& arg, View& v, const Idx0 n0,
const Idx1 n1, const Idx2 n2, const Idx3 n3,
const Idx4 n4, const Idx5 n5, const Idx6 n6,
const Idx7 n7, const dI& /*d_arg*/, View& d_v,
const dIdx0 /*d_n*/, const dIdx1 /*d_n*/,
const dIdx2 /*d_n*/, const dIdx3 /*d_n*/,
const dIdx4 /*d_n*/, const dIdx5 /*d_n*/,
const dIdx6 /*d_n*/, const dIdx7 /*d_n*/) {
::Kokkos::resize(arg, v, n0, n1, n2, n3, n4, n5, n6, n7);
::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}
template <typename View, typename Idx0, typename Idx1, typename Idx2,
typename Idx3, typename Idx4, typename Idx5, typename Idx6,
typename Idx7, typename dIdx0, typename dIdx1, typename dIdx2,
typename dIdx3, typename dIdx4, typename dIdx5, typename dIdx6,
typename dIdx7>
void resize_reverse_forw(View& v, const Idx0 n0, const Idx1 n1, const Idx2 n2,
const Idx3 n3, const Idx4 n4, const Idx5 n5,
const Idx6 n6, const Idx7 n7, View& d_v,
const dIdx0 /*d_n*/, const dIdx1 /*d_n*/,
const dIdx2 /*d_n*/, const dIdx3 /*d_n*/,
const dIdx4 /*d_n*/, const dIdx5 /*d_n*/,
const dIdx6 /*d_n*/, const dIdx7 /*d_n*/) {
::Kokkos::resize(v, n0, n1, n2, n3, n4, n5, n6, n7);
::Kokkos::resize(d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}
template <class I, class dI, class View, typename Idx0, typename Idx1,
typename Idx2, typename Idx3, typename Idx4, typename Idx5,
typename Idx6, typename Idx7, typename dIdx0, typename dIdx1,
typename dIdx2, typename dIdx3, typename dIdx4, typename dIdx5,
typename dIdx6, typename dIdx7>
void resize_reverse_forw(const I& arg, View& v, const Idx0 n0, const Idx1 n1,
const Idx2 n2, const Idx3 n3, const Idx4 n4,
const Idx5 n5, const Idx6 n6, const Idx7 n7,
const dI& /*d_arg*/, View& d_v, const dIdx0 /*d_n*/,
const dIdx1 /*d_n*/, const dIdx2 /*d_n*/,
const dIdx3 /*d_n*/, const dIdx4 /*d_n*/,
const dIdx5 /*d_n*/, const dIdx6 /*d_n*/,
const dIdx7 /*d_n*/) {
::Kokkos::resize(arg, v, n0, n1, n2, n3, n4, n5, n6, n7);
::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}
template <class... Args> void resize_pullback(Args... /*args*/) {}

/// Fence
template <typename S> void fence_pushforward(const S& s, const S& /*d_s*/) {
Expand Down
6 changes: 3 additions & 3 deletions unittests/Kokkos/ViewAccess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,16 @@ TEST(ViewAccess, Test2) {
EXPECT_NEAR(f_3_y.execute(3, 4), dy_f_3_FD, tolerance * dy_f_3_FD);

auto f_grad_exe = clad::gradient(f);
double dx, dy;
double dx = 0, dy = 0;
f_grad_exe.execute(3., 4., &dx, &dy);
EXPECT_NEAR(f_x.execute(3, 4), dx, tolerance * dx);

double dx_2, dy_2;
double dx_2 = 0, dy_2 = 0;
auto f_2_grad_exe = clad::gradient(f_2);
f_2_grad_exe.execute(3., 4., &dx_2, &dy_2);
EXPECT_NEAR(f_2_x.execute(3, 4), dx_2, tolerance * dx_2);

double dx_3, dy_3;
double dx_3 = 0, dy_3 = 0;
auto f_3_grad_exe = clad::gradient(f_3);
f_3_grad_exe.execute(3., 4., &dx_3, &dy_3);
EXPECT_NEAR(f_3_y.execute(3, 4), dy_3, tolerance * dy_3);
Expand Down
39 changes: 38 additions & 1 deletion unittests/Kokkos/ViewBasics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ double f_basics_resize_3(double x, double y) {
2);
Kokkos::deep_copy(a, 3 * x + y);

Kokkos::resize(Kokkos::WithoutInitializing, a, 5, 5);
Kokkos::resize(Kokkos::WithoutInitializing, a, 5,
5); // FIXME: this signature for the resize function is not yet
// supported in the reverse mode

a(4, 4, 0) = x * y;

Expand Down Expand Up @@ -250,6 +252,41 @@ TEST(ViewBasics, TestResize4) {
EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps);
}

double f_basics_resize_5_both_modes(double x, double y) {
Kokkos::View<double** [3], Kokkos::LayoutLeft, Kokkos::HostSpace> a("a", 3,
2);
Kokkos::View<double** [3], Kokkos::LayoutLeft, Kokkos::HostSpace> b("b", 5,
5);

b(4, 4, 0) = x * y * 2;
b(2, 1, 0) = 0;

Kokkos::deep_copy(a, 3 * x + y);
a(2, 1, 0) = x * y;

Kokkos::resize(a, 5, 5);
Kokkos::deep_copy(a, b);

return a(4, 4, 0);
}

TEST(ViewBasics, TestResize5) {
const double eps = 1e-8;

auto df = clad::differentiate(f_basics_resize_5_both_modes, 0);
auto gradf = clad::gradient(f_basics_resize_5_both_modes);
auto df_true_x = [](double x, double y) { return y * 2; };
for (double x = 3; x <= 5; x += 1)
for (double y = 3; y <= 5; y += 1) {
double dfdx = df.execute(x, y);
EXPECT_NEAR(dfdx, df_true_x(x, y), eps);
double dx = 0, dy = 0;
gradf.execute(x, y, &dx, &dy);
EXPECT_NEAR(dfdx, dx, eps);
EXPECT_NEAR(2 * x, dy, eps);
}
}

template <typename View> struct FooModifier {
double x;

Expand Down

0 comments on commit 8fd700c

Please sign in to comment.