diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index 213cf1d78..f2943ab23 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -2,6 +2,8 @@ #define CLAD_DIFFERENTIATOR_TBRANALYZER_H #include "clang/AST/StmtVisitor.h" +#include "clang/Analysis/CFG.h" + #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/Compatibility.h" @@ -131,14 +133,14 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { VarData(const QualType QT); ~VarData() { - if (type == OBJ_TYPE) { + if (type == OBJ_TYPE) for (auto& pair : *val.objData) delete pair.second; - } else if (type == ARR_TYPE) { + else if (type == ARR_TYPE) for (auto& pair : *val.arrData) delete pair.second; - } } + /// Recursively sets all the leaves' bools to isReq. void setIsRequired(bool isReq = true); /// Returns true if there is at least one required to store node among @@ -167,6 +169,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { void restoreRefs(std::unordered_map& refVars); }; + clang::CFGBlock* getCFGBlockByID(unsigned ID); + /// Given a MemberExpr*/ArraySubscriptExpr* return a pointer to its /// corresponding VarData. If the given element of an array does not have a /// VarData* yet it will be added automatically. If addNonConstIdx==false this @@ -191,7 +195,33 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// particular moment. /// Note: 'this' pointer does not have a declaration so nullptr is used as /// its key instead. - using VarsData = std::unordered_map; + struct VarsData { + std::unordered_map data = + std::unordered_map(); + VarsData* prev = nullptr; + + VarsData() {} + VarsData(VarsData& other) : data(other.data), prev(other.prev) {} + + using iterator = + std::unordered_map::iterator; + iterator begin() { return data.begin(); } + iterator end() { return data.end(); } + VarData*& operator[](const clang::VarDecl* VD) { return data[VD]; } + iterator find(const clang::VarDecl* VD) { return data.find(VD); } + void emplace(const clang::VarDecl* VD, VarData* varsData) { + data.emplace(VD, varsData); + } + void emplace(std::pair pair) { + data.emplace(pair); + } + + std::unique_ptr + collectDataFromPredecessors(VarsData* limit = nullptr); + VarsData* findLowestCommonAncestor(VarsData* other); + void merge(VarsData* mergeData); + }; + /// Used to find DeclRefExpr's that will be used in the backwards pass. /// In order to be marked as required, a variables has to appear in a place /// where it would have a differential influence and will appear non-linearly @@ -201,31 +231,23 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Tells if the variable at a given location is required to store. Basically, /// is the result of analysis. std::map TBRLocs; - /// Stores VarsData for every branch in control flow (e.g. if-else statements, - /// loops). - std::vector> reqStack; + /// Stores modes in a stack (used to retrieve the old mode after entering /// a new one). - std::vector modeStack; - /// Stores local variables to delete them after exiting the corresponding - /// scope. - /// Note: This is not used every time a new scope is entered. This is only - /// used when merging an if-else statement to get rid of local variables in - /// the then-branch. - std::vector> localVarsStack; + std::vector modeStack; ASTContext* m_Context; - /// The index of the innermost branch corresponding to a loop (used to handle - /// break/continue statements). - size_t innermostLoopLayer = 0; - /// Tells if the current branch should be deleted instead of merged with - /// others. This happens when the branch has a break/continue statement or a - /// return expression in it. - bool deleteCurBranch = false; - /// Loop bodies have to be passed twice. This tells us what pass is currently - /// happening. - bool firstLoopPass = false; + std::unique_ptr m_CFG; + + std::vector blockData; + + std::vector blockPassCounter; + + unsigned curBlockID; + + std::set CFGQueue; + /// Set to true when a non-const index is found while analysing an /// array subscript expression. bool nonConstIndexFound = false; @@ -242,32 +264,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// ArraySubscriptExpr* or MemberExpr*. void setIsRequired(const clang::Expr* E, bool isReq = true); - //// Control Flow - /// Returns the current branch. - VarsData& getCurBranch() { return reqStack.back().back(); } - /// Adds a new layer. - void addLayer() { reqStack.emplace_back(); } - /// Creates a new empty branch. - void addBranch() { reqStack.back().emplace_back(); } - /// Deletes the last branch. - void deleteBranch() { - for (auto& pair : getCurBranch()) - delete pair.second; - reqStack.back().pop_back(); - } - /// Merges the last layer into the one last branch on the previous layer - /// right and deletes the last layer. - void mergeLayer(); - /// Merges the last layer but, unlike the previous method, basically replaces - /// the last branch on the previous layer with the result of merging. After - /// that, removes the last layer. - void mergeLayerOnTop(); - /// Merges the branch with index targetBranch into a sourceBranchNum. - /// No branches are deleted. - void mergeBranchTo(size_t sourceBranchNum, VarsData& targetBranch); - /// Removes local variables from the current branch (uses localVarsStack). - /// This is necessary when merging if-else branches. - void removeLocalVars(); + VarsData& getCurBranch() { return *blockData[curBlockID]; } //// Modes Setters /// Sets the mode manually @@ -287,16 +284,17 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Constructor TBRAnalyzer(ASTContext* m_Context) : m_Context(m_Context) { modeStack.push_back(0); - addLayer(); - addBranch(); } /// Destructor ~TBRAnalyzer() { - for (auto& layer : reqStack) - for (auto& branch : layer) - for (auto& pair : branch) + for (auto varsData : blockData) { + if (varsData) { + for (auto pair : *varsData) delete pair.second; + delete varsData; + } + } } /// Delete copy/move operators and constructors. @@ -311,33 +309,29 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Visitors void Analyze(const clang::FunctionDecl* FD); + void VisitCFGBlock(clang::CFGBlock* block); + void Visit(const clang::Stmt* stmt) { clang::ConstStmtVisitor::Visit(stmt); } void VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); void VisitBinaryOperator(const clang::BinaryOperator* BinOp); - void VisitBreakStmt(const clang::BreakStmt* BS); void VisitCallExpr(const clang::CallExpr* CE); void VisitCompoundStmt(const clang::CompoundStmt* CS); void VisitConditionalOperator(const clang::ConditionalOperator* CO); - void VisitContinueStmt(const clang::ContinueStmt* CS); void VisitCXXConstructExpr(const clang::CXXConstructExpr* CE); void VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); void VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE); void VisitDeclRefExpr(const clang::DeclRefExpr* DRE); void VisitDeclStmt(const clang::DeclStmt* DS); - void VisitDoStmt(const clang::DoStmt* DS); void VisitExprWithCleanups(const clang::ExprWithCleanups* EWC); - void VisitForStmt(const clang::ForStmt* FS); - void VisitIfStmt(const clang::IfStmt* If); void VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE); void VisitInitListExpr(const clang::InitListExpr* ILE); void VisitMemberExpr(const clang::MemberExpr* ME); void VisitParenExpr(const clang::ParenExpr* PE); void VisitReturnStmt(const clang::ReturnStmt* RS); void VisitUnaryOperator(const clang::UnaryOperator* UnOp); - void VisitWhileStmt(const clang::WhileStmt* WS); /// FIXME: Make sure these are not necessary /// Unused Visitors: diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 1bdbf1c3b..7045dac4d 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -5,17 +5,16 @@ using namespace clang; namespace clad { void TBRAnalyzer::VarData::setIsRequired(bool isReq) { - if (type == FUND_TYPE) { + if (type == FUND_TYPE) val.fundData = isReq; - } else if (type == OBJ_TYPE) { + else if (type == OBJ_TYPE) for (auto& pair : *val.objData) pair.second->setIsRequired(isReq); - } else if (type == ARR_TYPE) { + else if (type == ARR_TYPE) for (auto& pair : *val.arrData) pair.second->setIsRequired(isReq); - } else if (type == REF_TYPE && val.refData) { + else if (type == REF_TYPE && val.refData) val.refData->setIsRequired(isReq); - } } void TBRAnalyzer::VarData::merge(VarData* mergeData) { @@ -39,7 +38,8 @@ void TBRAnalyzer::VarData::merge(VarData* mergeData) { } } } else if (this->type == REF_TYPE && this->val.refData) { - this->val.refData->merge(mergeData->val.refData); + /// FIXME: add support for merging references. + // this->val.refData->merge(mergeData->val.refData); } } @@ -130,9 +130,8 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, const auto* base = ME->getBase(); VarData* baseData = getExprVarData(base); /// If the VarData is ref type just go to the VarData being referenced. - if (baseData && baseData->type == VarData::VarDataType::REF_TYPE) { + if (baseData && baseData->type == VarData::VarDataType::REF_TYPE) baseData = baseData->val.refData; - } if (!baseData) return nullptr; @@ -163,9 +162,8 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, const auto* base = ASE->getBase()->IgnoreImpCasts(); VarData* baseData = getExprVarData(base); /// If the VarData is ref type just go to the VarData being referenced. - if (baseData && baseData->type == VarData::VarDataType::REF_TYPE) { + if (baseData && baseData->type == VarData::VarDataType::REF_TYPE) baseData = baseData->val.refData; - } if (!baseData) return nullptr; @@ -203,15 +201,15 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, /// ``this`` does not have a declaration so it is represented with nullptr. if (const auto* DRE = dyn_cast(E)) VD = dyn_cast(DRE->getDecl()); - /// The index i is shifted since otherwise the last value would be i=-1 - /// and size_t can only take positive values. - for (size_t i = reqStack.size(); i > 0; --i) { - auto& branch = reqStack[i - 1].back(); - const auto it = branch.find(VD); - if (it != branch.end()) { + + auto* branch = &getCurBranch(); + while (branch) { + auto it = branch->find(VD); + if (it != branch->end()) { EData = it->second; break; } + branch = branch->prev; } } if (const auto* ME = dyn_cast(E)) @@ -288,26 +286,16 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { /// treat it as an assigment operation. /// FIXME: this marks the SourceLocation of DeclStmt which doesn't work for /// declarations with multiple VarDecls. - size_t len = reqStack.size(); - auto& curBranch = reqStack[len - 1].back(); - if (curBranch.find(VD) != curBranch.end()) { - auto& VDData = curBranch[VD]; - if (VDData->type == VarData::VarDataType::FUND_TYPE) { - TBRLocs[VD->getBeginLoc()] = - (deleteCurBranch ? false : VDData->findReq()); - } - } - /// The index here is shifted by one since otherwise the loop would end with - /// i=-1 and size_t is positive only. - for (size_t i = len - 1; i > 0; --i) { - auto& branch = reqStack[i - 1].back(); - if (branch.find(VD) != branch.end()) { - curBranch[VD] = branch[VD]->copy(); - break; + auto& curBranch = getCurBranch(); + + auto* branch = curBranch.prev; + while (branch) { + auto it = branch->find(VD); + if (it != branch->end()) { + curBranch[VD] = it->second->copy(); + return; } - } - if (!localVarsStack.empty()) { - localVarsStack.back().push_back(VD); + branch = branch->prev; } QualType varType; @@ -330,7 +318,6 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { return; } } - curBranch[VD] = new VarData(varType); } @@ -345,115 +332,38 @@ void TBRAnalyzer::markLocation(const clang::Expr* E) { /// required to be stored (when passing *= operator) but then marked as not /// required to be stored (when passing = operator). Current method of /// marking locations does not allow to differentiate between these two. - ToBeRec = (deleteCurBranch ? false : ToBeRec || data->findReq()); + ToBeRec = ToBeRec || data->findReq(); } else /// If the current branch is going to be deleted then there is not point in /// storing anything in it. - TBRLocs[E->getBeginLoc()] = !deleteCurBranch; -} - -void TBRAnalyzer::mergeLayer() { - size_t len = reqStack.size(); - auto& removedLayer = reqStack[len - 1]; - auto& curBranch = reqStack[len - 2].back(); - - for (auto& removedBranch : removedLayer) { - for (auto& pair : curBranch) { - auto it = removedBranch.find(pair.first); - if (it != removedBranch.end()) - pair.second->merge(it->second); - } - for (auto& pair : removedBranch) { - auto it = curBranch.find(pair.first); - if (it == curBranch.end()) { - delete curBranch[pair.first]; - curBranch[pair.first] = pair.second; - } else { - delete pair.second; - } - } - } - reqStack.pop_back(); -} - -void TBRAnalyzer::mergeLayerOnTop() { - size_t len = reqStack.size(); - auto& removedLayer = reqStack[len - 1]; - - if (removedLayer.empty()) { - reqStack.pop_back(); - return; - } - - auto& curBranch = reqStack[len - 2].back(); - - /// First, we merge every branch on the layer with the first one there. - auto branchIter = removedLayer.begin(); - auto branchIterEnd = removedLayer.end(); - auto& firstBranch = *branchIter; - while ((++branchIter) != branchIterEnd) { - for (auto& pair : firstBranch) { - auto elemIter = branchIter->find(pair.first); - if (elemIter != branchIter->end()) - pair.second->merge(elemIter->second); - } - for (auto& pair : *branchIter) { - auto elemIter = firstBranch.find(pair.first); - if (elemIter == firstBranch.end()) { - delete firstBranch[pair.first]; - firstBranch[pair.first] = pair.second; - } else { - delete pair.second; - } - } - } - - /// Second, we place it on top of the branch on the previous layer's last - /// branch. - for (auto& pair : firstBranch) { - delete curBranch[pair.first]; - curBranch[pair.first] = pair.second; - } - - reqStack.pop_back(); -} - -void TBRAnalyzer::mergeBranchTo(size_t sourceBranchNum, - VarsData& targetBranch) { - for (auto& pair : targetBranch) { - /// Index i is shifted by one since otherwise its last value could be -1 - /// and size_t is only positive. - for (size_t i = sourceBranchNum + 1; i > 0; --i) { - auto& sourceBranch = reqStack[i - 1].back(); - auto it = sourceBranch.find(pair.first); - if (it != sourceBranch.end()) { - pair.second->merge(it->second); - break; - } - } - } + TBRLocs[E->getBeginLoc()] = true; } void TBRAnalyzer::setIsRequired(const clang::Expr* E, bool isReq) { if (!isReq || (modeStack.back() == (Mode::markingMode | Mode::nonLinearMode))) { VarData* data = getExprVarData(E, /*addNonConstIdx=*/isReq); - if (isReq || !nonConstIndexFound) { + if (isReq || !nonConstIndexFound) data->setIsRequired(isReq); - } /// If an array element with a non-const element is set to required /// all the elements of that array should be set to required. - if (isReq && nonConstIndexFound) { + if (isReq && nonConstIndexFound) overlay(E); - } nonConstIndexFound = false; } } void TBRAnalyzer::Analyze(const FunctionDecl* FD) { - /// If we are analysing a method add a VarData for 'this' pointer (it is + clang::CFG::BuildOptions Options; + m_CFG = clang::CFG::buildCFG(FD, FD->getBody(), m_Context, Options); + blockData.resize(m_CFG->size(), nullptr); + blockPassCounter.resize(m_CFG->size(), 0); + auto* entry = &m_CFG->getEntry(); + curBlockID = entry->getBlockID(); + blockData[curBlockID] = new VarsData(); + + /// If we are analysing a method, add a VarData for 'this' pointer (it is /// represented with nullptr). - if (isa(FD)) { const Type* recordType = dyn_cast(FD->getParent())->getTypeForDecl(); @@ -461,10 +371,171 @@ void TBRAnalyzer::Analyze(const FunctionDecl* FD) { new VarData(QualType::getFromOpaquePtr(recordType)); } auto paramsRef = FD->parameters(); - for (std::size_t i = 0; i < FD->getNumParams(); ++i) addVar(paramsRef[i]); - Visit(FD->getBody()); + CFGQueue.insert(curBlockID); + while (!CFGQueue.empty()) { + auto IDIter = std::prev(CFGQueue.end()); + curBlockID = *IDIter; + CFGQueue.erase(IDIter); + + auto* nextBlock = getCFGBlockByID(curBlockID); + VisitCFGBlock(nextBlock); + } + // for (int id = curBlockID; id >= 0; --id) { + // llvm::errs() << "\n-----BLOCK" << id << "-----\n\n"; + // for (auto succ : getCFGBlockByID(id)->succs()) { + // if (succ) + // llvm::errs() << "successor: " << succ->getBlockID() << "\n"; + // } + // } +} + +void TBRAnalyzer::VisitCFGBlock(CFGBlock* block) { + // llvm::errs() << "\n-----BLOCK" << block->getBlockID() << "-----\n"; + bool notLastPass = ++blockPassCounter[block->getBlockID()] <= 2; + for (clang::CFGElement& Element : *block) { + if (Element.getKind() == clang::CFGElement::Statement) { + auto* Stmt = Element.castAs().getStmt(); + // llvm::errs() << "stmt:\n"; // + // Stmt->dump(); // + Visit(Stmt); + } + } + for (auto succ : block->succs()) { + if (!succ) + continue; + auto*& varsData = blockData[succ->getBlockID()]; + if (!varsData) { + varsData = new VarsData(); + varsData->prev = blockData[block->getBlockID()]; + } else if (varsData->prev != blockData[block->getBlockID()]) { + varsData->merge(blockData[block->getBlockID()]); + } + if (notLastPass) { + CFGQueue.insert(succ->getBlockID()); + if (succ->getBlockID() < block->getBlockID()) + blockPassCounter[succ->getBlockID()] = 0; + } + } + // llvm::errs() << "----------------\n\n"; +} + +CFGBlock* TBRAnalyzer::getCFGBlockByID(unsigned ID) { + return *(m_CFG->begin() + ID); +} + +TBRAnalyzer::VarsData* +TBRAnalyzer::VarsData::findLowestCommonAncestor(TBRAnalyzer::VarsData* other) { + VarsData* pred1 = this; + VarsData* pred2 = other; + while (true) { + if (pred1 == pred2) + return pred1; + + auto branch = this; + while (branch != pred1) { + if (branch == pred2) + return branch; + branch = branch->prev; + } + + branch = other; + while (branch != pred2) { + if (branch == pred1) + return branch; + branch = branch->prev; + } + + if (pred1->prev) { + pred1 = pred1->prev; + /// This ensures we don't get an infinite loop because of VarsData being + /// connected in a loop themselves. + if (pred1 == this) + return nullptr; + } else { + /// pred1 not having a predecessor means it is corresponds to the entry + /// block and, therefore it is the lowest common ancestor. + return pred1; + } + + if (pred2->prev) { + pred2 = pred2->prev; + /// This ensures we don't get an infinite loop because of VarsData being + /// connected in a loop themselves. + if (pred2 == other) + return nullptr; + } else { + /// pred2 not having a predecessor means it is corresponds to the entry + /// block and, therefore it is the lowest common ancestor. + return pred2; + } + } + /// This is not supposed to ever happen. + return nullptr; +} + +std::unique_ptr +TBRAnalyzer::VarsData::collectDataFromPredecessors( + TBRAnalyzer::VarsData* limit) { + auto result = std::unique_ptr(new VarsData(*this)); + if (this != limit) { + auto pred = this->prev; + while (pred != limit) { + for (auto pair : *pred) + if (result->find(pair.first) == result->end()) + result->emplace(pair); + pred = pred->prev; + } + } + + return result; +} + +void TBRAnalyzer::VarsData::merge(TBRAnalyzer::VarsData* mergeData) { + auto* LCA = this->findLowestCommonAncestor(mergeData); + auto collectedMergeData = + mergeData->collectDataFromPredecessors(/*limit=*/LCA); + + for (auto& pair : *collectedMergeData) { + VarData* found = nullptr; + auto elemSearch = this->find(pair.first); + if (elemSearch == this->end()) { + auto* branch = this->prev; + while (branch) { + auto it = branch->find(pair.first); + if (it != branch->end()) { + found = it->second->copy(); + this->emplace(pair.first, found); + break; + } + branch = branch->prev; + } + } else { + found = elemSearch->second; + } + + if (found) + found->merge(pair.second); + else + this->emplace(pair.first, pair.second->copy()); + } + + auto collectedThis = this->collectDataFromPredecessors(/*limit=*/LCA); + for (auto& pair : *collectedThis) { + auto elemSearch = mergeData->find(pair.first); + if (elemSearch == this->end()) { + auto* branch = LCA; + while (branch) { + auto it = branch->find(pair.first); + if (it != branch->end()) { + pair.second->merge(it->second); + break; + } + branch = branch->prev; + } + } + } } void TBRAnalyzer::VisitCompoundStmt(const CompoundStmt* CS) { @@ -475,9 +546,6 @@ void TBRAnalyzer::VisitCompoundStmt(const CompoundStmt* CS) { void TBRAnalyzer::VisitDeclRefExpr(const DeclRefExpr* DRE) { if (const auto* VD = dyn_cast(DRE->getDecl())) { auto& curBranch = getCurBranch(); - // FIXME: this is only necessary to ensure global variables are added. - // It doesn't make any sense to first add variables when visiting DeclStmt - // and then checking if they were added while visiting DeclRefExpr. if (curBranch.find(VD) == curBranch.end()) addVar(VD); } @@ -500,7 +568,6 @@ void TBRAnalyzer::VisitParenExpr(const clang::ParenExpr* PE) { void TBRAnalyzer::VisitReturnStmt(const clang::ReturnStmt* RS) { Visit(RS->getRetValue()); - deleteCurBranch = true; } void TBRAnalyzer::VisitExprWithCleanups(const clang::ExprWithCleanups* EWC) { @@ -540,12 +607,18 @@ void TBRAnalyzer::VisitConditionalOperator( Visit(CO->getCond()); resetMode(); - addLayer(); - addBranch(); + auto* elseBranch = blockData[curBlockID]; + auto* thenBranch = new VarsData(); + thenBranch->prev = elseBranch; + + blockData[curBlockID] = thenBranch; Visit(CO->getTrueExpr()); - addBranch(); + + blockData[curBlockID] = elseBranch; Visit(CO->getFalseExpr()); - mergeLayerOnTop(); + + elseBranch->merge(thenBranch); + delete thenBranch; } void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { @@ -593,15 +666,6 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { opCode == BO_SubAssign) { /// Since we only care about non-linear usages of variables, there is /// no difference between operators =, -=, += in terms of TBR analysis. - // llvm::errs() << "before assignment:\n"; - // for(auto& pair : getCurBranch()) { - // if (pair.second->type == VarData::ARR_TYPE) { - // for (auto& pair2 : *pair.second->val.arrData) { - // llvm::errs() << pair.first->getNameAsString() << "." << - // pair2.first << ": " << pair2.second->val.fundData << "\n"; - // } - // } - // } Visit(L); startMarkingMode(); @@ -635,15 +699,6 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { /// already not required to store). setIsRequired(innerExpr, /*isReq=*/false); } - // llvm::errs() << "after assignment:\n"; - // for(auto& pair : getCurBranch()) { - // if (pair.second->type == VarData::ARR_TYPE) { - // for (auto& pair2 : *pair.second->val.arrData) { - // llvm::errs() << pair.first->getNameAsString() << "." << - // pair2.first << ": " << pair2.second->val.fundData << "\n"; - // } - // } - // } } else if (opCode == BO_Comma) { setMode(0); Visit(L); @@ -678,347 +733,20 @@ void TBRAnalyzer::VisitUnaryOperator(const clang::UnaryOperator* UnOp) { } } -void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { - const auto* cond = If->getCond(); - const auto* condVarDecl = If->getConditionVariable(); - const auto* condInit = If->getInit(); - - /// We have to separated analyse then-block and else-block and then merge - /// them together. First, we make a copy of the current branch and analyse - /// then-block on it. Then swap last two branches and analyse the else-block - /// on the last branch. Finally, we merge them together. This diagram explains - /// the transformations performed on the reqStack: - /// ... - - /// ... - - - /// ... - - - /// ... - - - /// ... - - - /// ... - - - const auto* thenBranch = If->getThen(); - const auto* elseBranch = If->getElse(); - - // localVarsStack.emplace_back(); - addLayer(); - - if (thenBranch) { - addBranch(); - Visit(cond); - if (condVarDecl) - addVar(condVarDecl); - if (condInit) { - setMode(Mode::markingMode); - Visit(condInit); - resetMode(); - } - Visit(thenBranch); - if (deleteCurBranch) { - /// This section is performed if this branch had break/continue/return - /// and, therefore, shouldn't be merged. - deleteBranch(); - deleteCurBranch = false; - } - } - - if (elseBranch) { - addBranch(); - Visit(cond); - if (condVarDecl) - addVar(condVarDecl); - if (condInit) { - setMode(Mode::markingMode); - Visit(condInit); - resetMode(); - } - Visit(elseBranch); - if (deleteCurBranch) { - /// This section is performed if this branch had break/continue/return - /// and, therefore, shouldn't be merged. - deleteBranch(); - deleteCurBranch = false; - } - } - - if (elseBranch) - mergeLayerOnTop(); - else - mergeLayer(); - - // removeLocalVars(); - // localVarsStack.pop_back(); -} - -void TBRAnalyzer::VisitWhileStmt(const clang::WhileStmt* WS) { - const auto* body = WS->getBody(); - const auto* cond = WS->getCond(); - size_t backupILB = innermostLoopLayer; - bool backupFLP = firstLoopPass; - bool backupDCB = deleteCurBranch; - /// Let's assume we have a section of code structured like this - /// (A, B, C represent blocks): - /// ``` - /// A - /// while (cond) B - /// C - /// ``` - /// Depending on cond, this could give us 3 types of scenarios: 'AC', 'ABC', - /// 'AB...BC'. We must notice two things: 1) C comes either after B or A, - /// 2) B comes either after A or B itself. So first, we have to merge original - /// state with after-first-iteration state and analyse B a second time on top - /// to get the state that represents arbitrary non-zero number of iterations. - /// Finally, we have to merge it with the original state once again to account - /// for the fact that the loop block may not be executed at all. - /// This diagram explains the transformations performed on the reqStack: - /// ... - - /// ... - - - /// ... - - - - /// ... - - - - /// ... - - - /// ... - - - /// ... - - - Visit(cond); - addLayer(); - addBranch(); - addBranch(); - - addLayer(); - addBranch(); - addBranch(); - /// First pass - innermostLoopLayer = reqStack.size() - 1; - firstLoopPass = true; - if (body) - Visit(body); - if (deleteCurBranch) { - deleteBranch(); - deleteCurBranch = backupDCB; - } else { - Visit(cond); - } - --innermostLoopLayer; - mergeLayer(); - mergeBranchTo(innermostLoopLayer - 1, reqStack[innermostLoopLayer].back()); - - /// Second pass - firstLoopPass = false; - if (body) - Visit(body); - if (deleteCurBranch) - deleteBranch(); - else { - Visit(cond); - } - mergeLayer(); - - innermostLoopLayer = backupILB; - firstLoopPass = backupFLP; - deleteCurBranch = backupDCB; -} - -void TBRAnalyzer::VisitForStmt(const clang::ForStmt* FS) { - const auto* body = FS->getBody(); - const auto* condVar = FS->getConditionVariable(); - auto* init = FS->getInit(); - const auto* cond = FS->getCond(); - const auto* incr = FS->getInc(); - size_t backupILB = innermostLoopLayer; - bool backupFLP = firstLoopPass; - bool backupDCB = deleteCurBranch; - /// The logic here is virtually the same as with while-loop. Take a look at - /// TBRAnalyzer::VisitWhileStmt for more details. - if (init) { - setMode(Mode::markingMode); - Visit(init); - resetMode(); - } - if (cond) - Visit(cond); - addLayer(); - addBranch(); - addBranch(); - if (condVar) - addVar(condVar); - addLayer(); - addBranch(); - addBranch(); - /// First pass - innermostLoopLayer = reqStack.size() - 1; - firstLoopPass = true; - if (body) - Visit(body); - if (deleteCurBranch) { - deleteBranch(); - deleteCurBranch = backupDCB; - } else { - if (incr) - Visit(incr); - if (cond) - Visit(cond); - } - --innermostLoopLayer; - mergeLayer(); - mergeBranchTo(innermostLoopLayer - 1, reqStack[innermostLoopLayer].back()); - - /// Second pass - firstLoopPass = false; - if (body) - Visit(body); - if (incr) - Visit(incr); - if (deleteCurBranch) - deleteBranch(); - else { - if (cond) - Visit(cond); - } - mergeLayer(); - - innermostLoopLayer = backupILB; - firstLoopPass = backupFLP; - deleteCurBranch = backupDCB; -} - -void TBRAnalyzer::VisitDoStmt(const clang::DoStmt* DS) { - const auto* body = DS->getBody(); - const auto* cond = DS->getCond(); - size_t backupILB = innermostLoopLayer; - bool backupFLP = firstLoopPass; - bool backupDCB = deleteCurBranch; - - /// The logic used here is virtually the same as with while-loop. Take a look - /// at TBRAnalyzer::VisitWhileStmt for more details. - /// FIXME: do-while-block is performed at least once and so we don't have to - /// account for the possibility of it not being performed at all. However, - /// having two loop branches is necessary for handling continue statements - /// so we can't just remove one of them. - - addLayer(); - addBranch(); - addBranch(); - addLayer(); - addBranch(); - addBranch(); - /// First pass - innermostLoopLayer = reqStack.size() - 2; - firstLoopPass = true; - if (body) - Visit(body); - if (deleteCurBranch) { - reqStack.pop_back(); - deleteCurBranch = backupDCB; - } else { - Visit(cond); - mergeLayer(); - } - - /// Second pass - --innermostLoopLayer; - firstLoopPass = false; - if (body) - Visit(body); - Visit(cond); - if (deleteCurBranch) { - reqStack.pop_back(); - mergeLayer(); - } - - innermostLoopLayer = backupILB; - firstLoopPass = backupFLP; - deleteCurBranch = backupDCB; - - addLayer(); - addBranch(); - addBranch(); - - addLayer(); - addBranch(); - addBranch(); - /// First pass - innermostLoopLayer = reqStack.size() - 1; - firstLoopPass = true; - if (body) - Visit(body); - if (deleteCurBranch) { - deleteBranch(); - deleteCurBranch = backupDCB; - } else { - Visit(cond); - } - --innermostLoopLayer; - mergeLayer(); - mergeBranchTo(innermostLoopLayer - 1, reqStack[innermostLoopLayer].back()); - - /// Second pass - firstLoopPass = false; - if (body) - Visit(body); - if (deleteCurBranch) - deleteBranch(); - else - Visit(cond); - mergeLayer(); - - innermostLoopLayer = backupILB; - firstLoopPass = backupFLP; - deleteCurBranch = backupDCB; -} - -void TBRAnalyzer::VisitContinueStmt(const clang::ContinueStmt* CS) { - /// If this is the first loop pass, the reqStack will look like this: - /// ... - - - - /// And so continue might be the end of this loop as well as the the end of - /// the first iteration. So we have to merge the current branch into first - /// two branches on the diagram. - /// If this is the second loop pass, the reqStack will look like this: - /// ... - - - /// And so this continue could be the end of this loop. So we have to merge - /// the current branch into the first branch on the diagram. - /// FIXME: If this is the second pass, this continue statement could still be - /// followed by another iteration. We have to either add an additional branch - /// or find a better solution. (However, this bug will matter only in really - /// rare cases) - - auto& targetLayer1 = reqStack[innermostLoopLayer]; - auto& targetBranch1 = targetLayer1[targetLayer1.size() - 2]; - size_t sourceBranchNum = reqStack.size() - 1; - mergeBranchTo(sourceBranchNum, targetBranch1); - /// After the continue statement, this branch cannot be followed by any other - /// code so we can delete it. - if (firstLoopPass) { - auto& targetLayer2 = reqStack[innermostLoopLayer - 1]; - auto& targetBranch2 = targetLayer2[targetLayer2.size() - 2]; - mergeBranchTo(sourceBranchNum, targetBranch2); - } - deleteCurBranch = true; -} - -void TBRAnalyzer::VisitBreakStmt(const clang::BreakStmt* BS) { - /// If this is the second loop pass, the reqStack will look like this: - /// ... - - - /// And so this break could be the end of this loop. So we have to merge - /// the current branch into the first branch on the diagram. - if (!firstLoopPass) { - auto& targetLayer = reqStack[innermostLoopLayer]; - auto& targetBranch = targetLayer[targetLayer.size() - 2]; - mergeBranchTo(reqStack.size() - 1, targetBranch); - } - /// After the break statement, this branch cannot be followed by any other - /// code so we can delete it. - deleteCurBranch = true; -} - void TBRAnalyzer::VisitCallExpr(const clang::CallExpr* CE) { /// FIXME: Currently TBR analysis just stops here and assumes that all the /// variables passed by value/reference are used/used and changed. Analysis /// could proceed to the function to analyse data flow inside it. auto* FD = CE->getDirectCallee(); + bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); setMode(Mode::markingMode | Mode::nonLinearMode); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { const clang::Expr* arg = CE->getArg(i); - bool passByRef = FD->getParamDecl(i)->getType()->isReferenceType(); + bool passByRef = false; + if (noHiddenParam) + passByRef = FD->getParamDecl(i)->getType()->isReferenceType(); + else if (i) + passByRef = FD->getParamDecl(i - 1)->getType()->isReferenceType(); setMode(Mode::markingMode | Mode::nonLinearMode); Visit(arg); resetMode(); @@ -1092,10 +820,4 @@ void TBRAnalyzer::VisitInitListExpr(const clang::InitListExpr* ILE) { resetMode(); } -void TBRAnalyzer::removeLocalVars() { - auto& curBranch = getCurBranch(); - for (const auto* VD : localVarsStack.back()) - curBranch.erase(VD); -} - } // end namespace clad