diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index cf1820a56..ad1a54c77 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -54,7 +54,7 @@ namespace clad { /// Based on To-Be-Recorded analysis performed before differentiation, /// tells UsefulToStoreGlobal whether a variable with a given /// SourceLocation has to be stored before being changed or not. - std::map m_ToBeRecorded; + std::set m_ToBeRecorded; /// A flag indicating if the Stmt we are currently visiting is inside loop. bool isInsideLoop = false; /// Output variable of vector-valued function diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index b2677e4e7..c3c223f43 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -223,7 +223,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { enum Mode { markingMode = 1, nonLinearMode = 2 }; /// Tells if the variable at a given location is required to store. Basically, /// is the result of analysis. - std::map TBRLocs; + std::set TBRLocs; /// Stores modes in a stack (used to retrieve the old mode after entering /// a new one). @@ -297,7 +297,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { TBRAnalyzer& operator=(const TBRAnalyzer&&) = delete; /// Returns the result of the whole analysis - std::map getResult() { return TBRLocs; } + std::set getResult() { return TBRLocs; } /// Visitors void Analyze(const clang::FunctionDecl* FD); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d8b9e36a1..690472456 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2769,16 +2769,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // current system that should have been decided by the parent expression. // FIXME: Here will be the entry point of the advanced activity analysis. if (isa(B) || isa(B)) { - // auto line = - // m_Context.getSourceManager().getPresumedLoc(B->getBeginLoc()).getLine(); - // auto column = - // m_Context.getSourceManager().getPresumedLoc(B->getBeginLoc()).getColumn(); - // llvm::errs() << line << "|" <getBeginLoc()); - if (it == m_ToBeRecorded.end()) { - return true; - } - return it->second; + auto found = m_ToBeRecorded.find(B->getBeginLoc()); + return found != m_ToBeRecorded.end(); } // FIXME: Attach checkpointing. diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 62f826781..b5f6af7e6 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -295,20 +295,16 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { void TBRAnalyzer::markLocation(const clang::Expr* E) { VarData* data = getExprVarData(E); - if (data) { + if (!data || findReq(data)) { /// FIXME: If any of the data's child nodes are required to store then data /// itself is stored. We might add an option to store separate fields. - bool& ToBeRec = TBRLocs[E->getBeginLoc()]; /// FIXME: Sometimes one location might correspond to multiple stores. /// For example, in ``(x*=y)=u`` x's location will first be marked as /// required to be stored (when passing *= operator) but then marked as not /// required to be stored (when passing = operator). Current method of /// marking locations does not allow to differentiate between these two. - ToBeRec = ToBeRec || findReq(data); - } else - /// If the current branch is going to be deleted then there is not point in - /// storing anything in it. - TBRLocs[E->getBeginLoc()] = true; + TBRLocs.insert(E->getBeginLoc()); + } } void TBRAnalyzer::setIsRequired(const clang::Expr* E, bool isReq) { @@ -723,15 +719,11 @@ void TBRAnalyzer::VisitCallExpr(const clang::CallExpr* CE) { // FIXME: this supports only DeclRefExpr const auto innerExpr = utils::GetInnermostReturnExpr(arg); if (passByRef) { - /// Mark SourceLocation as required for ref-type arguments. + /// Mark SourceLocation as required to store for ref-type arguments. if (isa(B) || isa(B)) { - TBRLocs[arg->getBeginLoc()] = true; + TBRLocs.insert(arg->getBeginLoc()); setIsRequired(arg, /*isReq=*/false); } - } else { - /// Mark SourceLocation as not required for non-ref-type arguments. - if (isa(B) || isa(B)) - TBRLocs[arg->getBeginLoc()] = false; } } resetMode(); @@ -755,13 +747,9 @@ void TBRAnalyzer::VisitCXXConstructExpr(const clang::CXXConstructExpr* CE) { if (passByRef) { /// Mark SourceLocation as required for ref-type arguments. if (isa(B) || isa(B)) { - TBRLocs[arg->getBeginLoc()] = true; + TBRLocs.insert(arg->getBeginLoc()); setIsRequired(arg, /*isReq=*/false); } - } else { - /// Mark SourceLocation as not required for non-ref-type arguments. - if (isa(B) || isa(B)) - TBRLocs[arg->getBeginLoc()] = false; } } resetMode();