diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6af50291e..cf5d0a00f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,6 +54,11 @@ jobs: compiler: clang clang-runtime: '16' + - name: osx-clang-runtime17 + os: macos-latest + compiler: clang + clang-runtime: '17' + - name: win-msvc-runtime14 os: windows-latest compiler: msvc @@ -394,6 +399,10 @@ jobs: clang-runtime: '16' doc_build: true + - name: ubu22-clang16-runtime17 + os: ubuntu-22.04 + compiler: clang-16 + clang-runtime: '17' steps: - uses: actions/checkout@v3 with: diff --git a/README.md b/README.md index 03efa06ac..bd13047f4 100644 --- a/README.md +++ b/README.md @@ -231,7 +231,7 @@ Clad also provides certain flags to save and print the generated derivative code - To print the Clad generated derivative: `-Xclang -plugin-arg-clad -Xclang -fdump-derived-fn` ## How to install -At the moment, LLVM/Clang 5.0.x - 16.0.x are supported. +At the moment, LLVM/Clang 5.0.x - 17.0.x are supported. ### Conda Installation diff --git a/docs/internalDocs/ReleaseNotes.md b/docs/internalDocs/ReleaseNotes.md index 4d2b1aa0a..c210c56ec 100644 --- a/docs/internalDocs/ReleaseNotes.md +++ b/docs/internalDocs/ReleaseNotes.md @@ -21,7 +21,7 @@ described first. External Dependencies --------------------- -* Clad now works with clang-5.0 to clang-16 +* Clad now works with clang-5.0 to clang-17 Forward Mode & Reverse Mode diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 09a7af293..f46addf2a 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -103,9 +103,8 @@ namespace clad { VisitorBase(DerivativeBuilder& builder) : m_Builder(builder), m_Sema(builder.m_Sema), m_CladPlugin(builder.m_CladPlugin), m_Context(builder.m_Context), - m_CurScope(m_Sema.TUScope), m_DerivativeFnScope(nullptr), - m_DerivativeInFlight(false), m_Derivative(nullptr), - m_Function(nullptr) {} + m_DerivativeFnScope(nullptr), m_DerivativeInFlight(false), + m_Derivative(nullptr), m_Function(nullptr) {} using Stmts = llvm::SmallVector; @@ -114,7 +113,6 @@ namespace clad { plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; /// Current Scope at the point of visiting. - clang::Scope* m_CurScope; /// Pointer to the topmost Scope in the created derivative function. clang::Scope* m_DerivativeFnScope; bool m_DerivativeInFlight; @@ -148,10 +146,12 @@ namespace clad { // FIXME: Fix this inconsistency, by making `this` pointer derivative // expression to be of object type in the reverse mode as well. clang::Expr* m_ThisExprDerivative = nullptr; + /// A function used to wrap result of visiting E in a lambda. Returns a call /// to the built lambda. Func is a functor that will be invoked inside /// lambda scope and block. Statements inside lambda are expected to be /// added by addToCurrentBlock from func invocation. + // FIXME: This will become problematic when we try to support C. template static clang::Expr* wrapInLambda(VisitorBase& V, clang::Sema& S, const clang::Expr* E, F&& func) { @@ -168,11 +168,26 @@ namespace clad { clang::Declarator D(DS, CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam CLAD_COMPAT_CLANG12_Declarator_LambdaExpr); +#if CLANG_VERSION_MAJOR > 16 + V.beginScope(clang::Scope::LambdaScope | clang::Scope::DeclScope | + clang::Scope::FunctionDeclarationScope | + clang::Scope::FunctionPrototypeScope); +#endif // CLANG_VERSION_MAJOR S.PushLambdaScope(); +#if CLANG_VERSION_MAJOR > 16 + S.ActOnLambdaExpressionAfterIntroducer(Intro, V.getCurrentScope()); + + S.ActOnLambdaClosureParameters(V.getCurrentScope(), /*ParamInfo=*/{}); +#endif // CLANG_VERSION_MAJOR + V.beginScope(clang::Scope::BlockScope | clang::Scope::FnScope | - clang::Scope::DeclScope); + clang::Scope::DeclScope | clang::Scope::CompoundStmtScope); S.ActOnStartOfLambdaDefinition(Intro, D, clad_compat::Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec(V.getCurrentScope(), DS)); +#if CLANG_VERSION_MAJOR > 16 + V.endScope(); +#endif // CLANG_VERSION_MAJOR + V.beginBlock(); func(); clang::CompoundStmt* body = V.endBlock(); @@ -211,22 +226,17 @@ namespace clad { bool addToBlock(clang::Stmt* S, Stmts& block); /// Get a current scope. - clang::Scope* getCurrentScope() { return m_CurScope; } + /// FIXME: Remove the pointer-ref + // clang::Scope* getCurrentScope() { return m_Sema.getCurScope(); } + clang::Scope*& getCurrentScope(); + void setCurrentScope(clang::Scope* S); + /// Returns the innermost enclosing file context which can be either a + /// namespace or the TU scope. + clang::Scope* getEnclosingNamespaceOrTUScope(); + /// Enters a new scope. - void beginScope(unsigned ScopeFlags) { - // FIXME: since Sema::CurScope is private, we cannot access it and have - // to use separate member variable m_CurScope. The only options to set - // CurScope of Sema seemt to be through Parser or ContextAndScopeRAII. - m_CurScope = - new clang::Scope(getCurrentScope(), ScopeFlags, m_Sema.Diags); - } - void endScope() { - // This will remove all the decls in the scope from the IdResolver. - m_Sema.ActOnPopScope(noLoc, m_CurScope); - auto oldScope = m_CurScope; - m_CurScope = oldScope->getParent(); - delete oldScope; - } + void beginScope(unsigned ScopeFlags); + void endScope(); /// A shorthand to simplify syntax for creation of new expressions. /// This function uses m_Sema.BuildUnOp internally to build unary diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 20722eed2..d65fae998 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -165,7 +165,7 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, SourceLocation loc{m_Function->getLocation()}; DeclarationNameInfo name(II, loc); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); - llvm::SaveAndRestore SaveScope(m_CurScope); + llvm::SaveAndRestore SaveScope(getCurrentScope()); DeclContext* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = @@ -426,7 +426,8 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, QualType derivedFnType = m_Context.getFunctionType( returnType, paramTypes, originalFnType->getExtProtoInfo()); llvm::SaveAndRestore saveContext(m_Sema.CurContext); - llvm::SaveAndRestore saveScope(m_CurScope); + llvm::SaveAndRestore saveScope(getCurrentScope(), + getEnclosingNamespaceOrTUScope()); auto* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index a85bfd49a..b44131a16 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -228,7 +228,8 @@ namespace clad { // Create the gradient function declaration. DeclContext* DC = const_cast(m_Function->getDeclContext()); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); - llvm::SaveAndRestore SaveScope(m_CurScope); + llvm::SaveAndRestore SaveScope(getCurrentScope(), + getEnclosingNamespaceOrTUScope()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction( diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index 606664301..3e8fb1f11 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -39,7 +39,8 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, sourceFnType->getExtProtoInfo()); llvm::SaveAndRestore saveContext(m_Sema.CurContext); - llvm::SaveAndRestore saveScope(m_CurScope); + llvm::SaveAndRestore saveScope(getCurrentScope(), + getEnclosingNamespaceOrTUScope()); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) m_Sema.CurContext = const_cast(m_Function->getDeclContext()); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 609ccdd20..542f80a26 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -346,7 +346,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Create the gradient function declaration. llvm::SaveAndRestore SaveContext(m_Sema.CurContext); - llvm::SaveAndRestore SaveScope(m_CurScope); + llvm::SaveAndRestore SaveScope(getCurrentScope(), + getEnclosingNamespaceOrTUScope()); auto* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = m_Builder.cloneFunction( @@ -488,7 +489,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); llvm::SaveAndRestore saveContext(m_Sema.CurContext); - llvm::SaveAndRestore saveScope(m_CurScope); + llvm::SaveAndRestore saveScope(getCurrentScope(), + getEnclosingNamespaceOrTUScope()); m_Sema.CurContext = const_cast(m_Function->getDeclContext()); DeclWithContext fnBuildRes = m_Builder.cloneFunction( @@ -2838,9 +2840,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // to let Sema::LookupName see the whole scope. auto* identifier = CreateUniqueIdentifier(prefix); // Save current scope and temporarily go to topmost function scope. - llvm::SaveAndRestore SaveScope(m_CurScope); + llvm::SaveAndRestore SaveScope(getCurrentScope()); assert(m_DerivativeFnScope && "must be set"); - m_CurScope = m_DerivativeFnScope; + setCurrentScope(m_DerivativeFnScope); VarDecl* Var = nullptr; if (isa(Type)) { diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index 87f268176..f07458814 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -117,7 +117,7 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, // Function declaration scope llvm::SaveAndRestore SaveContext(m_Sema.CurContext); - llvm::SaveAndRestore SaveScope(m_CurScope); + llvm::SaveAndRestore SaveScope(getCurrentScope()); beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | Scope::DeclScope); m_Sema.PushFunctionScope(); diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 872977d9a..0a18c981f 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -44,6 +44,10 @@ namespace clad { ignoreExpr, ignoreLoc, ignoreRange, ignoreRange, m_Context); } + bool VisitorBase::addToCurrentBlock(Stmt* S) { + return addToBlock(S, getCurrentBlock()); + } + bool VisitorBase::addToBlock(Stmt* S, Stmts& block) { if (!S) return false; @@ -55,8 +59,67 @@ namespace clad { return true; } - bool VisitorBase::addToCurrentBlock(Stmt* S) { - return addToBlock(S, getCurrentBlock()); + // A facility allowing us to access the private member CurScope of the Sema + // object using standard-conforming C++. + namespace { + template struct Rob { + friend typename Tag::type get(Tag) { return M; } + }; + + template struct TagBase { + typedef Member type; +#ifdef MSVC +#pragma warning(push, 0) +#endif // MSVC +#pragma GCC diagnostic push +#pragma clang diagnostic ignored "-Wunknown-warning-option" +#pragma GCC diagnostic ignored "-Wnon-template-friend" + friend type get(Tag); +#pragma GCC diagnostic pop +#ifdef MSVC +#pragma warning(pop) +#endif // MSVC + }; + + // Tag used to access Sema::CurScope. + using namespace clang; + struct Sema_CurScope : TagBase {}; + template struct Rob; + } // namespace + + clang::Scope*& VisitorBase::getCurrentScope() { + return m_Sema.*get(Sema_CurScope()); + } + + void VisitorBase::setCurrentScope(clang::Scope* S) { + m_Sema.*get(Sema_CurScope()) = S; + assert(getEnclosingNamespaceOrTUScope() && "Lost path to base."); + } + + clang::Scope* VisitorBase::getEnclosingNamespaceOrTUScope() { + auto isNamespaceOrTUScope = [](const clang::Scope* S) { + if (clang::DeclContext* DC = S->getEntity()) + return DC->isFileContext(); + return false; + }; + clang::Scope* InnermostFileScope = getCurrentScope(); + while (InnermostFileScope && !isNamespaceOrTUScope(InnermostFileScope)) + InnermostFileScope = InnermostFileScope->getParent(); + + return InnermostFileScope; + } + + void VisitorBase::beginScope(unsigned ScopeFlags) { + auto* S = new clang::Scope(getCurrentScope(), ScopeFlags, m_Sema.Diags); + setCurrentScope(S); + } + + void VisitorBase::endScope() { + // This will remove all the decls in the scope from the IdResolver. + m_Sema.ActOnPopScope(noLoc, getCurrentScope()); + clang::Scope* oldScope = getCurrentScope(); + setCurrentScope(oldScope->getParent()); + delete oldScope; } VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier, diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 4fef1c66b..f66323330 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -378,7 +378,7 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: { // CHECK-NEXT: size_t _d_count = 0; // CHECK-NEXT: size_t _d_max_count = _d_n; -// CHECK-NEXT: for (size_t count = 0; max_count; ++count) { +// CHECK-NEXT: for (size_t count = 0; size_t max_count = n; ++count) { // CHECK-NEXT: if (count >= max_count) // CHECK-NEXT: break; // CHECK-NEXT: {