diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index e35be3325..648b8d7a4 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -104,9 +104,9 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// whole structures instead (just one VarData for the whole array). struct VarData; - using ObjMap = std::unordered_map; + using ObjMap = std::unordered_map; using ArrMap = - std::unordered_map; + std::unordered_map; struct VarData { enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE }; @@ -119,57 +119,57 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { Expr* refData; VarDataValue() : fundData(false) {} }; - VarDataType type; + VarDataType type = UNDEFINED; VarDataValue val; VarData() = default; - VarData(const VarData&) = delete; - VarData& operator=(const VarData&) = delete; - VarData(const VarData&&) = delete; - VarData& operator=(const VarData&&) = delete; /// Builds a VarData object (and its children) based on the provided type. VarData(const QualType QT); - ~VarData() { - if (type == OBJ_TYPE) + /// Erases all children VarData's of this VarData. + void erase() { + if (type == OBJ_TYPE) { for (auto& pair : *val.objData) - delete pair.second; - else if (type == ARR_TYPE) + pair.second.erase(); + delete val.objData; + } else if (type == ARR_TYPE) { for (auto& pair : *val.arrData) - delete pair.second; + pair.second.erase(); + delete val.arrData; + } } }; /// Recursively sets all the leaves' bools to isReq. - void setIsRequired(VarData* varData, bool isReq = true); + void setIsRequired(VarData& varData, bool isReq = true); /// Whenever an array element with a non-constant index is set to required /// this function is used to set to required all the array elements that /// could match that element (e.g. set 'a[1].y' and 'a[6].y' to required /// when 'a[k].y' is set to required). Takes unwrapped sequence of /// indices/members of the expression being overlaid and the index of of the /// current index/member. - void overlay(VarData* targetData, + void overlay(VarData& targetData, llvm::SmallVector& IdxAndMemberSequence, size_t i); /// Returns true if there is at least one required to store node among /// child nodes. - bool findReq(const VarData* varData); + bool findReq(const VarData& varData); /// Used to merge together VarData for one variable from two branches /// (e.g. after an if-else statements). Look at the Control Flow section for /// more information. - void merge(VarData* targetData, VarData* mergeData); + void merge(VarData& targetData, VarData& mergeData); /// Used to recursively copy VarData when separating into different branches /// (e.g. when entering an if-else statements). Look at the Control Flow /// section for more information. - VarData* copy(VarData* copyData); + VarData copy(VarData& copyData); clang::CFGBlock* getCFGBlockByID(unsigned ID); /// Given a MemberExpr*/ArraySubscriptExpr* return a pointer to its /// corresponding VarData. If the given element of an array does not have a - /// VarData* yet it will be added automatically. If addNonConstIdx==false this - /// will return the last VarData* before the non-constant index - /// (e.g. for 'x.arr[k+1].y' the return value will be the VarData* of x.arr). + /// VarData yet it will be added automatically. If addNonConstIdx==false this + /// will return the last VarData before the non-constant index + /// (e.g. for 'x.arr[k+1].y' the return value will be the VarData of x.arr). /// Otherwise, non-const indices will be represented as index -1. VarData* getMemberVarData(const clang::MemberExpr* ME, bool addNonConstIdx = false); @@ -197,23 +197,28 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Note: 'this' pointer does not have a declaration so nullptr is used as /// its key instead. struct VarsData { - std::unordered_map data = - std::unordered_map(); + std::unordered_map data = + std::unordered_map(); VarsData* prev = nullptr; VarsData() {} VarsData(VarsData& other) : data(other.data), prev(other.prev) {} + ~VarsData() { + for (auto& pair : data) + pair.second.erase(); + } + using iterator = - std::unordered_map::iterator; + std::unordered_map::iterator; iterator begin() { return data.begin(); } iterator end() { return data.end(); } - VarData*& operator[](const clang::VarDecl* VD) { return data[VD]; } + VarData& operator[](const clang::VarDecl* VD) { return data[VD]; } iterator find(const clang::VarDecl* VD) { return data.find(VD); } - void emplace(const clang::VarDecl* VD, VarData* varsData) { + void emplace(const clang::VarDecl* VD, VarData varsData) { data.emplace(VD, varsData); } - void emplace(std::pair pair) { + void emplace(std::pair pair) { data.emplace(pair); } void clear() { @@ -223,12 +228,12 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Collects the data from 'varsData' and its predecessors until - /// 'limit' into one VarsData ('limit' VarsData is not included). + /// 'limit' into one map ('limit' VarsData is not included). /// If 'limit' is 'nullptr', data is collected starting with /// the entry CFG block. /// Note: the returned VarsData contains original data from /// the predecessors (NOT copies). It should not be modified. - std::unique_ptr + std::unordered_map collectDataFromPredecessors(VarsData* varsData, VarsData* limit = nullptr); /// Finds the lowest common ancestor of two VarsData @@ -318,8 +323,6 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { ~TBRAnalyzer() { for (auto varsData : blockData) { if (varsData) { - for (auto pair : *varsData) - delete pair.second; delete varsData; } } diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 252b33350..c28c5dfa9 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -4,83 +4,88 @@ using namespace clang; namespace clad { -void TBRAnalyzer::setIsRequired(VarData* varData, bool isReq) { - if (varData->type == VarData::FUND_TYPE) - varData->val.fundData = isReq; - else if (varData->type == VarData::OBJ_TYPE) - for (auto& pair : *varData->val.objData) +void TBRAnalyzer::setIsRequired(VarData& varData, bool isReq) { + if (varData.type == VarData::FUND_TYPE) + varData.val.fundData = isReq; + else if (varData.type == VarData::OBJ_TYPE) + for (auto& pair : *varData.val.objData) setIsRequired(pair.second, isReq); - else if (varData->type == VarData::ARR_TYPE) - for (auto& pair : *varData->val.arrData) + else if (varData.type == VarData::ARR_TYPE) + for (auto& pair : *varData.val.arrData) setIsRequired(pair.second, isReq); - else if (varData->type == VarData::REF_TYPE && varData->val.refData) - setIsRequired(getExprVarData(varData->val.refData), isReq); -} - -void TBRAnalyzer::merge(VarData* targetData, VarData* mergeData) { - if (targetData->type == VarData::FUND_TYPE) { - targetData->val.fundData = - targetData->val.fundData || mergeData->val.fundData; - } else if (targetData->type == VarData::OBJ_TYPE) { - for (auto& pair : *targetData->val.objData) - merge(pair.second, (*mergeData->val.objData)[pair.first]); - } else if (targetData->type == VarData::ARR_TYPE) { + else if (varData.type == VarData::REF_TYPE && varData.val.refData) + if (auto data = getExprVarData(varData.val.refData)) { + setIsRequired(*data, isReq); + } +} + +void TBRAnalyzer::merge(VarData& targetData, VarData& mergeData) { + if (targetData.type == VarData::FUND_TYPE) { + targetData.val.fundData = + targetData.val.fundData || mergeData.val.fundData; + } else if (targetData.type == VarData::OBJ_TYPE) { + for (auto& pair : *targetData.val.objData) + merge(pair.second, (*mergeData.val.objData)[pair.first]); + } else if (targetData.type == VarData::ARR_TYPE) { /// FIXME: Currently non-constant indices are not supported in merging. - for (auto& pair : *targetData->val.arrData) { - auto it = mergeData->val.arrData->find(pair.first); - if (it != mergeData->val.arrData->end()) + for (auto& pair : *targetData.val.arrData) { + auto it = mergeData.val.arrData->find(pair.first); + if (it != mergeData.val.arrData->end()) merge(pair.second, it->second); } - for (auto& pair : *mergeData->val.arrData) { - auto it = targetData->val.arrData->find(pair.first); - if (it == mergeData->val.arrData->end()) - (*targetData->val.arrData)[pair.first] = copy(pair.second); + for (auto& pair : *mergeData.val.arrData) { + auto it = targetData.val.arrData->find(pair.first); + if (it == mergeData.val.arrData->end()) + (*targetData.val.arrData)[pair.first] = copy(pair.second); } } /// This might be useful in future if used to analyse pointers. However, for /// now it's only used for references for which merging doesn't make sense. - // else if (this->type == VarData::REF_TYPE) {} -} - -TBRAnalyzer::VarData* TBRAnalyzer::copy(VarData* copyData) { - auto* res = new VarData(); - res->type = copyData->type; - if (copyData->type == VarData::FUND_TYPE) { - res->val.fundData = copyData->val.fundData; - } else if (copyData->type == VarData::OBJ_TYPE) { - res->val.objData = new ObjMap(); - for (auto& pair : *copyData->val.objData) - (*res->val.objData)[pair.first] = copy(pair.second); - } else if (copyData->type == VarData::ARR_TYPE) { - res->val.arrData = new ArrMap(); - for (auto& pair : *copyData->val.arrData) - (*res->val.arrData)[pair.first] = copy(pair.second); - } else if (copyData->type == VarData::REF_TYPE && copyData->val.refData) { - res->val.refData = copyData->val.refData; + // else if (this.type == VarData::REF_TYPE) {} +} + +TBRAnalyzer::VarData TBRAnalyzer::copy(VarData& copyData) { + VarData res; + res.type = copyData.type; + if (copyData.type == VarData::FUND_TYPE) { + res.val.fundData = copyData.val.fundData; + } else if (copyData.type == VarData::OBJ_TYPE) { + res.val.objData = new ObjMap(); + for (auto& pair : *copyData.val.objData) + (*res.val.objData)[pair.first] = copy(pair.second); + } else if (copyData.type == VarData::ARR_TYPE) { + res.val.arrData = new ArrMap(); + for (auto& pair : *copyData.val.arrData) + (*res.val.arrData)[pair.first] = copy(pair.second); + } else if (copyData.type == VarData::REF_TYPE && copyData.val.refData) { + res.val.refData = copyData.val.refData; } return res; } -bool TBRAnalyzer::findReq(const VarData* varData) { - if (varData->type == VarData::FUND_TYPE) - return varData->val.fundData; - if (varData->type == VarData::OBJ_TYPE) { - for (auto& pair : *varData->val.objData) +bool TBRAnalyzer::findReq(const VarData& varData) { + if (varData.type == VarData::FUND_TYPE) + return varData.val.fundData; + if (varData.type == VarData::OBJ_TYPE) { + for (auto& pair : *varData.val.objData) if (findReq(pair.second)) return true; - } else if (varData->type == VarData::ARR_TYPE) { - for (auto& pair : *varData->val.arrData) + } else if (varData.type == VarData::ARR_TYPE) { + for (auto& pair : *varData.val.arrData) if (findReq(pair.second)) return true; - } else if (varData->type == VarData::REF_TYPE && varData->val.refData) { - if (findReq(getExprVarData(varData->val.refData))) - return true; + } else if (varData.type == VarData::REF_TYPE && varData.val.refData) { + if (auto* data = getExprVarData(varData.val.refData)) { + if (findReq(*data)) { + return true; + } + } } return false; } void TBRAnalyzer::overlay( - VarData* targetData, + VarData& targetData, llvm::SmallVector& IdxAndMemberSequence, size_t i) { if (i == 0) { setIsRequired(targetData); @@ -89,15 +94,15 @@ void TBRAnalyzer::overlay( --i; IdxOrMember& curIdxOrMember = IdxAndMemberSequence[i]; if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::FIELD) { - overlay((*targetData->val.objData)[curIdxOrMember.val.field], + overlay((*targetData.val.objData)[curIdxOrMember.val.field], IdxAndMemberSequence, i); } else if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::INDEX) { auto idx = curIdxOrMember.val.index; if (eqAPInt(idx, llvm::APInt(2, -1, true))) - for (auto& pair : *targetData->val.arrData) + for (auto& pair : *targetData.val.arrData) overlay(pair.second, IdxAndMemberSequence, i); else - overlay((*targetData->val.arrData)[idx], IdxAndMemberSequence, i); + overlay((*targetData.val.arrData)[idx], IdxAndMemberSequence, i); } } @@ -115,7 +120,7 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, if (nonConstIndexFound && !addNonConstIdx) return baseData; - return (*baseData->val.objData)[FD]; + return &(*baseData->val.objData)[FD]; } return nullptr; } @@ -149,15 +154,15 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, /// Add the current index if it was not added previously if (it == baseArrMap->end()) { - auto*& idxData = (*baseArrMap)[idx]; + auto& idxData = (*baseArrMap)[idx]; /// Since -1 represents non-const indices, whenever we add a new index we /// have to copy the VarData of -1's element (if an element with undefined /// index was used this might be our current element). idxData = copy((*baseArrMap)[llvm::APInt(2, -1, true)]); - return idxData; + return &idxData; } - return it->second; + return &it->second; } TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, @@ -165,7 +170,7 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, /// This line is necessary for pointer member expressions (in 'x->y' /// x would be implicitly casted with the * operator). E = E->IgnoreImpCasts(); - VarData* EData; + VarData* EData = nullptr; if (isa(E) || isa(E)) { const VarDecl* VD = nullptr; /// ``this`` does not have a declaration so it is represented with nullptr. @@ -175,7 +180,7 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, while (branch) { auto it = branch->find(VD); if (it != branch->end()) { - EData = it->second; + EData = &it->second; break; } branch = branch->prev; @@ -205,7 +210,7 @@ TBRAnalyzer::VarData::VarData(const QualType QT) { else elemType = QT->getArrayElementTypeNoTypeQual(); auto& idxData = (*val.arrData)[llvm::APInt(2, -1, true)]; - idxData = new VarData(QualType::getFromOpaquePtr(elemType)); + idxData = VarData (QualType::getFromOpaquePtr(elemType)); } else if (QT->isBuiltinType()) { type = VarData::FUND_TYPE; val.fundData = false; @@ -216,7 +221,7 @@ TBRAnalyzer::VarData::VarData(const QualType QT) { newObjMap = new ObjMap(); for (const auto* field : recordDecl->fields()) { const auto varType = field->getType(); - (*newObjMap)[field] = new VarData(varType); + (*newObjMap)[field] = VarData(varType); } } } @@ -271,7 +276,7 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { } QualType varType; - if (const ParmVarDecl* arrayParam = dyn_cast(VD)) + if (const auto* arrayParam = dyn_cast(VD)) varType = arrayParam->getOriginalType(); else varType = VD->getType(); @@ -280,22 +285,22 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { varType = VD->getInit()->getType(); /// FIXME: If the pointer points to an object we represent it with a - /// OBJ_TYPE VarData*. This is done for '_d_this' pointer to be processed + /// OBJ_TYPE VarData. This is done for '_d_this' pointer to be processed /// correctly in hessian mode. This should be removed once full support for /// pointers in analysis is introduced. if (const auto pointerType = dyn_cast(varType)) { const auto* elemType = pointerType->getPointeeType().getTypePtrOrNull(); if (elemType && elemType->isRecordType()) { - curBranch[VD] = new VarData(QualType::getFromOpaquePtr(elemType)); + curBranch[VD] = VarData(QualType::getFromOpaquePtr(elemType)); return; } } - curBranch[VD] = new VarData(varType); + curBranch[VD] = VarData(varType); } void TBRAnalyzer::markLocation(const clang::Expr* E) { VarData* data = getExprVarData(E); - if (!data || findReq(data)) { + if (!data || findReq(*data)) { /// FIXME: If any of the data's child nodes are required to store then data /// itself is stored. We might add an option to store separate fields. /// FIXME: Sometimes one location might correspond to multiple stores. @@ -311,8 +316,8 @@ void TBRAnalyzer::setIsRequired(const clang::Expr* E, bool isReq) { if (!isReq || (modeStack.back() == (Mode::markingMode | Mode::nonLinearMode))) { VarData* data = getExprVarData(E, /*addNonConstIdx=*/isReq); - if (isReq || !nonConstIndexFound) - setIsRequired(data, isReq); + if (data && (isReq || !nonConstIndexFound)) + setIsRequired(*data, isReq); /// If an array element with a non-const element is set to required /// all the elements of that array should be set to required. if (isReq && nonConstIndexFound) @@ -340,7 +345,7 @@ void TBRAnalyzer::Analyze(const FunctionDecl* FD) { const Type* recordType = dyn_cast(FD->getParent())->getTypeForDecl(); getCurBranch()[nullptr] = - new VarData(QualType::getFromOpaquePtr(recordType)); + VarData(QualType::getFromOpaquePtr(recordType)); } auto paramsRef = FD->parameters(); for (std::size_t i = 0; i < FD->getNumParams(); ++i) @@ -383,7 +388,7 @@ void TBRAnalyzer::VisitCFGBlock(CFGBlock* block) { } /// Traverse successor CFG blocks. - for (auto succ : block->succs()) { + for (const auto succ : block->succs()) { /// Sometimes clang CFG does not create blocks for parts of code that /// are never executed (e.g. 'if (0) {...'). Add this check for safety. if (!succ) @@ -481,19 +486,19 @@ TBRAnalyzer::findLowestCommonAncestor(VarsData* varsData1, return nullptr; } -std::unique_ptr +std::unordered_map TBRAnalyzer::collectDataFromPredecessors(VarsData* varsData, TBRAnalyzer::VarsData* limit) { - auto result = std::unique_ptr(new VarsData(*varsData)); - + std::unordered_map result; if (varsData != limit) { /// Copy data from every predecessor. for (auto pred = varsData->prev; pred != limit; pred = pred->prev) { /// If a variable from 'pred' is not present /// in 'result', place it in there. - for (auto pair : *pred) - if (result->find(pair.first) == result->end()) - result->emplace(pair); + for (auto& pair : *pred) + if (result.find(pair.first) == result.end()) { + result[pair.first] = &pair.second; + } } } @@ -509,7 +514,7 @@ void TBRAnalyzer::merge(VarsData* targetData, VarsData* mergeData) { /// For every variable in 'collectedMergeData', search it in targetData /// and all its predecessors (if found in a predecessor, make a copy to /// targetData). - for (auto& pair : *collectedMergeData) { + for (auto& pair : collectedMergeData) { VarData* found = nullptr; auto elemSearch = targetData->find(pair.first); if (elemSearch == targetData->end()) { @@ -517,22 +522,22 @@ void TBRAnalyzer::merge(VarsData* targetData, VarsData* mergeData) { while (branch) { auto it = branch->find(pair.first); if (it != branch->end()) { - found = copy(it->second); - targetData->emplace(pair.first, found); + (*targetData)[pair.first] = copy(it->second); + found = &(*targetData)[pair.first]; break; } branch = branch->prev; } } else { - found = elemSearch->second; + found = &elemSearch->second; } /// If the variable was found, perform a merge. /// Else, just copy it from collectedMergeData. - if (found) - merge(found, pair.second); - else - targetData->emplace(pair.first, copy(pair.second)); + if (found) { + merge(*found, *pair.second); + } else + (*targetData)[pair.first] = copy(*pair.second); } /// For every variable in collectedTargetData, search it inside @@ -540,14 +545,14 @@ void TBRAnalyzer::merge(VarsData* targetData, VarsData* mergeData) { /// was not used anywhere between LCA and mergeData. /// To correctly merge, we have to take it from LCA's /// predecessors and merge it to targetData. - for (auto& pair : *collectedTargetData) { - auto elemSearch = collectedMergeData->find(pair.first); - if (elemSearch == collectedMergeData->end()) { + for (auto& pair : collectedTargetData) { + auto elemSearch = collectedMergeData.find(pair.first); + if (elemSearch == collectedMergeData.end()) { auto* branch = LCA; while (branch) { auto it = branch->find(pair.first); if (it != branch->end()) { - merge(pair.second, it->second); + merge(*pair.second, it->second); break; } branch = branch->prev; @@ -605,11 +610,11 @@ void TBRAnalyzer::VisitDeclStmt(const DeclStmt* DS) { Visit(init); resetMode(); auto& VDExpr = getCurBranch()[VD]; - /// if the declared variable is ref type attach its VarData* to the - /// VarData* of the RHS variable. + /// if the declared variable is ref type attach its VarData to the + /// VarData of the RHS variable. auto returnExprs = utils::GetInnermostReturnExpr(init); - if (VDExpr->type == VarData::REF_TYPE && !returnExprs.empty()) - VDExpr->val.refData = returnExprs[0]; + if (VDExpr.type == VarData::REF_TYPE && !returnExprs.empty()) + VDExpr.val.refData = returnExprs[0]; } } }