diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index fd74add6e..da6f53b42 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2349,12 +2349,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Stmts Lblock = EndBlockWithoutCreatingCS(direction::reverse); Expr* LCloned = Ldiff.getExpr(); - // For x, AssignedDiff is _d_x, for x[i] its _d_x[i], for reference exprs + // For x, ResultRef is _d_x, for x[i] its _d_x[i], for reference exprs // like (x = y) it propagates recursively, so _d_x is also returned. - Expr* AssignedDiff = Ldiff.getExpr_dx(); - if (!AssignedDiff) + ResultRef = Ldiff.getExpr_dx(); + if (!ResultRef) return Clone(BinOp); - ResultRef = AssignedDiff; // If assigned expr is dependent, first update its derivative; if (dfdx() && !Lblock.empty()) { addToCurrentBlock(*Lblock.begin(), direction::reverse); @@ -2386,13 +2385,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // For pointer types, no need to store old derivatives. if (!isPointerOp) - oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d", + oldValue = StoreAndRef(ResultRef, direction::reverse, "_r_d", /*forceDeclCreation=*/true); if (opCode == BO_Assign) { if (!isPointerOp) { // Add the statement `dl = 0;` - Expr* zero = getZeroInit(AssignedDiff->getType()); - addToCurrentBlock(BuildOp(BO_Assign, AssignedDiff, zero), + Expr* zero = getZeroInit(ResultRef->getType()); + addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero), direction::reverse); } Rdiff = Visit(R, oldValue); @@ -2424,8 +2423,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (isInsideLoop) addToCurrentBlock(LCloned, direction::forward); // Add the statement `dl = 0;` - Expr* zero = getZeroInit(AssignedDiff->getType()); - addToCurrentBlock(BuildOp(BO_Assign, AssignedDiff, zero), + Expr* zero = getZeroInit(ResultRef->getType()); + addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero), direction::reverse); /// Capture all the emitted statements while visiting R /// and insert them after `dl += dl * R` @@ -2434,7 +2433,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Rdiff = Visit(R, dr); Stmts RBlock = EndBlockWithoutCreatingCS(direction::reverse); addToCurrentBlock( - BuildOp(BO_AddAssign, AssignedDiff, + BuildOp(BO_AddAssign, ResultRef, BuildOp(BO_Mul, oldValue, Rdiff.getRevSweepAsExpr())), direction::reverse); for (auto& S : RBlock) @@ -2444,14 +2443,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, Rdiff.getExpr()); } else if (opCode == BO_DivAssign) { // Add the statement `dl = 0;` - Expr* zero = getZeroInit(AssignedDiff->getType()); - addToCurrentBlock(BuildOp(BO_Assign, AssignedDiff, zero), + Expr* zero = getZeroInit(ResultRef->getType()); + addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero), direction::reverse); auto RDelayed = DelayedGlobalStoreAndRef(R); StmtDiff RResult = RDelayed.Result; Expr* RStored = StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse); - addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, + addToCurrentBlock(BuildOp(BO_AddAssign, ResultRef, BuildOp(BO_Div, oldValue, RStored)), direction::reverse); if (!RDelayed.isConstant) {