Skip to content

Commit

Permalink
Move name collision handling to ReferencesUpdater.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 28, 2023
1 parent 9d26a46 commit a3415b8
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 10 deletions.
5 changes: 4 additions & 1 deletion include/clad/Differentiator/StmtClone.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,12 @@ namespace utils {
clang::Sema& m_Sema; // We don't own.
clang::Scope* m_CurScope; // We don't own.
const clang::FunctionDecl* m_Function; // We don't own.
const std::unordered_map<const clang::VarDecl*, clang::VarDecl*>& m_DeclReplacements; // We don't own.
public:
ReferencesUpdater(clang::Sema& SemaRef, clang::Scope* S,
const clang::FunctionDecl* FD);
const clang::FunctionDecl* FD,
const std::unordered_map<const clang::VarDecl*,
clang::VarDecl*>& DeclReplacements);
bool VisitDeclRefExpr(clang::DeclRefExpr* DRE);
bool VisitStmt(clang::Stmt* S);
/// Used to update the size expression of QT
Expand Down
6 changes: 1 addition & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1328,11 +1328,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Check if referenced Decl was "replaced" with another identifier inside
// the derivative
if (const auto* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
auto it = m_DeclReplacements.find(VD);
if (it != std::end(m_DeclReplacements))
clonedDRE = BuildDeclRef(it->second);
else
clonedDRE = cast<DeclRefExpr>(Clone(DRE));
clonedDRE = cast<DeclRefExpr>(Clone(DRE));
// If current context is different than the context of the original
// declaration (e.g. we are inside lambda), rebuild the DeclRefExpr
// with Sema::BuildDeclRefExpr. This is required in some cases, e.g.
Expand Down
15 changes: 13 additions & 2 deletions lib/Differentiator/StmtClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,26 @@ Stmt* StmtClone::VisitStmt(Stmt*) {
}

ReferencesUpdater::ReferencesUpdater(Sema& SemaRef, Scope* S,
const FunctionDecl* FD)
: m_Sema(SemaRef), m_CurScope(S), m_Function(FD) {}
const FunctionDecl* FD,
const std::unordered_map<const clang::VarDecl*,
clang::VarDecl*>& DeclReplacements)
: m_Sema(SemaRef), m_CurScope(S), m_Function(FD), m_DeclReplacements(DeclReplacements) {}

bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) {
// We should only update references of the declarations that were inside
// the original function declaration context.
// Original function = function that we are currently differentiating.
if (!DRE->getDecl()->getDeclContext()->Encloses(m_Function))
return true;

// Replace the declaration if it is present in `m_DeclReplacements`.
if (VarDecl* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
auto it = m_DeclReplacements.find(VD);
if (it != std::end(m_DeclReplacements)) {
DRE->setDecl(it->second);
}
}

DeclarationNameInfo DNI = DRE->getNameInfo();

LookupResult R(m_Sema, DNI, Sema::LookupOrdinaryName);
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ namespace clad {
}

void VisitorBase::updateReferencesOf(Stmt* InSubtree) {
utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function);
utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function, m_DeclReplacements);
up.TraverseStmt(InSubtree);
}

Expand Down Expand Up @@ -304,7 +304,7 @@ namespace clad {

QualType VisitorBase::CloneType(const QualType QT) {
auto clonedType = m_Builder.m_NodeCloner->CloneType(QT);
utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function);
utils::ReferencesUpdater up(m_Sema, getCurrentScope(), m_Function, m_DeclReplacements);
up.updateType(clonedType);
return clonedType;
}
Expand Down

0 comments on commit a3415b8

Please sign in to comment.