diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index da656a33f..ff5020ef5 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2056,7 +2056,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); auto* resAdjoint = utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); - return StmtDiff(resValue, nullptr, resAdjoint); + return StmtDiff(resValue, resAdjoint, resAdjoint); } if (utils::isNonConstReferenceType(returnType) || returnType->isPointerType()) { diff --git a/test/Gradient/STLCustomDerivatives.C b/test/Gradient/STLCustomDerivatives.C index 03886c7d8..f226d0034 100644 --- a/test/Gradient/STLCustomDerivatives.C +++ b/test/Gradient/STLCustomDerivatives.C @@ -103,17 +103,27 @@ double fn13(double u, double v) { return vec[0] + vec[1] + vec[2]; } +double fn14(double x, double y) { + std::vector a; + a.push_back(x); + a.push_back(x); + a[1] = x*x; + return a[1]; +} + int main() { double d_i, d_j; INIT_GRADIENT(fn10); INIT_GRADIENT(fn11); INIT_GRADIENT(fn12); INIT_GRADIENT(fn13); + INIT_GRADIENT(fn14); TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00} TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00} TEST_GRADIENT(fn12, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {4.00, 2.00} TEST_GRADIENT(fn13, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {3.00, 0.00} + TEST_GRADIENT(fn14, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {6.00, 0.00} } // CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) { @@ -381,3 +391,41 @@ int main() { // CHECK-NEXT: {{.*}}constructor_pullback(&vec, count, u, allocator, &_d_vec, &_d_count, &*_d_u, &_d_allocator); // CHECK-NEXT: *_d_u += _d_res; // CHECK-NEXT: } + +// CHECK: void fn14_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: std::vector _d_a({}); +// CHECK-NEXT: std::vector a; +// CHECK-NEXT: double _t0 = x; +// CHECK-NEXT: std::vector _t1 = a; +// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, x, &_d_a, *_d_x); +// CHECK-NEXT: double _t2 = x; +// CHECK-NEXT: std::vector _t3 = a; +// CHECK-NEXT: {{.*}}push_back_reverse_forw(&a, x, &_d_a, *_d_x); +// CHECK-NEXT: std::vector _t4 = a; +// CHECK-NEXT: clad::ValueAndAdjoint _t5 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r0); +// CHECK-NEXT: double _t6 = _t5.value; +// CHECK-NEXT: _t5.value = x * x; +// CHECK-NEXT: std::vector _t7 = a; +// CHECK-NEXT: clad::ValueAndAdjoint _t8 = {{.*}}operator_subscript_reverse_forw(&a, 1, &_d_a, _r1); +// CHECK-NEXT: { +// CHECK-NEXT: {{.*}} _r1 = 0; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t7, 1, 1, &_d_a, &_r1); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: _t5.value = _t6; +// CHECK-NEXT: double _r_d0 = _t5.adjoint; +// CHECK-NEXT: _t5.adjoint = 0; +// CHECK-NEXT: *_d_x += _r_d0 * x; +// CHECK-NEXT: *_d_x += x * _r_d0; +// CHECK-NEXT: {{.*}} _r0 = 0; +// CHECK-NEXT: {{.*}}operator_subscript_pullback(&_t4, 1, 0, &_d_a, &_r0); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: x = _t2; +// CHECK-NEXT: {{.*}}push_back_pullback(&_t3, _t2, &_d_a, &*_d_x); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: x = _t0; +// CHECK-NEXT: {{.*}}push_back_pullback(&_t1, _t0, &_d_a, &*_d_x); +// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file