Skip to content

Commit

Permalink
Add LLVM17 support.
Browse files Browse the repository at this point in the history
This patch replaces the artificial handling of the scopes and accesses the
private CurScope member of Sema. This allows us to build precise lexical scope
setup for when building AST nodes. That's essential for lambda support in
clang17 because it tightly couples both Scopes and DeclContexts.

Fixes #669
  • Loading branch information
vgvassilev committed Dec 28, 2023
1 parent d5e8dc2 commit 25e2f3f
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 34 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ jobs:
compiler: clang
clang-runtime: '16'

- name: osx-clang-runtime17
os: macos-latest
compiler: clang
clang-runtime: '17'

- name: win-msvc-runtime14
os: windows-latest
compiler: msvc
Expand Down Expand Up @@ -394,6 +399,10 @@ jobs:
clang-runtime: '16'
doc_build: true

- name: ubu22-clang16-runtime17
os: ubuntu-22.04
compiler: clang-16
clang-runtime: '17'
steps:
- uses: actions/checkout@v3
with:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ Clad also provides certain flags to save and print the generated derivative code
- To print the Clad generated derivative: `-Xclang -plugin-arg-clad -Xclang -fdump-derived-fn`

## How to install
At the moment, LLVM/Clang 5.0.x - 16.0.x are supported.
At the moment, LLVM/Clang 5.0.x - 17.0.x are supported.

### Conda Installation

Expand Down
2 changes: 1 addition & 1 deletion docs/internalDocs/ReleaseNotes.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ described first.
External Dependencies
---------------------

* Clad now works with clang-5.0 to clang-16
* Clad now works with clang-5.0 to clang-17


