Skip to content

Commit

Permalink
More
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Feb 16, 2024
1 parent 2bc5302 commit b479740
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 75 deletions.
37 changes: 5 additions & 32 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/CladUtils.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -59,42 +60,14 @@ namespace clad {
return true;
}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};

// Tag used to access Sema::CurScope.
using namespace clang;
struct Sema_CurScope : TagBase<Sema_CurScope, Scope * Sema::*> {};
template struct Rob<Sema_CurScope, &Sema::CurScope>;
} // namespace
ALLOW_ACCESS(Sema, CurScope, Scope *);

clang::Scope*& VisitorBase::getCurrentScope() {
return m_Sema.*get(Sema_CurScope());
return ACCESS(m_Sema, CurScope);
}

void VisitorBase::setCurrentScope(clang::Scope* S) {
m_Sema.*get(Sema_CurScope()) = S;
getCurrentScope() = S;
assert(getEnclosingNamespaceOrTUScope() && "Lost path to base.");
}

Expand Down
47 changes: 5 additions & 42 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "clad/Differentiator/DerivativeBuilder.h"
#include "clad/Differentiator/EstimationModel.h"

#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/Version.h"

#include "clang/AST/ASTConsumer.h"
Expand Down Expand Up @@ -121,39 +122,7 @@ namespace clad {

CladPlugin::~CladPlugin() {}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};
// Tag used to access MultiplexConsumer::Consumers.
using namespace clang;
struct MultiplexConsumer_Consumers
: TagBase<
MultiplexConsumer_Consumers,
std::vector<std::unique_ptr<ASTConsumer>> MultiplexConsumer::*> {
};
template struct Rob<MultiplexConsumer_Consumers,
&MultiplexConsumer::Consumers>;
} // namespace
ALLOW_ACCESS(MultiplexConsumer, Consumers, std::vector<std::unique_ptr<ASTConsumer>>);

void CladPlugin::Initialize(clang::ASTContext& C) {
// We know we have a multiplexer. We commit a sin here by stealing it and
Expand All @@ -163,7 +132,8 @@ namespace clad {
using namespace clang;

auto& MultiplexC = static_cast<MultiplexConsumer&>(m_CI.getASTConsumer());
auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers());
//auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers());
auto& RobbedCs = ACCESS(MultiplexC, Consumers);
assert(RobbedCs.back().get() == this && "Clad is not the last consumer");
std::vector<std::unique_ptr<ASTConsumer>> StolenConsumers;

Expand Down Expand Up @@ -191,11 +161,6 @@ namespace clad {
if (!m_DerivativeBuilder)
m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this));

// if HandleTopLevelDecl was called through clad we don't need to process
// it for diff requests
if (m_HandleTopLevelDeclInternal)
return m_Multiplexer->HandleTopLevelDecl(DGR); // true;

DiffSchedule requests{};
DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema());

Expand Down Expand Up @@ -226,9 +191,7 @@ namespace clad {
}

void CladPlugin::ProcessTopLevelDecl(Decl* D) {
m_HandleTopLevelDeclInternal = true;
m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D));
m_HandleTopLevelDeclInternal = false;
m_Multiplexer->HandleTopLevelDecl(DeclGroupRef(D));
}

FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) {
Expand Down
38 changes: 37 additions & 1 deletion tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ namespace clad {
std::unique_ptr<DerivativeBuilder> m_DerivativeBuilder;
bool m_HasRuntime = false;
bool m_PendingInstantiationsInFlight = false;
bool m_HandleTopLevelDeclInternal = false;
DerivedFnCollector m_DFC;
enum class CallKind {
HandleCXXStaticMemberVarInstantiation,
Expand Down Expand Up @@ -182,6 +181,43 @@ namespace clad {

// Not delayed.
void HandleTranslationUnit(clang::ASTContext& C) override {
/*for (const DelayedCallInfo& DCI : m_DelayedCalls) {
switch (DCI.m_Kind) {
case CallKind::HandleCXXStaticMemberVarInstantiation:
llvm::errs() << "HandleCXXStaticMemberVarInstantiation"; break;
case CallKind::HandleTopLevelDecl:
llvm::errs() << "HandleTopLevelDecl"; break;
case CallKind::HandleInlineFunctionDefinition:
llvm::errs() << "HandleInlineFunctionDefinition"; break;
case CallKind::HandleInterestingDecl:
llvm::errs() << "HandleInterestingDecl"; break;
case CallKind::HandleTagDeclDefinition:
llvm::errs() << "HandleTagDeclDefinition"; break;
case CallKind::HandleTagDeclRequiredDefinition:
llvm::errs() << "HandleTagDeclRequiredDefinition"; break;
case CallKind::HandleCXXImplicitFunctionInstantiation:
llvm::errs() << "HandleCXXImplicitFunctionInstantiation"; break;
case CallKind::HandleTopLevelDeclInObjCContainer:
llvm::errs() << "HandleTopLevelDeclInObjCContainer"; break;
case CallKind::HandleImplicitImportDecl:
llvm::errs() << "HandleImplicitImportDecl"; break;
case CallKind::CompleteTentativeDefinition:
llvm::errs() << "CompleteTentativeDefinition"; break;
case CallKind::CompleteExternalDeclaration:
llvm::errs() << "CompleteExternalDeclaration"; break;
case CallKind::AssignInheritanceModel:
llvm::errs() << "AssignInheritanceModel"; break;
case CallKind::HandleVTable:
llvm::errs() << "HandleVTable"; break;
case CallKind::InitializeSema:
llvm::errs() << "InitializeSema"; break;
case CallKind::ForgetSema:
llvm::errs() << "ForgetSema"; break;
};
for (const clang::Decl* D : DCI.m_DGR)
llvm::errs() << " " << D;
llvm::errs() << "\n";
}*/
m_Multiplexer->HandleTranslationUnit(C);
}
// No need to handle the listeners, they will be handled at non-delayed by
Expand Down

0 comments on commit b479740

Please sign in to comment.