Skip to content

Commit

Permalink
Make m_ToBeRecorded a set.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Oct 25, 2023
1 parent e46c34d commit 0e458a2
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 31 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ namespace clad {
/// Based on To-Be-Recorded analysis performed before differentiation,
/// tells UsefulToStoreGlobal whether a variable with a given
/// SourceLocation has to be stored before being changed or not.
std::map<clang::SourceLocation, bool> m_ToBeRecorded;
std::set<clang::SourceLocation> m_ToBeRecorded;
/// A flag indicating if the Stmt we are currently visiting is inside loop.
bool isInsideLoop = false;
/// Output variable of vector-valued function
Expand Down
4 changes: 2 additions & 2 deletions include/clad/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
enum Mode { markingMode = 1, nonLinearMode = 2 };
/// Tells if the variable at a given location is required to store. Basically,
/// is the result of analysis.
std::map<clang::SourceLocation, bool> TBRLocs;
std::set<clang::SourceLocation> TBRLocs;

/// Stores modes in a stack (used to retrieve the old mode after entering
/// a new one).
Expand Down Expand Up @@ -297,7 +297,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor<TBRAnalyzer> {
TBRAnalyzer& operator=(const TBRAnalyzer&&) = delete;

/// Returns the result of the whole analysis
std::map<clang::SourceLocation, bool> getResult() { return TBRLocs; }
std::set<clang::SourceLocation> getResult() { return TBRLocs; }

/// Visitors
void Analyze(const clang::FunctionDecl* FD);
Expand Down
12 changes: 2 additions & 10 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2769,16 +2769,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// current system that should have been decided by the parent expression.
// FIXME: Here will be the entry point of the advanced activity analysis.
if (isa<DeclRefExpr>(B) || isa<ArraySubscriptExpr>(B)) {
// auto line =
// m_Context.getSourceManager().getPresumedLoc(B->getBeginLoc()).getLine();
// auto column =
// m_Context.getSourceManager().getPresumedLoc(B->getBeginLoc()).getColumn();
// llvm::errs() << line << "|" <<column << "?\n";
auto it = m_ToBeRecorded.find(B->getBeginLoc());
if (it == m_ToBeRecorded.end()) {
return true;
}
return it->second;
auto found = m_ToBeRecorded.find(B->getBeginLoc());
return found != m_ToBeRecorded.end();
}

// FIXME: Attach checkpointing.
Expand Down
24 changes: 6 additions & 18 deletions lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,20 +295,16 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) {

void TBRAnalyzer::markLocation(const clang::Expr* E) {
VarData* data = getExprVarData(E);
if (data) {
if (!data || findReq(data)) {
/// FIXME: If any of the data's child nodes are required to store then data
/// itself is stored. We might add an option to store separate fields.
bool& ToBeRec = TBRLocs[E->getBeginLoc()];
/// FIXME: Sometimes one location might correspond to multiple stores.
/// For example, in ``(x*=y)=u`` x's location will first be marked as
/// required to be stored (when passing *= operator) but then marked as not
/// required to be stored (when passing = operator). Current method of
/// marking locations does not allow to differentiate between these two.
ToBeRec = ToBeRec || findReq(data);
} else
/// If the current branch is going to be deleted then there is not point in
/// storing anything in it.
TBRLocs[E->getBeginLoc()] = true;
TBRLocs.insert(E->getBeginLoc());
}
}

void TBRAnalyzer::setIsRequired(const clang::Expr* E, bool isReq) {
Expand Down Expand Up @@ -723,15 +719,11 @@ void TBRAnalyzer::VisitCallExpr(const clang::CallExpr* CE) {
// FIXME: this supports only DeclRefExpr
const auto innerExpr = utils::GetInnermostReturnExpr(arg);
if (passByRef) {
/// Mark SourceLocation as required for ref-type arguments.
/// Mark SourceLocation as required to store for ref-type arguments.
if (isa<DeclRefExpr>(B) || isa<MemberExpr>(B)) {
TBRLocs[arg->getBeginLoc()] = true;
TBRLocs.insert(arg->getBeginLoc());
setIsRequired(arg, /*isReq=*/false);
}
} else {
/// Mark SourceLocation as not required for non-ref-type arguments.
if (isa<DeclRefExpr>(B) || isa<MemberExpr>(B))
TBRLocs[arg->getBeginLoc()] = false;
}
}
resetMode();
Expand All @@ -755,13 +747,9 @@ void TBRAnalyzer::VisitCXXConstructExpr(const clang::CXXConstructExpr* CE) {
if (passByRef) {
/// Mark SourceLocation as required for ref-type arguments.
if (isa<DeclRefExpr>(B) || isa<MemberExpr>(B)) {
TBRLocs[arg->getBeginLoc()] = true;
TBRLocs.insert(arg->getBeginLoc());
setIsRequired(arg, /*isReq=*/false);
}
} else {
/// Mark SourceLocation as not required for non-ref-type arguments.
if (isa<DeclRefExpr>(B) || isa<MemberExpr>(B))
TBRLocs[arg->getBeginLoc()] = false;
}
}
resetMode();
Expand Down

0 comments on commit 0e458a2

Please sign in to comment.