From 764fa88f6241d18f3b2b57406fc6721d9d2e256d Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Mon, 20 Nov 2023 18:28:16 +0000 Subject: [PATCH] Switch to RecursiveASTVisitor. --- include/clad/Differentiator/TBRAnalyzer.h | 49 ++----- lib/Differentiator/TBRAnalyzer.cpp | 149 ++++++++++------------ 2 files changed, 78 insertions(+), 120 deletions(-) diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index 0aa5197ce..f8339811d 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -1,7 +1,7 @@ #ifndef CLAD_DIFFERENTIATOR_TBRANALYZER_H #define CLAD_DIFFERENTIATOR_TBRANALYZER_H -#include "clang/AST/StmtVisitor.h" +#include "clang/AST/RecursiveASTVisitor.h" #include "clang/Analysis/CFG.h" #include "clad/Differentiator/CladUtils.h" @@ -14,9 +14,7 @@ using namespace clang; namespace clad { -class TBRAnalyzer : public clang::ConstStmtVisitor { -private: - +class TBRAnalyzer : public clang::RecursiveASTVisitor { /// ProfileID is the key type for ArrMap used to represent array indices /// and object fields. using ProfileID = clad_compat::FoldingSetNodeID; @@ -295,37 +293,18 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Visitors void Analyze(const clang::FunctionDecl* FD); - void VisitCFGBlock(const clang::CFGBlock* block); - - void Visit(const clang::Stmt* stmt) { - clang::ConstStmtVisitor::Visit(stmt); - } - - void VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); - void VisitBinaryOperator(const clang::BinaryOperator* BinOp); - void VisitCallExpr(const clang::CallExpr* CE); - void VisitCompoundStmt(const clang::CompoundStmt* CS); - void VisitConditionalOperator(const clang::ConditionalOperator* CO); - void VisitCXXConstructExpr(const clang::CXXConstructExpr* CE); - void VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); - void VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE); - void VisitDeclRefExpr(const clang::DeclRefExpr* DRE); - void VisitDeclStmt(const clang::DeclStmt* DS); - void VisitExprWithCleanups(const clang::ExprWithCleanups* EWC); - void VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE); - void VisitInitListExpr(const clang::InitListExpr* ILE); - void VisitMemberExpr(const clang::MemberExpr* ME); - void VisitParenExpr(const clang::ParenExpr* PE); - void VisitReturnStmt(const clang::ReturnStmt* RS); - void VisitUnaryOperator(const clang::UnaryOperator* UnOp); - - /// FIXME: Make sure these are not necessary - /// Unused Visitors: - // void VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); - // void VisitCXXThisExpr(const clang::CXXThisExpr* TE); - // void VisitFloatingLiteral(const clang::FloatingLiteral* FL); - // void VisitIntegerLiteral(const clang::IntegerLiteral* IL); - // void VisitStmt(const clang::Stmt* S); + void VisitCFGBlock(const clang::CFGBlock& block); + + bool VisitArraySubscriptExpr(clang::ArraySubscriptExpr* ASE); + bool VisitBinaryOperator(clang::BinaryOperator* BinOp); + bool VisitCallExpr(clang::CallExpr* CE); + bool VisitConditionalOperator(clang::ConditionalOperator* CO); + bool VisitCXXConstructExpr(clang::CXXConstructExpr* CE); + bool VisitDeclRefExpr(clang::DeclRefExpr* DRE); + bool VisitDeclStmt(clang::DeclStmt* DS); + bool VisitInitListExpr(clang::InitListExpr* ILE); + bool VisitMemberExpr(clang::MemberExpr* ME); + bool VisitUnaryOperator(clang::UnaryOperator* UnOp); }; } // end namespace clad diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index dfa3fb596..2a634c8a5 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -348,7 +348,7 @@ void TBRAnalyzer::Analyze(const FunctionDecl* FD) { curBlockID = *IDIter; CFGQueue.erase(IDIter); - auto* nextBlock = getCFGBlockByID(curBlockID); + CFGBlock& nextBlock = *getCFGBlockByID(curBlockID); VisitCFGBlock(nextBlock); } // for (int id = curBlockID; id >= 0; --id) { @@ -360,24 +360,24 @@ void TBRAnalyzer::Analyze(const FunctionDecl* FD) { // } } -void TBRAnalyzer::VisitCFGBlock(const CFGBlock* block) { +void TBRAnalyzer::VisitCFGBlock(const CFGBlock& block) { // llvm::errs() << "\n-----BLOCK" << block->getBlockID() << "-----\n"; /// Visiting loop blocks just once is not enough since the end of one /// loop iteration may have an effect on the next one. However, two /// iterations is always enough. Allow a third visit without going to /// successors to correctly analyse loop conditions. - bool notLastPass = ++blockPassCounter[block->getBlockID()] <= 2; + bool notLastPass = ++blockPassCounter[block.getBlockID()] <= 2; /// Visit all the statements inside the block. - for (const clang::CFGElement& Element : *block) { + for (const clang::CFGElement& Element : block) { if (Element.getKind() == clang::CFGElement::Statement) { - const auto* Stmt = Element.castAs().getStmt(); - Visit(Stmt); + const clang::Stmt* S = Element.castAs().getStmt(); + TraverseStmt(const_cast(S)); } } /// Traverse successor CFG blocks. - for (const auto succ : block->succs()) { + for (const auto succ : block.succs()) { /// Sometimes clang CFG does not create blocks for parts of code that /// are never executed (e.g. 'if (0) {...'). Add this check for safety. if (!succ) @@ -389,7 +389,7 @@ void TBRAnalyzer::VisitCFGBlock(const CFGBlock* block) { /// current block as previous. if (!varsData) { varsData = std::unique_ptr(new VarsData()); - varsData->prev = blockData[block->getBlockID()].get(); + varsData->prev = blockData[block.getBlockID()].get(); } /// If this is the third (last) pass of block, it means block represents @@ -401,7 +401,7 @@ void TBRAnalyzer::VisitCFGBlock(const CFGBlock* block) { /// This part is necessary for loops. For other cases, this is not /// supposed to do anything. - if (succ->getBlockID() < block->getBlockID()) { + if (succ->getBlockID() < block.getBlockID()) { /// If there is another loop condition present inside a loop, /// We have to set it's loop pass counter to 0 (it might be 3 /// from the previous outer loop pass). @@ -413,8 +413,8 @@ void TBRAnalyzer::VisitCFGBlock(const CFGBlock* block) { /// If the successor's previous block is not this one, /// perform a merge. - if (varsData->prev != blockData[block->getBlockID()].get()) { - merge(varsData.get(), blockData[block->getBlockID()].get()); + if (varsData->prev != blockData[block.getBlockID()].get()) { + merge(varsData.get(), blockData[block.getBlockID()].get()); } } // llvm::errs() << "----------------\n\n"; @@ -548,12 +548,7 @@ void TBRAnalyzer::merge(VarsData* targetData, VarsData* mergeData) { } } -void TBRAnalyzer::VisitCompoundStmt(const CompoundStmt* CS) { - for (Stmt* S : CS->body()) - Visit(S); -} - -void TBRAnalyzer::VisitDeclRefExpr(const DeclRefExpr* DRE) { +bool TBRAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { if (const auto* VD = dyn_cast(DRE->getDecl())) { auto& curBranch = getCurBlockVarsData(); if (curBranch.find(VD) == curBranch.end()) @@ -562,39 +557,17 @@ void TBRAnalyzer::VisitDeclRefExpr(const DeclRefExpr* DRE) { if (const auto* E = dyn_cast(DRE)) setIsRequired(E); -} - -void TBRAnalyzer::VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE) { - Visit(ICE->getSubExpr()); -} - -void TBRAnalyzer::VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE) { - Visit(DE->getExpr()); -} -void TBRAnalyzer::VisitParenExpr(const clang::ParenExpr* PE) { - Visit(PE->getSubExpr()); + return true; } -void TBRAnalyzer::VisitReturnStmt(const clang::ReturnStmt* RS) { - Visit(RS->getRetValue()); -} - -void TBRAnalyzer::VisitExprWithCleanups(const clang::ExprWithCleanups* EWC) { - Visit(EWC->getSubExpr()); -} - -void TBRAnalyzer::VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE) { - Visit(SCE->getSubExpr()); -} - -void TBRAnalyzer::VisitDeclStmt(const DeclStmt* DS) { - for (const auto* D : DS->decls()) { - if (const auto* VD = dyn_cast(D)) { +bool TBRAnalyzer::VisitDeclStmt(DeclStmt* DS) { + for (auto* D : DS->decls()) { + if (auto* VD = dyn_cast(D)) { addVar(VD); - if (const clang::Expr* init = VD->getInit()) { + if (clang::Expr* init = VD->getInit()) { setMode(Mode::markingMode); - Visit(init); + TraverseStmt(init); resetMode(); auto& VDExpr = getCurBlockVarsData()[VD]; /// if the declared variable is ref type attach its VarData to the @@ -607,38 +580,39 @@ void TBRAnalyzer::VisitDeclStmt(const DeclStmt* DS) { } } } + return true; } -void TBRAnalyzer::VisitConditionalOperator( - const clang::ConditionalOperator* CO) { +bool TBRAnalyzer::VisitConditionalOperator(clang::ConditionalOperator* CO) { setMode(0); - Visit(CO->getCond()); + TraverseStmt(CO->getCond()); resetMode(); auto elseBranch = std::move(blockData[curBlockID]); blockData[curBlockID] = std::unique_ptr(new VarsData()); blockData[curBlockID]->prev = elseBranch.get(); - Visit(CO->getTrueExpr()); + TraverseStmt(CO->getTrueExpr()); auto thenBranch = std::move(blockData[curBlockID]); blockData[curBlockID] = std::move(elseBranch); - Visit(CO->getTrueExpr()); + TraverseStmt(CO->getTrueExpr()); merge(blockData[curBlockID].get(), thenBranch.get()); + return true; } -void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { +bool TBRAnalyzer::VisitBinaryOperator(BinaryOperator* BinOp) { const auto opCode = BinOp->getOpcode(); - const auto* L = BinOp->getLHS(); - const auto* R = BinOp->getRHS(); + Expr* L = BinOp->getLHS(); + Expr* R = BinOp->getRHS(); /// Addition is not able to create any differential influence by itself so /// markingMode should be left as it is. Similarly, addition does not affect /// linearity so nonLinearMode shouldn't be changed as well. The same applies /// to subtraction. if (opCode == BO_Add || opCode == BO_Sub) { - Visit(L); - Visit(R); + TraverseStmt(L); + TraverseStmt(R); } else if (opCode == BO_Mul) { /// Multiplication results in a linear expression if and only if one of the /// factors is constant. @@ -649,8 +623,8 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { if (nonLinear) startNonLinearMode(); - Visit(L); - Visit(R); + TraverseStmt(L); + TraverseStmt(R); if (nonLinear) resetMode(); @@ -663,8 +637,8 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { if (nonLinear) startNonLinearMode(); - Visit(L); - Visit(R); + TraverseStmt(L); + TraverseStmt(R); if (nonLinear) resetMode(); @@ -673,10 +647,10 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { opCode == BO_SubAssign) { /// Since we only care about non-linear usages of variables, there is /// no difference between operators =, -=, += in terms of TBR analysis. - Visit(L); + TraverseStmt(L); startMarkingMode(); - Visit(R); + TraverseStmt(R); resetMode(); } else if (opCode == BO_MulAssign || opCode == BO_DivAssign) { /// *= (/=) normally only performs a linear operation if and only if @@ -688,12 +662,12 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context); if (RisNotConst) setMode(Mode::markingMode | Mode::nonLinearMode); - Visit(L); + TraverseStmt(L); if (RisNotConst) resetMode(); setMode(Mode::markingMode | Mode::nonLinearMode); - Visit(R); + TraverseStmt(R); resetMode(); } llvm::SmallVector ExprsToStore; @@ -709,20 +683,21 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { } } else if (opCode == BO_Comma) { setMode(0); - Visit(L); + TraverseStmt(L); resetMode(); - Visit(R); + TraverseStmt(R); } // else { // FIXME: add logic/bitwise/comparison operators // } + return true; } -void TBRAnalyzer::VisitUnaryOperator(const clang::UnaryOperator* UnOp) { +bool TBRAnalyzer::VisitUnaryOperator(clang::UnaryOperator* UnOp) { const auto opCode = UnOp->getOpcode(); - const Expr* E = UnOp->getSubExpr(); - Visit(E); + Expr* E = UnOp->getSubExpr(); + TraverseStmt(E); if (opCode == UO_PostInc || opCode == UO_PostDec || opCode == UO_PreInc || opCode == UO_PreDec) { // FIXME: this doesn't support all the possible references @@ -736,24 +711,25 @@ void TBRAnalyzer::VisitUnaryOperator(const clang::UnaryOperator* UnOp) { markLocation(innerExpr); } } + return true; } -void TBRAnalyzer::VisitCallExpr(const clang::CallExpr* CE) { +bool TBRAnalyzer::VisitCallExpr(clang::CallExpr* CE) { /// FIXME: Currently TBR analysis just stops here and assumes that all the /// variables passed by value/reference are used/used and changed. Analysis /// could proceed to the function to analyse data flow inside it. - const auto* FD = CE->getDirectCallee(); + FunctionDecl* FD = CE->getDirectCallee(); bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); setMode(Mode::markingMode | Mode::nonLinearMode); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { - const clang::Expr* arg = CE->getArg(i); + clang::Expr* arg = CE->getArg(i); bool passByRef = false; if (noHiddenParam) passByRef = FD->getParamDecl(i)->getType()->isReferenceType(); else if (i!=0) passByRef = FD->getParamDecl(i - 1)->getType()->isReferenceType(); setMode(Mode::markingMode | Mode::nonLinearMode); - Visit(arg); + TraverseStmt(arg); resetMode(); const auto* B = arg->IgnoreParenImpCasts(); // FIXME: this supports only DeclRefExpr @@ -766,20 +742,21 @@ void TBRAnalyzer::VisitCallExpr(const clang::CallExpr* CE) { } } resetMode(); + return true; } -void TBRAnalyzer::VisitCXXConstructExpr(const clang::CXXConstructExpr* CE) { +bool TBRAnalyzer::VisitCXXConstructExpr(clang::CXXConstructExpr* CE) { /// FIXME: Currently TBR analysis just stops here and assumes that all the /// variables passed by value/reference are used/used and changed. Analysis /// could proceed to the constructor to analyse data flow inside it. /// FIXME: add support for default values - auto* FD = CE->getConstructor(); + FunctionDecl* FD = CE->getConstructor(); setMode(Mode::markingMode | Mode::nonLinearMode); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { - const auto* arg = CE->getArg(i); + auto* arg = CE->getArg(i); bool passByRef = FD->getParamDecl(i)->getType()->isReferenceType(); setMode(Mode::markingMode | Mode::nonLinearMode); - Visit(arg); + TraverseStmt(arg); resetMode(); const auto* B = arg->IgnoreParenImpCasts(); // FIXME: this supports only DeclRefExpr @@ -792,29 +769,31 @@ void TBRAnalyzer::VisitCXXConstructExpr(const clang::CXXConstructExpr* CE) { } } resetMode(); + return true; } -void TBRAnalyzer::VisitMemberExpr(const clang::MemberExpr* ME) { +bool TBRAnalyzer::VisitMemberExpr(clang::MemberExpr* ME) { setIsRequired(dyn_cast(ME)); + return true; } -void TBRAnalyzer::VisitArraySubscriptExpr( - const clang::ArraySubscriptExpr* ASE) { +bool TBRAnalyzer::VisitArraySubscriptExpr(clang::ArraySubscriptExpr* ASE) { setMode(0); - Visit(ASE->getBase()); + TraverseStmt(ASE->getBase()); resetMode(); setIsRequired(dyn_cast(ASE)); setMode(Mode::markingMode | Mode::nonLinearMode); - Visit(ASE->getIdx()); + TraverseStmt(ASE->getIdx()); resetMode(); + return true; } -void TBRAnalyzer::VisitInitListExpr(const clang::InitListExpr* ILE) { +bool TBRAnalyzer::VisitInitListExpr(clang::InitListExpr* ILE) { setMode(Mode::markingMode); - for (auto* init : ILE->inits()) { - Visit(init); - } + for (auto* init : ILE->inits()) + TraverseStmt(init); resetMode(); + return true; } } // end namespace clad