diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index 481782856..73d4d659c 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -16,69 +16,27 @@ namespace clad { class TBRAnalyzer : public clang::ConstStmtVisitor { private: - /// Used to provide a hash function for an unordered_map with llvm::APInt - /// type keys. - struct APIntHash { - size_t operator()(const llvm::APInt& x) const { - return llvm::hash_value(x); - } - }; - static bool eqAPInt(const llvm::APInt& x, const llvm::APInt& y) { - if (x.getBitWidth() != y.getBitWidth()) - return false; - return x == y; + /// ProfileID is the key type for ArrMap used to represent array indices + /// and object fields. + using ProfileID = clad_compat::FoldingSetNodeID; + + ProfileID getProfileID(const Expr* E) const{ + ProfileID profID; + E->Profile(profID, m_Context, /* Canonical */ true); + return profID; } - struct APIntComp { - bool operator()(const llvm::APInt& x, const llvm::APInt& y) const { - return eqAPInt(x, y); - } - }; + static ProfileID getProfileID(const FieldDecl* FD) { + ProfileID profID; + profID.AddPointer(FD); + return profID; + } - /// Just a helper struct serving as a wrapper for IdxOrMemberValue union. - /// Used to unwrap expressions like a[6].x.t[3]. Only used in - /// TBRAnalyzer::overlay(). - struct IdxOrMember { - enum IdxOrMemberType { FIELD, INDEX }; - union IdxOrMemberValue { - const clang::FieldDecl* field; - llvm::APInt index; - IdxOrMemberValue() : field(nullptr) {} - ~IdxOrMemberValue() {} - IdxOrMemberValue(const IdxOrMemberValue&) = delete; - IdxOrMemberValue& operator=(const IdxOrMemberValue&) = delete; - IdxOrMemberValue(const IdxOrMemberValue&&) = delete; - IdxOrMemberValue& operator=(const IdxOrMemberValue&&) = delete; - }; - IdxOrMemberType type; - IdxOrMemberValue val; - IdxOrMember(const clang::FieldDecl* field) : type(IdxOrMemberType::FIELD) { - val.field = field; + struct ProfileIDHash { + size_t operator()(const ProfileID& x) const { + return x.ComputeHash(); } - IdxOrMember(llvm::APInt&& index) : type(IdxOrMemberType::INDEX) { - new (&val.index) llvm::APInt(index); - } - IdxOrMember(const IdxOrMember& other) { - new (&val.index) llvm::APInt(); - *this = other; - } - IdxOrMember(const IdxOrMember&& other) noexcept { - new (&val.index) llvm::APInt(); - *this = other; - } - IdxOrMember& operator=(const IdxOrMember& other) { - type = other.type; - if (type == IdxOrMemberType::FIELD) - val.field = other.val.field; - else - val.index = other.val.index; - return *this; - } - IdxOrMember& operator=(const IdxOrMember&& other) noexcept { - return *this = other; - } - ~IdxOrMember() = default; }; /// Stores all the necessary information about one variable. Fundamental type @@ -93,17 +51,16 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// 'double& x = f(b);' is not supported. struct VarData; - using ObjMap = std::unordered_map; using ArrMap = - std::unordered_map; + std::unordered_map; struct VarData { enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE }; union VarDataValue { bool m_FundData; - /// m_ObjData, m_ArrData are stored as pointers for VarDataValue to take + /// m_ArrData is stored as pointers for VarDataValue to take /// less space. - ObjMap* m_ObjData; + /// Both arrays and and objects are modelled using m_ArrData; ArrMap* m_ArrData; Expr* m_RefData; VarDataValue() : m_FundData(false) {} @@ -117,12 +74,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { VarData(const QualType QT); /// Erases all children VarData's of this VarData. - void erase() { - if (type == OBJ_TYPE) { - for (auto& pair : *val.m_ObjData) - pair.second.erase(); - delete val.m_ObjData; - } else if (type == ARR_TYPE) { + void erase() const { + if (type == OBJ_TYPE || type == ARR_TYPE) { for (auto& pair : *val.m_ArrData) pair.second.erase(); delete val.m_ArrData; @@ -138,7 +91,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// indices/members of the expression being overlaid and the index of of the /// current index/member. void overlay(VarData& targetData, - llvm::SmallVector& IdxAndMemberSequence, + llvm::SmallVector& IDSequence, size_t i); /// Returns true if there is at least one required to store node among /// child nodes. @@ -190,8 +143,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { std::unordered_map(); VarsData* prev = nullptr; - VarsData() {} - VarsData(VarsData& other) : data(other.data), prev(other.prev) {} + VarsData() = default; ~VarsData() { for (auto& pair : data) @@ -264,7 +216,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { std::vector blockPassCounter; /// ID of the CFG block being visited. - unsigned curBlockID; + unsigned curBlockID{}; /// The set of IDs of the CFG blocks that should be visited. std::set CFGQueue; @@ -290,7 +242,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { //// Modes Setters /// Sets the mode manually - void setMode(int mode) { modeStack.push_back(mode); } + void setMode(short mode) { modeStack.push_back(mode); } /// Sets nonLinearMode but leaves markingMode just as it was. void startNonLinearMode() { modeStack.push_back(modeStack.back() | Mode::nonLinearMode); @@ -310,7 +262,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Destructor ~TBRAnalyzer() { - for (auto varsData : blockData) { + for (auto* varsData : blockData) { if (varsData) { delete varsData; } diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index a8837d819..817b8c1bb 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -7,10 +7,7 @@ namespace clad { void TBRAnalyzer::setIsRequired(VarData& varData, bool isReq) { if (varData.type == VarData::FUND_TYPE) varData.val.m_FundData = isReq; - else if (varData.type == VarData::OBJ_TYPE) - for (auto& pair : *varData.val.m_ObjData) - setIsRequired(pair.second, isReq); - else if (varData.type == VarData::ARR_TYPE) + else if (varData.type == VarData::OBJ_TYPE || varData.type == VarData::ARR_TYPE) for (auto& pair : *varData.val.m_ArrData) setIsRequired(pair.second, isReq); else if (varData.type == VarData::REF_TYPE && varData.val.m_RefData) @@ -24,8 +21,8 @@ void TBRAnalyzer::merge(VarData& targetData, VarData& mergeData) { targetData.val.m_FundData = targetData.val.m_FundData || mergeData.val.m_FundData; } else if (targetData.type == VarData::OBJ_TYPE) { - for (auto& pair : *targetData.val.m_ObjData) - merge(pair.second, (*mergeData.val.m_ObjData)[pair.first]); + for (auto& pair : *targetData.val.m_ArrData) + merge(pair.second, (*mergeData.val.m_ArrData)[pair.first]); } else if (targetData.type == VarData::ARR_TYPE) { /// FIXME: Currently non-constant indices are not supported in merging. for (auto& pair : *targetData.val.m_ArrData) { @@ -49,11 +46,7 @@ TBRAnalyzer::VarData TBRAnalyzer::copy(VarData& copyData) { res.type = copyData.type; if (copyData.type == VarData::FUND_TYPE) { res.val.m_FundData = copyData.val.m_FundData; - } else if (copyData.type == VarData::OBJ_TYPE) { - res.val.m_ObjData = new ObjMap(); - for (auto& pair : *copyData.val.m_ObjData) - (*res.val.m_ObjData)[pair.first] = copy(pair.second); - } else if (copyData.type == VarData::ARR_TYPE) { + } else if (copyData.type == VarData::OBJ_TYPE || copyData.type == VarData::ARR_TYPE) { res.val.m_ArrData = new ArrMap(); for (auto& pair : *copyData.val.m_ArrData) (*res.val.m_ArrData)[pair.first] = copy(pair.second); @@ -66,11 +59,7 @@ TBRAnalyzer::VarData TBRAnalyzer::copy(VarData& copyData) { bool TBRAnalyzer::findReq(const VarData& varData) { if (varData.type == VarData::FUND_TYPE) return varData.val.m_FundData; - if (varData.type == VarData::OBJ_TYPE) { - for (auto& pair : *varData.val.m_ObjData) - if (findReq(pair.second)) - return true; - } else if (varData.type == VarData::ARR_TYPE) { + if (varData.type == VarData::OBJ_TYPE || varData.type == VarData::ARR_TYPE) { for (auto& pair : *varData.val.m_ArrData) if (findReq(pair.second)) return true; @@ -86,23 +75,20 @@ bool TBRAnalyzer::findReq(const VarData& varData) { void TBRAnalyzer::overlay( VarData& targetData, - llvm::SmallVector& IdxAndMemberSequence, size_t i) { + llvm::SmallVector& IDSequence, size_t i) { if (i == 0) { setIsRequired(targetData); return; } --i; - IdxOrMember& curIdxOrMember = IdxAndMemberSequence[i]; - if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::FIELD) { - overlay((*targetData.val.m_ObjData)[curIdxOrMember.val.field], - IdxAndMemberSequence, i); - } else if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::INDEX) { - auto idx = curIdxOrMember.val.index; - if (eqAPInt(idx, llvm::APInt(2, -1, true))) - for (auto& pair : *targetData.val.m_ArrData) - overlay(pair.second, IdxAndMemberSequence, i); - else - overlay((*targetData.val.m_ArrData)[idx], IdxAndMemberSequence, i); + ProfileID& curID = IDSequence[i]; + // non-constant indices are represented with default ID. + ProfileID nonConstIdxID; + if (curID == nonConstIdxID) { + for (auto& pair : *targetData.val.m_ArrData) + overlay(pair.second, IDSequence, i); + } else { + overlay((*targetData.val.m_ArrData)[curID], IDSequence, i); } } @@ -120,7 +106,7 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, if (nonConstIndexFound && !addNonConstIdx) return baseData; - return &(*baseData->val.m_ObjData)[FD]; + return &(*baseData->val.m_ArrData)[getProfileID(FD)]; } return nullptr; } @@ -129,13 +115,12 @@ TBRAnalyzer::VarData* TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, bool addNonConstIdx) { const auto* idxExpr = ASE->getIdx(); - llvm::APInt idx; + ProfileID idxID; if (const auto* IL = dyn_cast(idxExpr)) { - idx = IL->getValue(); + idxID = getProfileID(IL); } else { nonConstIndexFound = true; - /// Non-const indices are represented with -1. - idx = llvm::APInt(2, -1, true); + /// Non-const indices are represented with default FoldingSetNodeID. } const auto* base = ASE->getBase()->IgnoreImpCasts(); @@ -150,15 +135,16 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, return baseData; auto* baseArrMap = baseData->val.m_ArrData; - auto it = baseArrMap->find(idx); + auto it = baseArrMap->find(idxID); /// Add the current index if it was not added previously if (it == baseArrMap->end()) { - auto& idxData = (*baseArrMap)[idx]; - /// Since -1 represents non-const indices, whenever we add a new index we - /// have to copy the VarData of -1's element (if an element with undefined - /// index was used this might be our current element). - idxData = copy((*baseArrMap)[llvm::APInt(2, -1, true)]); + auto& idxData = (*baseArrMap)[idxID]; + /// Since default ID represents non-const indices, whenever we add a new + /// index we have to copy the VarData of default ID's element (if an element + /// with undefined index was used this might be our current element). + ProfileID nonConstIdxID; + idxData = copy((*baseArrMap)[nonConstIdxID]); return &idxData; } @@ -209,7 +195,8 @@ TBRAnalyzer::VarData::VarData(const QualType QT) { elemType = pointerType->getPointeeType().getTypePtrOrNull(); else elemType = QT->getArrayElementTypeNoTypeQual(); - auto& idxData = (*val.m_ArrData)[llvm::APInt(2, -1, true)]; + ProfileID nonConstIdxID; + auto& idxData = (*val.m_ArrData)[nonConstIdxID]; idxData = VarData (QualType::getFromOpaquePtr(elemType)); } else if (QT->isBuiltinType()) { type = VarData::FUND_TYPE; @@ -217,18 +204,18 @@ TBRAnalyzer::VarData::VarData(const QualType QT) { } else if (QT->isRecordType()) { type = VarData::OBJ_TYPE; const auto* recordDecl = QT->getAs()->getDecl(); - auto& newObjMap = val.m_ObjData; - newObjMap = new ObjMap(); + auto& newArrMap = val.m_ArrData; + newArrMap = new ArrMap(); for (const auto* field : recordDecl->fields()) { const auto varType = field->getType(); - (*newObjMap)[field] = VarData(varType); + (*newArrMap)[getProfileID(field)] = VarData(varType); } } } void TBRAnalyzer::overlay(const clang::Expr* E) { nonConstIndexFound = false; - llvm::SmallVector IdxAndMemberSequence; + llvm::SmallVector IDSequence; const clang::DeclRefExpr* innermostDRE; bool cond = true; /// Unwrap the given expression to a vector of indices and fields. @@ -236,13 +223,13 @@ void TBRAnalyzer::overlay(const clang::Expr* E) { E = E->IgnoreImplicit(); if (const auto* ASE = dyn_cast(E)) { if (const auto* IL = dyn_cast(ASE->getIdx())) - IdxAndMemberSequence.push_back(IdxOrMember(IL->getValue())); + IDSequence.push_back(getProfileID(IL)); else - IdxAndMemberSequence.push_back(IdxOrMember(llvm::APInt(2, -1, true))); + IDSequence.push_back(ProfileID()); E = ASE->getBase(); } else if (const auto* ME = dyn_cast(E)) { if (const auto* FD = dyn_cast(ME->getMemberDecl())) - IdxAndMemberSequence.push_back(IdxOrMember(FD)); + IDSequence.push_back(getProfileID(FD)); E = ME->getBase(); } else if (isa(E)) { innermostDRE = dyn_cast(E); @@ -253,8 +240,7 @@ void TBRAnalyzer::overlay(const clang::Expr* E) { /// Overlay on all the VarData's recursively. if (const auto* VD = dyn_cast(innermostDRE->getDecl())) { - overlay(getCurBranch()[VD], IdxAndMemberSequence, - IdxAndMemberSequence.size()); + overlay(getCurBranch()[VD], IDSequence, IDSequence.size()); } }