Skip to content

Commit

Permalink
Add support for std::array in the rvs mode
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Sep 4, 2024
1 parent 685bcbf commit cc57837
Show file tree
Hide file tree
Showing 2 changed files with 334 additions and 1 deletion.
101 changes: 101 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,107 @@ void constructor_pullback(::std::vector<T>* v, S count, U val,
d_v->clear();
}

template <typename T, ::std::size_t N>
clad::ValueAndAdjoint<T&, T&> operator_subscript_reverse_forw(
::std::array<T, N>* arr, typename ::std::array<T, N>::size_type idx,
::std::array<T, N>* d_arr, typename ::std::array<T, N>::size_type d_idx) {
return {(*arr)[idx], (*d_arr)[idx]};
}
template <typename T, ::std::size_t N, typename P>
void operator_subscript_pullback(
::std::array<T, N>* arr, typename ::std::array<T, N>::size_type idx, P d_y,
::std::array<T, N>* d_arr, typename ::std::array<T, N>::size_type* d_idx) {
(*d_arr)[idx] += d_y;
}
template <typename T, ::std::size_t N>
clad::ValueAndAdjoint<T&, T&> at_reverse_forw(
::std::array<T, N>* arr, typename ::std::array<T, N>::size_type idx,
::std::array<T, N>* d_arr, typename ::std::array<T, N>::size_type d_idx) {
return {(*arr)[idx], (*d_arr)[idx]};
}
template <typename T, ::std::size_t N, typename P>
void at_pullback(::std::array<T, N>* arr,
typename ::std::array<T, N>::size_type idx, P d_y,
::std::array<T, N>* d_arr,
typename ::std::array<T, N>::size_type* d_idx) {
(*d_arr)[idx] += d_y;
}
template <typename T, ::std::size_t N>
void fill_reverse_forw(::std::array<T, N>* a,
const typename ::std::array<T, N>::value_type& u,
::std::array<T, N>* d_a,
const typename ::std::array<T, N>::value_type& d_u) {
a->fill(u);
d_a->fill(0);
}
template <typename T, ::std::size_t N>
void fill_pullback(::std::array<T, N>* arr,
const typename ::std::array<T, N>::value_type& u,
::std::array<T, N>* d_arr,
typename ::std::array<T, N>::value_type* d_u) {
size_t _d_i = 0;
size_t i = 0;
clad::tape<typename ::std::array<T, N>::value_type> _t1 = {};
size_t _t0 = 0;
for (i = 0;; ++i) {
if (!(i < N))
break;
_t0++;
clad::push(_t1, (*arr)[i]);
(*arr)[i] = u;
}
for (;; _t0--) {
if (!_t0)
break;
--i;
(*arr)[i] = clad::pop(_t1);
typename ::std::array<T, N>::value_type _r_d0 = (*d_arr)[i];
(*d_arr)[i] = 0;
*d_u += _r_d0;
}
}
template <typename T, ::std::size_t N>
clad::ValueAndAdjoint<T&, T&>
back_reverse_forw(::std::array<T, N>* arr, ::std::array<T, N>* d_arr) noexcept {
return {arr->back(), d_arr->back()};
}
template <typename T, ::std::size_t N>
void back_pullback(::std::array<T, N>* arr,
typename ::std::array<T, N>::value_type d_u,
::std::array<T, N>* d_arr) noexcept {
(*d_arr)[d_arr->size() - 1] += d_u;
}
template <typename T, ::std::size_t N>
clad::ValueAndAdjoint<T&, T&>
front_reverse_forw(::std::array<T, N>* arr,
::std::array<T, N>* d_arr) noexcept {
return {arr->front(), d_arr->front()};
}
template <typename T, ::std::size_t N>
void front_pullback(::std::array<T, N>* arr,
typename ::std::array<T, N>::value_type d_u,
::std::array<T, N>* d_arr) {
(*d_arr)[0] += d_u;
}
template <typename T, ::std::size_t N>
void size_pullback(::std::array<T, N>* a, ::std::array<T, N>* d_a) noexcept {}
template <typename T, ::std::size_t N>
::clad::ValueAndAdjoint<::std::array<T, N>, ::std::array<T, N>>
constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::array<T, N>>,
const ::std::array<T, N>& arr,
const ::std::array<T, N>& d_arr) {
::std::array<T, N> a = arr;
::std::array<T, N> d_a = d_arr;
return {a, d_a};
}
template <typename T, ::std::size_t N>
void constructor_pullback(::std::array<T, N>* a, const ::std::array<T, N>& arr,
::std::array<T, N>* d_a, ::std::array<T, N>* d_arr) {
for (size_t i = 0; i < N; ++i)
(*d_arr)[i] += (*d_a)[i];
// d_a->fill(0);
}

} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
Expand Down
234 changes: 233 additions & 1 deletion test/Gradient/STLCustomDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "../TestUtils.h"
#include "../PrintOverloads.h"

