From c10d76a8da077c5ce3ab94677ad12b266842d8c3 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Tue, 23 Jul 2024 12:17:13 +0200 Subject: [PATCH] Add support for `Kokkos::resize` in the forward mode This commit extends the Kokkos support in Clad by providing custom pushforwards for different overloads of the `Kokkos::resize` function. --- include/clad/Differentiator/KokkosBuiltins.h | 23 ++++++++++ unittests/Kokkos/ViewBasics.cpp | 46 +++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index 8d24adcf0..181f85b48 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -39,6 +39,29 @@ inline void deep_copy_pushforward(const View1& dst, const View2& src, T param, deep_copy(d_dst, d_src); } +template +inline void +resize_pushforward(View& v, const size_t n0, const size_t n1, const size_t n2, + const size_t n3, const size_t n4, const size_t n5, + const size_t n6, const size_t n7, View& d_v, + const size_t d_n0, const size_t d_n1, const size_t d_n2, + const size_t d_n3, const size_t d_n4, const size_t d_n5, + const size_t d_n6, const size_t d_n7) { + ::Kokkos::resize(v, n0, n1, n2, n3, n4, n5, n6, n7); + ::Kokkos::resize(d_v, n0, n1, n2, n3, n4, n5, n6, n7); +} + +template +inline void resize_pushforward( + const I& arg, View& v, const size_t n0, const size_t n1, const size_t n2, + const size_t n3, const size_t n4, const size_t n5, const size_t n6, + const size_t n7, const dI& d_arg, View& d_v, const size_t d_n0, + const size_t d_n1, const size_t d_n2, const size_t d_n3, const size_t d_n4, + const size_t d_n5, const size_t d_n6, const size_t d_n7) { + ::Kokkos::resize(arg, v, n0, n1, n2, n3, n4, n5, n6, n7); + ::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7); +} + template inline void parallel_for_pushforward(const ::std::string& str, const ExecPolicy& policy, diff --git a/unittests/Kokkos/ViewBasics.cpp b/unittests/Kokkos/ViewBasics.cpp index 8773f5451..4a78078db 100644 --- a/unittests/Kokkos/ViewBasics.cpp +++ b/unittests/Kokkos/ViewBasics.cpp @@ -3,9 +3,9 @@ // https://github.com/kliegeois/clad/blob/kokkos-PR/unittests/Kokkos/view_access.cpp // it has been modified to match gtest guidelines and improve readability -#include "ParallelAdd.h" #include #include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/KokkosBuiltins.h" #include "gtest/gtest.h" double f_basics(double x, double y) { @@ -160,4 +160,48 @@ TEST(ViewBasics, TestDeepCopy2Reverse) { // EXPECT_NEAR(dy_f_FD, dy, abs(tau*dy)); // } // } +} + +double f_basics_resize_1(double x, double y) { + Kokkos::View a("a", 3, + 2); + + Kokkos::deep_copy(a, 3 * x + y); + a(2, 1, 0) = x * y; + + Kokkos::resize(Kokkos::WithoutInitializing, a, 5, 5); + + return a(2, 1, 0); +} + +TEST(ViewBasics, TestResize1) { + const double eps = 1e-8; + + auto df = clad::differentiate(f_basics_resize_1, 0); + auto df_true = [](double x, double y) { return y; }; + for (double x = 3; x <= 5; x += 1) + for (double y = 3; y <= 5; y += 1) + EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps); +} + +double f_basics_resize_2(double x, double y) { + Kokkos::View a("a", 3, + 2); + + Kokkos::deep_copy(a, 3 * x + y); + a(2, 1, 0) = x * y; + + Kokkos::resize(a, 5, 5); + + return a(2, 1, 0); +} + +TEST(ViewBasics, TestResize2) { + const double eps = 1e-8; + + auto df = clad::differentiate(f_basics_resize_2, 0); + auto df_true = [](double x, double y) { return y; }; + for (double x = 3; x <= 5; x += 1) + for (double y = 3; y <= 5; y += 1) + EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps); } \ No newline at end of file