From b26af5533e839e4ab84ddd08011de23e4edb9b47 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Tue, 14 May 2024 19:07:58 +0200 Subject: [PATCH] Fix issue #865 and allow conditions in if statements to affect the derivatives --- lib/Differentiator/ReverseModeVisitor.cpp | 106 +++++++++------------- 1 file changed, 42 insertions(+), 64 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 089c421de..6ddfd3220 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -800,37 +800,56 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // to this scope. beginScope(Scope::DeclScope | Scope::ControlScope); - StmtDiff cond = Clone(If->getCond()); // Condition has to be stored as a "global" variable, to take the correct // branch in the reverse pass. // If we are inside loop, the condition has to be stored in a stack after // the if statement. Expr* PushCond = nullptr; Expr* PopCond = nullptr; - auto condExpr = Visit(cond.getExpr()); + // Create a block "around" if statement, e.g: + // { + // ... + // if (...) {...} + // } + beginBlock(direction::forward); + beginBlock(direction::reverse); + StmtDiff condDiff; + // this ensures we can differentiate conditions that affect the derivatives + // as well as declarations inside the condition: + beginBlock(direction::reverse); + if (const auto* condDeclStmt = If->getConditionVariableDeclStmt()) + condDiff = Visit(condDeclStmt); + else + condDiff = Visit(If->getCond()); + auto* RCS = endBlock(direction::reverse); + std::reverse( + RCS->body_begin(), + RCS->body_end()); // it is reversed in the endBlock() but we don't + // actually need this, so we reverse it once again + addToCurrentBlock(RCS, direction::reverse); + if (isInsideLoop) { - // If we are inside for loop, cond will be stored in the following way: - // forward: - // _t = cond; - // if (_t) { ... } - // clad::push(..., _t); - // reverse: + // If we are inside for loop, condDiff will be stored in the following + // way: forward: _t = cond; if (_t) { ... } clad::push(..., _t); reverse: // if (clad::pop(...)) { ... } // Simply doing // if (clad::push(..., _t) { ... } // is incorrect when if contains return statement inside: return will // skip corresponding push. - cond = StoreAndRef(condExpr.getExpr(), direction::forward, "_t", - /*forceDeclCreation=*/true); - StmtDiff condPushPop = GlobalStoreAndRef(cond.getExpr(), "_cond", - /*force=*/true); + condDiff = StoreAndRef(condDiff.getExpr(), m_Context.BoolTy, + direction::forward, "_t", + /*forceDeclCreation=*/true); + StmtDiff condPushPop = + GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond", + /*force=*/true); PushCond = condPushPop.getExpr(); PopCond = condPushPop.getExpr_dx(); } else - cond = GlobalStoreAndRef(condExpr.getExpr(), "_cond"); + condDiff = + GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond"); // Convert cond to boolean condition. We are modifying each Stmt in // StmtDiff. - for (Stmt*& S : cond.getBothStmts()) + for (Stmt*& S : condDiff.getBothStmts()) if (S) S = m_Sema .ActOnCondition(getCurrentScope(), noLoc, cast(S), @@ -838,13 +857,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, .get() .second; - // Create a block "around" if statement, e.g: - // { - // ... - // if (...) {...} - // } - beginBlock(direction::forward); - beginBlock(direction::reverse); const Stmt* init = If->getInit(); StmtDiff initResult = init ? Visit(init) : StmtDiff{}; // If there is Init, it's derivative will be output in the block before if: @@ -858,24 +870,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // This is done to avoid variable names clashes. addToCurrentBlock(initResult.getStmt_dx()); - VarDecl* condVarClone = nullptr; - if (const VarDecl* condVarDecl = If->getConditionVariable()) { - DeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl); - condVarClone = condVarDeclDiff.getDecl(); - if (condVarDeclDiff.getDecl_dx()) - addToBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()), m_Globals); - } - - // Condition is just cloned as it is, not derived. - // FIXME: if condition changes one of the variables, it may be reasonable - // to derive it, e.g. - // if (x += x) {...} - // should result in: - // { - // _d_y += _d_x - // if (y += x) {...} - // } - auto VisitBranch = [&](const Stmt* Branch) -> StmtDiff { if (!Branch) return {}; @@ -902,37 +896,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff thenDiff = VisitBranch(If->getThen()); StmtDiff elseDiff = VisitBranch(If->getElse()); - // It is problematic to specify both condVarDecl and cond thorugh - // Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor. - Stmt* Forward = clad_compat::IfStmt_Create(m_Context, - noLoc, - If->isConstexpr(), - initResult.getStmt(), - condVarClone, - cond.getExpr(), - noLoc, - noLoc, - thenDiff.getStmt(), - noLoc, - elseDiff.getStmt()); + Stmt* Forward = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), initResult.getStmt(), nullptr, + condDiff.getExpr(), noLoc, noLoc, thenDiff.getStmt(), noLoc, + elseDiff.getStmt()); addToCurrentBlock(Forward, direction::forward); - Expr* reverseCond = cond.getExpr_dx(); + Expr* reverseCond = condDiff.getExpr_dx(); if (isInsideLoop) { addToCurrentBlock(PushCond, direction::forward); reverseCond = PopCond; } - Stmt* Reverse = clad_compat::IfStmt_Create(m_Context, - noLoc, - If->isConstexpr(), - initResult.getStmt_dx(), - condVarClone, - reverseCond, - noLoc, - noLoc, - thenDiff.getStmt_dx(), - noLoc, - elseDiff.getStmt_dx()); + Stmt* Reverse = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), nullptr, + reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, + elseDiff.getStmt_dx()); addToCurrentBlock(Reverse, direction::reverse); CompoundStmt* ForwardBlock = endBlock(direction::forward); CompoundStmt* ReverseBlock = endBlock(direction::reverse);