diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 888191786..f967badb7 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2254,15 +2254,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // in Lblock beginBlock(direction::reverse); Ldiff = Visit(L, dfdx()); - auto* Lblock = endBlock(direction::reverse); if (L->HasSideEffects(m_Context)) { Expr* E = Ldiff.getExpr(); auto* storeE = - StoreAndRef(E, m_Context.getLValueReferenceType(E->getType())); - Ldiff.updateStmt(storeE); + GlobalStoreAndRef(BuildOp(UO_AddrOf, E)); + Ldiff.updateStmt(BuildOp(UO_Deref, storeE)); } + Stmts Lblock = EndBlockWithoutCreatingCS(direction::reverse); + Expr* LCloned = Ldiff.getExpr(); // For x, AssignedDiff is _d_x, for x[i] its _d_x[i], for reference exprs // like (x = y) it propagates recursively, so _d_x is also returned. @@ -2271,12 +2272,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Clone(BinOp); ResultRef = AssignedDiff; // If assigned expr is dependent, first update its derivative; - auto Lblock_begin = Lblock->body_rbegin(); - auto Lblock_end = Lblock->body_rend(); - - if (dfdx() && Lblock_begin != Lblock_end) { - addToCurrentBlock(*Lblock_begin, direction::reverse); - Lblock_begin = std::next(Lblock_begin); + if (dfdx() && !Lblock.empty()) { + addToCurrentBlock(*Lblock.begin(), direction::reverse); + Lblock.erase(Lblock.begin()); } // Store the value of the LHS of the assignment in the forward pass @@ -2385,8 +2383,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, opCode); // Output statements from Visit(L). - for (auto it = Lblock_begin; it != Lblock_end; ++it) - addToCurrentBlock(*it, direction::reverse); + for (Stmt* S : Lblock) + addToCurrentBlock(S, direction::reverse); } else if (opCode == BO_Comma) { auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); diff --git a/test/Gradient/Assignments.C b/test/Gradient/Assignments.C index 20e111ccd..e073721d8 100644 --- a/test/Gradient/Assignments.C +++ b/test/Gradient/Assignments.C @@ -414,19 +414,20 @@ double f9(double x, double y) { //CHECK: void f9_grad(double x, double y, double *_d_x, double *_d_y) { //CHECK-NEXT: double _d_t = 0; //CHECK-NEXT: double _t0; +//CHECK-NEXT: double *_t1; //CHECK-NEXT: double _t2; //CHECK-NEXT: double t = x; //CHECK-NEXT: _t0 = t; -//CHECK-NEXT: double &_t1 = (t *= x); -//CHECK-NEXT: _t2 = _t1; -//CHECK-NEXT: _t1 *= y; +//CHECK-NEXT: _t1 = &(t *= x); +//CHECK-NEXT: _t2 = *_t1; +//CHECK-NEXT: *_t1 *= y; //CHECK-NEXT: _d_t += 1; //CHECK-NEXT: { -//CHECK-NEXT: _t1 = _t2; +//CHECK-NEXT: *_t1 = _t2; //CHECK-NEXT: double _r_d1 = _d_t; //CHECK-NEXT: _d_t -= _r_d1; //CHECK-NEXT: _d_t += _r_d1 * y; -//CHECK-NEXT: *_d_y += _t1 * _r_d1; +//CHECK-NEXT: *_d_y += *_t1 * _r_d1; //CHECK-NEXT: t = _t0; //CHECK-NEXT: double _r_d0 = _d_t; //CHECK-NEXT: _d_t -= _r_d0; @@ -473,15 +474,16 @@ double f11(double x, double y) { //CHECK: void f11_grad(double x, double y, double *_d_x, double *_d_y) { //CHECK-NEXT: double _d_t = 0; //CHECK-NEXT: double _t0; +//CHECK-NEXT: double *_t1; //CHECK-NEXT: double _t2; //CHECK-NEXT: double t = x; //CHECK-NEXT: _t0 = t; -//CHECK-NEXT: double &_t1 = (t = x); -//CHECK-NEXT: _t2 = _t1; -//CHECK-NEXT: _t1 = y; +//CHECK-NEXT: _t1 = &(t = x); +//CHECK-NEXT: _t2 = *_t1; +//CHECK-NEXT: *_t1 = y; //CHECK-NEXT: _d_t += 1; //CHECK-NEXT: { -//CHECK-NEXT: _t1 = _t2; +//CHECK-NEXT: *_t1 = _t2; //CHECK-NEXT: double _r_d1 = _d_t; //CHECK-NEXT: _d_t -= _r_d1; //CHECK-NEXT: *_d_y += _r_d1; @@ -504,6 +506,7 @@ double f12(double x, double y) { //CHECK-NEXT: bool _cond0; //CHECK-NEXT: double _t0; //CHECK-NEXT: double _t1; +//CHECK-NEXT: double *_t2; //CHECK-NEXT: double _t3; //CHECK-NEXT: double t; //CHECK-NEXT: _cond0 = x > y; @@ -511,16 +514,16 @@ double f12(double x, double y) { //CHECK-NEXT: _t0 = t; //CHECK-NEXT: else //CHECK-NEXT: _t1 = t; -//CHECK-NEXT: double &_t2 = (_cond0 ? (t = x) : (t = y)); -//CHECK-NEXT: _t3 = _t2; -//CHECK-NEXT: _t2 *= y; +//CHECK-NEXT: _t2 = &(_cond0 ? (t = x) : (t = y)); +//CHECK-NEXT: _t3 = *_t2; +//CHECK-NEXT: *_t2 *= y; //CHECK-NEXT: _d_t += 1; //CHECK-NEXT: { -//CHECK-NEXT: _t2 = _t3; +//CHECK-NEXT: *_t2 = _t3; //CHECK-NEXT: double _r_d2 = (_cond0 ? _d_t : _d_t); //CHECK-NEXT: (_cond0 ? _d_t : _d_t) -= _r_d2; //CHECK-NEXT: (_cond0 ? _d_t : _d_t) += _r_d2 * y; -//CHECK-NEXT: *_d_y += _t2 * _r_d2; +//CHECK-NEXT: *_d_y += *_t2 * _r_d2; //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: t = _t0; //CHECK-NEXT: double _r_d0 = _d_t;