Forward Mode & Reverse Mode
Expand Down
50 changes: 30 additions & 20 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,8 @@ namespace clad {
VisitorBase(DerivativeBuilder& builder)
: m_Builder(builder), m_Sema(builder.m_Sema),
m_CladPlugin(builder.m_CladPlugin), m_Context(builder.m_Context),
m_CurScope(m_Sema.TUScope), m_DerivativeFnScope(nullptr),
m_DerivativeInFlight(false), m_Derivative(nullptr),
m_Function(nullptr) {}
m_DerivativeFnScope(nullptr), m_DerivativeInFlight(false),
m_Derivative(nullptr), m_Function(nullptr) {}

using Stmts = llvm::SmallVector<clang::Stmt*, 16>;

Expand All @@ -114,7 +113,6 @@ namespace clad {
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
/// Current Scope at the point of visiting.
clang::Scope* m_CurScope;
/// Pointer to the topmost Scope in the created derivative function.
clang::Scope* m_DerivativeFnScope;
bool m_DerivativeInFlight;
Expand Down Expand Up @@ -148,10 +146,12 @@ namespace clad {
// FIXME: Fix this inconsistency, by making `this` pointer derivative
// expression to be of object type in the reverse mode as well.
clang::Expr* m_ThisExprDerivative = nullptr;

/// A function used to wrap result of visiting E in a lambda. Returns a call
/// to the built lambda. Func is a functor that will be invoked inside
/// lambda scope and block. Statements inside lambda are expected to be
/// added by addToCurrentBlock from func invocation.
// FIXME: This will become problematic when we try to support C.
template <typename F>
static clang::Expr* wrapInLambda(VisitorBase& V, clang::Sema& S,
const clang::Expr* E, F&& func) {
Expand All @@ -168,11 +168,26 @@ namespace clad {
clang::Declarator D(DS,
CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam
CLAD_COMPAT_CLANG12_Declarator_LambdaExpr);
#if CLANG_VERSION_MAJOR > 16
V.beginScope(clang::Scope::LambdaScope | clang::Scope::DeclScope |
clang::Scope::FunctionDeclarationScope |
clang::Scope::FunctionPrototypeScope);
#endif // CLANG_VERSION_MAJOR
S.PushLambdaScope();
#if CLANG_VERSION_MAJOR > 16
S.ActOnLambdaExpressionAfterIntroducer(Intro, V.getCurrentScope());

S.ActOnLambdaClosureParameters(V.getCurrentScope(), /*ParamInfo=*/{});
#endif // CLANG_VERSION_MAJOR

V.beginScope(clang::Scope::BlockScope | clang::Scope::FnScope |
clang::Scope::DeclScope);
clang::Scope::DeclScope | clang::Scope::CompoundStmtScope);
S.ActOnStartOfLambdaDefinition(Intro, D,
clad_compat::Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec(V.getCurrentScope(), DS));
#if CLANG_VERSION_MAJOR > 16
V.endScope();
#endif // CLANG_VERSION_MAJOR

V.beginBlock();
func();
clang::CompoundStmt* body = V.endBlock();
Expand Down Expand Up @@ -211,22 +226,17 @@ namespace clad {
bool addToBlock(clang::Stmt* S, Stmts& block);

/// Get a current scope.
clang::Scope* getCurrentScope() { return m_CurScope; }
/// FIXME: Remove the pointer-ref
// clang::Scope* getCurrentScope() { return m_Sema.getCurScope(); }
clang::Scope*& getCurrentScope();
void setCurrentScope(clang::Scope* S);
/// Returns the innermost enclosing file context which can be either a
/// namespace or the TU scope.
clang::Scope* getEnclosingNamespaceOrTUScope();

/// Enters a new scope.
void beginScope(unsigned ScopeFlags) {
// FIXME: since Sema::CurScope is private, we cannot access it and have
// to use separate member variable m_CurScope. The only options to set
// CurScope of Sema seemt to be through Parser or ContextAndScopeRAII.
m_CurScope =
new clang::Scope(getCurrentScope(), ScopeFlags, m_Sema.Diags);
}
void endScope() {
// This will remove all the decls in the scope from the IdResolver.
m_Sema.ActOnPopScope(noLoc, m_CurScope);
auto oldScope = m_CurScope;
m_CurScope = oldScope->getParent();
delete oldScope;
}
void beginScope(unsigned ScopeFlags);
void endScope();

/// A shorthand to simplify syntax for creation of new expressions.
/// This function uses m_Sema.BuildUnOp internally to build unary
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
SourceLocation loc{m_Function->getLocation()};
DeclarationNameInfo name(II, loc);
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(m_CurScope);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result =
Expand Down Expand Up @@ -426,7 +426,8 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
QualType derivedFnType = m_Context.getFunctionType(
returnType, paramTypes, originalFnType->getExtProtoInfo());
llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(m_CurScope);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;

Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ namespace clad {
// Create the gradient function declaration.
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(m_CurScope);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
m_Sema.CurContext = DC;

DeclWithContext result = m_Builder.cloneFunction(
Expand Down
3 changes: 2 additions & 1 deletion lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
sourceFnType->getExtProtoInfo());

llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(m_CurScope);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
m_Sema.CurContext = const_cast<DeclContext*>(m_Function->getDeclContext());

Expand Down
10 changes: 6 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Create the gradient function declaration.
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(m_CurScope);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result = m_Builder.cloneFunction(
Expand Down Expand Up @@ -488,7 +489,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo());

llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(m_CurScope);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
m_Sema.CurContext = const_cast<DeclContext*>(m_Function->getDeclContext());

DeclWithContext fnBuildRes = m_Builder.cloneFunction(
Expand Down Expand Up @@ -2838,9 +2840,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to let Sema::LookupName see the whole scope.
auto* identifier = CreateUniqueIdentifier(prefix);
// Save current scope and temporarily go to topmost function scope.
llvm::SaveAndRestore<Scope*> SaveScope(m_CurScope);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());
assert(m_DerivativeFnScope && "must be set");
m_CurScope = m_DerivativeFnScope;
setCurrentScope(m_DerivativeFnScope);

VarDecl* Var = nullptr;
if (isa<ArrayType>(Type)) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD,

// Function declaration scope
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(m_CurScope);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
Expand Down
67 changes: 65 additions & 2 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ namespace clad {
ignoreExpr, ignoreLoc, ignoreRange, ignoreRange, m_Context);
}

bool VisitorBase::addToCurrentBlock(Stmt* S) {
return addToBlock(S, getCurrentBlock());
}

bool VisitorBase::addToBlock(Stmt* S, Stmts& block) {
if (!S)
return false;
Expand All @@ -55,8 +59,67 @@ namespace clad {
return true;
}

bool VisitorBase::addToCurrentBlock(Stmt* S) {
return addToBlock(S, getCurrentBlock());
// A facility allowing us to access the private member CurScope of the Sema
// object using standard-conforming C++.
namespace {
template <typename Tag, typename Tag::type M> struct Rob {
friend typename Tag::type get(Tag) { return M; }
};

template <typename Tag, typename Member> struct TagBase {
typedef Member type;
#ifdef MSVC
#pragma warning(push, 0)
#endif // MSVC
#pragma GCC diagnostic push
#pragma clang diagnostic ignored "-Wunknown-warning-option"
#pragma GCC diagnostic ignored "-Wnon-template-friend"
friend type get(Tag);
#pragma GCC diagnostic pop
#ifdef MSVC
#pragma warning(pop)
#endif // MSVC
};

// Tag used to access Sema::CurScope.
using namespace clang;
struct Sema_CurScope : TagBase<Sema_CurScope, Scope * Sema::*> {};
template struct Rob<Sema_CurScope, &Sema::CurScope>;
} // namespace

clang::Scope*& VisitorBase::getCurrentScope() {
return m_Sema.*get(Sema_CurScope());
}

void VisitorBase::setCurrentScope(clang::Scope* S) {
m_Sema.*get(Sema_CurScope()) = S;
assert(getEnclosingNamespaceOrTUScope() && "Lost path to base.");
}

clang::Scope* VisitorBase::getEnclosingNamespaceOrTUScope() {
auto isNamespaceOrTUScope = [](const clang::Scope* S) {
if (clang::DeclContext* DC = S->getEntity())
return DC->isFileContext();
return false;
};
clang::Scope* InnermostFileScope = getCurrentScope();
while (InnermostFileScope && !isNamespaceOrTUScope(InnermostFileScope))
InnermostFileScope = InnermostFileScope->getParent();

return InnermostFileScope;
}

void VisitorBase::beginScope(unsigned ScopeFlags) {
auto* S = new clang::Scope(getCurrentScope(), ScopeFlags, m_Sema.Diags);
setCurrentScope(S);
}

void VisitorBase::endScope() {
// This will remove all the decls in the scope from the IdResolver.
m_Sema.ActOnPopScope(noLoc, getCurrentScope());
clang::Scope* oldScope = getCurrentScope();
setCurrentScope(oldScope->getParent());
delete oldScope;
}

VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier,
Expand Down
2 changes: 1 addition & 1 deletion test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ double fn10_darg0(double x, size_t n);
// CHECK-NEXT: {
// CHECK-NEXT: size_t _d_count = 0;
// CHECK-NEXT: size_t _d_max_count = _d_n;
// CHECK-NEXT: for (size_t count = 0; max_count; ++count) {
// CHECK-NEXT: for (size_t count = 0; size_t max_count = n; ++count) {
// CHECK-NEXT: if (count >= max_count)
// CHECK-NEXT: break;
// CHECK-NEXT: {
Expand Down

0 comments on commit 25e2f3f

Please sign in to comment.