From 71ca82a3383cbef1fc6e6edeaee1e61632c225ce Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Sat, 21 Sep 2024 22:32:27 +0200 Subject: [PATCH] Fix some cases of `std::vector::push_back` in the rvs mode Fixes: #1071 --- include/clad/Differentiator/STLBuiltins.h | 16 +++++++++- test/Gradient/STLCustomDerivatives.C | 39 +++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/include/clad/Differentiator/STLBuiltins.h b/include/clad/Differentiator/STLBuiltins.h index 744bdff38..40b562dc4 100644 --- a/include/clad/Differentiator/STLBuiltins.h +++ b/include/clad/Differentiator/STLBuiltins.h @@ -392,13 +392,27 @@ size_pushforward(const ::std::array* a, // vector reverse mode // more can be found in tests: test/Gradient/STLCustomDerivatives.C +template +void push_back_reverse_forw(::std::vector* v, U val, ::std::vector* d_v, + pU /*d_val*/) { + v->push_back(val); + d_v->push_back(0); +} + template void push_back_reverse_forw(::std::vector* v, U val, ::std::vector* d_v, - U d_val) { + U /*d_val*/) { v->push_back(val); d_v->push_back(0); } +template +void push_back_pullback(::std::vector* v, U val, ::std::vector* d_v, + pU* d_val) { + *d_val += d_v->back(); + d_v->pop_back(); +} + template void push_back_pullback(::std::vector* v, U val, ::std::vector* d_v, U* d_val) { diff --git a/test/Gradient/STLCustomDerivatives.C b/test/Gradient/STLCustomDerivatives.C index f5739109f..adb1981cd 100644 --- a/test/Gradient/STLCustomDerivatives.C +++ b/test/Gradient/STLCustomDerivatives.C @@ -177,6 +177,13 @@ double fn20(double x, double y) { return res; // 11x+y } +double fn21(double x, double y) { + std::vector a; + a.push_back(0); + a[0] = x*x; + return a[0]; +} + int main() { double d_i, d_j; INIT_GRADIENT(fn10); @@ -190,6 +197,7 @@ int main() { INIT_GRADIENT(fn18); INIT_GRADIENT(fn19); INIT_GRADIENT(fn20); + INIT_GRADIENT(fn21); TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00} TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00} @@ -202,6 +210,7 @@ int main() { TEST_GRADIENT(fn18, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {2.00, 0.00} TEST_GRADIENT(fn19, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {3.00, 2.00} TEST_GRADIENT(fn20, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {11.00, 1.00} + TEST_GRADIENT(fn21, /*numOfDerivativeArgs=*/2, 3, 4, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00} } // CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) { @@ -841,3 +850,33 @@ int main() { // CHECK-NEXT: {{.*}}reserve_pullback(&_t0, 10, &_d_v, &_r0); // CHECK-NEXT: } // CHECK-NEXT: } + +// CHECK: void fn21_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: std::vector _d_a({}); +// CHECK-NEXT: std::vector a; +// CHECK-NEXT: std::vector _t0 = a; +// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, 0{{.*}}, &_d_a, _r0); +// CHECK-NEXT: std::vector _t1 = a; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t2 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r1); +// CHECK-NEXT: double _t3 = _t2.value; +// CHECK-NEXT: _t2.value = x * x; +// CHECK-NEXT: std::vector _t4 = a; +// CHECK-NEXT: {{.*}}ValueAndAdjoint _t5 = {{.*}}operator_subscript_reverse_forw(&a, 0, &_d_a, _r2); +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}}size_type _r2 = 0{{.*}}; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 0, 1, &_d_a, &_r2); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _t2.value = _t3; +// CHECK-NEXT: double _r_d0 = _t2.adjoint; +// CHECK-NEXT: _t2.adjoint = 0{{.*}}; +// CHECK-NEXT: *_d_x += _r_d0 * x; +// CHECK-NEXT: *_d_x += x * _r_d0; +// CHECK-NEXT: {{.*}}size_type _r1 = 0{{.*}}; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t1, 0, 0{{.*}}, &_d_a, &_r1); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}}value_type _r0 = 0.; +// CHECK-NEXT: {{.*}}push_back_pullback(&_t0, 0{{.*}}, &_d_a, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT: }