diff --git a/.gitignore b/.gitignore index 046c078b7..28702db7e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /build /.vscode +/inst \ No newline at end of file diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 6c7d1fe71..15e3dda9c 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -58,6 +58,8 @@ namespace clad { std::set m_ToBeRecorded; /// A flag indicating if the Stmt we are currently visiting is inside loop. bool isInsideLoop = false; + /// A flag indicating if the Stmt we are currently visiting is inside loop. + bool hasContStmt = false; /// Output variable of vector-valued function std::string outputArrayStr; std::vector m_LoopBlock; @@ -501,6 +503,15 @@ namespace clad { } }; + /// Helper function to bring the cases created by a continue or break stmt + /// foward to the loop's body and append them correctly. + /// The statements that belong to the main body of the loop are added directly + /// to the current block, while the cases followed by with their corresponding stmts + /// are stored in a separate vector. + void AppendCaseStmts(llvm::SmallVectorImpl& curBlock, + llvm::SmallVectorImpl& cases, clang::Stmt* S, + bool& afterCase); + /// Helper function to differentiate a loop body. /// ///\param[in] body body of the loop @@ -542,12 +553,6 @@ namespace clad { /// PopBreakContStmtHandler(); /// ``` class BreakContStmtHandler { - /// Keeps track of all the created switch cases. It is required - /// because we need to register all the switch cases later with the - /// switch statement that will be used to manage the control flow in - /// the reverse block. - llvm::SmallVector m_SwitchCases; - /// `m_ControlFlowTape` tape keeps track of which `break`/`continue` /// statement was hit in which iteration. /// \note `m_ControlFlowTape` is only initialized if the body contains @@ -560,8 +565,6 @@ namespace clad { /// `break`/`continue` statement. std::size_t m_CaseCounter = 0; - ReverseModeVisitor& m_RMV; - const bool m_IsInvokedBySwitchStmt = false; /// Builds and returns a literal expression of type `std::size_t` with /// `value` as value. @@ -577,6 +580,14 @@ namespace clad { clang::Expr* CreateCFTapePushExpr(std::size_t value); public: + /// Keeps track of all the created switch cases. It is required + /// because we need to register all the switch cases later with the + /// switch statement that will be used to manage the control flow in + /// the reverse block. + llvm::SmallVector m_SwitchCases; + + ReverseModeVisitor& m_RMV; + BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false) : m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {} diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ed8683483..0521b1395 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -897,9 +897,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Forward, Reverse); }; + llvm::SaveAndRestore SaveHasContStmt(hasContStmt); + hasContStmt = false; StmtDiff thenDiff = VisitBranch(If->getThen()); + llvm::SaveAndRestore SaveHasContStmtThen(hasContStmt); + hasContStmt = false; StmtDiff elseDiff = VisitBranch(If->getElse()); + // It is problematic to specify both condVarDecl and cond thorugh // Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor. Stmt* Forward = clad_compat::IfStmt_Create(m_Context, @@ -920,18 +925,48 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(PushCond, direction::forward); reverseCond = PopCond; } - Stmt* Reverse = clad_compat::IfStmt_Create(m_Context, - noLoc, - If->isConstexpr(), - initResult.getStmt_dx(), - condVarClone, - reverseCond, - noLoc, - noLoc, - thenDiff.getStmt_dx(), - noLoc, - elseDiff.getStmt_dx()); - addToCurrentBlock(Reverse, direction::reverse); + + // if neither then nor else block contains a continue statement, + // we can add the reverse block to the current block. + if (!SaveHasContStmtThen.get() && !hasContStmt){ + Stmt* Reverse = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), + condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, + elseDiff.getStmt_dx()); + + addToCurrentBlock(Reverse, direction::reverse); + } + // if both then and else block contain a continue statement, + // we need to add their cases to the current block. + else if (SaveHasContStmtThen.get() && hasContStmt){ + addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse); + addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse); + } + // if only then block contains a continue statement, we need to add + // the then block to the current block and create an if stmt for the else block + // afterwards to ensure that in the reverse pass it will be included in the prior case + else if (SaveHasContStmtThen.get()) { + addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse); + if (elseDiff.getStmt_dx()){ + Stmt* Reverse = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), + condVarClone, reverseCond, noLoc, noLoc, + m_Sema.ActOnNullStmt(noLoc).get(), noLoc, elseDiff.getStmt_dx()); + addToCurrentBlock(Reverse, direction::reverse); + } + } + // if only else block contains a continue statement, we need to add + // the else block to the current block and create an if stmt for the then block + // afterwards to ensure that in the reverse pass it will be included in the prior case + else if (hasContStmt) { + addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse); + Stmt* Reverse = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), + condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, + nullptr); + addToCurrentBlock(Reverse, direction::reverse); + } + CompoundStmt* ForwardBlock = endBlock(direction::forward); CompoundStmt* ReverseBlock = endBlock(direction::reverse); endScope(); @@ -3487,6 +3522,88 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {endBlock(direction::forward), endBlock(direction::reverse)}; } + void ReverseModeVisitor::AppendCaseStmts(llvm::SmallVectorImpl& curBlock, + llvm::SmallVectorImpl& cases, + Stmt* S, bool& afterCase) { + if (auto CS = dyn_cast_or_null(S)) { + // create a new list to store the nested stmts + Stmts newBlock; + // This stmts is a compound and not a case + // so its nested stmts do not come immediately after a case. + // The whole compound though may belong to a case stmt, + // hence, we store the original flag's value + SaveAndRestore SaveAfterCase(afterCase); + afterCase = false; + for (auto stmt : CS->body()) + AppendCaseStmts(newBlock, cases, stmt, afterCase); + if (!newBlock.empty()){ + auto Stmts_ref = clad_compat::makeArrayRef(newBlock.data(), newBlock.size()); + auto newCS = clad_compat::CompoundStmt_Create( + m_Context, Stmts_ref /**/ + CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( + FPOptionsOverride()), + noLoc, noLoc); + // if the compound belongs to a case, add it to the `cases` vector + // else add it to the main body of the for loop + if (SaveAfterCase.get()) + cases.push_back(newCS); + else{ + curBlock.push_back(newCS); + } + } + } else if (isa(S)) { + afterCase = true; + cases.push_back(S); + } else if (auto If = dyn_cast_or_null(S)) { + if (auto IfThenCS = dyn_cast_or_null(If->getThen())) { + Stmts thenBlock; + SaveAndRestore SaveAfterCase(afterCase); + afterCase = false; + for (auto stmt : IfThenCS->body()) + AppendCaseStmts(thenBlock, cases, stmt, afterCase); + auto Stmts_ref = + clad_compat::makeArrayRef(thenBlock.data(), + thenBlock.size()); + auto newThenCS = clad_compat::CompoundStmt_Create( + m_Context, + Stmts_ref /**/ + CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( + FPOptionsOverride()), + noLoc, noLoc); + If->setThen(newThenCS); + } + if (auto IfElseCS = dyn_cast_or_null(If->getElse())) { + Stmts elseBlock; + SaveAndRestore SaveAfterCase(afterCase); + afterCase = false; + for (auto stmt : IfElseCS->body()) + AppendCaseStmts(elseBlock, cases, stmt, afterCase); + auto Stmts_ref = + clad_compat::makeArrayRef(elseBlock.data(), + elseBlock.size()); + auto newElseCS = clad_compat::CompoundStmt_Create( + m_Context, + Stmts_ref /**/ + CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( + FPOptionsOverride()), + noLoc, noLoc); + If->setElse(newElseCS); + } + if (afterCase) + cases.push_back(If); + else + curBlock.push_back(If); + } else if (S) { + if (afterCase) + cases.push_back(S); + else + curBlock.push_back(S); + } + // No need to check fo other stmts that have a body, + // as while and for loops as well as do stmts have their own switch. + // Functions and class objects are also independent. + } + StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body, LoopCounter& loopCounter, Stmt* condVarDiff, @@ -3523,15 +3640,69 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // for forward-pass loop statement body endScope(); } + Stmts revLoopBlock = m_LoopBlock.back(); - utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); - if (!revLoopBlock.empty()) + + if (!activeBreakContHandler->m_SwitchCases.empty()) { + // Add case statement in the beginning of the reverse block + // and corresponding push expression for this case statement + // at the end of the forward block to cover the case when no + // `break`/`continue` statements are hit. + auto* lastSC = activeBreakContHandler->GetNextCFCaseStmt(); + auto* pushExprToCurrentCase = + activeBreakContHandler->CreateCFTapePushExprToCurrentCase(); + + Stmt* forwBlock = nullptr; + forwBlock = utils::AppendAndCreateCompoundStmt( + activeBreakContHandler->m_RMV.m_Context, bodyDiff.getStmt(), + pushExprToCurrentCase); + bodyDiff.updateStmt(forwBlock); + + bool afterCase = false; + Stmts cases; + AppendCaseStmts(revLoopBlock, cases, bodyDiff.getStmt_dx(), afterCase); + revLoopBlock.append(cases.begin(), cases.end()); + revLoopBlock.insert(revLoopBlock.begin(), lastSC); + Stmts revLoopBlockIndexed; // stores the correctly indexed version of the loop's body + bool betweenCase = false; + Stmts curBlockStmts; + + // Add the Stmts between cases as SubStmt of the first CaseStmt + if (!revLoopBlock.empty()) { + CaseStmt* curCaseStmt = nullptr; + for (auto revLoopStmt : revLoopBlock) { + if (auto caseStmt = dyn_cast_or_null(revLoopStmt)) { + if (!betweenCase) { + betweenCase = true; + } else { + curBlockStmts.push_back(new (m_Context) + BreakStmt(Stmt::EmptyShell())); // compatible with all clang versions + curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts)); + curBlockStmts.clear(); + } + curCaseStmt = caseStmt; + revLoopBlockIndexed.push_back(caseStmt); + } else { + curBlockStmts.push_back(revLoopStmt); + } + } + curBlockStmts.push_back(new (m_Context) BreakStmt(Stmt::EmptyShell())); + curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts)); + bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlockIndexed)); + } + } + else{ + utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock)); + } m_LoopBlock.pop_back(); - // Increment statement in the for-loop is only executed if the iteration - // did not end with a break/continue statement. Therefore, forLoopIncDiff - // should be inside the last switch case in the reverse pass. + activeBreakContHandler->EndCFSwitchStmtScope(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + PopBreakContStmtHandler(); + + // Increment statement in the for-loop should be executed in the beginning for + // every case, hence it should be added prior to the switch statement. if (forLoopIncDiff) { if (bodyDiff.getStmt_dx()) { bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( @@ -3541,10 +3712,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - activeBreakContHandler->EndCFSwitchStmtScope(); - activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); - PopBreakContStmtHandler(); - Expr* counterDecrement = loopCounter.getCounterDecrement(); // Create reverse pass loop body statements by arranging various @@ -3568,6 +3735,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitContinueStmt(const ContinueStmt* CS) { + hasContStmt = true; beginBlock(direction::forward); Stmt* newCS = m_Sema.ActOnContinueStmt(noLoc, getCurrentScope()).get(); auto* activeBreakContHandler = GetActiveBreakContStmtHandler(); @@ -3659,25 +3827,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt) return; - // Add case statement in the beginning of the reverse block - // and corresponding push expression for this case statement - // at the end of the forward block to cover the case when no - // `break`/`continue` statements are hit. - auto* lastSC = GetNextCFCaseStmt(); - auto* pushExprToCurrentCase = CreateCFTapePushExprToCurrentCase(); - - Stmt* forwBlock = nullptr; - Stmt* revBlock = nullptr; - - forwBlock = utils::AppendAndCreateCompoundStmt(m_RMV.m_Context, - bodyDiff.getStmt(), - pushExprToCurrentCase); - revBlock = utils::PrependAndCreateCompoundStmt(m_RMV.m_Context, - bodyDiff.getStmt_dx(), - lastSC); - - bodyDiff = {forwBlock, revBlock}; - auto condResult = m_RMV.m_Sema.ActOnCondition(m_RMV.getCurrentScope(), noLoc, m_ControlFlowTape->Pop, Sema::ConditionKind::Switch);