Skip to content

Commit

Permalink
More
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Feb 18, 2024
1 parent 6f0fb3a commit 07723ff
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 116 deletions.
29 changes: 29 additions & 0 deletions include/clad/Differentiator/Sins.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef CLAD_DIFFERENTIATOR_SINS_H
#define CLAD_DIFFERENTIATOR_SINS_H

#include <type_traits>

/// Standard-protected facility allowing access into private members in C++.
/// Use with caution!
// NOLINTBEGIN(cppcoreguidelines-macro-usage)
#define CONCATE_(X, Y) X##Y
#define CONCATE(X, Y) CONCATE_(X, Y)
#define ALLOW_ACCESS(CLASS, MEMBER, ...) \
template <typename Only, __VA_ARGS__ CLASS::*Member> \
struct CONCATE(MEMBER, __LINE__) { \
friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \
}; \
template <typename> struct Only_##MEMBER; \
template <> struct Only_##MEMBER<CLASS> { \
friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER<CLASS>*); \
}; \
template struct CONCATE(MEMBER, \
__LINE__)<Only_##MEMBER<CLASS>, &CLASS::MEMBER>

#define ACCESS(OBJECT, MEMBER) \
(OBJECT).*Access((Only_##MEMBER< \
std::remove_reference<decltype(OBJECT)>::type>*)nullptr)

// NOLINTEND(cppcoreguidelines-macro-usage)

#endif // CLAD_DIFFERENTIATOR_SINS_H
37 changes: 5 additions & 32 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/CladUtils.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -59,42 +60,14 @@ namespace clad {
return true;
}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};

// Tag used to access Sema::CurScope.
using namespace clang;
struct Sema_CurScope : TagBase<Sema_CurScope, Scope * Sema::*> {};
template struct Rob<Sema_CurScope, &Sema::CurScope>;
} // namespace
ALLOW_ACCESS(Sema, CurScope, Scope*);

clang::Scope*& VisitorBase::getCurrentScope() {
return m_Sema.*get(Sema_CurScope());
return ACCESS(m_Sema, CurScope);
}

void VisitorBase::setCurrentScope(clang::Scope* S) {
m_Sema.*get(Sema_CurScope()) = S;
getCurrentScope() = S;
assert(getEnclosingNamespaceOrTUScope() && "Lost path to base.");
}

Expand Down
8 changes: 8 additions & 0 deletions test/FirstDerivative/CodeGenSimple.C
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...);

int f_1_darg0(int x);

double sq_defined_later(double);

int main() {
int x = 4;
clad::differentiate(f_1, 0);
auto df = clad::differentiate(sq_defined_later, "x");
printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2
printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6
return 0;
}

double sq_defined_later(double x) {
return x * x;
}
8 changes: 8 additions & 0 deletions test/Misc/ClangConsumers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out -Xclang -print-stats 2>&1 | FileCheck %s
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"
// CHECK: HandleTopLevelDecl
int main() {

}
210 changes: 149 additions & 61 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "clad/Differentiator/DerivativeBuilder.h"
#include "clad/Differentiator/EstimationModel.h"

#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/Version.h"

#include "clang/AST/ASTConsumer.h"
Expand Down Expand Up @@ -121,39 +122,8 @@ namespace clad {

CladPlugin::~CladPlugin() {}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};
// Tag used to access MultiplexConsumer::Consumers.
using namespace clang;
struct MultiplexConsumer_Consumers
: TagBase<
MultiplexConsumer_Consumers,
std::vector<std::unique_ptr<ASTConsumer>> MultiplexConsumer::*> {
};
template struct Rob<MultiplexConsumer_Consumers,
&MultiplexConsumer::Consumers>;
} // namespace
ALLOW_ACCESS(MultiplexConsumer, Consumers,
std::vector<std::unique_ptr<ASTConsumer>>);

