diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 81b9d6d3d..623fc7e29 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -164,6 +164,9 @@ namespace clad { /// Returns true if `QT` is Array or Pointer Type, otherwise returns false. bool isArrayOrPointerType(const clang::QualType QT); + /// Returns true if `T` is auto or auto* type, otherwise returns false. + bool IsAutoOrAutoPtrType(const clang::Type* T); + clang::DeclarationNameInfo BuildDeclarationNameInfo(clang::Sema& S, llvm::StringRef name); diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index 586041c66..213cf1d78 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -17,8 +17,20 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { /// Used to provide a hash function for an unordered_map with llvm::APInt /// type keys. struct APIntHash { - size_t operator()(const llvm::APInt& apint) const { - return llvm::hash_value(apint); + size_t operator()(const llvm::APInt& x) const { + return llvm::hash_value(x); + } + }; + + static bool eqAPInt(const llvm::APInt& x, const llvm::APInt& y) { + if (x.getBitWidth() != y.getBitWidth()) + return false; + return x == y; + } + + struct APIntComp { + bool operator()(const llvm::APInt& x, const llvm::APInt& y) const { + return eqAPInt(x, y); } }; @@ -91,7 +103,8 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { struct VarData; using ObjMap = std::unordered_map; - using ArrMap = std::unordered_map; + using ArrMap = + std::unordered_map; struct VarData { enum VarDataType { UNDEFINED, FUND_TYPE, OBJ_TYPE, ARR_TYPE, REF_TYPE }; @@ -114,15 +127,16 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { VarData(const VarData&&) = delete; VarData& operator=(const VarData&&) = delete; + /// Builds a VarData object (and its children) based on the provided type. + VarData(const QualType QT); + ~VarData() { if (type == OBJ_TYPE) { - for (auto& pair : *val.objData) { + for (auto& pair : *val.objData) delete pair.second; - } } else if (type == ARR_TYPE) { - for (auto& pair : *val.arrData) { + for (auto& pair : *val.arrData) delete pair.second; - } } } /// Recursively sets all the leaves' bools to isReq. @@ -165,9 +179,6 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { bool addNonConstIdx = false); /// Given an Expr* returns its corresponding VarData. VarData* getExprVarData(const clang::Expr* E, bool addNonConstIdx = false); - /// Adds the field FD to objData. - void addField(std::unordered_map* objData, - const FieldDecl* FD); /// Whenever an array element with a non-constant index is set to required /// this function is used to set to required all the array elements that /// could match that element (e.g. set 'a[1].y' and 'a[6].y' to required when diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 4c0bfff44..432357631 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -590,5 +590,17 @@ namespace clad { Finder finder; return finder.Find(E); } + + bool IsAutoOrAutoPtrType(const clang::Type* T) { + if (isa(T)) + return true; + + if (const auto pointerType = dyn_cast(T)) { + return IsAutoOrAutoPtrType( + pointerType->getPointeeType().getTypePtrOrNull()); + } + + return false; + } } // namespace utils } // namespace clad diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 89868c74c..1bdbf1c3b 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -8,13 +8,11 @@ void TBRAnalyzer::VarData::setIsRequired(bool isReq) { if (type == FUND_TYPE) { val.fundData = isReq; } else if (type == OBJ_TYPE) { - for (auto& pair : *val.objData) { + for (auto& pair : *val.objData) pair.second->setIsRequired(isReq); - } } else if (type == ARR_TYPE) { - for (auto& pair : *val.arrData) { + for (auto& pair : *val.arrData) pair.second->setIsRequired(isReq); - } } else if (type == REF_TYPE && val.refData) { val.refData->setIsRequired(isReq); } @@ -24,16 +22,14 @@ void TBRAnalyzer::VarData::merge(VarData* mergeData) { if (this->type == FUND_TYPE) { this->val.fundData = this->val.fundData || mergeData->val.fundData; } else if (this->type == OBJ_TYPE) { - for (auto& pair : *this->val.objData) { + for (auto& pair : *this->val.objData) pair.second->merge((*mergeData->val.objData)[pair.first]); - } } else if (this->type == ARR_TYPE) { /// FIXME: Currently non-constant indices are not supported in merging. for (auto& pair : *this->val.arrData) { auto it = mergeData->val.arrData->find(pair.first); - if (it != mergeData->val.arrData->end()) { + if (it != mergeData->val.arrData->end()) pair.second->merge(it->second); - } } for (auto& pair : *mergeData->val.arrData) { auto it = this->val.arrData->find(pair.first); @@ -59,9 +55,8 @@ TBRAnalyzer::VarData::copy(std::unordered_map& refVars) { auto* res = new VarData(); /// The child node of a reference node should be copied only once. Hence, /// we use refVars to match original referenced nodes to corresponding copies. - if (isReferenced) { + if (isReferenced) refVars[this] = res; - } res->type = this->type; if (this->type == FUND_TYPE) { res->val.fundData = this->val.fundData; @@ -71,9 +66,8 @@ TBRAnalyzer::VarData::copy(std::unordered_map& refVars) { (*res->val.objData)[pair.first] = pair.second->copy(refVars); } else if (this->type == ARR_TYPE) { res->val.arrData = new ArrMap(); - for (auto& pair : *this->val.arrData) { + for (auto& pair : *this->val.arrData) (*res->val.arrData)[pair.first] = pair.second->copy(refVars); - } } else if (this->type == REF_TYPE && this->val.refData) { res->val.refData = this->val.refData; } @@ -82,38 +76,30 @@ TBRAnalyzer::VarData::copy(std::unordered_map& refVars) { void TBRAnalyzer::VarData::restoreRefs( std::unordered_map& refVars) { - if (this->type == OBJ_TYPE) { + if (this->type == OBJ_TYPE) for (auto& pair : *val.objData) pair.second->restoreRefs(refVars); - } else if (this->type == ARR_TYPE) { - for (auto& pair : *this->val.arrData) { + else if (this->type == ARR_TYPE) + for (auto& pair : *this->val.arrData) pair.second->restoreRefs(refVars); - } - } else if (this->type == REF_TYPE && this->val.refData) { + else if (this->type == REF_TYPE && this->val.refData) this->val.refData = refVars[this->val.refData]; - } } bool TBRAnalyzer::VarData::findReq() const { - if (type == FUND_TYPE) { + if (type == FUND_TYPE) return val.fundData; - } if (type == OBJ_TYPE) { - for (auto& pair : *val.objData) { - if (pair.second->findReq()) { + for (auto& pair : *val.objData) + if (pair.second->findReq()) return true; - } - } } else if (type == ARR_TYPE) { - for (auto& pair : *val.arrData) { - if (pair.second->findReq()) { + for (auto& pair : *val.arrData) + if (pair.second->findReq()) return true; - } - } } else if (type == REF_TYPE && val.refData) { - if (val.refData->findReq()) { + if (val.refData->findReq()) return true; - } } return false; } @@ -130,13 +116,11 @@ void TBRAnalyzer::VarData::overlay( (*val.objData)[curIdxOrMember.val.field]->overlay(IdxAndMemberSequence, i); } else if (curIdxOrMember.type == IdxOrMember::IdxOrMemberType::INDEX) { auto idx = curIdxOrMember.val.index; - if (idx == llvm::APInt(2, -1, true)) { - for (auto& pair : *val.arrData) { + if (eqAPInt(idx, llvm::APInt(2, -1, true))) + for (auto& pair : *val.arrData) pair.second->overlay(IdxAndMemberSequence, i); - } - } else { + else (*val.arrData)[idx]->overlay(IdxAndMemberSequence, i); - } } } @@ -152,26 +136,13 @@ TBRAnalyzer::VarData* TBRAnalyzer::getMemberVarData(const clang::MemberExpr* ME, if (!baseData) return nullptr; - /// FUND_TYPE might be set by default earlier. - if (baseData->type == VarData::VarDataType::FUND_TYPE) { - baseData->type = VarData::VarDataType::OBJ_TYPE; - baseData->val.objData = new ObjMap(); - } + /// if non-const index was found and it is not supposed to be added just /// return the current VarData*. if (nonConstIndexFound && !addNonConstIdx) return baseData; - auto& baseObjData = baseData->val.objData; - /// Add the current field if it was not added previously - if (baseObjData->find(FD) == baseObjData->end()) { - (*baseObjData)[FD] = new VarData(); - auto* FDData = (*baseObjData)[FD]; - FDData->type = VarData::VarDataType::UNDEFINED; - return FDData; - } - - return (*baseObjData)[FD]; + return (*baseData->val.objData)[FD]; } return nullptr; } @@ -198,38 +169,27 @@ TBRAnalyzer::getArrSubVarData(const clang::ArraySubscriptExpr* ASE, if (!baseData) return nullptr; - /// FUND_TYPE might be set by default earlier. - if (baseData->type == VarData::VarDataType::FUND_TYPE) { - baseData->type = VarData::VarDataType::ARR_TYPE; - baseData->val.arrData = new ArrMap(); - } /// if non-const index was found and it is not supposed to be added just /// return the current VarData*. if (nonConstIndexFound && !addNonConstIdx) return baseData; - auto& baseArrData = baseData->val.arrData; - auto itEnd = baseArrData->end(); + auto* baseArrMap = baseData->val.arrData; + auto it = baseArrMap->find(idx); /// Add the current index if it was not added previously - if (baseArrData->find(idx) == itEnd) { - (*baseArrData)[idx] = new VarData(); - auto* idxData = (*baseArrData)[idx]; + if (it == baseArrMap->end()) { + auto*& idxData = (*baseArrMap)[idx]; /// Since -1 represents non-const indices, whenever we add a new index we /// have to copy the VarData of -1's element (if an element with undefined /// index was used this might be our current element). - auto it = baseArrData->find(llvm::APInt(2, -1, true)); - if (it != itEnd) { - std::unordered_map dummy; - idxData = it->second->copy(dummy); - } else { - idxData->type = VarData::VarDataType::UNDEFINED; - } + std::unordered_map dummy; + idxData = (*baseArrMap)[llvm::APInt(2, -1, true)]->copy(dummy); return idxData; } - return (*baseArrData)[idx]; + return it->second; } TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, @@ -254,44 +214,39 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, } } } - if (const auto* ME = dyn_cast(E)) { + if (const auto* ME = dyn_cast(E)) EData = getMemberVarData(ME, addNonConstIdx); - } - if (const auto* ASE = dyn_cast(E)) { + if (const auto* ASE = dyn_cast(E)) EData = getArrSubVarData(ASE, addNonConstIdx); - } - /// If the type of this VarData was not defined previously set it to - /// FUND_TYPE. - /// FIXME: this assumes that we only assign fundamental values and not - /// objects or pointers. - if (EData && EData->type == VarData::VarDataType::UNDEFINED) { - EData->type = VarData::VarDataType::FUND_TYPE; - EData->val.fundData = false; - } return EData; } -void TBRAnalyzer::addField(ObjMap* objData, const FieldDecl* FD) { - const auto varType = FD->getType(); - (*objData)[FD] = new VarData(); - VarData* data = (*objData)[FD]; - - if (varType->isReferenceType()) { - data->type = VarData::VarDataType::REF_TYPE; - data->val.refData = nullptr; - } else if (utils::isArrayOrPointerType(varType)) { - data->type = VarData::VarDataType::ARR_TYPE; - data->val.arrData = new ArrMap(); - } else if (varType->isBuiltinType()) { - data->type = VarData::VarDataType::FUND_TYPE; - data->val.fundData = false; - } else if (varType->isRecordType()) { - data->type = VarData::VarDataType::OBJ_TYPE; - const auto* recordDecl = varType->getAs()->getDecl(); - auto& newObjData = data->val.objData; +TBRAnalyzer::VarData::VarData(const QualType QT) { + if (QT->isReferenceType()) { + type = VarData::VarDataType::REF_TYPE; + val.refData = nullptr; + } else if (utils::isArrayOrPointerType(QT)) { + type = VarData::VarDataType::ARR_TYPE; + val.arrData = new ArrMap(); + const Type* elemType; + if (const auto pointerType = llvm::dyn_cast(QT)) + elemType = pointerType->getPointeeType().getTypePtrOrNull(); + else + elemType = QT->getArrayElementTypeNoTypeQual(); + auto& idxData = (*val.arrData)[llvm::APInt(2, -1, true)]; + idxData = new VarData(QualType::getFromOpaquePtr(elemType)); + } else if (QT->isBuiltinType()) { + type = VarData::VarDataType::FUND_TYPE; + val.fundData = false; + } else if (QT->isRecordType()) { + type = VarData::VarDataType::OBJ_TYPE; + const auto* recordDecl = QT->getAs()->getDecl(); + auto& newObjMap = val.objData; + newObjMap = new ObjMap(); for (const auto* field : recordDecl->fields()) { - addField(newObjData, field); + const auto varType = field->getType(); + (*newObjMap)[field] = new VarData(varType); } } } @@ -305,18 +260,17 @@ void TBRAnalyzer::overlay(const clang::Expr* E) { while (cond) { E = E->IgnoreImplicit(); if (const auto* ASE = dyn_cast(E)) { - if (const auto* IL = dyn_cast(ASE->getIdx())) { + if (const auto* IL = dyn_cast(ASE->getIdx())) IdxAndMemberSequence.push_back(IdxOrMember(IL->getValue())); - } else { + else IdxAndMemberSequence.push_back(IdxOrMember(llvm::APInt(2, -1, true))); - } E = ASE->getBase(); } else if (const auto* ME = dyn_cast(E)) { - if (const auto* FD = dyn_cast(ME->getMemberDecl())) { + if (const auto* FD = dyn_cast(ME->getMemberDecl())) IdxAndMemberSequence.push_back(IdxOrMember(FD)); - } E = ME->getBase(); - } else if ((innermostDRE = dyn_cast(E))) { + } else if (isa(E)) { + innermostDRE = dyn_cast(E); cond = false; } else return; @@ -352,57 +306,41 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { break; } } - if (!localVarsStack.empty()) { localVarsStack.back().push_back(VD); } - const auto varType = VD->getType(); - curBranch[VD] = new VarData(); - VarData* data = curBranch[VD]; - - if (varType->isReferenceType()) { - data->type = VarData::VarDataType::REF_TYPE; - data->val.refData = nullptr; - } else if (utils::isArrayOrPointerType(varType)) { - if (const auto pointerType = llvm::dyn_cast(varType)) { - /// FIXME: If the pointer points to an object we represent it with a - /// OBJ_TYPE VarData*. - const auto pointeeType = pointerType->getPointeeType().getTypePtrOrNull(); - if (pointeeType && pointeeType->isRecordType()) { - data->type = VarData::VarDataType::OBJ_TYPE; - const auto* recordDecl = pointeeType->getAs()->getDecl(); - auto& objData = data->val.objData; - objData = new ObjMap(); - for (const auto* field : recordDecl->fields()) { - addField(objData, field); - } - return; - } - } - data->type = VarData::VarDataType::ARR_TYPE; - data->val.arrData = new ArrMap(); - } else if (varType->isBuiltinType()) { - data->type = VarData::VarDataType::FUND_TYPE; - data->val.fundData = false; - } else if (varType->isRecordType()) { - data->type = VarData::VarDataType::OBJ_TYPE; - const auto* recordDecl = varType->getAs()->getDecl(); - auto& objData = data->val.objData; - objData = new ObjMap(); - for (const auto* field : recordDecl->fields()) { - addField(objData, field); + QualType varType; + if (const ParmVarDecl* arrayParam = dyn_cast(VD)) + varType = arrayParam->getOriginalType(); + else + varType = VD->getType(); + /// If varType represents auto or auto*, get the type of init. + if (utils::IsAutoOrAutoPtrType(varType.getTypePtr())) + varType = VD->getInit()->getType(); + + /// FIXME: If the pointer points to an object we represent it with a + /// OBJ_TYPE VarData*. This is done for '_d_this' pointer to be processed + /// correctly in hessian mode. This should be removed once full support for + /// pointers in analysis is introduced. + if (const auto pointerType = dyn_cast(varType)) { + const auto* elemType = pointerType->getPointeeType().getTypePtrOrNull(); + if (elemType && elemType->isRecordType()) { + curBranch[VD] = new VarData(QualType::getFromOpaquePtr(elemType)); + return; } } + + curBranch[VD] = new VarData(varType); } void TBRAnalyzer::markLocation(const clang::Expr* E) { VarData* data = getExprVarData(E); if (data) { /// FIXME: If any of the data's child nodes are required to store then data - /// itselt is stored. We might add an option to store separate fields. + /// itself is stored. We might add an option to store separate fields. bool& ToBeRec = TBRLocs[E->getBeginLoc()]; - /// FIXME: Sometimes one location might correspong to multiple stores. + /// FIXME: Sometimes one location might correspond to multiple stores. /// For example, in ``(x*=y)=u`` x's location will first be marked as /// required to be stored (when passing *= operator) but then marked as not /// required to be stored (when passing = operator). Current method of @@ -422,9 +360,8 @@ void TBRAnalyzer::mergeLayer() { for (auto& removedBranch : removedLayer) { for (auto& pair : curBranch) { auto it = removedBranch.find(pair.first); - if (it != removedBranch.end()) { + if (it != removedBranch.end()) pair.second->merge(it->second); - } } for (auto& pair : removedBranch) { auto it = curBranch.find(pair.first); @@ -457,9 +394,8 @@ void TBRAnalyzer::mergeLayerOnTop() { while ((++branchIter) != branchIterEnd) { for (auto& pair : firstBranch) { auto elemIter = branchIter->find(pair.first); - if (elemIter != branchIter->end()) { + if (elemIter != branchIter->end()) pair.second->merge(elemIter->second); - } } for (auto& pair : *branchIter) { auto elemIter = firstBranch.find(pair.first); @@ -519,15 +455,10 @@ void TBRAnalyzer::Analyze(const FunctionDecl* FD) { /// represented with nullptr). if (isa(FD)) { - auto& thisData = getCurBranch()[nullptr]; - thisData = new VarData(); - thisData->type = VarData::VarDataType::OBJ_TYPE; - auto recordDecl = dyn_cast(FD->getParent()); - auto& objData = thisData->val.objData; - objData = new ObjMap(); - for (const auto* field : recordDecl->fields()) { - addField(objData, field); - } + const Type* recordType = + dyn_cast(FD->getParent())->getTypeForDecl(); + getCurBranch()[nullptr] = + new VarData(QualType::getFromOpaquePtr(recordType)); } auto paramsRef = FD->parameters(); @@ -551,9 +482,8 @@ void TBRAnalyzer::VisitDeclRefExpr(const DeclRefExpr* DRE) { addVar(VD); } - if (const auto* E = dyn_cast(DRE)) { + if (const auto* E = dyn_cast(DRE)) setIsRequired(E); - } } void TBRAnalyzer::VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE) { @@ -592,9 +522,10 @@ void TBRAnalyzer::VisitDeclStmt(const DeclStmt* DS) { auto& VDExpr = getCurBranch()[VD]; /// if the declared variable is ref type attach its VarData* to the /// VarData* of the RHS variable. - if (VDExpr->type == VarData::VarDataType::REF_TYPE) { - auto* RHSExpr = - getExprVarData(utils::GetInnermostReturnExpr(init)[0]); + auto returnExprs = utils::GetInnermostReturnExpr(init); + if (VDExpr->type == VarData::VarDataType::REF_TYPE && + !returnExprs.empty()) { + auto* RHSExpr = getExprVarData(returnExprs[0]); VDExpr->val.refData = RHSExpr; RHSExpr->isReferenced = true; } @@ -658,11 +589,19 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { if (nonLinear) resetMode(); } else if (BinOp->isAssignmentOp()) { - if (opCode == BO_Assign || opCode == BO_AddAssign || opCode == BO_SubAssign) { /// Since we only care about non-linear usages of variables, there is /// no difference between operators =, -=, += in terms of TBR analysis. + // llvm::errs() << "before assignment:\n"; + // for(auto& pair : getCurBranch()) { + // if (pair.second->type == VarData::ARR_TYPE) { + // for (auto& pair2 : *pair.second->val.arrData) { + // llvm::errs() << pair.first->getNameAsString() << "." << + // pair2.first << ": " << pair2.second->val.fundData << "\n"; + // } + // } + // } Visit(L); startMarkingMode(); @@ -686,8 +625,8 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { Visit(R); resetMode(); } - const auto return_exprs = utils::GetInnermostReturnExpr(L); - for (const auto* innerExpr : return_exprs) { + const auto returnExprs = utils::GetInnermostReturnExpr(L); + for (const auto* innerExpr : returnExprs) { /// Mark corresponding SourceLocation as required/not required to be /// stored for all expressions that could be used changed. markLocation(innerExpr); @@ -696,7 +635,15 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { /// already not required to store). setIsRequired(innerExpr, /*isReq=*/false); } - + // llvm::errs() << "after assignment:\n"; + // for(auto& pair : getCurBranch()) { + // if (pair.second->type == VarData::ARR_TYPE) { + // for (auto& pair2 : *pair.second->val.arrData) { + // llvm::errs() << pair.first->getNameAsString() << "." << + // pair2.first << ": " << pair2.second->val.fundData << "\n"; + // } + // } + // } } else if (opCode == BO_Comma) { setMode(0); Visit(L); @@ -751,7 +698,7 @@ void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { const auto* thenBranch = If->getThen(); const auto* elseBranch = If->getElse(); - localVarsStack.emplace_back(); + // localVarsStack.emplace_back(); addLayer(); if (thenBranch) { @@ -792,9 +739,13 @@ void TBRAnalyzer::VisitIfStmt(const clang::IfStmt* If) { } } - mergeLayerOnTop(); - removeLocalVars(); - localVarsStack.pop_back(); + if (elseBranch) + mergeLayerOnTop(); + else + mergeLayer(); + + // removeLocalVars(); + // localVarsStack.pop_back(); } void TBRAnalyzer::VisitWhileStmt(const clang::WhileStmt* WS) { @@ -1006,9 +957,8 @@ void TBRAnalyzer::VisitDoStmt(const clang::DoStmt* DS) { Visit(body); if (deleteCurBranch) deleteBranch(); - else { + else Visit(cond); - } mergeLayer(); innermostLoopLayer = backupILB; @@ -1137,9 +1087,8 @@ void TBRAnalyzer::VisitArraySubscriptExpr( void TBRAnalyzer::VisitInitListExpr(const clang::InitListExpr* ILE) { setMode(0); - for (auto* init : ILE->inits()) { + for (auto* init : ILE->inits()) Visit(init); - } resetMode(); }