From 3b0acb48e28175242ac74cef38c43b56c0936b49 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Thu, 15 Aug 2024 21:51:37 +0200 Subject: [PATCH] Add support for `Kokkos::fence` in the fwd mode Although this function doesn't need to be differentiated and is correctly used by Clad automatically, this custom pushforward prevents Clad from throwing a warning during that. --- include/clad/Differentiator/KokkosBuiltins.h | 5 +++++ unittests/Kokkos/ParallelFor.cpp | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/include/clad/Differentiator/KokkosBuiltins.h b/include/clad/Differentiator/KokkosBuiltins.h index af676091f..d21710705 100644 --- a/include/clad/Differentiator/KokkosBuiltins.h +++ b/include/clad/Differentiator/KokkosBuiltins.h @@ -67,6 +67,11 @@ inline void resize_pushforward(const I& arg, View& v, const size_t n0, ::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7); } +/// Fence +template void fence_pushforward(const S& s, const S& /*d_s*/) { + ::Kokkos::fence(s); +} + /// Parallel for template // range policy void parallel_for_pushforward( diff --git a/unittests/Kokkos/ParallelFor.cpp b/unittests/Kokkos/ParallelFor.cpp index eeee34e7a..aa9d18b38 100644 --- a/unittests/Kokkos/ParallelFor.cpp +++ b/unittests/Kokkos/ParallelFor.cpp @@ -102,6 +102,22 @@ template struct Foo { void operator()(const int i) const { res(i) = x * i; } }; +double parallel_for_functor_simplest_case_fence(double x) { + Kokkos::View res("res"); + + Kokkos::fence("named fence"); + + Foo> f(res, x); + + f(0); // FIXME: this is a workaround to put Foo::operator() into the + // differentiation plan. This needs to be solved in clad. + + Kokkos::parallel_for(5, f); + Kokkos::fence(); + + return res(3); +} + double parallel_for_functor_simplest_case_intpol(double x) { Kokkos::View res("res"); @@ -191,6 +207,10 @@ double parallel_for_functor_simplest_case_mdpol_space_and_anon(double x) { TEST(ParallelFor, FunctorSimplestCases) { const double eps = 1e-8; + auto df0 = clad::differentiate(parallel_for_functor_simplest_case_fence, 0); + for (double x = 3; x <= 5; x += 1) + EXPECT_NEAR(df0.execute(x), 3, eps); + auto df1 = clad::differentiate(parallel_for_functor_simplest_case_intpol, 0); for (double x = 3; x <= 5; x += 1) EXPECT_NEAR(df1.execute(x), 3, eps);