Skip to content

Commit

Permalink
Delay the differentiation process until the end of TU.
Browse files Browse the repository at this point in the history
Before this patch clad attaches itself as a first consumer and applies AD before
code generation. However, that is limited since clang sends every top-level
declaration to codegen which limits the amount of flexibility clad has. For
example, we have to instantiate all pending templates at every
HandleTopLevelDecl calls; we cannot really differentiate virtual functions
whose classes have sent their key function to CodeGen; and in general we perform
actions which are semantically useful for the end of the translation unit.

This patch makes clad a single consumer of clang which dispatches to the others.
That's done by delaying all calls to the consumers until the end of the TU where
clad can replay the exact sequence of calls to the other consumers as if they
were directly connected to the frontend.

Fixes #248
  • Loading branch information
vgvassilev committed Mar 10, 2024
1 parent e930a4a commit c90b311
Show file tree
Hide file tree
Showing 6 changed files with 452 additions and 79 deletions.
29 changes: 29 additions & 0 deletions include/clad/Differentiator/Sins.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef CLAD_DIFFERENTIATOR_SINS_H
#define CLAD_DIFFERENTIATOR_SINS_H

#include <type_traits>

/// Standard-protected facility allowing access into private members in C++.
/// Use with caution!
// NOLINTBEGIN(cppcoreguidelines-macro-usage)
#define CONCATE_(X, Y) X##Y
#define CONCATE(X, Y) CONCATE_(X, Y)
#define ALLOW_ACCESS(CLASS, MEMBER, ...) \
template <typename Only, __VA_ARGS__ CLASS::*Member> \
struct CONCATE(MEMBER, __LINE__) { \
friend __VA_ARGS__ CLASS::*Access(Only*) { return Member; } \
}; \
template <typename> struct Only_##MEMBER; \
template <> struct Only_##MEMBER<CLASS> { \
friend __VA_ARGS__ CLASS::*Access(Only_##MEMBER<CLASS>*); \
}; \
template struct CONCATE(MEMBER, \
__LINE__)<Only_##MEMBER<CLASS>, &CLASS::MEMBER>

#define ACCESS(OBJECT, MEMBER) \
(OBJECT).*Access((Only_##MEMBER< \
std::remove_reference<decltype(OBJECT)>::type>*)nullptr)

// NOLINTEND(cppcoreguidelines-macro-usage)

#endif // CLAD_DIFFERENTIATOR_SINS_H
37 changes: 5 additions & 32 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

#include "ConstantFolder.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/Sins.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/CladUtils.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
Expand Down Expand Up @@ -59,42 +60,14 @@ namespace clad {
return true;
}

// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
using type = Member;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#endif // __clang__
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};

// Tag used to access Sema::CurScope.
using namespace clang;
struct Sema_CurScope : TagBase<Sema_CurScope, Scope * Sema::*> {};
template struct Rob<Sema_CurScope, &Sema::CurScope>;
} // namespace
ALLOW_ACCESS(Sema, CurScope, Scope*);

clang::Scope*& VisitorBase::getCurrentScope() {
return m_Sema.*get(Sema_CurScope());
return ACCESS(m_Sema, CurScope);
}

void VisitorBase::setCurrentScope(clang::Scope* S) {
m_Sema.*get(Sema_CurScope()) = S;
getCurrentScope() = S;
assert(getEnclosingNamespaceOrTUScope() && "Lost path to base.");
}

Expand Down
8 changes: 8 additions & 0 deletions test/FirstDerivative/CodeGenSimple.C
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,17 @@ extern "C" int printf(const char* fmt, ...);

int f_1_darg0(int x);

double sq_defined_later(double);

int main() {
int x = 4;
clad::differentiate(f_1, 0);
auto df = clad::differentiate(sq_defined_later, "x");
printf("Result is = %d\n", f_1_darg0(1)); // CHECK-EXEC: Result is = 2
printf("Result is = %f\n", df.execute(3)); // CHECK-EXEC: Result is = 6
return 0;
}

double sq_defined_later(double x) {
return x * x;
}
61 changes: 61 additions & 0 deletions test/Misc/ClangConsumers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// RUN: %cladclang %s -I%S/../../include -oClangConsumers.out \
// RUN: -fms-compatibility -std=c++14 -fmodules -Xclang \
// RUN: -print-stats 2>&1 | FileCheck %s
// CHECK-NOT: {{.*error|warning|note:.*}}
//
// RUN: clang -xc -Xclang -add-plugin -Xclang clad -Xclang -load \
// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \
// RUN: -Xclang -print-stats 2>&1 | \
// RUN: FileCheck -check-prefix=CHECK_C %s
// CHECK_C-NOT: {{.*error|warning|note:.*}}
//
// RUN: clang -xobjective-c -Xclang -add-plugin -Xclang clad -Xclang -load \
// RUN: -Xclang %cladlib %s -I%S/../../include -oClangConsumers.out \
// RUN: -Xclang -print-stats 2>&1 | \
// RUN: FileCheck -check-prefix=CHECK_OBJC %s
// CHECK_OBJC-NOT: {{.*error|warning|note:.*}}

#ifdef __cplusplus

#pragma clang module build N
module N {}
#pragma clang module contents
#pragma clang module begin N
struct f { void operator()() const {} };
template <typename T> auto vtemplate = f{};
#pragma clang module end
#pragma clang module endbuild

#pragma clang module import N

class __single_inheritance IncSingle;

struct V { virtual int f(); };
int V::f() { return 1; }

// CHECK: HandleImplicitImportDecl
// CHECK: AssignInheritanceModel
// CHECK: HandleTopLevelDecl
// CHECK: HandleInterestingDecl
// CHECK: HandleVTable
// CHECK: HandleCXXStaticMemberVarInstantiation

#endif // __cplusplus

#ifdef __STDC_VERSION__ // C mode
int i;
// CHECK_C: CompleteTentativeDefinition
#endif // __STDC_VERSION__

#ifdef __OBJC__
@interface rdar10902015
void f() {}
@end
// CHECK_OBJC: HandleTopLevelDeclInObjCContainer
#endif // __OBJC__

int main() {
#ifdef __cplusplus
vtemplate<int>();
#endif // __cplusplus
}
Loading

0 comments on commit c90b311

Please sign in to comment.