From 1bdc579929e1034e9db3328d049a07c652675421 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 21 Sep 2023 14:33:48 +0530 Subject: [PATCH] fix DifferentiateVarDecl for constructors in reverse mode --- lib/Differentiator/ReverseModeVisitor.cpp | 12 ++++++--- test/Gradient/MemberFunctions.C | 33 +++++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 55256330e..6c5d09cc2 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2440,11 +2440,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // need to call `Visit` since non-local variables are not differentiated. if (!isDerivativeOfRefType) { Expr* derivedE = BuildDeclRef(VDDerived); - initDiff = VD->getInit() ? Visit(VD->getInit(), derivedE) : StmtDiff{}; + initDiff = StmtDiff{}; + if (VD->getInit()) { + if (isa(VD->getInit())) + initDiff = Visit(VD->getInit()); + else + initDiff = Visit(VD->getInit(), derivedE); + } // If we are differentiating `VarDecl` corresponding to a local variable // inside a loop, then we need to reset it to 0 at each iteration. - // + // // for example, if defined inside a loop, // ``` // double localVar = i; @@ -2454,7 +2460,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // { // *_d_i += _d_localVar; // _d_localVar = 0; - // } + // } if (isInsideLoop) { Stmt* assignToZero = BuildOp(BinaryOperatorKind::BO_Assign, BuildDeclRef(VDDerived), diff --git a/test/Gradient/MemberFunctions.C b/test/Gradient/MemberFunctions.C index a6ffeac36..03c6b07fe 100644 --- a/test/Gradient/MemberFunctions.C +++ b/test/Gradient/MemberFunctions.C @@ -786,6 +786,10 @@ double fn2(SimpleFunctions& sf, double i) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn3(double x, double y, double i, double j) { + SimpleFunctions sf(x, y); + return sf.mem_fn(i, j); +} int main() { auto d_mem_fn = clad::gradient(&SimpleFunctions::mem_fn); @@ -880,4 +884,33 @@ int main() { // CHECK-NEXT: * _d_j += _r3; // CHECK-NEXT: } // CHECK-NEXT: } + + auto d_fn3 = clad::gradient(fn3, "i,j"); + result[0] = result[1] = 0; + d_fn3.execute(2, 3, 4, 5, &result[0], &result[1]); + printf("%.2f %.2f", result[0], result[1]); // CHECK-EXEC: 10.00 4.00 + + // CHECK: void fn3_grad_2_3(double x, double y, double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _d_x = 0; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: SimpleFunctions _d_sf({}); +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: SimpleFunctions _t2; +// CHECK-NEXT: SimpleFunctions sf(x, y); +// CHECK-NEXT: _t0 = i; +// CHECK-NEXT: _t1 = j; +// CHECK-NEXT: _t2 = sf; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: double _grad1 = 0.; +// CHECK-NEXT: _t2.mem_fn_pullback(_t0, _t1, 1, &_d_sf, &_grad0, &_grad1); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: * _d_i += _r0; +// CHECK-NEXT: double _r1 = _grad1; +// CHECK-NEXT: * _d_j += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } }