From 7eef009811b91cb4e41c3ecf65ebb4da97bcbfd9 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Fri, 19 Jul 2024 18:01:28 +0200 Subject: [PATCH] Add support for `std::array` in the fwd mode Fixes: #829 --- include/clad/Differentiator/STLBuiltins.h | 74 +++++++++++++++++++++++ test/ForwardMode/STLCustomDerivatives.C | 72 +++++++++++++++++++++- 2 files changed, 145 insertions(+), 1 deletion(-) diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 12a8048f7..9e5b39963 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -128,6 +128,80 @@ operator_subscript_pushforward(const ::std::vector* v, unsigned idx, return {(*v)[idx], (*d_v)[idx]}; } +template +constexpr clad::ValueAndPushforward +operator_subscript_pushforward(::std::array* a, ::std::size_t i, + ::std::array* _d_a, + ::std::size_t _d_i) noexcept { + return {(*a)[i], (*_d_a)[i]}; +} + +template +constexpr clad::ValueAndPushforward +at_pushforward(::std::array* a, ::std::size_t i, ::std::array* _d_a, + ::std::size_t _d_i) noexcept { + return {(*a)[i], (*_d_a)[i]}; +} + +template +constexpr clad::ValueAndPushforward +operator_subscript_pushforward(const ::std::array* a, ::std::size_t i, + const ::std::array* _d_a, + ::std::size_t _d_i) noexcept { + return {(*a)[i], (*_d_a)[i]}; +} + +template +constexpr clad::ValueAndPushforward +at_pushforward(const ::std::array* a, ::std::size_t i, + const ::std::array* _d_a, ::std::size_t _d_i) noexcept { + return {(*a)[i], (*_d_a)[i]}; +} + +template +constexpr clad::ValueAndPushforward<::std::array&, ::std::array&> +operator_equal_pushforward(::std::array* a, + const ::std::array& param, + ::std::array* _d_a, + const ::std::array& _d_param) noexcept { + (*a) = param; + (*_d_a) = _d_param; + return {*a, *_d_a}; +} + +template +inline constexpr clad::ValueAndPushforward +front_pushforward(const ::std::array* a, + const ::std::array* _d_a) noexcept { + return {a->front(), _d_a->front()}; +} + +template +inline constexpr clad::ValueAndPushforward +front_pushforward(::std::array* a, ::std::array* _d_a) noexcept { + return {a->front(), _d_a->front()}; +} + +template +inline constexpr clad::ValueAndPushforward +back_pushforward(const ::std::array* a, + const ::std::array* _d_a) noexcept { + return {a->back(), _d_a->back()}; +} + +template +inline constexpr clad::ValueAndPushforward +back_pushforward(::std::array* a, ::std::array* _d_a) noexcept { + return {a->back(), _d_a->back()}; +} + +template +void fill_pushforward(::std::array* a, const T& u, + ::std::array* _d_a, const T& _d_u) { + a->fill(u); + _d_a->fill(_d_u); +} + } // namespace class_functions } // namespace custom_derivatives } // namespace clad diff --git a/test/ForwardMode/STLCustomDerivatives.C b/test/ForwardMode/STLCustomDerivatives.C index 5943227a0..537e1626d 100644 --- a/test/ForwardMode/STLCustomDerivatives.C +++ b/test/ForwardMode/STLCustomDerivatives.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oSTLCustomDerivatives.out | %filecheck %s +// RUN: %cladclang %s -Wc++14-extensions -I%S/../../include -oSTLCustomDerivatives.out | %filecheck %s // RUN: ./STLCustomDerivatives.out | %filecheck_exec %s // CHECK-NOT: {{.*error|warning|note:.*}} @@ -116,14 +116,84 @@ double fnVec4(double u, double v) { // CHECK-NEXT: return _t1.pushforward * _t4 + _t3 * _t2.pushforward; // CHECK-NEXT: } +double fnArr1(double x) { + std::array a; + a.fill(x); + + for (size_t i = 0; i < a.size(); ++i) { + a[i] *= i; + } + + double res = 0; + for (size_t i = 0; i < a.size(); ++i) { + res += a[i]; + } + + return res; +} + +//CHECK: double fnArr1_darg0(double x) { +//CHECK-NEXT: double _d_x = 1; +//CHECK-NEXT: std::array _d_a; +//CHECK-NEXT: std::array a; +//CHECK-NEXT: clad::custom_derivatives::class_functions::fill_pushforward(&a, x, &_d_a, _d_x); +//CHECK-NEXT: { +//CHECK-NEXT: size_t _d_i = 0; +//CHECK-NEXT: for (size_t i = 0; i < a.size(); ++i) { +//CHECK-NEXT: clad::ValueAndPushforward _t0 = clad::custom_derivatives::class_functions::operator_subscript_pushforward(&a, i, &_d_a, _d_i); +//CHECK-NEXT: double &_t1 = _t0.pushforward; +//CHECK-NEXT: double &_t2 = _t0.value; +//CHECK-NEXT: _t1 = _t1 * i + _t2 * _d_i; +//CHECK-NEXT: _t2 *= i; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: double _d_res = 0; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: { +//CHECK-NEXT: size_t _d_i = 0; +//CHECK-NEXT: for (size_t i = 0; i < a.size(); ++i) { +//CHECK-NEXT: clad::ValueAndPushforward _t3 = clad::custom_derivatives::class_functions::operator_subscript_pushforward(&a, i, &_d_a, _d_i); +//CHECK-NEXT: _d_res += _t3.pushforward; +//CHECK-NEXT: res += _t3.value; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: return _d_res; +//CHECK-NEXT: } + +double fnArr2(double x) { + std::array a{5, x}; + const std::array b{x, 0, x*x}; + return a.back() * b.front() * b.at(2); +} + +//CHECK: double fnArr2_darg0(double x) { +//CHECK-NEXT: double _d_x = 1; +//CHECK-NEXT: std::array _d_a{{[{][{]0, _d_x[}][}]}}; +//CHECK-NEXT: std::array a{{[{][{]5, x[}][}]}}; +//CHECK-NEXT: const std::array _d_b{{[{][{]_d_x, 0, _d_x \* x \+ x \* _d_x[}][}]}}; +//CHECK-NEXT: const std::array b{{[{][{]x, 0, x \* x[}][}]}}; +//CHECK-NEXT: clad::ValueAndPushforward _t0 = clad::custom_derivatives::class_functions::back_pushforward(&a, &_d_a); +//CHECK-NEXT: clad::ValueAndPushforward _t1 = clad::custom_derivatives::class_functions::front_pushforward(&b, &_d_b); +//CHECK-NEXT: double &_t2 = _t0.value; +//CHECK-NEXT: const double _t3 = _t1.value; +//CHECK-NEXT: clad::ValueAndPushforward _t4 = clad::custom_derivatives::class_functions::at_pushforward(&b, 2, &_d_b, 0); +//CHECK-NEXT: double _t5 = _t2 * _t3; +//CHECK-NEXT: const double _t6 = _t4.value; +//CHECK-NEXT: return (_t0.pushforward * _t3 + _t2 * _t1.pushforward) * _t6 + _t5 * _t4.pushforward; +//CHECK-NEXT: } + int main() { INIT_DIFFERENTIATE(fnVec1, "u"); INIT_DIFFERENTIATE(fnVec2, "u"); INIT_DIFFERENTIATE(fnVec3, "u"); INIT_DIFFERENTIATE(fnVec4, "u"); + INIT_DIFFERENTIATE(fnArr1, "x"); + INIT_DIFFERENTIATE(fnArr2, "x"); TEST_DIFFERENTIATE(fnVec1, 3, 5); // CHECK-EXEC: {10.00} TEST_DIFFERENTIATE(fnVec2, 3, 5); // CHECK-EXEC: {5.00} TEST_DIFFERENTIATE(fnVec3, 3, 5); // CHECK-EXEC: {2.00} TEST_DIFFERENTIATE(fnVec4, 3, 5); // CHECK-EXEC: {30.00} + TEST_DIFFERENTIATE(fnArr1, 3); // CHECK-EXEC: {3.00} + TEST_DIFFERENTIATE(fnArr2, 3); // CHECK-EXEC: {108.00} } \ No newline at end of file