Skip to content

Commit

Permalink
Add support for Kokkos::resize in the forward mode
Browse files Browse the repository at this point in the history
This commit extends the Kokkos support in Clad by
providing custom pushforwards for different overloads
of the `Kokkos::resize` function.
  • Loading branch information
gojakuch committed Jul 23, 2024
1 parent f03fc98 commit c10d76a
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
23 changes: 23 additions & 0 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,29 @@ inline void deep_copy_pushforward(const View1& dst, const View2& src, T param,
deep_copy(d_dst, d_src);
}

template <class View>
inline void
resize_pushforward(View& v, const size_t n0, const size_t n1, const size_t n2,
const size_t n3, const size_t n4, const size_t n5,
const size_t n6, const size_t n7, View& d_v,
const size_t d_n0, const size_t d_n1, const size_t d_n2,
const size_t d_n3, const size_t d_n4, const size_t d_n5,
const size_t d_n6, const size_t d_n7) {
::Kokkos::resize(v, n0, n1, n2, n3, n4, n5, n6, n7);
::Kokkos::resize(d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}

template <class I, class dI, class View>
inline void resize_pushforward(
const I& arg, View& v, const size_t n0, const size_t n1, const size_t n2,
const size_t n3, const size_t n4, const size_t n5, const size_t n6,
const size_t n7, const dI& d_arg, View& d_v, const size_t d_n0,
const size_t d_n1, const size_t d_n2, const size_t d_n3, const size_t d_n4,
const size_t d_n5, const size_t d_n6, const size_t d_n7) {
::Kokkos::resize(arg, v, n0, n1, n2, n3, n4, n5, n6, n7);
::Kokkos::resize(arg, d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}

template <class ExecPolicy, class FunctorType>
inline void
parallel_for_pushforward(const ::std::string& str, const ExecPolicy& policy,
Expand Down
46 changes: 45 additions & 1 deletion unittests/Kokkos/ViewBasics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
// https://github.com/kliegeois/clad/blob/kokkos-PR/unittests/Kokkos/view_access.cpp
// it has been modified to match gtest guidelines and improve readability

#include "ParallelAdd.h"
#include <Kokkos_Core.hpp>
#include "clad/Differentiator/Differentiator.h"
#include "clad/Differentiator/KokkosBuiltins.h"
#include "gtest/gtest.h"

double f_basics(double x, double y) {
Expand Down Expand Up @@ -160,4 +160,48 @@ TEST(ViewBasics, TestDeepCopy2Reverse) {
// EXPECT_NEAR(dy_f_FD, dy, abs(tau*dy));
// }
// }
}

double f_basics_resize_1(double x, double y) {
Kokkos::View<double** [3], Kokkos::LayoutLeft, Kokkos::HostSpace> a("a", 3,
2);

Kokkos::deep_copy(a, 3 * x + y);
a(2, 1, 0) = x * y;

Kokkos::resize(Kokkos::WithoutInitializing, a, 5, 5);

return a(2, 1, 0);
}

TEST(ViewBasics, TestResize1) {
const double eps = 1e-8;

auto df = clad::differentiate(f_basics_resize_1, 0);
auto df_true = [](double x, double y) { return y; };
for (double x = 3; x <= 5; x += 1)
for (double y = 3; y <= 5; y += 1)
EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps);
}

double f_basics_resize_2(double x, double y) {
Kokkos::View<double** [3], Kokkos::LayoutLeft, Kokkos::HostSpace> a("a", 3,
2);

Kokkos::deep_copy(a, 3 * x + y);
a(2, 1, 0) = x * y;

Kokkos::resize(a, 5, 5);

return a(2, 1, 0);
}

TEST(ViewBasics, TestResize2) {
const double eps = 1e-8;

auto df = clad::differentiate(f_basics_resize_2, 0);
auto df_true = [](double x, double y) { return y; };
for (double x = 3; x <= 5; x += 1)
for (double y = 3; y <= 5; y += 1)
EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps);
}

0 comments on commit c10d76a

Please sign in to comment.