diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index ae2d26a99..a784f7e21 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -27,6 +27,10 @@ namespace clad { /// function `FD`. std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD); + // Unwraps S to a single statement if it's a compound statement only + // containing 1 statement. + clang::Stmt* unwrapIfSingleStmt(clang::Stmt* S); + /// Creates and returns a compound statement having statements as follows: /// {`S`, all the statement of `initial` in sequence} clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C, diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index b7e7706b1..3d16cd8e4 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -764,11 +764,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { StmtDiff bodyVisited = Visit(body); for (Stmt* S : bodyVisited.getBothStmts()) addToCurrentBlock(S); - CompoundStmt* bodyResultCmpd = endBlock(); - if (bodyResultCmpd->size() == 1) - bodyResult = bodyResultCmpd->body_front(); - else - bodyResult = bodyResultCmpd; + bodyResult = utils::unwrapIfSingleStmt(endBlock()); endScope(); Stmt* forStmtDiff = new (m_Context) diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index c180d0b1a..831bbcb08 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -102,6 +102,19 @@ namespace clad { } } + Stmt* unwrapIfSingleStmt(Stmt* S) { + if (!S) + return nullptr; + if (!isa(S)) + return S; + auto* CS = cast(S); + if (CS->size() == 0) + return nullptr; + if (CS->size() == 1) + return CS->body_front(); + return CS; + } + CompoundStmt* PrependAndCreateCompoundStmt(ASTContext& C, Stmt* initial, Stmt* S) { llvm::SmallVector block; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f9ec442d8..4aba32332 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -783,19 +783,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Forward, Reverse); } - static Stmt* unwrapIfSingleStmt(Stmt* S) { - if (!S) - return nullptr; - if (!isa(S)) - return S; - auto* CS = cast(S); - if (CS->size() == 0) - return nullptr; - if (CS->size() == 1) - return CS->body_front(); - return CS; - } - StmtDiff ReverseModeVisitor::VisitIfStmt(const clang::IfStmt* If) { // Control scope of the IfStmt. E.g., in if (double x = ...) {...}, x goes // to this scope. @@ -888,8 +875,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource ->ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt(); - Stmt* Forward = unwrapIfSingleStmt(endBlock(direction::forward)); - Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); + Stmt* Forward = utils::unwrapIfSingleStmt(endBlock(direction::forward)); + Stmt* Reverse = utils::unwrapIfSingleStmt(BranchDiff.getStmt_dx()); return StmtDiff(Forward, Reverse); }; @@ -914,8 +901,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CompoundStmt* ForwardBlock = endBlock(direction::forward); CompoundStmt* ReverseBlock = endBlock(direction::reverse); endScope(); - return StmtDiff(unwrapIfSingleStmt(ForwardBlock), - unwrapIfSingleStmt(ReverseBlock)); + return StmtDiff(utils::unwrapIfSingleStmt(ForwardBlock), + utils::unwrapIfSingleStmt(ReverseBlock)); } StmtDiff ReverseModeVisitor::VisitConditionalOperator( @@ -941,8 +928,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto Result = DifferentiateSingleExpr(Branch, dfdx); StmtDiff BranchDiff = Result.first; StmtDiff ExprDiff = Result.second; - Stmt* Forward = unwrapIfSingleStmt(BranchDiff.getStmt()); - Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); + Stmt* Forward = utils::unwrapIfSingleStmt(BranchDiff.getStmt()); + Stmt* Reverse = utils::unwrapIfSingleStmt(BranchDiff.getStmt_dx()); return {StmtDiff(Forward, Reverse), ExprDiff}; }; @@ -1128,7 +1115,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Reverse = endBlock(direction::reverse); endScope(); - return {unwrapIfSingleStmt(Forward), unwrapIfSingleStmt(Reverse)}; + return {utils::unwrapIfSingleStmt(Forward), + utils::unwrapIfSingleStmt(Reverse)}; } StmtDiff @@ -2766,7 +2754,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse); CompoundStmt* RCS = endBlock(direction::reverse); std::reverse(RCS->body_begin(), RCS->body_end()); - Stmt* ReverseResult = unwrapIfSingleStmt(RCS); + Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS); return StmtDiff(SDiff.getStmt(), ReverseResult); } @@ -2780,7 +2768,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CompoundStmt* RCS = endBlock(direction::reverse); Stmt* ForwardResult = endBlock(direction::forward); std::reverse(RCS->body_begin(), RCS->body_end()); - Stmt* ReverseResult = unwrapIfSingleStmt(RCS); + Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS); return {StmtDiff(ForwardResult, ReverseResult), EDiff}; } @@ -2929,7 +2917,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (Decl* decl : decls) addToBlock(BuildDeclStmt(decl), m_Globals); Stmt* initAssignments = MakeCompoundStmt(inits); - initAssignments = unwrapIfSingleStmt(initAssignments); + initAssignments = utils::unwrapIfSingleStmt(initAssignments); return StmtDiff(initAssignments); } @@ -3574,7 +3562,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterProcessingSingleStmtBodyInVisitForLoop(); - Stmt* reverseBlock = unwrapIfSingleStmt(bodyDiff.getStmt_dx()); + Stmt* reverseBlock = utils::unwrapIfSingleStmt(bodyDiff.getStmt_dx()); bodyDiff = {endBlock(direction::forward), reverseBlock}; // for forward-pass loop statement body endScope(); @@ -3619,7 +3607,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(condVarDiff, direction::reverse); addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse); bodyDiff = {bodyDiff.getStmt(), - unwrapIfSingleStmt(endBlock(direction::reverse))}; + utils::unwrapIfSingleStmt(endBlock(direction::reverse))}; return bodyDiff; }