diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index c4879cf27..115104f10 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -43,6 +43,9 @@ namespace clad { std::vector m_Reverse; /// Accumulates local variables for all visited blocks. std::vector> m_Locals; + /// Stores all expressions used as placeholders which have to be + /// reset later. + std::set m_Placeholders; /// Stack is used to pass the arguments (dfdx) to further nodes /// in the Visit method. std::stack m_Stack; @@ -180,6 +183,16 @@ namespace clad { return blk; } + clang::Expr* Clone(const clang::Expr* E) { + if (m_Placeholders.find(E)!=m_Placeholders.end()) + return const_cast(E); + return VisitorBase::Clone(E); + } + + clang::Stmt* Clone(const clang::Stmt* E) { + return VisitorBase::Clone(E); + } + /// Add more comments. void EmitRevSweepDecls(); @@ -282,15 +295,15 @@ namespace clad { struct DelayedStoreResult { ReverseModeVisitor& V; StmtDiff Result; - bool isConstant; bool isInsideLoop; bool needsUpdate; + clang::Expr* Placeholder; DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult, - bool pIsConstant, bool pIsInsideLoop, - bool pNeedsUpdate = false) - : V(pV), Result(pResult), isConstant(pIsConstant), - isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {} - void Finalize(clang::Expr* New); + bool pIsInsideLoop, bool pNeedsUpdate = false, + clang::Expr* pPlaceholder = nullptr) + : V(pV), Result(pResult), isInsideLoop(pIsInsideLoop), + needsUpdate(pNeedsUpdate), Placeholder(pPlaceholder) {} + void Finalize(StmtDiff New); }; /// Sometimes (e.g. when visiting multiplication/division operator), we @@ -302,7 +315,8 @@ namespace clad { /// This is what DelayedGlobalStoreAndRef does. E is expected to be the /// original (uncloned) expression. DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E, - llvm::StringRef prefix = "_t"); + llvm::StringRef prefix = "_t", + bool forceNoRecompute = false); struct CladTapeResult { ReverseModeVisitor& V; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 7b33c9ee5..e5ed047bc 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2204,17 +2204,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // to reduce cloning complexity and only clones once. Storing it in a // global variable allows to save current result and make it accessible // in the reverse pass. - std::unique_ptr RDelayed; - StmtDiff RResult; - // If R has no side effects, it can be just cloned - // (no need to store it). - if (!ShouldRecompute(R)) { - RDelayed = std::unique_ptr( - new DelayedStoreResult(DelayedGlobalStoreAndRef(R))); - RResult = RDelayed->Result; - } else { - RResult = StmtDiff(Clone(R)); - } + DelayedStoreResult RDelayed = DelayedGlobalStoreAndRef(R); + StmtDiff& RResult = RDelayed.Result; Expr* dl = nullptr; if (dfdx()) @@ -2228,15 +2219,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t", /*force=*/true); Expr::EvalResult dummy; - if (RDelayed || - !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) { + if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) { Expr* dr = nullptr; if (dfdx()) dr = BuildOp(BO_Mul, LStored.getRevSweepAsExpr(), dfdx()); Rdiff = Visit(R, dr); // Assign right multiplier's variable with R. - if (RDelayed) - RDelayed->Finalize(Rdiff.getExpr()); + RDelayed.Finalize(Rdiff); } std::tie(Ldiff, Rdiff) = std::make_pair(LStored.getExpr(), RResult.getExpr()); @@ -2244,8 +2233,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // xi = xl / xr // dxi/xl = 1 / xr // df/dxl += df/dxi * dxi/xl = df/dxi * (1/xr) - auto RDelayed = DelayedGlobalStoreAndRef(R); - StmtDiff RResult = RDelayed.Result; + auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t", /*forceNoRecompute=*/false); + StmtDiff& RResult = RDelayed.Result; Expr* RStored = StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse); Expr* dl = nullptr; @@ -2260,7 +2249,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr)) // Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is // produced instead of 1 / (R * R). - if (!RDelayed.isConstant) { + Expr::EvalResult dummy; + if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) { Expr* dr = nullptr; if (dfdx()) { Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); @@ -2271,7 +2261,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, dr = StoreAndRef(dr, direction::reverse); } Rdiff = Visit(R, dr); - RDelayed.Finalize(Rdiff.getExpr()); + RDelayed.Finalize(Rdiff); } std::tie(Ldiff, Rdiff) = std::make_pair(LStored.getExpr(), RResult.getExpr()); @@ -2445,14 +2435,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Ldiff.getRevSweepAsExpr()); std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, Rdiff.getExpr()); } else if (opCode == BO_DivAssign) { - auto RDelayed = DelayedGlobalStoreAndRef(R); - StmtDiff RResult = RDelayed.Result; + auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t", /*forceNoRecompute=*/true); + StmtDiff& RResult = RDelayed.Result; Expr* RStored = StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse); addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, BuildOp(BO_Div, oldValue, RStored)), direction::reverse); - if (!RDelayed.isConstant) { + Expr::EvalResult dummy; + if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) { if (isInsideLoop) addToCurrentBlock(LCloned, direction::forward); Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); @@ -2460,7 +2451,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(UO_Minus, BuildOp(BO_Div, LCloned, RxR))); dr = StoreAndRef(dr, direction::reverse); Rdiff = Visit(R, dr); - RDelayed.Finalize(Rdiff.getExpr()); + RDelayed.Finalize(Rdiff); } valueForRevPass = BuildOp(BO_Div, Rdiff.getRevSweepAsExpr(), Ldiff.getRevSweepAsExpr()); @@ -2967,47 +2958,85 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {}; } - void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) { - if (isConstant || !needsUpdate) + void ReverseModeVisitor::DelayedStoreResult::Finalize(StmtDiff New) { + class PlaceholderReplacer : public RecursiveASTVisitor { + public: + const Expr* placeholder; + Expr* newExpr; + PlaceholderReplacer(const Expr* Placeholder) + : placeholder(Placeholder), newExpr(nullptr) {} + + bool VisitExpr(Expr *E) { + for (auto iter = E->child_begin(), e = E->child_end(); + iter!=e; ++iter) { + if (*iter == placeholder) + *iter = newExpr; + else + TraverseStmt(*iter); + } + return true; + } + }; + + if (!needsUpdate) + return; + + if (Placeholder!=nullptr) { + PlaceholderReplacer repl(Placeholder); + repl.newExpr = New.getExpr(); + for (Stmt* S : V.getCurrentBlock(direction::forward)) + repl.TraverseStmt(S); + repl.newExpr = New.getRevSweepAsExpr(); + for (Stmt* S : V.getCurrentBlock(direction::reverse)) + repl.TraverseStmt(S); + Result = New; + V.m_Placeholders.erase(Placeholder); return; + } + if (isInsideLoop) { auto* Push = cast(Result.getExpr()); unsigned lastArg = Push->getNumArgs() - 1; - Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New).get()); + Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New.getExpr()).get()); } else { - V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New), + V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New.getExpr()), direction::forward); } } ReverseModeVisitor::DelayedStoreResult ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E, - llvm::StringRef prefix) { + llvm::StringRef prefix, + bool forceNoRecompute) { assert(E && "must be provided"); if (!UsefulToStore(E)) { StmtDiff Ediff = Visit(E); Expr::EvalResult evalRes; - bool isConst = - clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context); return DelayedStoreResult{*this, Ediff, - /*isConstant*/ isConst, /*isInsideLoop*/ false, /*pNeedsUpdate=*/false}; } + if (!forceNoRecompute && ShouldRecompute(E)) { + Expr* PH + = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); + m_Placeholders.insert(PH); + return DelayedStoreResult{*this, StmtDiff{PH, nullptr, nullptr, PH}, + /*isInsideLoop*/ false, + /*pNeedsUpdate=*/true, + /*pPlaceholder=*/PH}; + } if (isInsideLoop) { Expr* dummy = E; auto CladTape = MakeCladTapeFor(dummy); Expr* Push = CladTape.Push; Expr* Pop = CladTape.Pop; return DelayedStoreResult{*this, StmtDiff{Push, nullptr, nullptr, Pop}, - /*isConstant*/ false, /*isInsideLoop*/ true, /*pNeedsUpdate=*/true}; } Expr* Ref = BuildDeclRef(GlobalStoreImpl( getNonConstType(E->getType(), m_Context, m_Sema), prefix)); // Return reference to the declaration instead of original expression. return DelayedStoreResult{*this, StmtDiff{Ref, nullptr, nullptr, Ref}, - /*isConstant*/ false, /*isInsideLoop*/ false, /*pNeedsUpdate=*/true}; }