Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redesign of loop's body in reverse pass #835

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Fix AST indexing of loop body in reverse mode and improve VisitIfStmt…
… node creation when there's a continue stmt
kchristin22 committed Mar 19, 2024
commit 88ef7e80921f14cd073a010734ef64b67050a708
16 changes: 8 additions & 8 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
@@ -522,12 +522,6 @@ namespace clad {
/// PopBreakContStmtHandler();
/// ```
class BreakContStmtHandler {
/// Keeps track of all the created switch cases. It is required
/// because we need to register all the switch cases later with the
/// switch statement that will be used to manage the control flow in
/// the reverse block.
llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;

/// `m_ControlFlowTape` tape keeps track of which `break`/`continue`
/// statement was hit in which iteration.
/// \note `m_ControlFlowTape` is only initialized if the body contains
@@ -540,8 +534,6 @@ namespace clad {
/// `break`/`continue` statement.
std::size_t m_CaseCounter = 0;

ReverseModeVisitor& m_RMV;

const bool m_IsInvokedBySwitchStmt = false;
/// Builds and returns a literal expression of type `std::size_t` with
/// `value` as value.
@@ -557,6 +549,14 @@ namespace clad {
clang::Expr* CreateCFTapePushExpr(std::size_t value);

public:
/// Keeps track of all the created switch cases. It is required
/// because we need to register all the switch cases later with the
/// switch statement that will be used to manage the control flow in
/// the reverse block.
llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member variable 'm_SwitchCases' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]

      llvm::SmallVector<clang::SwitchCase*, 4> m_SwitchCases;
                                               ^


ReverseModeVisitor& m_RMV;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member variable 'm_RMV' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]

      ReverseModeVisitor& m_RMV;
                          ^


BreakContStmtHandler(ReverseModeVisitor& RMV, bool forSwitchStmt = false)
: m_RMV(RMV), m_IsInvokedBySwitchStmt(forSwitchStmt) {}

8 changes: 4 additions & 4 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
@@ -439,12 +439,12 @@ namespace clad {

void AppendIndividualStmts(llvm::SmallVectorImpl<clang::Stmt*>& block,
clang::Stmt* S) {
if (auto CS = dyn_cast_or_null<CompoundStmt>(S))
if (auto CS = dyn_cast_or_null<CompoundStmt>(S)) {
for (auto stmt : CS->body())
block.push_back(stmt);
else if (S)
AppendIndividualStmts(block, stmt);
} else if (S)
block.push_back(S);
}
}

MemberExpr*
BuildMemberExpr(clang::Sema& semaRef, clang::Scope* S, clang::Expr* base,
136 changes: 97 additions & 39 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
@@ -925,23 +925,46 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(PushCond, direction::forward);
reverseCond = PopCond;
}
Stmt* Reverse = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt_dx(),
condVarClone,
reverseCond,
noLoc,
noLoc,
thenDiff.getStmt_dx(),
noLoc,
elseDiff.getStmt_dx());
if (!SaveHasContStmtThen.get())

// if neither then nor else block contains a continue statement,
// we can add the reverse block to the current block.
if (!SaveHasContStmtThen.get() && !hasContStmt){
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(),
condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc,
elseDiff.getStmt_dx());

addToCurrentBlock(Reverse, direction::reverse);
else{
}
// if both then and else block contain a continue statement,
// we need to add their cases to the current block.
else if (SaveHasContStmtThen.get() && hasContStmt){
addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse);
addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse);
}
// if only then block contains a continue statement, we need to add
// the then block to the current block and create an if stmt for the else block
else if (SaveHasContStmtThen.get()) {
addToCurrentBlock(thenDiff.getStmt_dx(), direction::reverse);
if (elseDiff.getStmt_dx()){
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(),
condVarClone, reverseCond, noLoc, noLoc,
m_Sema.ActOnNullStmt(noLoc).get(), noLoc, elseDiff.getStmt_dx());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we invert the condition instead of using a null stmt?
e.g.

 if(!cond)
   *else stmt*

instead of

 if(cond)
   ;
 else
   *else stmt*

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this approach before but the way I believe this would turn out is checking for each binary condition operator (<=, >=, <, >, ==, !=) and replacing it accordingly. I will look into it more though, thanks for the comment.

addToCurrentBlock(Reverse, direction::reverse);
}
}
// if only else block contains a continue statement, we need to add
// the else block to the current block and create an if stmt for the then block
else if (hasContStmt) {
addToCurrentBlock(elseDiff.getStmt_dx(), direction::reverse);
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(),
condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc,
nullptr);
addToCurrentBlock(Reverse, direction::reverse);
}

