Skip to content

Commit

Permalink
Provide pushforward methods for Kokkos::View indexing
Browse files Browse the repository at this point in the history
Previously, we relied on automatically generated pushforwards
for these operator calls, but this solution is way safer and
should work for more machines and Kokkos versions.
  • Loading branch information
gojakuch committed Aug 26, 2024
1 parent 6f4b081 commit 1d55095
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
59 changes: 59 additions & 0 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,65 @@ constructor_pushforward(
Kokkos::View<DataType, ViewParams...>(
"_diff_" + name, idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7)};
}

/// View indexing
template <typename View, typename Idx>
inline clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, const View* _d_v,
Idx /*_d_i0*/) {
return {(*v)(i0), (*_d_v)(i0)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, const View* _d_v,
Idx /*_d_i0*/, Idx /*_d_i1*/) {
return {(*v)(i0, i1), (*_d_v)(i0, i1)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2,
const View* _d_v, Idx /*_d_i0*/, Idx /*_d_i1*/,
Idx /*_d_i2*/) {
return {(*v)(i0, i1, i2), (*_d_v)(i0, i1, i2)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3,
const View* _d_v, Idx /*_d_i0*/, Idx /*_d_i1*/,
Idx /*_d_i2*/, Idx /*_d_i3*/) {
return {(*v)(i0, i1, i2, i3), (*_d_v)(i0, i1, i2, i3)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
const View* _d_v, Idx /*_d_i0*/, Idx /*_d_i1*/,
Idx /*_d_i2*/, Idx /*_d_i3*/, Idx /*_d_i4*/) {
return {(*v)(i0, i1, i2, i3, i4), (*_d_v)(i0, i1, i2, i3, i4)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
Idx i5, const View* _d_v, Idx /*_d_i0*/,
Idx /*_d_i1*/, Idx /*_d_i2*/, Idx /*_d_i3*/,
Idx /*_d_i4*/, Idx /*_d_i5*/) {
return {(*v)(i0, i1, i2, i3, i4, i5), (*_d_v)(i0, i1, i2, i3, i4, i5)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
Idx i5, Idx i6, const View* _d_v, Idx /*_d_i0*/,
Idx /*_d_i1*/, Idx /*_d_i2*/, Idx /*_d_i3*/,
Idx /*_d_i4*/, Idx /*_d_i5*/, Idx /*_d_i6*/) {
return {(*v)(i0, i1, i2, i3, i4, i5, i6),
(*_d_v)(i0, i1, i2, i3, i4, i5, i6)};
}
} // namespace class_functions

/// Kokkos functions (view utils)
Expand Down
29 changes: 29 additions & 0 deletions unittests/Kokkos/ViewBasics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,33 @@ TEST(ViewBasics, TestResize4) {
for (double x = 3; x <= 5; x += 1)
for (double y = 3; y <= 5; y += 1)
EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps);
}

template <typename View> struct FooModifier {
double x;

FooModifier(View& v, double x) : x(x) {}

void operator()(View& v) { v(1, 0, 1, 0, 1, 0, 1) += x; }
};

double f_basics_call(double x) {
Kokkos::View<double[2][2][2][2][2][2][2], Kokkos::LayoutLeft,
Kokkos::HostSpace>
a("a");
Kokkos::deep_copy(a, 3 * x);

FooModifier f(a, x);

f(a);

return a(1, 0, 1, 0, 1, 0, 1);
}

TEST(ViewBasics, FunctorCall4) {
const double eps = 1e-8;

auto df = clad::differentiate(f_basics_call, 0);
for (double x = 3; x <= 5; x += 1)
EXPECT_NEAR(df.execute(x), 4, eps);
}

0 comments on commit 1d55095

Please sign in to comment.