From 4735b41e51a42a28e3c2bbd5c127134c95092eea Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Tue, 12 Mar 2024 19:20:12 +0000 Subject: [PATCH] Address review comments --- tools/ClangPlugin.cpp | 14 ++++---------- tools/ClangPlugin.h | 8 +++++++- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 64c5d44d3..266250f73 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -100,9 +100,7 @@ namespace clad { // making the consumer pass-through so that we can delay all operations // until clad is happy. - using namespace clang; - - auto& MultiplexC = static_cast(m_CI.getASTConsumer()); + auto& MultiplexC = cast(m_CI.getASTConsumer()); auto& RobbedCs = ACCESS(MultiplexC, Consumers); assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); std::vector> StolenConsumers; @@ -111,16 +109,12 @@ namespace clad { // dispatched this call. Generally, it is unsafe to delete elements while // iterating but we know we are in the end of the loop and ::end() won't // be invalidated. - for (auto& RC : RobbedCs) - if (RC.get() == this) - RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1); - else - StolenConsumers.push_back(std::move(RC)); + std::move(RobbedCs.begin(), RobbedCs.end() - 1, + std::back_inserter(StolenConsumers)); + RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1); m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers))); } - // We cannot use HandleTranslationUnit because codegen already emits code on - // HandleTopLevelDecl calls and makes updateCall with no effect. void CladPlugin::HandleTopLevelDeclForClad(DeclGroupRef DGR) { if (!CheckBuiltins()) return; diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index a0fd56cbf..35f20945d 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -112,7 +112,9 @@ namespace clad { // FIXME: We should replace this comparison with the canonical decl // from the differentiation plan... - return ND->getName().contains("_darg"); + llvm::StringRef Name = ND->getName(); + return Name.contains("_darg") || Name.contains("_grad") || + Name.contains("_hessian") || Name.contains("_jacobian"); }); } }; @@ -154,6 +156,10 @@ namespace clad { if (m_Kind != other.m_Kind) return false; + if (std::distance(m_DGR.begin(), m_DGR.end()) != + std::distance(other.m_DGR.begin(), other.m_DGR.end())) + return false; + clang::Decl* const* first1 = m_DGR.begin(); clang::Decl* const* first2 = other.m_DGR.begin(); clang::Decl* const* last1 = m_DGR.end();