From 0375658b1b4137b2baf90983ed98b00d6e3ef520 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Thu, 15 Feb 2024 15:48:34 +0000 Subject: [PATCH] Delay the differentiation process until the end of TU. Before this patch clad attaches itself as a first consumer and applies AD before code generation. However, that is limited since clang sends every top-level declaration to codegen which limits the amount of flexibility clad has. For example, we have to instantiate all pending templates at every HandleTopLevelDecl calls; we cannot really differentiate virtual functions whose classes have sent their key function to CodeGen; and in general we perform actions which are semantically useful for the end of the translation unit. This patch makes clad a single consumer of clang which dispatches to the others. That's done by delaying all calls to the consumers until the end of the TU where clad can replay the exact sequence of calls to the other consumers as if they were directly connected to the frontend. Fixes #248 --- include/clad/Differentiator/Sins.h | 29 ++++ lib/Differentiator/VisitorBase.cpp | 37 +---- test/FirstDerivative/CodeGenSimple.C | 8 + test/Misc/ClangConsumers.cpp | 44 ++++++ tools/ClangPlugin.cpp | 228 ++++++++++++++++++++++----- tools/ClangPlugin.h | 168 +++++++++++++++++++- 6 files changed, 435 insertions(+), 79 deletions(-) create mode 100644 include/clad/Differentiator/Sins.h create mode 100644 test/Misc/ClangConsumers.cpp diff --git a/include/clad/Differentiator/Sins.h b/include/clad/Differentiator/Sins.h new file mode 100644 index 000000000..28983d626 --- /dev/null +++ b/include/clad/Differentiator/Sins.h @@ -0,0 +1,29 @@ +#ifndef CLAD_DIFFERENTIATOR_SINS_H +#define CLAD_DIFFERENTIATOR_SINS_H + +#include + +/// Standard-protected facility allowing access into private members in C++. +/// Use with caution! +// NOLINTBEGIN(cppcoreguidelines-macro-usage) +#define CONCATE_(X, Y) X##Y +#define CONCATE(X, Y) CONCATE_(X, Y) +#define ALLOW_ACCESS(CLASS, MEMBER, ...) \ + template \ + struct CONCATE(MEMBER, __LINE__) { \ + friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \ + }; \ + template struct Only_##MEMBER; \ + template <> struct Only_##MEMBER { \ + friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER*); \ + }; \ + template struct CONCATE(MEMBER, \ + __LINE__), &CLASS::MEMBER> + +#define ACCESS(OBJECT, MEMBER) \ + (OBJECT).*Access((Only_##MEMBER< \ + std::remove_reference::type>*)nullptr) + +// NOLINTEND(cppcoreguidelines-macro-usage) + +#endif // CLAD_DIFFERENTIATOR_SINS_H diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index eef3e2353..32ab9f161 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -8,10 +8,11 @@ #include "ConstantFolder.h" +#include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/StmtClone.h" -#include "clad/Differentiator/CladUtils.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" @@ -59,42 +60,14 @@ namespace clad { return true; } - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. - namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - - // Tag used to access Sema::CurScope. - using namespace clang; - struct Sema_CurScope : TagBase {}; - template struct Rob; - } // namespace + ALLOW_ACCESS(Sema, CurScope, Scope*); clang::Scope*& VisitorBase::getCurrentScope() { - return m_Sema.*get(Sema_CurScope()); + return ACCESS(m_Sema, CurScope); } void VisitorBase::setCurrentScope(clang::Scope* S) { - m_Sema.*get(Sema_CurScope()) = S; + getCurrentScope() = S; assert(getEnclosingNamespaceOrTUScope() && "Lost path to base."); } diff --git a/test/FirstDerivative/CodeGenSimple.C b/test/FirstDerivative/CodeGenSimple.C index 02a815c92..4ff77e806 100644 --- a/test/FirstDerivative/CodeGenSimple.C +++ b/test/FirstDerivative/CodeGenSimple.C @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...); int f_1_darg0(int x); +double sq_defined_later(double); + int main() { int x = 4; clad::differentiate(f_1, 0); + auto df = clad::differentiate(sq_defined_later, "x"); printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2 + printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6 return 0; } + +double sq_defined_later(double x) { + return x * x; +} diff --git a/test/Misc/ClangConsumers.cpp b/test/Misc/ClangConsumers.cpp new file mode 100644 index 000000000..7b0c47cbf --- /dev/null +++ b/test/Misc/ClangConsumers.cpp @@ -0,0 +1,44 @@ +// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out \ +// RUN: -fms-compatibility -std=c++14 -fmodules -Xclang \ +// RUN: -print-stats 2>&1 | FileCheck %s +// CHECK-NOT: {{.*error|warning|note:.*}} +// +// RUN: clang -xc -Xclang -add-plugin -Xclang clad -Xclang -load \ +// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \ +// RUN: -Xclang -print-stats 2>&1 | \ +// RUN: FileCheck -check-prefix=CHECK_C %s +// CHECK_C-NOT: {{.*error|warning|note:.*}} + +#ifdef __cplusplus + +#pragma clang module build N + module N {} + #pragma clang module contents + #pragma clang module begin N + struct f { void operator()() const {} }; + template auto vtemplate = f{}; + #pragma clang module end +#pragma clang module endbuild + +#pragma clang module import N + +class __single_inheritance IncSingle; + +// CHECK: HandleImplicitImportDecl +// CHECK: AssignInheritanceModel +// CHECK: HandleTopLevelDecl +// CHECK: HandleInterestingDecl +// CHECK: HandleCXXStaticMemberVarInstantiation + +#endif // __cplusplus + +#ifdef __STDC_VERSION__ // C mode +int i; +// CHECK_C: CompleteTentativeDefinition +#endif // __STDC_VERSION__ + +int main() { +#ifdef __cplusplus + vtemplate(); +#endif // __cplusplus +} diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 0f55cf065..64c5d44d3 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -9,6 +9,7 @@ #include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/EstimationModel.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/Version.h" #include "clang/AST/ASTConsumer.h" @@ -91,60 +92,52 @@ namespace clad { CladPlugin::~CladPlugin() {} + ALLOW_ACCESS(MultiplexConsumer, Consumers, + std::vector>); + + void CladPlugin::Initialize(clang::ASTContext& C) { + // We know we have a multiplexer. We commit a sin here by stealing it and + // making the consumer pass-through so that we can delay all operations + // until clad is happy. + + using namespace clang; + + auto& MultiplexC = static_cast(m_CI.getASTConsumer()); + auto& RobbedCs = ACCESS(MultiplexC, Consumers); + assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); + std::vector> StolenConsumers; + + // The range-based for loop in MultiplexConsumer::Initialize has + // dispatched this call. Generally, it is unsafe to delete elements while + // iterating but we know we are in the end of the loop and ::end() won't + // be invalidated. + for (auto& RC : RobbedCs) + if (RC.get() == this) + RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1); + else + StolenConsumers.push_back(std::move(RC)); + m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers))); + } + // We cannot use HandleTranslationUnit because codegen already emits code on // HandleTopLevelDecl calls and makes updateCall with no effect. - bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) { + void CladPlugin::HandleTopLevelDeclForClad(DeclGroupRef DGR) { if (!CheckBuiltins()) - return true; + return; Sema& S = m_CI.getSema(); if (!m_DerivativeBuilder) - m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this)); - - // if HandleTopLevelDecl was called through clad we don't need to process - // it for diff requests - if (m_HandleTopLevelDeclInternal) - return true; - - DiffSchedule requests{}; - DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema()); - - if (requests.empty()) - return true; + m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this)); - // FIXME: flags have to be set manually since DiffCollector's constructor - // does not have access to m_DO. - if (m_DO.EnableTBRAnalysis) - for (DiffRequest& request : requests) - request.EnableTBRAnalysis = true; - - // FIXME: Remove the PerformPendingInstantiations altogether. We should - // somehow make the relevant functions referenced. - // Instantiate all pending for instantiations templates, because we will - // need the full bodies to produce derivatives. - // FIXME: Confirm if we really need `m_PendingInstantiationsInFlight`? - if (!m_PendingInstantiationsInFlight) { - m_PendingInstantiationsInFlight = true; - S.PerformPendingInstantiations(); - m_PendingInstantiationsInFlight = false; - } - - for (DiffRequest& request : requests) - ProcessDiffRequest(request); - return true; // Happiness - } - - void CladPlugin::ProcessTopLevelDecl(Decl* D) { - m_HandleTopLevelDeclInternal = true; - m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D)); - m_HandleTopLevelDeclInternal = false; + DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, S); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { Sema& S = m_CI.getSema(); // Required due to custom derivatives function templates that might be // used in the function that we need to derive. + // FIXME: Remove the call to PerformPendingInstantiations(). S.PerformPendingInstantiations(); if (request.Function->getDefinition()) request.Function = request.Function->getDefinition(); @@ -267,6 +260,8 @@ namespace clad { // Call CodeGen only if the produced Decl is a top-most // decl or is contained in a namespace decl. + // FIXME: We could get rid of this by prepending the produced + // derivatives in CladPlugin::HandleTranslationUnitDecl DeclContext* derivativeDC = DerivativeDecl->getDeclContext(); bool isTUorND = derivativeDC->isTranslationUnit() || derivativeDC->isNamespace(); @@ -296,6 +291,70 @@ namespace clad { return nullptr; } + void CladPlugin::SendToMultiplexer() { + for (auto DelayedCall : m_DelayedCalls) { + DeclGroupRef& D = DelayedCall.m_DGR; + switch (DelayedCall.m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + m_Multiplexer->HandleCXXStaticMemberVarInstantiation( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTopLevelDecl: + m_Multiplexer->HandleTopLevelDecl(D); + break; + case CallKind::HandleInlineFunctionDefinition: + m_Multiplexer->HandleInlineFunctionDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleInterestingDecl: + m_Multiplexer->HandleInterestingDecl(D); + break; + case CallKind::HandleTagDeclDefinition: + m_Multiplexer->HandleTagDeclDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTagDeclRequiredDefinition: + m_Multiplexer->HandleTagDeclRequiredDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + m_Multiplexer->HandleCXXImplicitFunctionInstantiation( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTopLevelDeclInObjCContainer: + m_Multiplexer->HandleTopLevelDeclInObjCContainer(D); + break; + case CallKind::HandleImplicitImportDecl: + m_Multiplexer->HandleImplicitImportDecl( + cast(D.getSingleDecl())); + break; + case CallKind::CompleteTentativeDefinition: + m_Multiplexer->CompleteTentativeDefinition( + cast(D.getSingleDecl())); + break; +#if CLANG_VERSION_MAJOR > 9 + case CallKind::CompleteExternalDeclaration: + m_Multiplexer->CompleteExternalDeclaration( + cast(D.getSingleDecl())); + break; +#endif + case CallKind::AssignInheritanceModel: + m_Multiplexer->AssignInheritanceModel( + cast(D.getSingleDecl())); + break; + case CallKind::HandleVTable: + m_Multiplexer->HandleVTable(cast(D.getSingleDecl())); + break; + case CallKind::InitializeSema: + m_Multiplexer->InitializeSema(m_CI.getSema()); + break; + case CallKind::ForgetSema: + m_Multiplexer->ForgetSema(); + break; + }; + } + } + bool CladPlugin::CheckBuiltins() { // If we have included "clad/Differentiator/Differentiator.h" return. if (m_HasRuntime) @@ -318,6 +377,95 @@ namespace clad { m_HasRuntime = !R.empty(); return m_HasRuntime; } + + void CladPlugin::HandleTranslationUnit(ASTContext& C) { + Sema& S = m_CI.getSema(); + // Restore the TUScope that became a 0 in Sema::ActOnEndOfTranslationUnit. + S.TUScope = m_StoredTUScope; + constexpr bool Enabled = true; + Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled); + Sema::LocalEagerInstantiationScope LocalInstantiations(S); + + for (DiffRequest& request : m_DiffSchedule) { + // FIXME: flags have to be set manually since DiffCollector's + // constructor does not have access to m_DO. + request.EnableTBRAnalysis = m_DO.EnableTBRAnalysis; + ProcessDiffRequest(request); + } + // Put the TUScope in a consistent state after clad is done. + S.TUScope = nullptr; + // Force emission of the produced pending template instantiations. + LocalInstantiations.perform(); + GlobalInstantiations.perform(); + + SendToMultiplexer(); + m_Multiplexer->HandleTranslationUnit(C); + } + + void CladPlugin::PrintStats() { + llvm::errs() << "*** INFORMATION ABOUT THE DELAYED CALLS\n"; + for (const DelayedCallInfo& DCI : m_DelayedCalls) { + llvm::errs() << " "; + switch (DCI.m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + llvm::errs() << "HandleCXXStaticMemberVarInstantiation"; + break; + case CallKind::HandleTopLevelDecl: + llvm::errs() << "HandleTopLevelDecl"; + break; + case CallKind::HandleInlineFunctionDefinition: + llvm::errs() << "HandleInlineFunctionDefinition"; + break; + case CallKind::HandleInterestingDecl: + llvm::errs() << "HandleInterestingDecl"; + break; + case CallKind::HandleTagDeclDefinition: + llvm::errs() << "HandleTagDeclDefinition"; + break; + case CallKind::HandleTagDeclRequiredDefinition: + llvm::errs() << "HandleTagDeclRequiredDefinition"; + break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + llvm::errs() << "HandleCXXImplicitFunctionInstantiation"; + break; + case CallKind::HandleTopLevelDeclInObjCContainer: + llvm::errs() << "HandleTopLevelDeclInObjCContainer"; + break; + case CallKind::HandleImplicitImportDecl: + llvm::errs() << "HandleImplicitImportDecl"; + break; + case CallKind::CompleteTentativeDefinition: + llvm::errs() << "CompleteTentativeDefinition"; + break; +#if CLANG_VERSION_MAJOR > 9 + case CallKind::CompleteExternalDeclaration: + llvm::errs() << "CompleteExternalDeclaration"; + break; +#endif + case CallKind::AssignInheritanceModel: + llvm::errs() << "AssignInheritanceModel"; + break; + case CallKind::HandleVTable: + llvm::errs() << "HandleVTable"; + break; + case CallKind::InitializeSema: + llvm::errs() << "InitializeSema"; + break; + case CallKind::ForgetSema: + llvm::errs() << "ForgetSema"; + break; + }; + for (const clang::Decl* D : DCI.m_DGR) { + llvm::errs() << " " << D; + if (const auto* ND = dyn_cast(D)) + llvm::errs() << " " << ND->getNameAsString(); + } + llvm::errs() << "\n"; + } + + m_Multiplexer->PrintStats(); + } + } // end namespace plugin clad::CladTimerGroup::CladTimerGroup() diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 808443e49..86be47260 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -13,11 +13,12 @@ #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/Version.h" -#include "clang/AST/ASTConsumer.h" #include "clang/AST/Decl.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/Basic/Version.h" #include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Frontend/MultiplexConsumer.h" +#include "clang/Sema/SemaConsumer.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" @@ -92,24 +93,177 @@ namespace clad { std::string CustomModelName; }; - class CladPlugin : public clang::ASTConsumer { + class CladExternalSource : public clang::ExternalSemaSource { + // ExternalSemaSource + void ReadUndefinedButUsed( + llvm::MapVector& Undefined) + override { + // namespace { double f_darg0(double x); } will issue a warning that + // f_darg0 has internal linkage but is not defined. This is because we + // have not yet started to differentiate it. The warning is triggered by + // Sema::ActOnEndOfTranslationUnit before Clad is given control. + // To avoid the warning we should remove the entry from here. + using namespace clang; + Undefined.remove_if([](std::pair P) { + NamedDecl* ND = P.first; + + if (!ND->getDeclName().isIdentifier()) + return false; + + // FIXME: We should replace this comparison with the canonical decl + // from the differentiation plan... + return ND->getName().contains("_darg"); + }); + } + }; + class CladPlugin : public clang::SemaConsumer { clang::CompilerInstance& m_CI; DifferentiationOptions m_DO; std::unique_ptr m_DerivativeBuilder; bool m_HasRuntime = false; - bool m_PendingInstantiationsInFlight = false; - bool m_HandleTopLevelDeclInternal = false; CladTimerGroup m_CTG; DerivedFnCollector m_DFC; + DiffSchedule m_DiffSchedule; + enum class CallKind { + HandleCXXStaticMemberVarInstantiation, + HandleTopLevelDecl, + HandleInlineFunctionDefinition, + HandleInterestingDecl, + HandleTagDeclDefinition, + HandleTagDeclRequiredDefinition, + HandleCXXImplicitFunctionInstantiation, + HandleTopLevelDeclInObjCContainer, + HandleImplicitImportDecl, + CompleteTentativeDefinition, +#if CLANG_VERSION_MAJOR > 9 + CompleteExternalDeclaration, +#endif + AssignInheritanceModel, + HandleVTable, + InitializeSema, + ForgetSema + }; + struct DelayedCallInfo { + CallKind m_Kind; + clang::DeclGroupRef m_DGR; + DelayedCallInfo(CallKind K, clang::DeclGroupRef DGR) + : m_Kind(K), m_DGR(DGR) {} + DelayedCallInfo(CallKind K, const clang::Decl* D) + : m_Kind(K), m_DGR(const_cast(D)) {} + bool operator==(const DelayedCallInfo& other) const { + if (m_Kind != other.m_Kind) + return false; + + clang::Decl* const* first1 = m_DGR.begin(); + clang::Decl* const* first2 = other.m_DGR.begin(); + clang::Decl* const* last1 = m_DGR.end(); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + for (; first1 != last1; ++first1, ++first2) + if (!(*first1 == *first2)) + return false; + return true; + } + }; + /// The calls to the main action which clad delayed and will dispatch at + /// then end of the translation unit. + std::vector m_DelayedCalls; + /// The default clang consumers which are called after clad is done. + std::unique_ptr m_Multiplexer; + + /// The Sema::TUScope to restore in CladPlugin::HandleTranslationUnit. + clang::Scope* m_StoredTUScope = nullptr; + public: CladPlugin(clang::CompilerInstance& CI, DifferentiationOptions& DO); ~CladPlugin(); - bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override; + // ASTConsumer + void Initialize(clang::ASTContext& Context) override; + void HandleCXXStaticMemberVarInstantiation(clang::VarDecl* D) override { + AppendDelayed({CallKind::HandleCXXStaticMemberVarInstantiation, D}); + } + bool HandleTopLevelDecl(clang::DeclGroupRef D) override { + HandleTopLevelDeclForClad(D); + AppendDelayed({CallKind::HandleTopLevelDecl, D}); + return true; // happyness, continue parsing + } + void HandleInlineFunctionDefinition(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleInlineFunctionDefinition, D}); + } + void HandleInterestingDecl(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleInterestingDecl, D}); + } + void HandleTagDeclDefinition(clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclDefinition, D}); + } + void HandleTagDeclRequiredDefinition(const clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclRequiredDefinition, D}); + } + void + HandleCXXImplicitFunctionInstantiation(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleCXXImplicitFunctionInstantiation, D}); + } + void HandleTopLevelDeclInObjCContainer(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleTopLevelDeclInObjCContainer, D}); + } + void HandleImplicitImportDecl(clang::ImportDecl* D) override { + AppendDelayed({CallKind::HandleImplicitImportDecl, D}); + } + void CompleteTentativeDefinition(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteTentativeDefinition, D}); + } +#if CLANG_VERSION_MAJOR > 9 + void CompleteExternalDeclaration(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteExternalDeclaration, D}); + } +#endif + void AssignInheritanceModel(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::AssignInheritanceModel, D}); + } + void HandleVTable(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::HandleVTable, D}); + } + + // Not delayed. + void HandleTranslationUnit(clang::ASTContext& C) override; + + // No need to handle the listeners, they will be handled at non-delayed by + // the parent multiplexer. + // + // clang::ASTMutationListener *GetASTMutationListener() override; + // clang::ASTDeserializationListener *GetASTDeserializationListener() + // override; + void PrintStats() override; + + bool shouldSkipFunctionBody(clang::Decl* D) override { + return m_Multiplexer->shouldSkipFunctionBody(D); + } + + // SemaConsumer + void InitializeSema(clang::Sema& S) override { + // We are also a ExternalSemaSource. + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + S.addExternalSource(new CladExternalSource()); // Owned by Sema. + m_StoredTUScope = S.TUScope; + AppendDelayed({CallKind::InitializeSema, nullptr}); + } + void ForgetSema() override { + AppendDelayed({CallKind::ForgetSema, nullptr}); + } + + // FIXME: We should hide ProcessDiffRequest when we implement proper + // handling of the differentiation plans. clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); private: + void AppendDelayed(DelayedCallInfo DCI) { m_DelayedCalls.push_back(DCI); } + void SendToMultiplexer(); bool CheckBuiltins(); - void ProcessTopLevelDecl(clang::Decl* D); + void ProcessTopLevelDecl(clang::Decl* D) { + DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D}; + assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!"); + AppendDelayed(DCI); + } + void HandleTopLevelDeclForClad(clang::DeclGroupRef DGR); }; clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P, @@ -190,7 +344,7 @@ namespace clad { } PluginASTAction::ActionType getActionType() override { - return AddBeforeMainAction; + return AddAfterMainAction; } }; } // end namespace plugin