diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6357d78fe..895361ee6 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2159,6 +2159,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff Ldiff{}; StmtDiff Rdiff{}; StmtDiff Lstored{}; + Expr* valueForRevPass = nullptr; auto L = BinOp->getLHS(); auto R = BinOp->getRHS(); // If it is an assignment operator, its result is a reference to LHS and @@ -2260,8 +2261,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!LDRE && !RDRE) return Clone(BinOp); - Expr* LExpr = LDRE ? Visit(L).getExpr() : L; - Expr* RExpr = RDRE ? Visit(R).getExpr() : R; + Expr* LExpr = LDRE ? Visit(L).getRevSweepExpr() : L; + Expr* RExpr = RDRE ? Visit(R).getRevSweepExpr() : R; return BuildOp(opCode, LExpr, RExpr); } @@ -2380,14 +2381,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (opCode == BO_Assign) { Rdiff = Visit(R, oldValue); + valueForRevPass = Rdiff.getRevSweepExpr(); } else if (opCode == BO_AddAssign) { addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), direction::reverse); Rdiff = Visit(R, oldValue); + valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepExpr(), Ldiff.getRevSweepExpr()); } else if (opCode == BO_SubAssign) { addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), direction::reverse); Rdiff = Visit(R, BuildOp(UO_Minus, oldValue)); + valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepExpr(), Ldiff.getRevSweepExpr()); } else if (opCode == BO_MulAssign) { auto RDelayed = DelayedGlobalStoreAndRef(R); StmtDiff RResult = RDelayed.Result; @@ -2418,6 +2422,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Rdiff = Visit(R, dr); RDelayed.Finalize(Rdiff.getExpr()); } + valueForRevPass = BuildOp(BO_Mul, Rdiff.getRevSweepExpr(), Ldiff.getRevSweepExpr()); std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr()); } else if (opCode == BO_DivAssign) { auto RDelayed = DelayedGlobalStoreAndRef(R); @@ -2439,6 +2444,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Rdiff = Visit(R, dr); RDelayed.Finalize(Rdiff.getExpr()); } + valueForRevPass = BuildOp(BO_Div, Rdiff.getRevSweepExpr(), Ldiff.getRevSweepExpr()); std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr()); } else llvm_unreachable("unknown assignment opCode"); @@ -2455,6 +2461,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); Ldiff = Visit(L, zero); Rdiff = Visit(R, dfdx()); + valueForRevPass = Ldiff.getRevSweepExpr(); ResultRef = Ldiff.getExpr(); } else { // We should not output any warning on visiting boolean conditions