CompoundStmt* ForwardBlock = endBlock(direction::forward);
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
endScope();
@@ -3649,17 +3672,71 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// for forward-pass loop statement body
endScope();
}

Stmts revLoopBlock = m_LoopBlock.back();
utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx());
if (!revLoopBlock.empty())

if (!activeBreakContHandler->m_SwitchCases.empty()) {
// Add case statement in the beginning of the reverse block
// and corresponding push expression for this case statement
// at the end of the forward block to cover the case when no
// `break`/`continue` statements are hit.
auto* lastSC = activeBreakContHandler->GetNextCFCaseStmt();
auto* pushExprToCurrentCase =
activeBreakContHandler->CreateCFTapePushExprToCurrentCase();

Stmt* forwBlock = nullptr;
Stmt* revBlock = nullptr;

forwBlock = utils::AppendAndCreateCompoundStmt(
activeBreakContHandler->m_RMV.m_Context, bodyDiff.getStmt(),
pushExprToCurrentCase);
revBlock = utils::PrependAndCreateCompoundStmt(
activeBreakContHandler->m_RMV.m_Context, bodyDiff.getStmt_dx(),
lastSC);

bodyDiff = {forwBlock, revBlock};

utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx());
Stmts revLoopBlockIndexed;
bool betweenCase = false;
Stmts curBlockStmts;

// Add the Stmts between cases as SubStmt of the first CaseStmt
if (!revLoopBlock.empty()) {
CaseStmt* curCaseStmt = nullptr;
for (auto revLoopStmt : revLoopBlock) {
if (auto caseStmt = dyn_cast_or_null<CaseStmt>(revLoopStmt)) {
if (!betweenCase) {
betweenCase = true;
} else {
curBlockStmts.push_back(new (m_Context)
BreakStmt(Stmt::EmptyShell()));
curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts));
curBlockStmts.clear();
}
curCaseStmt = caseStmt;
revLoopBlockIndexed.push_back(caseStmt);
} else {
curBlockStmts.push_back(revLoopStmt);
}
}
curBlockStmts.push_back(new (m_Context) BreakStmt(Stmt::EmptyShell()));
curCaseStmt->setSubStmt(MakeCompoundStmt(curBlockStmts));
bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlockIndexed));
}
}
else{
utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx());
bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock));
}
m_LoopBlock.pop_back();

activeBreakContHandler->EndCFSwitchStmtScope();
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();

// Increment statement in the for-loop is executed in the beginning for every case
// Increment statement in the for-loop should be executed in the beginning for
// every case, hence it should be added prior to the switch statement.
if (forLoopIncDiff) {
if (bodyDiff.getStmt_dx()) {
bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt(
@@ -3693,15 +3770,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff ReverseModeVisitor::VisitContinueStmt(const ContinueStmt* CS) {
hasContStmt = true;
// beginBlock(direction::forward);
beginBlock(direction::forward);
Stmt* newCS = m_Sema.ActOnContinueStmt(noLoc, getCurrentScope()).get();
auto* activeBreakContHandler = GetActiveBreakContStmtHandler();
Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt();
Stmt* pushExprToCurrentCase = activeBreakContHandler
->CreateCFTapePushExprToCurrentCase();
addToCurrentBlock(pushExprToCurrentCase);
// addToCurrentBlock(newCS);
return {newCS, CFCaseStmt};
addToCurrentBlock(newCS);
return {endBlock(direction::forward), CFCaseStmt};
}

StmtDiff ReverseModeVisitor::VisitBreakStmt(const BreakStmt* BS) {
@@ -3784,25 +3861,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_SwitchCases.empty() && !m_IsInvokedBySwitchStmt)
return;

// Add case statement in the beginning of the reverse block
// and corresponding push expression for this case statement
// at the end of the forward block to cover the case when no
// `break`/`continue` statements are hit.
auto* lastSC = GetNextCFCaseStmt();
auto* pushExprToCurrentCase = CreateCFTapePushExprToCurrentCase();

Stmt* forwBlock = nullptr;
Stmt* revBlock = nullptr;

forwBlock = utils::AppendAndCreateCompoundStmt(m_RMV.m_Context,
bodyDiff.getStmt(),
pushExprToCurrentCase);
revBlock = utils::PrependAndCreateCompoundStmt(m_RMV.m_Context,
bodyDiff.getStmt_dx(),
lastSC);

bodyDiff = {forwBlock, revBlock};

auto condResult = m_RMV.m_Sema.ActOnCondition(m_RMV.getCurrentScope(),
noLoc, m_ControlFlowTape->Pop,
Sema::ConditionKind::Switch);