diff --git a/include/clad/Differentiator/Sins.h b/include/clad/Differentiator/Sins.h new file mode 100644 index 000000000..28983d626 --- /dev/null +++ b/include/clad/Differentiator/Sins.h @@ -0,0 +1,29 @@ +#ifndef CLAD_DIFFERENTIATOR_SINS_H +#define CLAD_DIFFERENTIATOR_SINS_H + +#include + +/// Standard-protected facility allowing access into private members in C++. +/// Use with caution! +// NOLINTBEGIN(cppcoreguidelines-macro-usage) +#define CONCATE_(X, Y) X##Y +#define CONCATE(X, Y) CONCATE_(X, Y) +#define ALLOW_ACCESS(CLASS, MEMBER, ...) \ + template \ + struct CONCATE(MEMBER, __LINE__) { \ + friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \ + }; \ + template struct Only_##MEMBER; \ + template <> struct Only_##MEMBER { \ + friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER*); \ + }; \ + template struct CONCATE(MEMBER, \ + __LINE__), &CLASS::MEMBER> + +#define ACCESS(OBJECT, MEMBER) \ + (OBJECT).*Access((Only_##MEMBER< \ + std::remove_reference::type>*)nullptr) + +// NOLINTEND(cppcoreguidelines-macro-usage) + +#endif // CLAD_DIFFERENTIATOR_SINS_H diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index eef3e2353..32ab9f161 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -8,10 +8,11 @@ #include "ConstantFolder.h" +#include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/StmtClone.h" -#include "clad/Differentiator/CladUtils.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" @@ -59,42 +60,14 @@ namespace clad { return true; } - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. - namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - - // Tag used to access Sema::CurScope. - using namespace clang; - struct Sema_CurScope : TagBase {}; - template struct Rob; - } // namespace + ALLOW_ACCESS(Sema, CurScope, Scope*); clang::Scope*& VisitorBase::getCurrentScope() { - return m_Sema.*get(Sema_CurScope()); + return ACCESS(m_Sema, CurScope); } void VisitorBase::setCurrentScope(clang::Scope* S) { - m_Sema.*get(Sema_CurScope()) = S; + getCurrentScope() = S; assert(getEnclosingNamespaceOrTUScope() && "Lost path to base."); } diff --git a/test/FirstDerivative/CodeGenSimple.C b/test/FirstDerivative/CodeGenSimple.C index 02a815c92..4ff77e806 100644 --- a/test/FirstDerivative/CodeGenSimple.C +++ b/test/FirstDerivative/CodeGenSimple.C @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...); int f_1_darg0(int x); +double sq_defined_later(double); + int main() { int x = 4; clad::differentiate(f_1, 0); + auto df = clad::differentiate(sq_defined_later, "x"); printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2 + printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6 return 0; } + +double sq_defined_later(double x) { + return x * x; +} diff --git a/test/Misc/ClangConsumers.cpp b/test/Misc/ClangConsumers.cpp new file mode 100644 index 000000000..16e7c9692 --- /dev/null +++ b/test/Misc/ClangConsumers.cpp @@ -0,0 +1,28 @@ +// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out \ +// RUN: -fms-compatibility -std=c++14 -fmodules -Xclang \ +// RUN: -print-stats 2>&1 | FileCheck %s +// CHECK-NOT: {{.*error|warning|note:.*}} + + + +#pragma clang module build N + module N {} + #pragma clang module contents + #pragma clang module begin N + struct f { void operator()() const {} }; + template auto vtemplate = f{}; + #pragma clang module end +#pragma clang module endbuild + +#pragma clang module import N + +class __single_inheritance IncSingle; +// CHECK: HandleImplicitImportDecl +// CHECK: AssignInheritanceModel +// CHECK: HandleTopLevelDecl +// CHECK: HandleInterestingDecl +// CHECK: HandleCXXStaticMemberVarInstantiation + +int main() { + vtemplate(); +} diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 0f55cf065..64c5d44d3 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -9,6 +9,7 @@ #include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/EstimationModel.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/Version.h" #include "clang/AST/ASTConsumer.h" @@ -91,60 +92,52 @@ namespace clad { CladPlugin::~CladPlugin() {} + ALLOW_ACCESS(MultiplexConsumer, Consumers, + std::vector>); + + void CladPlugin::Initialize(clang::ASTContext& C) { + // We know we have a multiplexer. We commit a sin here by stealing it and + // making the consumer pass-through so that we can delay all operations + // until clad is happy. + + using namespace clang; + + auto& MultiplexC = static_cast(m_CI.getASTConsumer()); + auto& RobbedCs = ACCESS(MultiplexC, Consumers); + assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); + std::vector> StolenConsumers; + + // The range-based for loop in MultiplexConsumer::Initialize has + // dispatched this call. Generally, it is unsafe to delete elements while + // iterating but we know we are in the end of the loop and ::end() won't + // be invalidated. + for (auto& RC : RobbedCs) + if (RC.get() == this) + RobbedCs.erase(RobbedCs.begin(), RobbedCs.end() - 1); + else + StolenConsumers.push_back(std::move(RC)); + m_Multiplexer.reset(new MultiplexConsumer(std::move(StolenConsumers))); + } + // We cannot use HandleTranslationUnit because codegen already emits code on // HandleTopLevelDecl calls and makes updateCall with no effect. - bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) { + void CladPlugin::HandleTopLevelDeclForClad(DeclGroupRef DGR) { if (!CheckBuiltins()) - return true; + return; Sema& S = m_CI.getSema(); if (!m_DerivativeBuilder) - m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this)); - - // if HandleTopLevelDecl was called through clad we don't need to process - // it for diff requests - if (m_HandleTopLevelDeclInternal) - return true; - - DiffSchedule requests{}; - DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema()); - - if (requests.empty()) - return true; + m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this)); - // FIXME: flags have to be set manually since DiffCollector's constructor - // does not have access to m_DO. - if (m_DO.EnableTBRAnalysis) - for (DiffRequest& request : requests) - request.EnableTBRAnalysis = true; - - // FIXME: Remove the PerformPendingInstantiations altogether. We should - // somehow make the relevant functions referenced. - // Instantiate all pending for instantiations templates, because we will - // need the full bodies to produce derivatives. - // FIXME: Confirm if we really need `m_PendingInstantiationsInFlight`? - if (!m_PendingInstantiationsInFlight) { - m_PendingInstantiationsInFlight = true; - S.PerformPendingInstantiations(); - m_PendingInstantiationsInFlight = false; - } - - for (DiffRequest& request : requests) - ProcessDiffRequest(request); - return true; // Happiness - } - - void CladPlugin::ProcessTopLevelDecl(Decl* D) { - m_HandleTopLevelDeclInternal = true; - m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D)); - m_HandleTopLevelDeclInternal = false; + DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, S); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { Sema& S = m_CI.getSema(); // Required due to custom derivatives function templates that might be // used in the function that we need to derive. + // FIXME: Remove the call to PerformPendingInstantiations(). S.PerformPendingInstantiations(); if (request.Function->getDefinition()) request.Function = request.Function->getDefinition(); @@ -267,6 +260,8 @@ namespace clad { // Call CodeGen only if the produced Decl is a top-most // decl or is contained in a namespace decl. + // FIXME: We could get rid of this by prepending the produced + // derivatives in CladPlugin::HandleTranslationUnitDecl DeclContext* derivativeDC = DerivativeDecl->getDeclContext(); bool isTUorND = derivativeDC->isTranslationUnit() || derivativeDC->isNamespace(); @@ -296,6 +291,70 @@ namespace clad { return nullptr; } + void CladPlugin::SendToMultiplexer() { + for (auto DelayedCall : m_DelayedCalls) { + DeclGroupRef& D = DelayedCall.m_DGR; + switch (DelayedCall.m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + m_Multiplexer->HandleCXXStaticMemberVarInstantiation( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTopLevelDecl: + m_Multiplexer->HandleTopLevelDecl(D); + break; + case CallKind::HandleInlineFunctionDefinition: + m_Multiplexer->HandleInlineFunctionDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleInterestingDecl: + m_Multiplexer->HandleInterestingDecl(D); + break; + case CallKind::HandleTagDeclDefinition: + m_Multiplexer->HandleTagDeclDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTagDeclRequiredDefinition: + m_Multiplexer->HandleTagDeclRequiredDefinition( + cast(D.getSingleDecl())); + break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + m_Multiplexer->HandleCXXImplicitFunctionInstantiation( + cast(D.getSingleDecl())); + break; + case CallKind::HandleTopLevelDeclInObjCContainer: + m_Multiplexer->HandleTopLevelDeclInObjCContainer(D); + break; + case CallKind::HandleImplicitImportDecl: + m_Multiplexer->HandleImplicitImportDecl( + cast(D.getSingleDecl())); + break; + case CallKind::CompleteTentativeDefinition: + m_Multiplexer->CompleteTentativeDefinition( + cast(D.getSingleDecl())); + break; +#if CLANG_VERSION_MAJOR > 9 + case CallKind::CompleteExternalDeclaration: + m_Multiplexer->CompleteExternalDeclaration( + cast(D.getSingleDecl())); + break; +#endif + case CallKind::AssignInheritanceModel: + m_Multiplexer->AssignInheritanceModel( + cast(D.getSingleDecl())); + break; + case CallKind::HandleVTable: + m_Multiplexer->HandleVTable(cast(D.getSingleDecl())); + break; + case CallKind::InitializeSema: + m_Multiplexer->InitializeSema(m_CI.getSema()); + break; + case CallKind::ForgetSema: + m_Multiplexer->ForgetSema(); + break; + }; + } + } + bool CladPlugin::CheckBuiltins() { // If we have included "clad/Differentiator/Differentiator.h" return. if (m_HasRuntime) @@ -318,6 +377,95 @@ namespace clad { m_HasRuntime = !R.empty(); return m_HasRuntime; } + + void CladPlugin::HandleTranslationUnit(ASTContext& C) { + Sema& S = m_CI.getSema(); + // Restore the TUScope that became a 0 in Sema::ActOnEndOfTranslationUnit. + S.TUScope = m_StoredTUScope; + constexpr bool Enabled = true; + Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled); + Sema::LocalEagerInstantiationScope LocalInstantiations(S); + + for (DiffRequest& request : m_DiffSchedule) { + // FIXME: flags have to be set manually since DiffCollector's + // constructor does not have access to m_DO. + request.EnableTBRAnalysis = m_DO.EnableTBRAnalysis; + ProcessDiffRequest(request); + } + // Put the TUScope in a consistent state after clad is done. + S.TUScope = nullptr; + // Force emission of the produced pending template instantiations. + LocalInstantiations.perform(); + GlobalInstantiations.perform(); + + SendToMultiplexer(); + m_Multiplexer->HandleTranslationUnit(C); + } + + void CladPlugin::PrintStats() { + llvm::errs() << "*** INFORMATION ABOUT THE DELAYED CALLS\n"; + for (const DelayedCallInfo& DCI : m_DelayedCalls) { + llvm::errs() << " "; + switch (DCI.m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + llvm::errs() << "HandleCXXStaticMemberVarInstantiation"; + break; + case CallKind::HandleTopLevelDecl: + llvm::errs() << "HandleTopLevelDecl"; + break; + case CallKind::HandleInlineFunctionDefinition: + llvm::errs() << "HandleInlineFunctionDefinition"; + break; + case CallKind::HandleInterestingDecl: + llvm::errs() << "HandleInterestingDecl"; + break; + case CallKind::HandleTagDeclDefinition: + llvm::errs() << "HandleTagDeclDefinition"; + break; + case CallKind::HandleTagDeclRequiredDefinition: + llvm::errs() << "HandleTagDeclRequiredDefinition"; + break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + llvm::errs() << "HandleCXXImplicitFunctionInstantiation"; + break; + case CallKind::HandleTopLevelDeclInObjCContainer: + llvm::errs() << "HandleTopLevelDeclInObjCContainer"; + break; + case CallKind::HandleImplicitImportDecl: + llvm::errs() << "HandleImplicitImportDecl"; + break; + case CallKind::CompleteTentativeDefinition: + llvm::errs() << "CompleteTentativeDefinition"; + break; +#if CLANG_VERSION_MAJOR > 9 + case CallKind::CompleteExternalDeclaration: + llvm::errs() << "CompleteExternalDeclaration"; + break; +#endif + case CallKind::AssignInheritanceModel: + llvm::errs() << "AssignInheritanceModel"; + break; + case CallKind::HandleVTable: + llvm::errs() << "HandleVTable"; + break; + case CallKind::InitializeSema: + llvm::errs() << "InitializeSema"; + break; + case CallKind::ForgetSema: + llvm::errs() << "ForgetSema"; + break; + }; + for (const clang::Decl* D : DCI.m_DGR) { + llvm::errs() << " " << D; + if (const auto* ND = dyn_cast(D)) + llvm::errs() << " " << ND->getNameAsString(); + } + llvm::errs() << "\n"; + } + + m_Multiplexer->PrintStats(); + } + } // end namespace plugin clad::CladTimerGroup::CladTimerGroup() diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 808443e49..86be47260 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -13,11 +13,12 @@ #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/Version.h" -#include "clang/AST/ASTConsumer.h" #include "clang/AST/Decl.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/Basic/Version.h" #include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Frontend/MultiplexConsumer.h" +#include "clang/Sema/SemaConsumer.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" @@ -92,24 +93,177 @@ namespace clad { std::string CustomModelName; }; - class CladPlugin : public clang::ASTConsumer { + class CladExternalSource : public clang::ExternalSemaSource { + // ExternalSemaSource + void ReadUndefinedButUsed( + llvm::MapVector& Undefined) + override { + // namespace { double f_darg0(double x); } will issue a warning that + // f_darg0 has internal linkage but is not defined. This is because we + // have not yet started to differentiate it. The warning is triggered by + // Sema::ActOnEndOfTranslationUnit before Clad is given control. + // To avoid the warning we should remove the entry from here. + using namespace clang; + Undefined.remove_if([](std::pair P) { + NamedDecl* ND = P.first; + + if (!ND->getDeclName().isIdentifier()) + return false; + + // FIXME: We should replace this comparison with the canonical decl + // from the differentiation plan... + return ND->getName().contains("_darg"); + }); + } + }; + class CladPlugin : public clang::SemaConsumer { clang::CompilerInstance& m_CI; DifferentiationOptions m_DO; std::unique_ptr m_DerivativeBuilder; bool m_HasRuntime = false; - bool m_PendingInstantiationsInFlight = false; - bool m_HandleTopLevelDeclInternal = false; CladTimerGroup m_CTG; DerivedFnCollector m_DFC; + DiffSchedule m_DiffSchedule; + enum class CallKind { + HandleCXXStaticMemberVarInstantiation, + HandleTopLevelDecl, + HandleInlineFunctionDefinition, + HandleInterestingDecl, + HandleTagDeclDefinition, + HandleTagDeclRequiredDefinition, + HandleCXXImplicitFunctionInstantiation, + HandleTopLevelDeclInObjCContainer, + HandleImplicitImportDecl, + CompleteTentativeDefinition, +#if CLANG_VERSION_MAJOR > 9 + CompleteExternalDeclaration, +#endif + AssignInheritanceModel, + HandleVTable, + InitializeSema, + ForgetSema + }; + struct DelayedCallInfo { + CallKind m_Kind; + clang::DeclGroupRef m_DGR; + DelayedCallInfo(CallKind K, clang::DeclGroupRef DGR) + : m_Kind(K), m_DGR(DGR) {} + DelayedCallInfo(CallKind K, const clang::Decl* D) + : m_Kind(K), m_DGR(const_cast(D)) {} + bool operator==(const DelayedCallInfo& other) const { + if (m_Kind != other.m_Kind) + return false; + + clang::Decl* const* first1 = m_DGR.begin(); + clang::Decl* const* first2 = other.m_DGR.begin(); + clang::Decl* const* last1 = m_DGR.end(); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + for (; first1 != last1; ++first1, ++first2) + if (!(*first1 == *first2)) + return false; + return true; + } + }; + /// The calls to the main action which clad delayed and will dispatch at + /// then end of the translation unit. + std::vector m_DelayedCalls; + /// The default clang consumers which are called after clad is done. + std::unique_ptr m_Multiplexer; + + /// The Sema::TUScope to restore in CladPlugin::HandleTranslationUnit. + clang::Scope* m_StoredTUScope = nullptr; + public: CladPlugin(clang::CompilerInstance& CI, DifferentiationOptions& DO); ~CladPlugin(); - bool HandleTopLevelDecl(clang::DeclGroupRef DGR) override; + // ASTConsumer + void Initialize(clang::ASTContext& Context) override; + void HandleCXXStaticMemberVarInstantiation(clang::VarDecl* D) override { + AppendDelayed({CallKind::HandleCXXStaticMemberVarInstantiation, D}); + } + bool HandleTopLevelDecl(clang::DeclGroupRef D) override { + HandleTopLevelDeclForClad(D); + AppendDelayed({CallKind::HandleTopLevelDecl, D}); + return true; // happyness, continue parsing + } + void HandleInlineFunctionDefinition(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleInlineFunctionDefinition, D}); + } + void HandleInterestingDecl(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleInterestingDecl, D}); + } + void HandleTagDeclDefinition(clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclDefinition, D}); + } + void HandleTagDeclRequiredDefinition(const clang::TagDecl* D) override { + AppendDelayed({CallKind::HandleTagDeclRequiredDefinition, D}); + } + void + HandleCXXImplicitFunctionInstantiation(clang::FunctionDecl* D) override { + AppendDelayed({CallKind::HandleCXXImplicitFunctionInstantiation, D}); + } + void HandleTopLevelDeclInObjCContainer(clang::DeclGroupRef D) override { + AppendDelayed({CallKind::HandleTopLevelDeclInObjCContainer, D}); + } + void HandleImplicitImportDecl(clang::ImportDecl* D) override { + AppendDelayed({CallKind::HandleImplicitImportDecl, D}); + } + void CompleteTentativeDefinition(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteTentativeDefinition, D}); + } +#if CLANG_VERSION_MAJOR > 9 + void CompleteExternalDeclaration(clang::VarDecl* D) override { + AppendDelayed({CallKind::CompleteExternalDeclaration, D}); + } +#endif + void AssignInheritanceModel(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::AssignInheritanceModel, D}); + } + void HandleVTable(clang::CXXRecordDecl* D) override { + AppendDelayed({CallKind::HandleVTable, D}); + } + + // Not delayed. + void HandleTranslationUnit(clang::ASTContext& C) override; + + // No need to handle the listeners, they will be handled at non-delayed by + // the parent multiplexer. + // + // clang::ASTMutationListener *GetASTMutationListener() override; + // clang::ASTDeserializationListener *GetASTDeserializationListener() + // override; + void PrintStats() override; + + bool shouldSkipFunctionBody(clang::Decl* D) override { + return m_Multiplexer->shouldSkipFunctionBody(D); + } + + // SemaConsumer + void InitializeSema(clang::Sema& S) override { + // We are also a ExternalSemaSource. + // NOLINTNEXTLINE(cppcoreguidelines-owning-memory) + S.addExternalSource(new CladExternalSource()); // Owned by Sema. + m_StoredTUScope = S.TUScope; + AppendDelayed({CallKind::InitializeSema, nullptr}); + } + void ForgetSema() override { + AppendDelayed({CallKind::ForgetSema, nullptr}); + } + + // FIXME: We should hide ProcessDiffRequest when we implement proper + // handling of the differentiation plans. clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); private: + void AppendDelayed(DelayedCallInfo DCI) { m_DelayedCalls.push_back(DCI); } + void SendToMultiplexer(); bool CheckBuiltins(); - void ProcessTopLevelDecl(clang::Decl* D); + void ProcessTopLevelDecl(clang::Decl* D) { + DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D}; + assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!"); + AppendDelayed(DCI); + } + void HandleTopLevelDeclForClad(clang::DeclGroupRef DGR); }; clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P, @@ -190,7 +344,7 @@ namespace clad { } PluginASTAction::ActionType getActionType() override { - return AddBeforeMainAction; + return AddAfterMainAction; } }; } // end namespace plugin