void CladPlugin::Initialize(clang::ASTContext& C) {
// We know we have a multiplexer. We commit a sin here by stealing it and
Expand All @@ -163,7 +133,7 @@ namespace clad {
using namespace clang;

auto& MultiplexC = static_cast<MultiplexConsumer&>(m_CI.getASTConsumer());
auto& RobbedCs = MultiplexC.*get(MultiplexConsumer_Consumers());
auto& RobbedCs = ACCESS(MultiplexC, Consumers);
assert(RobbedCs.back().get() == this && "Clad is not the last consumer");
std::vector<std::unique_ptr<ASTConsumer>> StolenConsumers;

Expand All @@ -181,60 +151,38 @@ namespace clad {

// We cannot use HandleTranslationUnit because codegen already emits code on
// HandleTopLevelDecl calls and makes updateCall with no effect.
bool CladPlugin::HandleTopLevelDecl(DeclGroupRef DGR) {
AppendDelayed({CallKind::HandleTopLevelDecl, DGR});
void CladPlugin::HandleTopLevelDeclForClad(DeclGroupRef DGR) {
if (!CheckBuiltins())
return m_Multiplexer->HandleTopLevelDecl(DGR); // true;

Sema& S = m_CI.getSema();
return;

if (!m_DerivativeBuilder)
m_DerivativeBuilder.reset(new DerivativeBuilder(m_CI.getSema(), *this));

// if HandleTopLevelDecl was called through clad we don't need to process
// it for diff requests
if (m_HandleTopLevelDeclInternal)
return m_Multiplexer->HandleTopLevelDecl(DGR); // true;

DiffSchedule requests{};
DiffCollector collector(DGR, CladEnabledRange, requests, m_CI.getSema());

if (requests.empty())
return m_Multiplexer->HandleTopLevelDecl(DGR); // true;
return;

// FIXME: flags have to be set manually since DiffCollector's constructor
// does not have access to m_DO.
if (m_DO.EnableTBRAnalysis)
for (DiffRequest& request : requests)
request.EnableTBRAnalysis = true;

// FIXME: Remove the PerformPendingInstantiations altogether. We should
// somehow make the relevant functions referenced.
// Instantiate all pending for instantiations templates, because we will
// need the full bodies to produce derivatives.
// FIXME: Confirm if we really need `m_PendingInstantiationsInFlight`?
if (!m_PendingInstantiationsInFlight) {
m_PendingInstantiationsInFlight = true;
S.PerformPendingInstantiations();
m_PendingInstantiationsInFlight = false;
}

for (DiffRequest& request : requests)
ProcessDiffRequest(request);

return m_Multiplexer->HandleTopLevelDecl(DGR); // Happiness
}

void CladPlugin::ProcessTopLevelDecl(Decl* D) {
m_HandleTopLevelDeclInternal = true;
m_CI.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(D));
m_HandleTopLevelDeclInternal = false;
AppendDelayed({CallKind::HandleTopLevelDecl, D});
}

FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) {
Sema& S = m_CI.getSema();
// Required due to custom derivatives function templates that might be
// used in the function that we need to derive.
// FIXME: Remove the call to PerformPendingInstantiations().
S.PerformPendingInstantiations();
if (request.Function->getDefinition())
request.Function = request.Function->getDefinition();
Expand Down Expand Up @@ -345,6 +293,8 @@ namespace clad {

// Call CodeGen only if the produced Decl is a top-most
// decl or is contained in a namespace decl.
// FIXME: We could get rid of this by prepending the produced
// derivatives in CladPlugin::HandleTranslationUnitDecl
DeclContext* derivativeDC = DerivativeDecl->getDeclContext();
bool isTUorND =
derivativeDC->isTranslationUnit() || derivativeDC->isNamespace();
Expand Down Expand Up @@ -396,6 +346,144 @@ namespace clad {
m_HasRuntime = !R.empty();
return m_HasRuntime;
}

void CladPlugin::HandleTranslationUnit(ASTContext& C) {
for (auto DCI = m_DelayedCalls.begin(); DCI != m_DelayedCalls.end();
++DCI) {
DeclGroupRef& D = DCI->m_DGR;
switch (DCI->m_Kind) {
case CallKind::HandleCXXStaticMemberVarInstantiation:
m_Multiplexer->HandleCXXStaticMemberVarInstantiation(
cast<VarDecl>(D.getSingleDecl()));
break;
case CallKind::HandleTopLevelDecl: {
Sema& S = m_CI.getSema();
Sema::GlobalEagerInstantiationScope GlobalInstantiations(
S, /*Recursive=*/true);
Sema::LocalEagerInstantiationScope LocalInstantiations(S);
HandleTopLevelDeclForClad(D);
LocalInstantiations.perform();
GlobalInstantiations.perform();
m_Multiplexer->HandleTopLevelDecl(D);
break;
}
case CallKind::HandleInlineFunctionDefinition:
m_Multiplexer->HandleInlineFunctionDefinition(
cast<FunctionDecl>(D.getSingleDecl()));
break;
case CallKind::HandleInterestingDecl:
m_Multiplexer->HandleInterestingDecl(D);
break;
case CallKind::HandleTagDeclDefinition:
m_Multiplexer->HandleTagDeclDefinition(
cast<TagDecl>(D.getSingleDecl()));
break;
case CallKind::HandleTagDeclRequiredDefinition:
m_Multiplexer->HandleTagDeclRequiredDefinition(
cast<TagDecl>(D.getSingleDecl()));
break;
case CallKind::HandleCXXImplicitFunctionInstantiation:
m_Multiplexer->HandleCXXImplicitFunctionInstantiation(
cast<FunctionDecl>(D.getSingleDecl()));
break;
case CallKind::HandleTopLevelDeclInObjCContainer:
m_Multiplexer->HandleTopLevelDeclInObjCContainer(D);
break;
case CallKind::HandleImplicitImportDecl:
m_Multiplexer->HandleImplicitImportDecl(
cast<ImportDecl>(D.getSingleDecl()));
break;
case CallKind::CompleteTentativeDefinition:
m_Multiplexer->CompleteTentativeDefinition(
cast<VarDecl>(D.getSingleDecl()));
break;
#if CLANG_VERSION_MAJOR > 9
case CallKind::CompleteExternalDeclaration:
m_Multiplexer->CompleteExternalDeclaration(
cast<VarDecl>(D.getSingleDecl()));
break;
#endif
case CallKind::AssignInheritanceModel:
m_Multiplexer->AssignInheritanceModel(
cast<CXXRecordDecl>(D.getSingleDecl()));
break;
case CallKind::HandleVTable:
m_Multiplexer->HandleVTable(cast<CXXRecordDecl>(D.getSingleDecl()));
break;
case CallKind::InitializeSema:
m_Multiplexer->InitializeSema(m_CI.getSema());
break;
case CallKind::ForgetSema:
m_Multiplexer->ForgetSema();
break;
};
}

m_Multiplexer->HandleTranslationUnit(C);
}

void CladPlugin::PrintStats() {
llvm::errs() << "*** INFORMATION ABOUT THE DELAYED CALLS\n";
for (const DelayedCallInfo& DCI : m_DelayedCalls) {
llvm::errs() << " ";
switch (DCI.m_Kind) {
case CallKind::HandleCXXStaticMemberVarInstantiation:
llvm::errs() << "HandleCXXStaticMemberVarInstantiation";
break;
case CallKind::HandleTopLevelDecl:
llvm::errs() << "HandleTopLevelDecl";
break;
case CallKind::HandleInlineFunctionDefinition:
llvm::errs() << "HandleInlineFunctionDefinition";
break;
case CallKind::HandleInterestingDecl:
llvm::errs() << "HandleInterestingDecl";
break;
case CallKind::HandleTagDeclDefinition:
llvm::errs() << "HandleTagDeclDefinition";
break;
case CallKind::HandleTagDeclRequiredDefinition:
llvm::errs() << "HandleTagDeclRequiredDefinition";
break;
case CallKind::HandleCXXImplicitFunctionInstantiation:
llvm::errs() << "HandleCXXImplicitFunctionInstantiation";
break;
case CallKind::HandleTopLevelDeclInObjCContainer:
llvm::errs() << "HandleTopLevelDeclInObjCContainer";
break;
case CallKind::HandleImplicitImportDecl:
llvm::errs() << "HandleImplicitImportDecl";
break;
case CallKind::CompleteTentativeDefinition:
llvm::errs() << "CompleteTentativeDefinition";
break;
case CallKind::CompleteExternalDeclaration:
llvm::errs() << "CompleteExternalDeclaration";
break;
case CallKind::AssignInheritanceModel:
llvm::errs() << "AssignInheritanceModel";
break;
case CallKind::HandleVTable:
llvm::errs() << "HandleVTable";
break;
case CallKind::InitializeSema:
llvm::errs() << "InitializeSema";
break;
case CallKind::ForgetSema:
llvm::errs() << "ForgetSema";
break;
};
for (const clang::Decl* D : DCI.m_DGR) {
llvm::errs() << " " << D;
if (auto* ND = dyn_cast<NamedDecl>(D))
llvm::errs() << " " << ND->getNameAsString();
}
llvm::errs() << "\n";
}

m_Multiplexer->PrintStats();
}

} // end namespace plugin

// Routine to check clang version at runtime against the clang version for
Expand Down
Loading

0 comments on commit 07723ff

Please sign in to comment.