Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Mar 13, 2024
1 parent 4516ad6 commit 4735b41
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
14 changes: 4 additions & 10 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ namespace clad {
// making the consumer pass-through so that we can delay all operations
// until clad is happy.

using namespace clang;

auto& MultiplexC = static_cast<MultiplexConsumer&>(m_CI.getASTConsumer());
auto& MultiplexC = cast<MultiplexConsumer>(m_CI.getASTConsumer());
auto& RobbedCs = ACCESS(MultiplexC, Consumers);
assert(RobbedCs.back().get() == this && "Clad is not the last consumer");
std::vector<std::unique_ptr<ASTConsumer>> StolenConsumers;
Expand All @@ -111,16 +109,12 @@ namespace clad {
// dispatched this call. Generally, it is unsafe to delete elements while
// iterating but we know we are in the end of the loop and ::end() won't
// be invalidated.
for (auto& RC : RobbedCs)
if (RC.get() == this)
RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1);
else
StolenConsumers.push_back(std::move(RC));
std::move(RobbedCs.begin(), RobbedCs.end() - 1,
std::back_inserter(StolenConsumers));
RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1);
m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers)));
}

// We cannot use HandleTranslationUnit because codegen already emits code on
// HandleTopLevelDecl calls and makes updateCall with no effect.
void CladPlugin::HandleTopLevelDeclForClad(DeclGroupRef DGR) {
if (!CheckBuiltins())
return;
Expand Down
8 changes: 7 additions & 1 deletion tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ namespace clad {

// FIXME: We should replace this comparison with the canonical decl
// from the differentiation plan...
return ND->getName().contains("_darg");
llvm::StringRef Name = ND->getName();
return Name.contains("_darg") || Name.contains("_grad") ||
Name.contains("_hessian") || Name.contains("_jacobian");
});

Check warning on line 118 in tools/ClangPlugin.h

View check run for this annotation

Codecov / codecov/patch

tools/ClangPlugin.h#L118

Added line #L118 was not covered by tests
}
};
Expand Down Expand Up @@ -154,6 +156,10 @@ namespace clad {
if (m_Kind != other.m_Kind)
return false;

if (std::distance(m_DGR.begin(), m_DGR.end()) !=
std::distance(other.m_DGR.begin(), other.m_DGR.end()))
return false;

clang::Decl* const* first1 = m_DGR.begin();
clang::Decl* const* first2 = other.m_DGR.begin();
clang::Decl* const* last1 = m_DGR.end();
Expand Down

0 comments on commit 4735b41

Please sign in to comment.