diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f967badb7..0cf4e2e9d 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2265,12 +2265,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Stmts Lblock = EndBlockWithoutCreatingCS(direction::reverse); Expr* LCloned = Ldiff.getExpr(); - // For x, AssignedDiff is _d_x, for x[i] its _d_x[i], for reference exprs + // For x, ResultRef is _d_x, for x[i] its _d_x[i], for reference exprs // like (x = y) it propagates recursively, so _d_x is also returned. - Expr* AssignedDiff = Ldiff.getExpr_dx(); - if (!AssignedDiff) + ResultRef = Ldiff.getExpr_dx(); + if (!ResultRef) return Clone(BinOp); - ResultRef = AssignedDiff; // If assigned expr is dependent, first update its derivative; if (dfdx() && !Lblock.empty()) { addToCurrentBlock(*Lblock.begin(), direction::reverse); @@ -2300,11 +2299,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // For pointer types, no need to store old derivatives. if (!isPointerOp) - oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d", + oldValue = StoreAndRef(ResultRef, direction::reverse, "_r_d", /*forceDeclCreation=*/true); if (opCode == BO_Assign) { // Add the statement `dl -= oldValue;` - addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), + addToCurrentBlock(BuildOp(BO_SubAssign, ResultRef, oldValue), direction::reverse); Rdiff = Visit(R, oldValue); valueForRevPass = Rdiff.getRevSweepAsExpr(); @@ -2335,7 +2334,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (isInsideLoop) addToCurrentBlock(LCloned, direction::forward); // Add the statement `dl -= oldValue;` - addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), + addToCurrentBlock(BuildOp(BO_SubAssign, ResultRef, oldValue), direction::reverse); /// Capture all the emitted statements while visiting R /// and insert them after `dl += dl * R` @@ -2344,7 +2343,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Rdiff = Visit(R, dr); Stmts RBlock = EndBlockWithoutCreatingCS(direction::reverse); addToCurrentBlock( - BuildOp(BO_AddAssign, AssignedDiff, + BuildOp(BO_AddAssign, ResultRef, BuildOp(BO_Mul, oldValue, Rdiff.getRevSweepAsExpr())), direction::reverse); for (auto& S : RBlock) @@ -2354,13 +2353,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, Rdiff.getExpr()); } else if (opCode == BO_DivAssign) { // Add the statement `dl -= oldValue;` - addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), + addToCurrentBlock(BuildOp(BO_SubAssign, ResultRef, oldValue), direction::reverse); auto RDelayed = DelayedGlobalStoreAndRef(R); StmtDiff RResult = RDelayed.Result; Expr* RStored = StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse); - addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, + addToCurrentBlock(BuildOp(BO_AddAssign, ResultRef, BuildOp(BO_Div, oldValue, RStored)), direction::reverse); if (!RDelayed.isConstant) {