From cc57837897a180b4a665164683716cc03e526e4d Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 4 Sep 2024 13:12:39 +0200 Subject: [PATCH] Add support for `std::array` in the rvs mode Fixes: #1000 --- include/clad/Differentiator/STLBuiltins.h | 101 ++++++++++ test/Gradient/STLCustomDerivatives.C | 234 +++++++++++++++++++++- 2 files changed, 334 insertions(+), 1 deletion(-) diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 147938080..462d8355d 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -253,6 +253,107 @@ void constructor_pullback(::std::vector* v, S count, U val, d_v->clear(); } +template +clad::ValueAndAdjoint operator_subscript_reverse_forw( + ::std::array* arr, typename ::std::array::size_type idx, + ::std::array* d_arr, typename ::std::array::size_type d_idx) { + return {(*arr)[idx], (*d_arr)[idx]}; +} +template +void operator_subscript_pullback( + ::std::array* arr, typename ::std::array::size_type idx, P d_y, + ::std::array* d_arr, typename ::std::array::size_type* d_idx) { + (*d_arr)[idx] += d_y; +} +template +clad::ValueAndAdjoint at_reverse_forw( + ::std::array* arr, typename ::std::array::size_type idx, + ::std::array* d_arr, typename ::std::array::size_type d_idx) { + return {(*arr)[idx], (*d_arr)[idx]}; +} +template +void at_pullback(::std::array* arr, + typename ::std::array::size_type idx, P d_y, + ::std::array* d_arr, + typename ::std::array::size_type* d_idx) { + (*d_arr)[idx] += d_y; +} +template +void fill_reverse_forw(::std::array* a, + const typename ::std::array::value_type& u, + ::std::array* d_a, + const typename ::std::array::value_type& d_u) { + a->fill(u); + d_a->fill(0); +} +template +void fill_pullback(::std::array* arr, + const typename ::std::array::value_type& u, + ::std::array* d_arr, + typename ::std::array::value_type* d_u) { + size_t _d_i = 0; + size_t i = 0; + clad::tape::value_type> _t1 = {}; + size_t _t0 = 0; + for (i = 0;; ++i) { + if (!(i < N)) + break; + _t0++; + clad::push(_t1, (*arr)[i]); + (*arr)[i] = u; + } + for (;; _t0--) { + if (!_t0) + break; + --i; + (*arr)[i] = clad::pop(_t1); + typename ::std::array::value_type _r_d0 = (*d_arr)[i]; + (*d_arr)[i] = 0; + *d_u += _r_d0; + } +} +template +clad::ValueAndAdjoint +back_reverse_forw(::std::array* arr, ::std::array* d_arr) noexcept { + return {arr->back(), d_arr->back()}; +} +template +void back_pullback(::std::array* arr, + typename ::std::array::value_type d_u, + ::std::array* d_arr) noexcept { + (*d_arr)[d_arr->size() - 1] += d_u; +} +template +clad::ValueAndAdjoint +front_reverse_forw(::std::array* arr, + ::std::array* d_arr) noexcept { + return {arr->front(), d_arr->front()}; +} +template +void front_pullback(::std::array* arr, + typename ::std::array::value_type d_u, + ::std::array* d_arr) { + (*d_arr)[0] += d_u; +} +template +void size_pullback(::std::array* a, ::std::array* d_a) noexcept {} +template +::clad::ValueAndAdjoint<::std::array, ::std::array> +constructor_reverse_forw(::clad::ConstructorReverseForwTag<::std::array>, + const ::std::array& arr, + const ::std::array& d_arr) { + ::std::array a = arr; + ::std::array d_a = d_arr; + return {a, d_a}; +} +template +void constructor_pullback(::std::array* a, const ::std::array& arr, + ::std::array* d_a, ::std::array* d_arr) { + for (size_t i = 0; i < N; ++i) + (*d_arr)[i] += (*d_a)[i]; + // d_a->fill(0); +} + } // namespace class_functions } // namespace custom_derivatives } // namespace clad diff --git a/test/Gradient/STLCustomDerivatives.C b/test/Gradient/STLCustomDerivatives.C index 3bcbd903a..f3ac774b1 100644 --- a/test/Gradient/STLCustomDerivatives.C +++ b/test/Gradient/STLCustomDerivatives.C @@ -10,6 +10,7 @@ #include "../TestUtils.h" #include "../PrintOverloads.h" +#include #include double fn10(double u, double v) { @@ -111,6 +112,42 @@ double fn14(double x, double y) { return a[1]; } +double fn15(double x, double y) { + std::array a; + a.fill(x); + + double res = 0; + for (size_t i = 0; i < a.size(); ++i) { + res += a.at(i); + } + + return res; +} + +double fn16(double x, double y) { + std::array a; + a[0] = 5; + a[1] = y; + std::array _b; + _b[0] = x; + _b[1] = 0; + _b[2] = x*x; + const std::array b = _b; + return a.back() * b.front() * b.at(2) + b[1]; +} + +double fn17(double x, double y) { + std::array a; + a.fill(y+x+x); + return a[49]+a[3]; +} + +double fn18(double x, double y) { + std::array a; + a[1] = 2*x; + return a[1]; +} + int main() { double d_i, d_j; INIT_GRADIENT(fn10); @@ -118,12 +155,20 @@ int main() { INIT_GRADIENT(fn12); INIT_GRADIENT(fn13); INIT_GRADIENT(fn14); + INIT_GRADIENT(fn15); + INIT_GRADIENT(fn16); + INIT_GRADIENT(fn17); + INIT_GRADIENT(fn18); TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00} TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00} TEST_GRADIENT(fn12, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {4.00, 2.00} TEST_GRADIENT(fn13, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {3.00, 0.00} TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00} + TEST_GRADIENT(fn15, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {3.00, 0.00} + TEST_GRADIENT(fn16, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {108.00, 27.00} + TEST_GRADIENT(fn17, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {4.00, 2.00} + TEST_GRADIENT(fn18, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {2.00, 0.00} } // CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) { @@ -428,4 +473,191 @@ int main() { // CHECK-NEXT: x = _t0; // CHECK-NEXT: {{.*}}push_back_pullback(&_t1, _t0, &_d_a, &*_d_x); // CHECK-NEXT: } -// CHECK-NEXT: } \ No newline at end of file +// CHECK-NEXT: } + +// CHECK: void fn15_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: size_t _d_i = 0UL; +// CHECK-NEXT: size_t i = 0UL; +// CHECK-NEXT: clad::tape > _t3 = {}; +// CHECK-NEXT: clad::tape _t4 = {}; +// CHECK-NEXT: clad::tape > _t5 = {}; +// CHECK-NEXT: std::array _d_a({}); +// CHECK-NEXT: std::array a; +// CHECK-NEXT: double _t0 = x; +// CHECK-NEXT: std::array _t1 = a; +// CHECK-NEXT: {{.*}}fill_reverse_forw(&a, x, &_d_a, *_d_x); +// CHECK-NEXT: double _d_res = 0.; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: unsigned long _t2 = 0UL; +// CHECK-NEXT: for (i = 0; ; ++i) { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: clad::push(_t3, a); +// CHECK-NEXT: } +// CHECK-NEXT: if (!(i < a.size())) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: _t2++; +// CHECK-NEXT: clad::push(_t4, res); +// CHECK-NEXT: clad::push(_t5, a); +// CHECK-NEXT: clad::ValueAndAdjoint _t6 = {{.*}}at_reverse_forw(&a, i, &_d_a, _r0); +// CHECK-NEXT: res += _t6.value; +// CHECK-NEXT: } +// CHECK-NEXT: _d_res += 1; +// CHECK-NEXT: for (;; _t2--) { +// CHECK-NEXT: { +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}}size_pullback(&clad::back(_t3), &_d_a); +// CHECK-NEXT: clad::pop(_t3); +// CHECK-NEXT: } +// CHECK-NEXT: if (!_t2) +// CHECK-NEXT: break; +// CHECK-NEXT: } +// CHECK-NEXT: --i; +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t4); +// CHECK-NEXT: double _r_d0 = _d_res; +// CHECK-NEXT: size_t _r0 = 0UL; +// CHECK-NEXT: {{.*}}at_pullback(&clad::back(_t5), i, _r_d0, &_d_a, &_r0); +// CHECK-NEXT: _d_i += _r0; +// CHECK-NEXT: clad::pop(_t5); +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: x = _t0; +// CHECK-NEXT: {{.*}}fill_pullback(&_t1, _t0, &_d_a, &*_d_x); +// CHECK-NEXT: } +//} + +// CHECK: void fn16_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: std::array _d_a({}); +// CHECK-NEXT: std::array a; +// CHECK-NEXT: std::array _t0 = a; +// CHECK-NEXT: clad::ValueAndAdjoint _t1 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r0); +// CHECK-NEXT: double _t2 = _t1.value; +// CHECK-NEXT: _t1.value = 5; +// CHECK-NEXT: std::array _t3 = a; +// CHECK-NEXT: clad::ValueAndAdjoint _t4 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r1); +// CHECK-NEXT: double _t5 = _t4.value; +// CHECK-NEXT: _t4.value = y; +// CHECK-NEXT: std::array _d__b({}); +// CHECK-NEXT: std::array _b0; +// CHECK-NEXT: std::array _t6 = _b0; +// CHECK-NEXT: clad::ValueAndAdjoint _t7 = {{.*}}operator_subscript_reverse_forw(&_b0, 0, &_d__b, _r2); +// CHECK-NEXT: double _t8 = _t7.value; +// CHECK-NEXT: _t7.value = x; +// CHECK-NEXT: std::array _t9 = _b0; +// CHECK-NEXT: clad::ValueAndAdjoint _t10 = {{.*}}operator_subscript_reverse_forw(&_b0, 1, &_d__b, _r3); +// CHECK-NEXT: double _t11 = _t10.value; +// CHECK-NEXT: _t10.value = 0; +// CHECK-NEXT: std::array _t12 = _b0; +// CHECK-NEXT: clad::ValueAndAdjoint _t13 = {{.*}}operator_subscript_reverse_forw(&_b0, 2, &_d__b, _r4); +// CHECK-NEXT: double _t14 = _t13.value; +// CHECK-NEXT: _t13.value = x * x; +// CHECK-NEXT: ::clad::ValueAndAdjoint< ::std::array, ::std::array > _t15 = {{.*}}constructor_reverse_forw(clad::ConstructorReverseForwTag >(), _b0, _d__b); +// CHECK-NEXT: std::array _d_b = _t15.adjoint; +// CHECK-NEXT: const std::array b = _t15.value; +// CHECK-NEXT: std::array _t18 = a; +// CHECK-NEXT: clad::ValueAndAdjoint _t19 = {{.*}}back_reverse_forw(&a, &_d_a); +// CHECK-NEXT: std::array _t20 = b; +// CHECK-NEXT: std::array::value_type _t17 = b.front(); +// CHECK-NEXT: std::array _t21 = b; +// CHECK-NEXT: std::array::value_type _t16 = b.at(2); +// CHECK-NEXT: std::array _t22 = b; +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}}back_pullback(&_t18, 1 * _t16 * _t17, &_d_a); +// CHECK-NEXT: {{.*}}front_pullback(&_t20, _t19.value * 1 * _t16, &_d_b); +// CHECK-NEXT: {{.*}}size_type _r5 = 0UL; +// CHECK-NEXT: {{.*}}at_pullback(&_t21, 2, _t19.value * _t17 * 1, &_d_b, &_r5); +// CHECK-NEXT: {{.*}}size_type _r6 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t22, 1, 1, &_d_b, &_r6); +// CHECK-NEXT: } +// CHECK-NEXT: {{.*}}constructor_pullback(&b, _b0, &_d_b, &_d__b); +// CHECK-NEXT: { +// CHECK-NEXT: _t13.value = _t14; +// CHECK-NEXT: double _r_d4 = _t13.adjoint; +// CHECK-NEXT: _t13.adjoint = 0.; +// CHECK-NEXT: *_d_x += _r_d4 * x; +// CHECK-NEXT: *_d_x += x * _r_d4; +// CHECK-NEXT: {{.*}}size_type _r4 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t12, 2, 0., &_d__b, &_r4); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _t10.value = _t11; +// CHECK-NEXT: double _r_d3 = _t10.adjoint; +// CHECK-NEXT: _t10.adjoint = 0.; +// CHECK-NEXT: {{.*}}size_type _r3 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t9, 1, 0., &_d__b, &_r3); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _t7.value = _t8; +// CHECK-NEXT: double _r_d2 = _t7.adjoint; +// CHECK-NEXT: _t7.adjoint = 0.; +// CHECK-NEXT: *_d_x += _r_d2; +// CHECK-NEXT: {{.*}}size_type _r2 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t6, 0, 0., &_d__b, &_r2); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _t4.value = _t5; +// CHECK-NEXT: double _r_d1 = _t4.adjoint; +// CHECK-NEXT: _t4.adjoint = 0.; +// CHECK-NEXT: *_d_y += _r_d1; +// CHECK-NEXT: {{.*}}size_type _r1 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t3, 1, 0., &_d_a, &_r1); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _t1.value = _t2; +// CHECK-NEXT: double _r_d0 = _t1.adjoint; +// CHECK-NEXT: _t1.adjoint = 0.; +// CHECK-NEXT: {{.*}}size_type _r0 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t0, 0, 0., &_d_a, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void fn17_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: std::array _d_a({}); +// CHECK-NEXT: std::array a; +// CHECK-NEXT: std::array _t0 = a; +// CHECK-NEXT: {{.*}}fill_reverse_forw(&a, y + x + x, &_d_a, _r0); +// CHECK-NEXT: std::array _t1 = a; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = {{.*}}operator_subscript_reverse_forw(&a, 49, &_d_a, _r1); +// CHECK-NEXT: std::array _t3 = a; +// CHECK-NEXT: clad::ValueAndAdjoint _t4 = {{.*}}operator_subscript_reverse_forw(&a, 3, &_d_a, _r2); +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}}size_type _r1 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 49, 1, &_d_a, &_r1); +// CHECK-NEXT: {{.*}}size_type _r2 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t3, 3, 1, &_d_a, &_r2); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r0 = 0.; +// CHECK-NEXT: {{.*}}fill_pullback(&_t0, y + x + x, &_d_a, &_r0); +// CHECK-NEXT: *_d_y += _r0; +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void fn18_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: std::array _d_a({}); +// CHECK-NEXT: std::array a; +// CHECK-NEXT: std::array _t0 = a; +// CHECK-NEXT: clad::ValueAndAdjoint _t1 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r0); +// CHECK-NEXT: {{.*}} _t2 = _t1.value; +// CHECK-NEXT: _t1.value = 2 * x; +// CHECK-NEXT: std::array _t3 = a; +// CHECK-NEXT: clad::ValueAndAdjoint _t4 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r1); +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}}size_type _r1 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t3, 1, 1, &_d_a, &_r1); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _t1.value = _t2; +// CHECK-NEXT: {{.*}} _r_d0 = _t1.adjoint; +// CHECK-NEXT: _t1.adjoint = 0.; +// CHECK-NEXT: *_d_x += 2 * _r_d0; +// CHECK-NEXT: {{.*}}size_type _r0 = 0UL; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t0, 1, 0., &_d_a, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT: } +