From 8185afc460f9fc0185f25f087c25514106b57802 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Mon, 26 Aug 2024 19:23:20 +0200 Subject: [PATCH] Provide pushforward methods for `Kokkos::View` indexing Previously, we relied on automatically generated pushforwards for these operator calls, but this solution is way safer and should work for more machines and Kokkos versions. --- include/clad/Differentiator/KokkosBuiltins.h | 59 ++++++++++++++++++++ unittests/Kokkos/ViewBasics.cpp | 31 ++++++++++ 2 files changed, 90 insertions(+) diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index d21710705..d5d2a5b7c 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -28,6 +28,65 @@ constructor_pushforward( Kokkos::View( "_diff_" + name, idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7)}; } + +/// View indexing +template +inline clad::ValueAndPushforward +operator_call_pushforward(const View* v, Idx i0, const View* _d_v, + Idx /*_d_i0*/) { + return {(*v)(i0), (*_d_v)(i0)}; +} +template +clad::ValueAndPushforward +operator_call_pushforward(const View* v, Idx i0, Idx i1, const View* _d_v, + Idx /*_d_i0*/, Idx /*_d_i1*/) { + return {(*v)(i0, i1), (*_d_v)(i0, i1)}; +} +template +clad::ValueAndPushforward +operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, + const View* _d_v, Idx /*_d_i0*/, Idx /*_d_i1*/, + Idx /*_d_i2*/) { + return {(*v)(i0, i1, i2), (*_d_v)(i0, i1, i2)}; +} +template +clad::ValueAndPushforward +operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, + const View* _d_v, Idx /*_d_i0*/, Idx /*_d_i1*/, + Idx /*_d_i2*/, Idx /*_d_i3*/) { + return {(*v)(i0, i1, i2, i3), (*_d_v)(i0, i1, i2, i3)}; +} +template +clad::ValueAndPushforward +operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4, + const View* _d_v, Idx /*_d_i0*/, Idx /*_d_i1*/, + Idx /*_d_i2*/, Idx /*_d_i3*/, Idx /*_d_i4*/) { + return {(*v)(i0, i1, i2, i3, i4), (*_d_v)(i0, i1, i2, i3, i4)}; +} +template +clad::ValueAndPushforward +operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4, + Idx i5, const View* _d_v, Idx /*_d_i0*/, + Idx /*_d_i1*/, Idx /*_d_i2*/, Idx /*_d_i3*/, + Idx /*_d_i4*/, Idx /*_d_i5*/) { + return {(*v)(i0, i1, i2, i3, i4, i5), (*_d_v)(i0, i1, i2, i3, i4, i5)}; +} +template +clad::ValueAndPushforward +operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4, + Idx i5, Idx i6, const View* _d_v, Idx /*_d_i0*/, + Idx /*_d_i1*/, Idx /*_d_i2*/, Idx /*_d_i3*/, + Idx /*_d_i4*/, Idx /*_d_i5*/, Idx /*_d_i6*/) { + return {(*v)(i0, i1, i2, i3, i4, i5, i6), + (*_d_v)(i0, i1, i2, i3, i4, i5, i6)}; +} } // namespace class_functions /// Kokkos functions (view utils) diff --git a/unittests/Kokkos/ViewBasics.cpp b/unittests/Kokkos/ViewBasics.cpp index b16e59f90..4f996d921 100644 --- a/unittests/Kokkos/ViewBasics.cpp +++ b/unittests/Kokkos/ViewBasics.cpp @@ -248,4 +248,35 @@ TEST(ViewBasics, TestResize4) { for (double x = 3; x <= 5; x += 1) for (double y = 3; y <= 5; y += 1) EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps); +} + +template struct FooModifier { + double x; + + FooModifier(View& v, double x) : x(x) {} + + void operator()(View& v) { v(1, 0, 1, 0, 1, 0, 1) += x; } +}; + +double f_basics_call(double x) { + Kokkos::View + a("a"); + Kokkos::deep_copy(a, 3 * x); + + FooModifier> + f(a, x); + + f(a); + + return a(1, 0, 1, 0, 1, 0, 1); +} + +TEST(ViewBasics, FunctorCall4) { + const double eps = 1e-8; + + auto df = clad::differentiate(f_basics_call, 0); + for (double x = 3; x <= 5; x += 1) + EXPECT_NEAR(df.execute(x), 4, eps); } \ No newline at end of file