From 06c9a9f8240537cf59b6f3985494ad1b806b3688 Mon Sep 17 00:00:00 2001 From: kchristin22 Date: Sun, 17 Mar 2024 19:35:37 +0200 Subject: [PATCH 1/5] Fix switch case in for with continue stmt and get the decrement variable out of switch --- .gitignore | 1 + .../clad/Differentiator/ReverseModeVisitor.h | 2 ++ lib/Differentiator/ReverseModeVisitor.cpp | 31 ++++++++++++------- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 046c078b7..28702db7e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /build /.vscode +/inst \ No newline at end of file diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 10564dbbf..6c7ec78af 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -58,6 +58,8 @@ namespace clad { std::set m_ToBeRecorded; /// A flag indicating if the Stmt we are currently visiting is inside loop. bool isInsideLoop = false; + /// A flag indicating if the Stmt we are currently visiting is inside loop. + bool hasContStmt = false; /// Output variable of vector-valued function std::string outputArrayStr; std::vector m_LoopBlock; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 2bcd1a7f7..f6603d364 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -897,9 +897,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Forward, Reverse); }; + llvm::SaveAndRestore SaveHasContStmt(hasContStmt); + hasContStmt = false; StmtDiff thenDiff = VisitBranch(If->getThen()); + llvm::SaveAndRestore SaveHasContStmtThen(hasContStmt); + hasContStmt = false; StmtDiff elseDiff = VisitBranch(If->getElse()); + // It is problematic to specify both condVarDecl and cond thorugh // Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor. Stmt* Forward = clad_compat::IfStmt_Create(m_Context, @@ -931,7 +936,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, thenDiff.getStmt_dx(), noLoc, elseDiff.getStmt_dx()); - addToCurrentBlock(Reverse, direction::reverse); + if (!SaveHasContStmtThen.get()) + addToCurrentBlock(Reverse, direction::reverse); + else{ + addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse); + addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse); + } CompoundStmt* ForwardBlock = endBlock(direction::forward); CompoundStmt* ReverseBlock = endBlock(direction::reverse); endScope(); @@ -3644,10 +3654,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!revLoopBlock.empty()) bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock)); m_LoopBlock.pop_back(); + + activeBreakContHandler->EndCFSwitchStmtScope(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + PopBreakContStmtHandler(); - // Increment statement in the for-loop is only executed if the iteration - // did not end with a break/continue statement. Therefore, forLoopIncDiff - // should be inside the last switch case in the reverse pass. + // Increment statement in the for-loop is executed in the beginning for every case if (forLoopIncDiff) { if (bodyDiff.getStmt_dx()) { bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( @@ -3657,10 +3669,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - activeBreakContHandler->EndCFSwitchStmtScope(); - activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); - PopBreakContStmtHandler(); - Expr* counterDecrement = loopCounter.getCounterDecrement(); // Create reverse pass loop body statements by arranging various @@ -3684,15 +3692,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitContinueStmt(const ContinueStmt* CS) { - beginBlock(direction::forward); + hasContStmt = true; + // beginBlock(direction::forward); Stmt* newCS = m_Sema.ActOnContinueStmt(noLoc, getCurrentScope()).get(); auto* activeBreakContHandler = GetActiveBreakContStmtHandler(); Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt(); Stmt* pushExprToCurrentCase = activeBreakContHandler ->CreateCFTapePushExprToCurrentCase(); addToCurrentBlock(pushExprToCurrentCase); - addToCurrentBlock(newCS); - return {endBlock(direction::forward), CFCaseStmt}; + // addToCurrentBlock(newCS); + return {newCS, CFCaseStmt}; } StmtDiff ReverseModeVisitor::VisitBreakStmt(const BreakStmt* BS) { From 88ef7e80921f14cd073a010734ef64b67050a708 Mon Sep 17 00:00:00 2001 From: kchristin22 Date: Tue, 19 Mar 2024 12:49:05 +0200 Subject: [PATCH 2/5] Fix AST indexing of loop body in reverse mode and improve VisitIfStmt node creation when there's a continue stmt --- .../clad/Differentiator/ReverseModeVisitor.h | 16 +-- lib/Differentiator/CladUtils.cpp | 8 +- lib/Differentiator/ReverseModeVisitor.cpp | 136 +++++++++++++----- 3 files changed, 109 insertions(+), 51 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 6c7ec78af..578b900dd 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -522,12 +522,6 @@ namespace clad { /// PopBreakContStmtHandler(); /// ``` class BreakContStmtHandler { - /// Keeps track of all the created switch cases. It is required - /// because we need to register all the switch cases later with the - /// switch statement that will be used to manage the control flow in - /// the reverse block. - llvm::SmallVector m_SwitchCases; - /// `m_ControlFlowTape` tape keeps track of which `break`/`continue` /// statement was hit in which iteration. /// \note `m_ControlFlowTape` is only initialized if the body contains @@ -540,8 +534,6 @@ namespace clad { /// `break`/`continue` statement. std::size_t m_CaseCounter = 0; - ReverseModeVisitor& m_RMV; - const bool m_IsInvokedBySwitchStmt = false; /// Builds and returns a literal expression of type `std::size_t` with /// `value` as value. @@ -557,6 +549,14 @@ namespace clad { clang::Expr* CreateCFTapePushExpr(std::size_t value); public: + /// Keeps track of all the created switch cases. It is required + /// because we need to register all the switch cases later with the + /// switch statement that will be used to manage the control flow in + /// the reverse block. + llvm::SmallVector m_SwitchCases; + + ReverseModeVisitor& m_RMV; + BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false) : m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {} diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index d7d599ce5..bcc3a49b7 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -439,12 +439,12 @@ namespace clad { void AppendIndividualStmts(llvm::SmallVectorImpl& block, clang::Stmt* S) { - if (auto CS = dyn_cast_or_null(S)) + if (auto CS = dyn_cast_or_null(S)) { for (auto stmt : CS->body()) - block.push_back(stmt); - else if (S) + AppendIndividualStmts(block, stmt); + } else if (S) block.push_back(S); - } + } MemberExpr* BuildMemberExpr(clang::Sema& semaRef, clang::Scope* S, clang::Expr* base, diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f6603d364..820f8514c 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -925,23 +925,46 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(PushCond, direction::forward); reverseCond = PopCond; } - Stmt* Reverse = clad_compat::IfStmt_Create(m_Context, - noLoc, - If->isConstexpr(), - initResult.getStmt_dx(), - condVarClone, - reverseCond, - noLoc, - noLoc, - thenDiff.getStmt_dx(), - noLoc, - elseDiff.getStmt_dx()); - if (!SaveHasContStmtThen.get()) + + // if neither then nor else block contains a continue statement, + // we can add the reverse block to the current block. + if (!SaveHasContStmtThen.get() && !hasContStmt){ + Stmt* Reverse = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), + condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, + elseDiff.getStmt_dx()); + addToCurrentBlock(Reverse, direction::reverse); - else{ + } + // if both then and else block contain a continue statement, + // we need to add their cases to the current block. + else if (SaveHasContStmtThen.get() && hasContStmt){ + addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse); + addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse); + } + // if only then block contains a continue statement, we need to add + // the then block to the current block and create an if stmt for the else block + else if (SaveHasContStmtThen.get()) { addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse); + if (elseDiff.getStmt_dx()){ + Stmt* Reverse = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), + condVarClone, reverseCond, noLoc, noLoc, + m_Sema.ActOnNullStmt(noLoc).get(), noLoc, elseDiff.getStmt_dx()); + addToCurrentBlock(Reverse, direction::reverse); + } + } + // if only else block contains a continue statement, we need to add + // the else block to the current block and create an if stmt for the then block + else if (hasContStmt) { addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse); + Stmt* Reverse = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), + condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, + nullptr); + addToCurrentBlock(Reverse, direction::reverse); } + CompoundStmt* ForwardBlock = endBlock(direction::forward); CompoundStmt* ReverseBlock = endBlock(direction::reverse); endScope(); @@ -3649,17 +3672,71 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // for forward-pass loop statement body endScope(); } + Stmts revLoopBlock = m_LoopBlock.back(); - utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); - if (!revLoopBlock.empty()) + + if (!activeBreakContHandler->m_SwitchCases.empty()) { + // Add case statement in the beginning of the reverse block + // and corresponding push expression for this case statement + // at the end of the forward block to cover the case when no + // `break`/`continue` statements are hit. + auto* lastSC = activeBreakContHandler->GetNextCFCaseStmt(); + auto* pushExprToCurrentCase = + activeBreakContHandler->CreateCFTapePushExprToCurrentCase(); + + Stmt* forwBlock = nullptr; + Stmt* revBlock = nullptr; + + forwBlock = utils::AppendAndCreateCompoundStmt( + activeBreakContHandler->m_RMV.m_Context, bodyDiff.getStmt(), + pushExprToCurrentCase); + revBlock = utils::PrependAndCreateCompoundStmt( + activeBreakContHandler->m_RMV.m_Context, bodyDiff.getStmt_dx(), + lastSC); + + bodyDiff = {forwBlock, revBlock}; + + utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); + Stmts revLoopBlockIndexed; + bool betweenCase = false; + Stmts curBlockStmts; + + // Add the Stmts between cases as SubStmt of the first CaseStmt + if (!revLoopBlock.empty()) { + CaseStmt* curCaseStmt = nullptr; + for (auto revLoopStmt : revLoopBlock) { + if (auto caseStmt = dyn_cast_or_null(revLoopStmt)) { + if (!betweenCase) { + betweenCase = true; + } else { + curBlockStmts.push_back(new (m_Context) + BreakStmt(Stmt::EmptyShell())); + curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts)); + curBlockStmts.clear(); + } + curCaseStmt = caseStmt; + revLoopBlockIndexed.push_back(caseStmt); + } else { + curBlockStmts.push_back(revLoopStmt); + } + } + curBlockStmts.push_back(new (m_Context) BreakStmt(Stmt::EmptyShell())); + curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts)); + bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlockIndexed)); + } + } + else{ + utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock)); + } m_LoopBlock.pop_back(); - + activeBreakContHandler->EndCFSwitchStmtScope(); activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); PopBreakContStmtHandler(); - // Increment statement in the for-loop is executed in the beginning for every case + // Increment statement in the for-loop should be executed in the beginning for + // every case, hence it should be added prior to the switch statement. if (forLoopIncDiff) { if (bodyDiff.getStmt_dx()) { bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( @@ -3693,15 +3770,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitContinueStmt(const ContinueStmt* CS) { hasContStmt = true; - // beginBlock(direction::forward); + beginBlock(direction::forward); Stmt* newCS = m_Sema.ActOnContinueStmt(noLoc, getCurrentScope()).get(); auto* activeBreakContHandler = GetActiveBreakContStmtHandler(); Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt(); Stmt* pushExprToCurrentCase = activeBreakContHandler ->CreateCFTapePushExprToCurrentCase(); addToCurrentBlock(pushExprToCurrentCase); - // addToCurrentBlock(newCS); - return {newCS, CFCaseStmt}; + addToCurrentBlock(newCS); + return {endBlock(direction::forward), CFCaseStmt}; } StmtDiff ReverseModeVisitor::VisitBreakStmt(const BreakStmt* BS) { @@ -3784,25 +3861,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt) return; - // Add case statement in the beginning of the reverse block - // and corresponding push expression for this case statement - // at the end of the forward block to cover the case when no - // `break`/`continue` statements are hit. - auto* lastSC = GetNextCFCaseStmt(); - auto* pushExprToCurrentCase = CreateCFTapePushExprToCurrentCase(); - - Stmt* forwBlock = nullptr; - Stmt* revBlock = nullptr; - - forwBlock = utils::AppendAndCreateCompoundStmt(m_RMV.m_Context, - bodyDiff.getStmt(), - pushExprToCurrentCase); - revBlock = utils::PrependAndCreateCompoundStmt(m_RMV.m_Context, - bodyDiff.getStmt_dx(), - lastSC); - - bodyDiff = {forwBlock, revBlock}; - auto condResult = m_RMV.m_Sema.ActOnCondition(m_RMV.getCurrentScope(), noLoc, m_ControlFlowTape->Pop, Sema::ConditionKind::Switch); From a825e101685b9e026761464dd44b7a9ad96244d8 Mon Sep 17 00:00:00 2001 From: kchristin22 Date: Tue, 19 Mar 2024 13:22:59 +0200 Subject: [PATCH 3/5] Add more comments --- lib/Differentiator/CladUtils.cpp | 2 +- lib/Differentiator/ReverseModeVisitor.cpp | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index bcc3a49b7..4847bad39 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -444,7 +444,7 @@ namespace clad { AppendIndividualStmts(block, stmt); } else if (S) block.push_back(S); - } + } MemberExpr* BuildMemberExpr(clang::Sema& semaRef, clang::Scope* S, clang::Expr* base, diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 820f8514c..e986a79f0 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -944,6 +944,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } // if only then block contains a continue statement, we need to add // the then block to the current block and create an if stmt for the else block + // afterwards to ensure that in the reverse pass it will be included in the prior case else if (SaveHasContStmtThen.get()) { addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse); if (elseDiff.getStmt_dx()){ @@ -956,6 +957,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } // if only else block contains a continue statement, we need to add // the else block to the current block and create an if stmt for the then block + // afterwards to ensure that in the reverse pass it will be included in the prior case else if (hasContStmt) { addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse); Stmt* Reverse = clad_compat::IfStmt_Create( @@ -3710,7 +3712,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, betweenCase = true; } else { curBlockStmts.push_back(new (m_Context) - BreakStmt(Stmt::EmptyShell())); + BreakStmt(Stmt::EmptyShell())); // compatible with all clang versions curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts)); curBlockStmts.clear(); } From 215e7f7be252ab6f55ccdb4cb8aa2bcfbba8653d Mon Sep 17 00:00:00 2001 From: kchristin22 Date: Wed, 20 Mar 2024 22:15:41 +0200 Subject: [PATCH 4/5] Add support of nested ifs with continue stmts --- .../clad/Differentiator/ReverseModeVisitor.h | 9 ++ lib/Differentiator/CladUtils.cpp | 2 +- lib/Differentiator/ReverseModeVisitor.cpp | 100 ++++++++++++++++-- 3 files changed, 100 insertions(+), 11 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 724faeddb..15e3dda9c 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -503,6 +503,15 @@ namespace clad { } }; + /// Helper function to bring the cases created by a continue or break stmt + /// foward to the loop's body and append them correctly. + /// The statements that belong to the main body of the loop are added directly + /// to the current block, while the cases followed by with their corresponding stmts + /// are stored in a separate vector. + void AppendCaseStmts(llvm::SmallVectorImpl& curBlock, + llvm::SmallVectorImpl& cases, clang::Stmt* S, + bool& afterCase); + /// Helper function to differentiate a loop body. /// ///\param[in] body body of the loop diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 7cf59b462..e38a345dd 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -445,7 +445,7 @@ namespace clad { clang::Stmt* S) { if (auto CS = dyn_cast_or_null(S)) { for (auto stmt : CS->body()) - AppendIndividualStmts(block, stmt); + block.push_back(stmt); } else if (S) block.push_back(S); } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 38d0ff9c2..0521b1395 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -3522,6 +3522,88 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return {endBlock(direction::forward), endBlock(direction::reverse)}; } + void ReverseModeVisitor::AppendCaseStmts(llvm::SmallVectorImpl& curBlock, + llvm::SmallVectorImpl& cases, + Stmt* S, bool& afterCase) { + if (auto CS = dyn_cast_or_null(S)) { + // create a new list to store the nested stmts + Stmts newBlock; + // This stmts is a compound and not a case + // so its nested stmts do not come immediately after a case. + // The whole compound though may belong to a case stmt, + // hence, we store the original flag's value + SaveAndRestore SaveAfterCase(afterCase); + afterCase = false; + for (auto stmt : CS->body()) + AppendCaseStmts(newBlock, cases, stmt, afterCase); + if (!newBlock.empty()){ + auto Stmts_ref = clad_compat::makeArrayRef(newBlock.data(), newBlock.size()); + auto newCS = clad_compat::CompoundStmt_Create( + m_Context, Stmts_ref /**/ + CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( + FPOptionsOverride()), + noLoc, noLoc); + // if the compound belongs to a case, add it to the `cases` vector + // else add it to the main body of the for loop + if (SaveAfterCase.get()) + cases.push_back(newCS); + else{ + curBlock.push_back(newCS); + } + } + } else if (isa(S)) { + afterCase = true; + cases.push_back(S); + } else if (auto If = dyn_cast_or_null(S)) { + if (auto IfThenCS = dyn_cast_or_null(If->getThen())) { + Stmts thenBlock; + SaveAndRestore SaveAfterCase(afterCase); + afterCase = false; + for (auto stmt : IfThenCS->body()) + AppendCaseStmts(thenBlock, cases, stmt, afterCase); + auto Stmts_ref = + clad_compat::makeArrayRef(thenBlock.data(), + thenBlock.size()); + auto newThenCS = clad_compat::CompoundStmt_Create( + m_Context, + Stmts_ref /**/ + CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( + FPOptionsOverride()), + noLoc, noLoc); + If->setThen(newThenCS); + } + if (auto IfElseCS = dyn_cast_or_null(If->getElse())) { + Stmts elseBlock; + SaveAndRestore SaveAfterCase(afterCase); + afterCase = false; + for (auto stmt : IfElseCS->body()) + AppendCaseStmts(elseBlock, cases, stmt, afterCase); + auto Stmts_ref = + clad_compat::makeArrayRef(elseBlock.data(), + elseBlock.size()); + auto newElseCS = clad_compat::CompoundStmt_Create( + m_Context, + Stmts_ref /**/ + CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( + FPOptionsOverride()), + noLoc, noLoc); + If->setElse(newElseCS); + } + if (afterCase) + cases.push_back(If); + else + curBlock.push_back(If); + } else if (S) { + if (afterCase) + cases.push_back(S); + else + curBlock.push_back(S); + } + // No need to check fo other stmts that have a body, + // as while and for loops as well as do stmts have their own switch. + // Functions and class objects are also independent. + } + StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body, LoopCounter& loopCounter, Stmt* condVarDiff, @@ -3571,19 +3653,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, activeBreakContHandler->CreateCFTapePushExprToCurrentCase(); Stmt* forwBlock = nullptr; - Stmt* revBlock = nullptr; - forwBlock = utils::AppendAndCreateCompoundStmt( activeBreakContHandler->m_RMV.m_Context, bodyDiff.getStmt(), pushExprToCurrentCase); - revBlock = utils::PrependAndCreateCompoundStmt( - activeBreakContHandler->m_RMV.m_Context, bodyDiff.getStmt_dx(), - lastSC); - - bodyDiff = {forwBlock, revBlock}; - - utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); - Stmts revLoopBlockIndexed; + bodyDiff.updateStmt(forwBlock); + + bool afterCase = false; + Stmts cases; + AppendCaseStmts(revLoopBlock, cases, bodyDiff.getStmt_dx(), afterCase); + revLoopBlock.append(cases.begin(), cases.end()); + revLoopBlock.insert(revLoopBlock.begin(), lastSC); + Stmts revLoopBlockIndexed; // stores the correctly indexed version of the loop's body bool betweenCase = false; Stmts curBlockStmts; From 7185b6e3d0450abddf2921dba8957d9084b8e52f Mon Sep 17 00:00:00 2001 From: kchristin22 Date: Wed, 20 Mar 2024 22:25:52 +0200 Subject: [PATCH 5/5] Revert back changes of CladUtils.cpp --- lib/Differentiator/CladUtils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index e38a345dd..6144ae23a 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -443,10 +443,10 @@ namespace clad { void AppendIndividualStmts(llvm::SmallVectorImpl& block, clang::Stmt* S) { - if (auto CS = dyn_cast_or_null(S)) { + if (auto CS = dyn_cast_or_null(S)) for (auto stmt : CS->body()) block.push_back(stmt); - } else if (S) + else if (S) block.push_back(S); }