Skip to content

Commit

Permalink
More
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Feb 16, 2024
1 parent 2bc5302 commit 55c80ca
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 75 deletions.
20 changes: 20 additions & 0 deletions include/clad/Differentiator/Sins.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef CLAD_DIFFERENTIATOR_SINS
#define CLAD_DIFFERENTIATOR_SINS

#include <type_traits>

/// Standard-protected facility allowing access into private members in C++.
/// Use with caution!
#define CONCATE_(X, Y) X##Y
#define CONCATE(X, Y) CONCATE_(X, Y)
#define ALLOW_ACCESS(CLASS, MEMBER, ...) \
template<typename Only, __VA_ARGS__ CLASS::*Member> \
struct CONCATE(MEMBER, __LINE__) { friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } }; \
template<typename> struct Only_##MEMBER; \
template<> struct Only_##MEMBER<CLASS> { friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER<CLASS>*); }; \
template struct CONCATE(MEMBER, __LINE__)<Only_##MEMBER<CLASS>, &CLASS::MEMBER>

#define ACCESS(OBJECT, MEMBER) \
(OBJECT).*Access((Only_##MEMBER<std::remove_reference<decltype(OBJECT)>::type>*)nullptr)

#endif // CLAD_DIFFERENTIATOR_SINS
37 changes: 5 additions & 32 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/CladUtils.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -59,42 +60,14 @@ namespace clad {
return true;
}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};

// Tag used to access Sema::CurScope.
using namespace clang;
struct Sema_CurScope : TagBase<Sema_CurScope, Scope * Sema::*> {};
template struct Rob<Sema_CurScope, &Sema::CurScope>;
} // namespace
ALLOW_ACCESS(Sema, CurScope, Scope *);

clang::Scope*& VisitorBase::getCurrentScope() {
return m_Sema.*get(Sema_CurScope());
return ACCESS(m_Sema, CurScope);
}

void VisitorBase::setCurrentScope(clang::Scope* S) {
m_Sema.*get(Sema_CurScope()) = S;
getCurrentScope() = S;
assert(getEnclosingNamespaceOrTUScope() && "Lost path to base.");
}

Expand Down
47 changes: 5 additions & 42 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "clad/Differentiator/DerivativeBuilder.h"
#include "clad/Differentiator/EstimationModel.h"

#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/Version.h"

#include "clang/AST/ASTConsumer.h"
Expand Down Expand Up @@ -121,39 +122,7 @@ namespace clad {

CladPlugin::~CladPlugin() {}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};
// Tag used to access MultiplexConsumer::Consumers.
using namespace clang;
struct MultiplexConsumer_Consumers
: TagBase<
MultiplexConsumer_Consumers,
std::vector<std::unique_ptr<ASTConsumer>> MultiplexConsumer::*> {
};
template struct Rob<MultiplexConsumer_Consumers,
&MultiplexConsumer::Consumers>;
} // namespace
ALLOW_ACCESS(MultiplexConsumer, Consumers, std::vector<std::unique_ptr<ASTConsumer>>);

void CladPlugin::Initialize(clang::ASTContext& C) {
// We know we have a multiplexer. We commit a sin here by stealing it and
Expand All @@ -163,7 +132,8 @@ namespace clad {
using namespace clang;

auto& MultiplexC = static_cast<MultiplexConsumer&>(m_CI.getASTConsumer());
auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers());
//auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers());
auto& RobbedCs = ACCESS(MultiplexC, Consumers);
assert(RobbedCs.back().get() == this && "Clad is not the last consumer");
std::vector<std::unique_ptr<ASTConsumer>> StolenConsumers;

Expand Down Expand Up @@ -191,11 +161,6 @@ namespace clad {
if (!m_DerivativeBuilder)
m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this));

// if HandleTopLevelDecl was called through clad we don't need to process
// it for diff requests
if (m_HandleTopLevelDeclInternal)
return m_Multiplexer->HandleTopLevelDecl(DGR); // true;

DiffSchedule requests{};
DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema());

Expand Down Expand Up @@ -226,9 +191,7 @@ namespace clad {
}

void CladPlugin::ProcessTopLevelDecl(Decl* D) {
m_HandleTopLevelDeclInternal = true;
m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D));
m_HandleTopLevelDeclInternal = false;
m_Multiplexer->HandleTopLevelDecl(DeclGroupRef(D));
}

FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) {
Expand Down
38 changes: 37 additions & 1 deletion tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ namespace clad {
std::unique_ptr<DerivativeBuilder> m_DerivativeBuilder;
bool m_HasRuntime = false;
bool m_PendingInstantiationsInFlight = false;
bool m_HandleTopLevelDeclInternal = false;
DerivedFnCollector m_DFC;
enum class CallKind {
HandleCXXStaticMemberVarInstantiation,
Expand Down Expand Up @@ -182,6 +181,43 @@ namespace clad {

// Not delayed.
void HandleTranslationUnit(clang::ASTContext& C) override {
/*for (const DelayedCallInfo& DCI : m_DelayedCalls) {
switch (DCI.m_Kind) {
case CallKind::HandleCXXStaticMemberVarInstantiation:
llvm::errs() << "HandleCXXStaticMemberVarInstantiation"; break;
case CallKind::HandleTopLevelDecl:
llvm::errs() << "HandleTopLevelDecl"; break;
case CallKind::HandleInlineFunctionDefinition:
llvm::errs() << "HandleInlineFunctionDefinition"; break;
case CallKind::HandleInterestingDecl:
llvm::errs() << "HandleInterestingDecl"; break;
case CallKind::HandleTagDeclDefinition:
llvm::errs() << "HandleTagDeclDefinition"; break;
case CallKind::HandleTagDeclRequiredDefinition:
llvm::errs() << "HandleTagDeclRequiredDefinition"; break;
case CallKind::HandleCXXImplicitFunctionInstantiation:
llvm::errs() << "HandleCXXImplicitFunctionInstantiation"; break;
case CallKind::HandleTopLevelDeclInObjCContainer:
llvm::errs() << "HandleTopLevelDeclInObjCContainer"; break;
case CallKind::HandleImplicitImportDecl:
llvm::errs() << "HandleImplicitImportDecl"; break;
case CallKind::CompleteTentativeDefinition:
llvm::errs() << "CompleteTentativeDefinition"; break;
case CallKind::CompleteExternalDeclaration:
llvm::errs() << "CompleteExternalDeclaration"; break;
case CallKind::AssignInheritanceModel:
llvm::errs() << "AssignInheritanceModel"; break;
case CallKind::HandleVTable:
llvm::errs() << "HandleVTable"; break;
case CallKind::InitializeSema:
llvm::errs() << "InitializeSema"; break;
case CallKind::ForgetSema:
llvm::errs() << "ForgetSema"; break;
};
for (const clang::Decl* D : DCI.m_DGR)
llvm::errs() << " " << D;
llvm::errs() << "\n";
}*/
m_Multiplexer->HandleTranslationUnit(C);
}
// No need to handle the listeners, they will be handled at non-delayed by
Expand Down

0 comments on commit 55c80ca

Please sign in to comment.