diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index c4ab7039a..d678d12e1 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -121,11 +121,70 @@ namespace clad { CladPlugin::~CladPlugin() {} + // A facility allowing us to access the private member CurScope of the Sema + // object using standard-conforming C++. + namespace { + template struct Rob { + friend typename Tag::type get(Tag) { return M; } + }; + + template struct TagBase { + using type = Member; +#ifdef MSVC +#pragma warning(push, 0) +#endif // MSVC +#pragma GCC diagnostic push +#ifdef __clang__ +#pragma clang diagnostic ignored "-Wunknown-warning-option" +#endif // __clang__ +#pragma GCC diagnostic ignored "-Wnon-template-friend" + friend type get(Tag); +#pragma GCC diagnostic pop +#ifdef MSVC +#pragma warning(pop) +#endif // MSVC + }; + // Tag used to access MultiplexConsumer::Consumers. + using namespace clang; + struct MultiplexConsumer_Consumers + : TagBase< + MultiplexConsumer_Consumers, + std::vector> MultiplexConsumer::*> { + }; + template struct Rob; + } // namespace + + void CladPlugin::Initialize(clang::ASTContext& C) { + // We know we have a multiplexer. We commit a sin here by stealing it and + // making the consumer pass-through so that we can delay all operations + // until clad is happy. + + using namespace clang; + + auto& MultiplexC = static_cast(m_CI.getASTConsumer()); + auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers()); + assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); + std::vector> StolenConsumers; + + // The range-based for loop in MultiplexConsumer::Initialize has + // dispatched this call. Generally, it is unsafe to delete elements while + // iterating but we know we are in the end of the loop and ::end() won't + // be invalidated. + for (auto& RC : RobbedCs) + if (RC.get() == this) + RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1); + else + StolenConsumers.push_back(std::move(RC)); + m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers))); + } + // We cannot use HandleTranslationUnit because codegen already emits code on // HandleTopLevelDecl calls and makes updateCall with no effect. bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) { + AppendDelayed({CallKind::HandleTopLevelDecl, DGR}); if (!CheckBuiltins()) - return true; + return m_Multiplexer->HandleTopLevelDecl(DGR); // true; Sema& S = m_CI.getSema(); @@ -135,13 +194,13 @@ namespace clad { // if HandleTopLevelDecl was called through clad we don't need to process // it for diff requests if (m_HandleTopLevelDeclInternal) - return true; + return m_Multiplexer->HandleTopLevelDecl(DGR); // true; DiffSchedule requests{}; DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema()); if (requests.empty()) - return true; + return m_Multiplexer->HandleTopLevelDecl(DGR); // true; // FIXME: flags have to be set manually since DiffCollector's constructor // does not have access to m_DO. @@ -162,7 +221,8 @@ namespace clad { for (DiffRequest& request : requests) ProcessDiffRequest(request); - return true; // Happiness + + return m_Multiplexer->HandleTopLevelDecl(DGR); // Happiness } void CladPlugin::ProcessTopLevelDecl(Decl* D) { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 3394a0d38..3e99a93b9 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -13,11 +13,12 @@ #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/Version.h" -#include "clang/AST/ASTConsumer.h" #include "clang/AST/Decl.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/Basic/Version.h" #include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Frontend/MultiplexConsumer.h" +#include "clang/Sema/SemaConsumer.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/DenseMap.h" @@ -82,7 +83,7 @@ namespace clad { std::string CustomModelName; }; - class CladPlugin : public clang::ASTConsumer { + class CladPlugin : public clang::SemaConsumer { clang::CompilerInstance& m_CI; DifferentiationOptions m_DO; std::unique_ptr m_DerivativeBuilder; @@ -90,13 +91,125 @@ namespace clad { bool m_PendingInstantiationsInFlight = false; bool m_HandleTopLevelDeclInternal = false; DerivedFnCollector m_DFC; + enum class CallKind { + HandleCXXStaticMemberVarInstantiation, + HandleTopLevelDecl, + HandleInlineFunctionDefinition, + HandleInterestingDecl, + HandleTagDeclDefinition, + HandleTagDeclRequiredDefinition, + HandleCXXImplicitFunctionInstantiation, + HandleTopLevelDeclInObjCContainer, + HandleImplicitImportDecl, + CompleteTentativeDefinition, + CompleteExternalDeclaration, + AssignInheritanceModel, + HandleVTable, + InitializeSema, + ForgetSema + }; + struct DelayedCallInfo { + CallKind m_Kind; + clang::DeclGroupRef m_DGR; + DelayedCallInfo(CallKind K, clang::DeclGroupRef DGR) + : m_Kind(K), m_DGR(DGR) {} + DelayedCallInfo(CallKind K, const clang::Decl* D) + : m_Kind(K), m_DGR(const_cast(D)) {} + }; + std::vector m_DelayedCalls; + std::unique_ptr m_Multiplexer; + public: CladPlugin(clang::CompilerInstance& CI, DifferentiationOptions& DO); ~CladPlugin(); - bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override; + // ASTConsumer + void Initialize(clang::ASTContext& Context) override; + void HandleCXXStaticMemberVarInstantiation(clang::VarDecl* D) override { + AppendDelayed({CallKind::HandleCXXStaticMemberVarInstantiation, D}); + m_Multiplexer->HandleCXXStaticMemberVarInstantiation(D); + } + bool HandleTopLevelDecl(clang::DeclGroupRef D) override; /*{ + AppendDelayed({CallKind::HandleTopLevelDecl, D}); + return true; // happyness, continue parsing + }*/ + void HandleInlineFunctionDefinition(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleInlineFunctionDefinition, D}); + m_Multiplexer->HandleInlineFunctionDefinition(D); + } + void HandleInterestingDecl(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleInterestingDecl, D}); + m_Multiplexer->HandleInterestingDecl(D); + } + void HandleTagDeclDefinition(clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclDefinition, D}); + m_Multiplexer->HandleTagDeclDefinition(D); + } + void HandleTagDeclRequiredDefinition(const clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclRequiredDefinition, D}); + m_Multiplexer->HandleTagDeclRequiredDefinition(D); + } + void + HandleCXXImplicitFunctionInstantiation(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleCXXImplicitFunctionInstantiation, D}); + m_Multiplexer->HandleCXXImplicitFunctionInstantiation(D); + } + void HandleTopLevelDeclInObjCContainer(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleTopLevelDeclInObjCContainer, D}); + m_Multiplexer->HandleTopLevelDeclInObjCContainer(D); + } + void HandleImplicitImportDecl(clang::ImportDecl* D) override { + AppendDelayed({CallKind::HandleImplicitImportDecl, D}); + m_Multiplexer->HandleImplicitImportDecl(D); + } + void CompleteTentativeDefinition(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteTentativeDefinition, D}); + m_Multiplexer->CompleteTentativeDefinition(D); + } +#if CLANG_VERSION_MAJOR > 9 + void CompleteExternalDeclaration(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteExternalDeclaration, D}); + m_Multiplexer->CompleteExternalDeclaration(D); + } +#endif + void AssignInheritanceModel(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::AssignInheritanceModel, D}); + m_Multiplexer->AssignInheritanceModel(D); + } + void HandleVTable(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::HandleVTable, D}); + m_Multiplexer->HandleVTable(D); + } + + // Not delayed. + void HandleTranslationUnit(clang::ASTContext& C) override { + m_Multiplexer->HandleTranslationUnit(C); + } + // No need to handle the listeners, they will be handled at non-delayed by + // the parent multiplexer. + // + // clang::ASTMutationListener *GetASTMutationListener() override; + // clang::ASTDeserializationListener *GetASTDeserializationListener() + // override; + void PrintStats() override { m_Multiplexer->PrintStats(); } + bool shouldSkipFunctionBody(clang::Decl* D) override { + return m_Multiplexer->shouldSkipFunctionBody(D); + } + + // SemaConsumer + void InitializeSema(clang::Sema& S) override { + AppendDelayed({CallKind::InitializeSema, nullptr}); + m_Multiplexer->InitializeSema(S); + } + void ForgetSema() override { + AppendDelayed({CallKind::ForgetSema, nullptr}); + m_Multiplexer->ForgetSema(); + } + + // bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override; clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); private: + void AppendDelayed(DelayedCallInfo DCI) { m_DelayedCalls.push_back(DCI); } bool CheckBuiltins(); void ProcessTopLevelDecl(clang::Decl* D); }; @@ -179,7 +292,7 @@ namespace clad { } PluginASTAction::ActionType getActionType() override { - return AddBeforeMainAction; + return AddAfterMainAction; } }; } // end namespace plugin