#include <array>
#include <vector>

double fn10(double u, double v) {
Expand Down Expand Up @@ -111,19 +112,63 @@ double fn14(double x, double y) {
return a[1];
}

double fn15(double x, double y) {
std::array<double, 3> a;
a.fill(x);

double res = 0;
for (size_t i = 0; i < a.size(); ++i) {
res += a.at(i);
}

return res;
}

double fn16(double x, double y) {
std::array<double, 2> a;
a[0] = 5;
a[1] = y;
std::array<double, 3> _b;
_b[0] = x;
_b[1] = 0;
_b[2] = x*x;
const std::array<double, 3> b = _b;
return a.back() * b.front() * b.at(2) + b[1];
}

double fn17(double x, double y) {
std::array<double, 50> a;
a.fill(y+x+x);
return a[49]+a[3];
}

double fn18(double x, double y) {
std::array<double, 2> a;
a[1] = 2*x;
return a[1];
}

int main() {
double d_i, d_j;
INIT_GRADIENT(fn10);
INIT_GRADIENT(fn11);
INIT_GRADIENT(fn12);
INIT_GRADIENT(fn13);
INIT_GRADIENT(fn14);
INIT_GRADIENT(fn15);
INIT_GRADIENT(fn16);
INIT_GRADIENT(fn17);
INIT_GRADIENT(fn18);

TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00}
TEST_GRADIENT(fn12, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {4.00, 2.00}
TEST_GRADIENT(fn13, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {3.00, 0.00}
TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00}
TEST_GRADIENT(fn15, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {3.00, 0.00}
TEST_GRADIENT(fn16, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {108.00, 27.00}
TEST_GRADIENT(fn17, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {4.00, 2.00}
TEST_GRADIENT(fn18, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {2.00, 0.00}
}

// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) {
Expand Down Expand Up @@ -428,4 +473,191 @@ int main() {
// CHECK-NEXT: x = _t0;
// CHECK-NEXT: {{.*}}push_back_pullback(&_t1, _t0, &_d_a, &*_d_x);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn15_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: size_t _d_i = 0UL;
// CHECK-NEXT: size_t i = 0UL;
// CHECK-NEXT: clad::tape<std::array<double, 3> > _t3 = {};
// CHECK-NEXT: clad::tape<double> _t4 = {};
// CHECK-NEXT: clad::tape<std::array<double, 3> > _t5 = {};
// CHECK-NEXT: std::array<double, 3> _d_a({});
// CHECK-NEXT: std::array<double, 3> a;
// CHECK-NEXT: double _t0 = x;
// CHECK-NEXT: std::array<double, 3> _t1 = a;
// CHECK-NEXT: {{.*}}fill_reverse_forw(&a, x, &_d_a, *_d_x);
// CHECK-NEXT: double _d_res = 0.;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: unsigned long _t2 = 0UL;
// CHECK-NEXT: for (i = 0; ; ++i) {
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t3, a);
// CHECK-NEXT: }
// CHECK-NEXT: if (!(i < a.size()))
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: _t2++;
// CHECK-NEXT: clad::push(_t4, res);
// CHECK-NEXT: clad::push(_t5, a);
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t6 = {{.*}}at_reverse_forw(&a, i, &_d_a, _r0);
// CHECK-NEXT: res += _t6.value;
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t2--) {
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_pullback(&clad::back(_t3), &_d_a);
// CHECK-NEXT: clad::pop(_t3);
// CHECK-NEXT: }
// CHECK-NEXT: if (!_t2)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: --i;
// CHECK-NEXT: {
// CHECK-NEXT: res = clad::pop(_t4);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: size_t _r0 = 0UL;
// CHECK-NEXT: {{.*}}at_pullback(&clad::back(_t5), i, _r_d0, &_d_a, &_r0);
// CHECK-NEXT: _d_i += _r0;
// CHECK-NEXT: clad::pop(_t5);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: x = _t0;
// CHECK-NEXT: {{.*}}fill_pullback(&_t1, _t0, &_d_a, &*_d_x);
// CHECK-NEXT: }
//}

// CHECK: void fn16_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::array<double, 2> _d_a({});
// CHECK-NEXT: std::array<double, 2> a;
// CHECK-NEXT: std::array<double, 2> _t0 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r0);
// CHECK-NEXT: double _t2 = _t1.value;
// CHECK-NEXT: _t1.value = 5;
// CHECK-NEXT: std::array<double, 2> _t3 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t4 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r1);
// CHECK-NEXT: double _t5 = _t4.value;
// CHECK-NEXT: _t4.value = y;
// CHECK-NEXT: std::array<double, 3> _d__b({});
// CHECK-NEXT: std::array<double, 3> _b0;
// CHECK-NEXT: std::array<double, 3> _t6 = _b0;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t7 = {{.*}}operator_subscript_reverse_forw(&_b0, 0, &_d__b, _r2);
// CHECK-NEXT: double _t8 = _t7.value;
// CHECK-NEXT: _t7.value = x;
// CHECK-NEXT: std::array<double, 3> _t9 = _b0;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t10 = {{.*}}operator_subscript_reverse_forw(&_b0, 1, &_d__b, _r3);
// CHECK-NEXT: double _t11 = _t10.value;
// CHECK-NEXT: _t10.value = 0;
// CHECK-NEXT: std::array<double, 3> _t12 = _b0;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t13 = {{.*}}operator_subscript_reverse_forw(&_b0, 2, &_d__b, _r4);
// CHECK-NEXT: double _t14 = _t13.value;
// CHECK-NEXT: _t13.value = x * x;
// CHECK-NEXT: ::clad::ValueAndAdjoint< ::std::array<double, 3UL>, ::std::array<double, 3UL> > _t15 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag<array<double, 3> >(), _b0, _d__b);
// CHECK-NEXT: std::array<double, 3> _d_b = _t15.adjoint;
// CHECK-NEXT: const std::array<double, 3> b = _t15.value;
// CHECK-NEXT: std::array<double, 2> _t18 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t19 = {{.*}}back_reverse_forw(&a, &_d_a);
// CHECK-NEXT: std::array<double, 3> _t20 = b;
// CHECK-NEXT: std::array<double, 3>::value_type _t17 = b.front();
// CHECK-NEXT: std::array<double, 3> _t21 = b;
// CHECK-NEXT: std::array<double, 3>::value_type _t16 = b.at(2);
// CHECK-NEXT: std::array<double, 3> _t22 = b;
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}back_pullback(&_t18, 1 * _t16 * _t17, &_d_a);
// CHECK-NEXT: {{.*}}front_pullback(&_t20, _t19.value * 1 * _t16, &_d_b);
// CHECK-NEXT: {{.*}}size_type _r5 = 0UL;
// CHECK-NEXT: {{.*}}at_pullback(&_t21, 2, _t19.value * _t17 * 1, &_d_b, &_r5);
// CHECK-NEXT: {{.*}}size_type _r6 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t22, 1, 1, &_d_b, &_r6);
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}}constructor_pullback(&b, _b0, &_d_b, &_d__b);
// CHECK-NEXT: {
// CHECK-NEXT: _t13.value = _t14;
// CHECK-NEXT: double _r_d4 = _t13.adjoint;
// CHECK-NEXT: _t13.adjoint = 0.;
// CHECK-NEXT: *_d_x += _r_d4 * x;
// CHECK-NEXT: *_d_x += x * _r_d4;
// CHECK-NEXT: {{.*}}size_type _r4 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t12, 2, 0., &_d__b, &_r4);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t10.value = _t11;
// CHECK-NEXT: double _r_d3 = _t10.adjoint;
// CHECK-NEXT: _t10.adjoint = 0.;
// CHECK-NEXT: {{.*}}size_type _r3 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t9, 1, 0., &_d__b, &_r3);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t7.value = _t8;
// CHECK-NEXT: double _r_d2 = _t7.adjoint;
// CHECK-NEXT: _t7.adjoint = 0.;
// CHECK-NEXT: *_d_x += _r_d2;
// CHECK-NEXT: {{.*}}size_type _r2 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t6, 0, 0., &_d__b, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t4.value = _t5;
// CHECK-NEXT: double _r_d1 = _t4.adjoint;
// CHECK-NEXT: _t4.adjoint = 0.;
// CHECK-NEXT: *_d_y += _r_d1;
// CHECK-NEXT: {{.*}}size_type _r1 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t3, 1, 0., &_d_a, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t1.value = _t2;
// CHECK-NEXT: double _r_d0 = _t1.adjoint;
// CHECK-NEXT: _t1.adjoint = 0.;
// CHECK-NEXT: {{.*}}size_type _r0 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t0, 0, 0., &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn17_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::array<double, 50> _d_a({});
// CHECK-NEXT: std::array<double, 50> a;
// CHECK-NEXT: std::array<double, 50> _t0 = a;
// CHECK-NEXT: {{.*}}fill_reverse_forw(&a, y + x + x, &_d_a, _r0);
// CHECK-NEXT: std::array<double, 50> _t1 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t2 = {{.*}}operator_subscript_reverse_forw(&a, 49, &_d_a, _r1);
// CHECK-NEXT: std::array<double, 50> _t3 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t4 = {{.*}}operator_subscript_reverse_forw(&a, 3, &_d_a, _r2);
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r1 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 49, 1, &_d_a, &_r1);
// CHECK-NEXT: {{.*}}size_type _r2 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t3, 3, 1, &_d_a, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = 0.;
// CHECK-NEXT: {{.*}}fill_pullback(&_t0, y + x + x, &_d_a, &_r0);
// CHECK-NEXT: *_d_y += _r0;
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void fn18_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: std::array<double, 2> _d_a({});
// CHECK-NEXT: std::array<double, 2> a;
// CHECK-NEXT: std::array<double, 2> _t0 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r0);
// CHECK-NEXT: {{.*}} _t2 = _t1.value;
// CHECK-NEXT: _t1.value = 2 * x;
// CHECK-NEXT: std::array<double, 2> _t3 = a;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t4 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r1);
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}}size_type _r1 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t3, 1, 1, &_d_a, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: _t1.value = _t2;
// CHECK-NEXT: {{.*}} _r_d0 = _t1.adjoint;
// CHECK-NEXT: _t1.adjoint = 0.;
// CHECK-NEXT: *_d_x += 2 * _r_d0;
// CHECK-NEXT: {{.*}}size_type _r0 = 0UL;
// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t0, 1, 0., &_d_a, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit cc57837

Please sign in to comment.