diff --git a/include/clad/Differentiator/Sins.h b/include/clad/Differentiator/Sins.h new file mode 100644 index 000000000..803d451b2 --- /dev/null +++ b/include/clad/Differentiator/Sins.h @@ -0,0 +1,20 @@ +#ifndef CLAD_DIFFERENTIATOR_SINS +#define CLAD_DIFFERENTIATOR_SINS + +#include + +/// Standard-protected facility allowing access into private members in C++. +/// Use with caution! +#define CONCATE_(X, Y) X##Y +#define CONCATE(X, Y) CONCATE_(X, Y) +#define ALLOW_ACCESS(CLASS, MEMBER, ...) \ + template \ + struct CONCATE(MEMBER, __LINE__) { friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } }; \ + template struct Only_##MEMBER; \ + template<> struct Only_##MEMBER { friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER*); }; \ + template struct CONCATE(MEMBER, __LINE__), &CLASS::MEMBER> + +#define ACCESS(OBJECT, MEMBER) \ + (OBJECT).*Access((Only_##MEMBER::type>*)nullptr) + +#endif // CLAD_DIFFERENTIATOR_SINS diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index eef3e2353..c25f9db72 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -8,10 +8,11 @@ #include "ConstantFolder.h" +#include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/ErrorEstimator.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/StmtClone.h" -#include "clad/Differentiator/CladUtils.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" @@ -59,42 +60,14 @@ namespace clad { return true; } - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. - namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - - // Tag used to access Sema::CurScope. - using namespace clang; - struct Sema_CurScope : TagBase {}; - template struct Rob; - } // namespace + ALLOW_ACCESS(Sema, CurScope, Scope *); clang::Scope*& VisitorBase::getCurrentScope() { - return m_Sema.*get(Sema_CurScope()); + return ACCESS(m_Sema, CurScope); } void VisitorBase::setCurrentScope(clang::Scope* S) { - m_Sema.*get(Sema_CurScope()) = S; + getCurrentScope() = S; assert(getEnclosingNamespaceOrTUScope() && "Lost path to base."); } diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index d678d12e1..130ca3119 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -9,6 +9,7 @@ #include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/EstimationModel.h" +#include "clad/Differentiator/Sins.h" #include "clad/Differentiator/Version.h" #include "clang/AST/ASTConsumer.h" @@ -121,39 +122,7 @@ namespace clad { CladPlugin::~CladPlugin() {} - // A facility allowing us to access the private member CurScope of the Sema - // object using standard-conforming C++. - namespace { - template struct Rob { - friend typename Tag::type get(Tag) { return M; } - }; - - template struct TagBase { - using type = Member; -#ifdef MSVC -#pragma warning(push, 0) -#endif // MSVC -#pragma GCC diagnostic push -#ifdef __clang__ -#pragma clang diagnostic ignored "-Wunknown-warning-option" -#endif // __clang__ -#pragma GCC diagnostic ignored "-Wnon-template-friend" - friend type get(Tag); -#pragma GCC diagnostic pop -#ifdef MSVC -#pragma warning(pop) -#endif // MSVC - }; - // Tag used to access MultiplexConsumer::Consumers. - using namespace clang; - struct MultiplexConsumer_Consumers - : TagBase< - MultiplexConsumer_Consumers, - std::vector> MultiplexConsumer::*> { - }; - template struct Rob; - } // namespace + ALLOW_ACCESS(MultiplexConsumer, Consumers, std::vector>); void CladPlugin::Initialize(clang::ASTContext& C) { // We know we have a multiplexer. We commit a sin here by stealing it and @@ -163,7 +132,8 @@ namespace clad { using namespace clang; auto& MultiplexC = static_cast(m_CI.getASTConsumer()); - auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers()); + //auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers()); + auto& RobbedCs = ACCESS(MultiplexC, Consumers); assert(RobbedCs.back().get() == this && "Clad is not the last consumer"); std::vector> StolenConsumers; @@ -191,11 +161,6 @@ namespace clad { if (!m_DerivativeBuilder) m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this)); - // if HandleTopLevelDecl was called through clad we don't need to process - // it for diff requests - if (m_HandleTopLevelDeclInternal) - return m_Multiplexer->HandleTopLevelDecl(DGR); // true; - DiffSchedule requests{}; DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema()); @@ -226,9 +191,7 @@ namespace clad { } void CladPlugin::ProcessTopLevelDecl(Decl* D) { - m_HandleTopLevelDeclInternal = true; - m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D)); - m_HandleTopLevelDeclInternal = false; + m_Multiplexer->HandleTopLevelDecl(DeclGroupRef(D)); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 3e99a93b9..3503091bd 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -89,7 +89,6 @@ namespace clad { std::unique_ptr m_DerivativeBuilder; bool m_HasRuntime = false; bool m_PendingInstantiationsInFlight = false; - bool m_HandleTopLevelDeclInternal = false; DerivedFnCollector m_DFC; enum class CallKind { HandleCXXStaticMemberVarInstantiation, @@ -182,6 +181,43 @@ namespace clad { // Not delayed. void HandleTranslationUnit(clang::ASTContext& C) override { + /*for (const DelayedCallInfo& DCI : m_DelayedCalls) { + switch (DCI.m_Kind) { + case CallKind::HandleCXXStaticMemberVarInstantiation: + llvm::errs() << "HandleCXXStaticMemberVarInstantiation"; break; + case CallKind::HandleTopLevelDecl: + llvm::errs() << "HandleTopLevelDecl"; break; + case CallKind::HandleInlineFunctionDefinition: + llvm::errs() << "HandleInlineFunctionDefinition"; break; + case CallKind::HandleInterestingDecl: + llvm::errs() << "HandleInterestingDecl"; break; + case CallKind::HandleTagDeclDefinition: + llvm::errs() << "HandleTagDeclDefinition"; break; + case CallKind::HandleTagDeclRequiredDefinition: + llvm::errs() << "HandleTagDeclRequiredDefinition"; break; + case CallKind::HandleCXXImplicitFunctionInstantiation: + llvm::errs() << "HandleCXXImplicitFunctionInstantiation"; break; + case CallKind::HandleTopLevelDeclInObjCContainer: + llvm::errs() << "HandleTopLevelDeclInObjCContainer"; break; + case CallKind::HandleImplicitImportDecl: + llvm::errs() << "HandleImplicitImportDecl"; break; + case CallKind::CompleteTentativeDefinition: + llvm::errs() << "CompleteTentativeDefinition"; break; + case CallKind::CompleteExternalDeclaration: + llvm::errs() << "CompleteExternalDeclaration"; break; + case CallKind::AssignInheritanceModel: + llvm::errs() << "AssignInheritanceModel"; break; + case CallKind::HandleVTable: + llvm::errs() << "HandleVTable"; break; + case CallKind::InitializeSema: + llvm::errs() << "InitializeSema"; break; + case CallKind::ForgetSema: + llvm::errs() << "ForgetSema"; break; + }; + for (const clang::Decl* D : DCI.m_DGR) + llvm::errs() << " " << D; + llvm::errs() << "\n"; + }*/ m_Multiplexer->HandleTranslationUnit(C); } // No need to handle the listeners, they will be handled at non-delayed by