diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0473684d8..fd74add6e 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2335,8 +2335,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (L->HasSideEffects(m_Context)) { Expr* E = Ldiff.getExpr(); - auto* storeE = GlobalStoreAndRef(BuildOp(UO_AddrOf, E)); - Ldiff.updateStmt(BuildOp(UO_Deref, storeE)); + llvm::SmallVector returnExprs; + utils::GetInnermostReturnExpr(E, returnExprs); + if (returnExprs.size() == 1) { + addToCurrentBlock(E, direction::forward); + Ldiff.updateStmt(returnExprs[0]); + } else { + auto* storeE = GlobalStoreAndRef(BuildOp(UO_AddrOf, E)); + Ldiff.updateStmt(BuildOp(UO_Deref, storeE)); + } } Stmts Lblock = EndBlockWithoutCreatingCS(direction::reverse); diff --git a/test/Gradient/Assignments.C b/test/Gradient/Assignments.C index 7a010d838..2913840c9 100644 --- a/test/Gradient/Assignments.C +++ b/test/Gradient/Assignments.C @@ -414,20 +414,19 @@ double f9(double x, double y) { //CHECK: void f9_grad(double x, double y, double *_d_x, double *_d_y) { //CHECK-NEXT: double _d_t = 0; //CHECK-NEXT: double _t0; -//CHECK-NEXT: double *_t1; -//CHECK-NEXT: double _t2; +//CHECK-NEXT: double _t1; //CHECK-NEXT: double t = x; //CHECK-NEXT: _t0 = t; -//CHECK-NEXT: _t1 = &(t *= x); -//CHECK-NEXT: _t2 = *_t1; -//CHECK-NEXT: *_t1 *= y; +//CHECK-NEXT: (t *= x); +//CHECK-NEXT: _t1 = t; +//CHECK-NEXT: t *= y; //CHECK-NEXT: _d_t += 1; //CHECK-NEXT: { -//CHECK-NEXT: *_t1 = _t2; +//CHECK-NEXT: t = _t1; //CHECK-NEXT: double _r_d1 = _d_t; //CHECK-NEXT: _d_t = 0; //CHECK-NEXT: _d_t += _r_d1 * y; -//CHECK-NEXT: *_d_y += *_t1 * _r_d1; +//CHECK-NEXT: *_d_y += t * _r_d1; //CHECK-NEXT: t = _t0; //CHECK-NEXT: double _r_d0 = _d_t; //CHECK-NEXT: _d_t = 0; @@ -474,16 +473,15 @@ double f11(double x, double y) { //CHECK: void f11_grad(double x, double y, double *_d_x, double *_d_y) { //CHECK-NEXT: double _d_t = 0; //CHECK-NEXT: double _t0; -//CHECK-NEXT: double *_t1; -//CHECK-NEXT: double _t2; +//CHECK-NEXT: double _t1; //CHECK-NEXT: double t = x; //CHECK-NEXT: _t0 = t; -//CHECK-NEXT: _t1 = &(t = x); -//CHECK-NEXT: _t2 = *_t1; -//CHECK-NEXT: *_t1 = y; +//CHECK-NEXT: (t = x); +//CHECK-NEXT: _t1 = t; +//CHECK-NEXT: t = y; //CHECK-NEXT: _d_t += 1; //CHECK-NEXT: { -//CHECK-NEXT: *_t1 = _t2; +//CHECK-NEXT: t = _t1; //CHECK-NEXT: double _r_d1 = _d_t; //CHECK-NEXT: _d_t = 0; //CHECK-NEXT: *_d_y += _r_d1;