Skip to content

Commit

Permalink
Move unwrapIfSingleStmt to the clad::utils namespace
Browse files Browse the repository at this point in the history
This commit moves the `unwrapIfSingleStmt` function that
was previously defined as a helper function for the reverse
mode differentiation only.
  • Loading branch information
gojakuch committed May 30, 2024
1 parent b1a2c36 commit a0ff8a8
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 30 deletions.
4 changes: 4 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ namespace clad {
/// function `FD`.
std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD);

// Unwraps S to a single statement if it's a compound statement only
// containing 1 statement.
clang::Stmt* unwrapIfSingleStmt(clang::Stmt* S);

/// Creates and returns a compound statement having statements as follows:
/// {`S`, all the statement of `initial` in sequence}
clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C,
Expand Down
6 changes: 1 addition & 5 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,11 +763,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
StmtDiff bodyVisited = Visit(body);
for (Stmt* S : bodyVisited.getBothStmts())
addToCurrentBlock(S);
CompoundStmt* bodyResultCmpd = endBlock();
if (bodyResultCmpd->size() == 1)
bodyResult = bodyResultCmpd->body_front();
else
bodyResult = bodyResultCmpd;
bodyResult = utils::unwrapIfSingleStmt(endBlock());
endScope();

Stmt* forStmtDiff = new (m_Context)
Expand Down
13 changes: 13 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,19 @@ namespace clad {
}
}

Stmt* unwrapIfSingleStmt(Stmt* S) {
if (!S)
return nullptr;
if (!isa<CompoundStmt>(S))
return S;
auto* CS = cast<CompoundStmt>(S);
if (CS->size() == 0)
return nullptr;
if (CS->size() == 1)
return CS->body_front();
return CS;
}

CompoundStmt* PrependAndCreateCompoundStmt(ASTContext& C, Stmt* initial,
Stmt* S) {
llvm::SmallVector<Stmt*, 16> block;
Expand Down
38 changes: 13 additions & 25 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,19 +783,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Forward, Reverse);
}

static Stmt* unwrapIfSingleStmt(Stmt* S) {
if (!S)
return nullptr;
if (!isa<CompoundStmt>(S))
return S;
auto* CS = cast<CompoundStmt>(S);
if (CS->size() == 0)
return nullptr;
if (CS->size() == 1)
return CS->body_front();
return CS;
}

StmtDiff ReverseModeVisitor::VisitIfStmt(const clang::IfStmt* If) {
// Control scope of the IfStmt. E.g., in if (double x = ...) {...}, x goes
// to this scope.
Expand Down Expand Up @@ -888,8 +875,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource
->ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt();

Stmt* Forward = unwrapIfSingleStmt(endBlock(direction::forward));
Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx());
Stmt* Forward = utils::unwrapIfSingleStmt(endBlock(direction::forward));
Stmt* Reverse = utils::unwrapIfSingleStmt(BranchDiff.getStmt_dx());
return StmtDiff(Forward, Reverse);
};

Expand All @@ -914,8 +901,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CompoundStmt* ForwardBlock = endBlock(direction::forward);
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
endScope();
return StmtDiff(unwrapIfSingleStmt(ForwardBlock),
unwrapIfSingleStmt(ReverseBlock));
return StmtDiff(utils::unwrapIfSingleStmt(ForwardBlock),
utils::unwrapIfSingleStmt(ReverseBlock));
}

StmtDiff ReverseModeVisitor::VisitConditionalOperator(
Expand All @@ -941,8 +928,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto Result = DifferentiateSingleExpr(Branch, dfdx);
StmtDiff BranchDiff = Result.first;
StmtDiff ExprDiff = Result.second;
Stmt* Forward = unwrapIfSingleStmt(BranchDiff.getStmt());
Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx());
Stmt* Forward = utils::unwrapIfSingleStmt(BranchDiff.getStmt());
Stmt* Reverse = utils::unwrapIfSingleStmt(BranchDiff.getStmt_dx());
return {StmtDiff(Forward, Reverse), ExprDiff};
};

Expand Down Expand Up @@ -1128,7 +1115,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Reverse = endBlock(direction::reverse);
endScope();

return {unwrapIfSingleStmt(Forward), unwrapIfSingleStmt(Reverse)};
return {utils::unwrapIfSingleStmt(Forward),
utils::unwrapIfSingleStmt(Reverse)};
}

StmtDiff
Expand Down Expand Up @@ -2766,7 +2754,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse);
CompoundStmt* RCS = endBlock(direction::reverse);
std::reverse(RCS->body_begin(), RCS->body_end());
Stmt* ReverseResult = unwrapIfSingleStmt(RCS);
Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS);
return StmtDiff(SDiff.getStmt(), ReverseResult);
}

Expand All @@ -2780,7 +2768,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CompoundStmt* RCS = endBlock(direction::reverse);
Stmt* ForwardResult = endBlock(direction::forward);
std::reverse(RCS->body_begin(), RCS->body_end());
Stmt* ReverseResult = unwrapIfSingleStmt(RCS);
Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS);
return {StmtDiff(ForwardResult, ReverseResult), EDiff};
}

Expand Down Expand Up @@ -2929,7 +2917,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (Decl* decl : decls)
addToBlock(BuildDeclStmt(decl), m_Globals);
Stmt* initAssignments = MakeCompoundStmt(inits);
initAssignments = unwrapIfSingleStmt(initAssignments);
initAssignments = utils::unwrapIfSingleStmt(initAssignments);
return StmtDiff(initAssignments);
}

Expand Down Expand Up @@ -3574,7 +3562,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (m_ExternalSource)
m_ExternalSource->ActAfterProcessingSingleStmtBodyInVisitForLoop();

Stmt* reverseBlock = unwrapIfSingleStmt(bodyDiff.getStmt_dx());
Stmt* reverseBlock = utils::unwrapIfSingleStmt(bodyDiff.getStmt_dx());
bodyDiff = {endBlock(direction::forward), reverseBlock};
// for forward-pass loop statement body
endScope();
Expand Down Expand Up @@ -3619,7 +3607,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(condVarDiff, direction::reverse);
addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse);
bodyDiff = {bodyDiff.getStmt(),
unwrapIfSingleStmt(endBlock(direction::reverse))};
utils::unwrapIfSingleStmt(endBlock(direction::reverse))};
return bodyDiff;
}

Expand Down

0 comments on commit a0ff8a8

Please sign in to comment.