diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 5e9d54ac2..a01facc9c 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -38,6 +38,7 @@ namespace clad { class CladPlugin; clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P, DiffRequest& request); + void ProcessTopLevelDecl(CladPlugin& P, clang::Decl* D); } // namespace plugin } // namespace clad diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index ba0cf7c14..74286c8d4 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -431,6 +431,7 @@ namespace clad { StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS); DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD, bool keepLocal = false); + clang::Expr* DifferentiateGlobalVarDecl(clang::VarDecl* VD); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 2ab7e35bc..55d5fb4b0 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -290,8 +290,9 @@ namespace clad { /// \returns The newly built variable declaration. clang::VarDecl* BuildVarDecl(clang::QualType Type, clang::IdentifierInfo* Identifier, - clang::Scope* scope, clang::Expr* Init = nullptr, - bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr, + clang::Scope* scope, clang::DeclContext* DeclCtx, + clang::Expr* Init = nullptr, bool DirectInit = false, + clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = clang::VarDecl::InitializationStyle::CInit); /// Builds variable declaration to be used inside the derivative @@ -334,7 +335,8 @@ namespace clad { clang::Expr* Init = nullptr, bool DirectInit = false, clang::TypeSourceInfo* TSI = nullptr, clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit); + clang::VarDecl::InitializationStyle::CInit, + clang::DeclContext* DeclCtx = nullptr); /// Creates a namespace declaration and enters its context. All subsequent /// Stmts are built inside that namespace, until /// m_Sema.PopDeclContextIsUsed. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9ecf043c8..07096ce7d 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1413,8 +1413,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Check DeclRefExpr is a reference to an independent variable. auto it = m_Variables.find(VD); if (it == std::end(m_Variables)) { - // Is not an independent variable, ignored. - return StmtDiff(clonedDRE); + if (VD->isFileVarDecl()) { + Expr* DREDiff = DifferentiateGlobalVarDecl(VD); + it = m_Variables.emplace(VD, DREDiff).first; + } else + // Is not an independent variable, ignored. + return StmtDiff(clonedDRE); } // Create the (_d_param[idx] += dfdx) statement. if (dfdx()) { @@ -1440,6 +1444,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(clonedDRE); } + Expr* ReverseModeVisitor::DifferentiateGlobalVarDecl(VarDecl* VD) { + assert(VD->isFileVarDecl() && "Must be a global variable"); + std::string nameDiff_str = "_d_" + VD->getNameAsString(); + DeclarationName nameDiff = &m_Context.Idents.get(nameDiff_str); + DeclContext* DC = VD->getDeclContext(); + + // Attempt to find the adjoint of VD in case it has already been created. + LookupResult result(m_Sema, nameDiff, noLoc, Sema::LookupOrdinaryName); + m_Sema.LookupQualifiedName(result, DC); + if (!result.empty()) { + // Found, return a reference + Expr* foundExpr = m_Sema + .BuildDeclarationNameExpr(CXXScopeSpec{}, result, + /*ADL=*/false) + .get(); + return foundExpr; + } + // Not found, construct the adjoint and register it. + VarDecl* VDDiff = + BuildVarDecl(VD->getType(), CreateUniqueIdentifier(nameDiff_str), + m_DerivativeFnScope->getParent()->getParent(), DC, + getZeroInit(VD->getType())); + + DC->addDecl(VDDiff); + DC->makeDeclVisibleInContext(VDDiff); + plugin::ProcessTopLevelDecl(m_CladPlugin, VDDiff); + // diag(DiagnosticsEngine::Warning, + // VD->getLocation(), + // "The gradient utilizes a global variable '%0' and its adjoint + // '%1'" + // ". Please make sure to properly reset '%0' and '%1' before + // re-running the gradient.", + // {VD->getNameAsString(), nameDiff_str}); + return BuildDeclRef(VDDiff); + } + StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) { auto* Constant0 = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 87c55b4cd..e75d62083 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -107,18 +107,19 @@ namespace clad { Expr* Init, bool DirectInit, TypeSourceInfo* TSI, VarDecl::InitializationStyle IS) { - return BuildVarDecl(Type, Identifier, getCurrentScope(), Init, DirectInit, - TSI, IS); + return BuildVarDecl(Type, Identifier, getCurrentScope(), m_Sema.CurContext, + Init, DirectInit, TSI, IS); } VarDecl* VisitorBase::BuildVarDecl(QualType Type, IdentifierInfo* Identifier, - Scope* Scope, Expr* Init, bool DirectInit, + Scope* Scope, DeclContext* DeclCtx, + Expr* Init, bool DirectInit, TypeSourceInfo* TSI, VarDecl::InitializationStyle IS) { // add namespace specifier in variable declaration if needed. Type = utils::AddNamespaceSpecifier(m_Sema, m_Context, Type); - auto* VD = VarDecl::Create( - m_Context, m_Sema.CurContext, m_DiffReq->getLocation(), - m_DiffReq->getLocation(), Identifier, Type, TSI, SC_None); + auto* VD = VarDecl::Create(m_Context, DeclCtx, m_DiffReq->getLocation(), + m_DiffReq->getLocation(), Identifier, Type, TSI, + SC_None); if (Init) { m_Sema.AddInitializerToDecl(VD, Init, DirectInit); @@ -149,9 +150,12 @@ namespace clad { VarDecl* VisitorBase::BuildGlobalVarDecl(QualType Type, llvm::StringRef prefix, Expr* Init, bool DirectInit, TypeSourceInfo* TSI, - VarDecl::InitializationStyle IS) { + VarDecl::InitializationStyle IS, + DeclContext* DeclCtx) { + DeclCtx = DeclCtx ? DeclCtx : m_Sema.CurContext; return BuildVarDecl(Type, CreateUniqueIdentifier(prefix), - m_DerivativeFnScope, Init, DirectInit, TSI, IS); + m_DerivativeFnScope, DeclCtx, Init, DirectInit, TSI, + IS); } NamespaceDecl* VisitorBase::BuildNamespaceDecl(IdentifierInfo* II, diff --git a/test/Gradient/Functors.C b/test/Gradient/Functors.C index 07bac5513..334978e87 100644 --- a/test/Gradient/Functors.C +++ b/test/Gradient/Functors.C @@ -168,6 +168,7 @@ int main() { // CHECK: inline void operator_call_grad(double ii, double j, double *_d_ii, double *_d_j) const { // CHECK-NEXT: { + // CHECK-NEXT: _d_x += 1 * j * ii; // CHECK-NEXT: *_d_ii += x * 1 * j; // CHECK-NEXT: *_d_j += x * ii * 1; // CHECK-NEXT: } diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index a23fe2e8b..2d39c112a 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -666,6 +666,8 @@ float running_sum(float* p, int n) { // CHECK-NEXT: } double global = 7; +// CHECK: double _d_global = 0.; +// expected-warning {{The gradient utilizes a global variable 'global' and its adjoint '_d_global'. Please make sure to properly reset 'global' and '_d_global' before re-running the gradient.}} double fn_global_var_use(double i, double j) { double& ref = global; @@ -673,7 +675,7 @@ double fn_global_var_use(double i, double j) { } // CHECK: void fn_global_var_use_grad(double i, double j, double *_d_i, double *_d_j) { -// CHECK-NEXT: double _d_ref = 0.; +// CHECK-NEXT: double &_d_ref = _d_global; // CHECK-NEXT: double &ref = global; // CHECK-NEXT: { // CHECK-NEXT: _d_ref += 1 * i; @@ -1143,6 +1145,49 @@ double f_ref_in_rhs(double x, double y) { //CHECK-NEXT: } //CHECK-NEXT: } +double glob1 = 5; + +double g(double a, double b) { + glob1 = b; + return a; +} + +//CHECK: void g_pullback(double a, double b, double _d_y, double *_d_a, double *_d_b); + +//CHECK: double _d_glob1 = 0.; +// expected-warning {{The gradient utilizes a global variable 'glob1' and its adjoint '_d_glob1'. Please make sure to properly reset 'glob1' and '_d_glob1' before re-running the gradient.}} + +double f_reuse_global(double x, double t) { + t = g(t, x); + glob1 *= t; + return -glob1; +} // -x * t + +//CHECK: void f_reuse_global_grad(double x, double t, double *_d_x, double *_d_t) { +//CHECK-NEXT: double _t0 = t; +//CHECK-NEXT: t = g(t, x); +//CHECK-NEXT: double _t1 = glob1; +//CHECK-NEXT: glob1 *= t; +//CHECK-NEXT: _d_glob1 += -1; +//CHECK-NEXT: { +//CHECK-NEXT: glob1 = _t1; +//CHECK-NEXT: double _r_d1 = _d_glob1; +//CHECK-NEXT: _d_glob1 = 0.; +//CHECK-NEXT: _d_glob1 += _r_d1 * t; +//CHECK-NEXT: *_d_t += glob1 * _r_d1; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: t = _t0; +//CHECK-NEXT: double _r_d0 = *_d_t; +//CHECK-NEXT: *_d_t = 0.; +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: double _r1 = 0.; +//CHECK-NEXT: g_pullback(t, x, _r_d0, &_r0, &_r1); +//CHECK-NEXT: *_d_t += _r0; +//CHECK-NEXT: *_d_x += _r1; +//CHECK-NEXT: } +//CHECK-NEXT: } + #define TEST(F, x, y) \ { \ result[0] = 0; \ @@ -1239,4 +1284,19 @@ int main() { INIT_GRADIENT(f_ref_in_rhs); TEST_GRADIENT(f_ref_in_rhs, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {5.00, 13.00} + + INIT_GRADIENT(f_reuse_global); + TEST_GRADIENT(f_reuse_global, /*numOfDerivativeArgs=*/2, -3, 4, &d_i, &d_j); // CHECK-EXEC: {-4.00, 3.00} } + +//CHECK-NEXT: void g_pullback(double a, double b, double _d_y, double *_d_a, double *_d_b) { +//CHECK-NEXT: double _t0 = glob1; +//CHECK-NEXT: glob1 = b; +//CHECK-NEXT: *_d_a += _d_y; +//CHECK-NEXT: { +//CHECK-NEXT: glob1 = _t0; +//CHECK-NEXT: double _r_d0 = _d_glob1; +//CHECK-NEXT: _d_glob1 = 0.; +//CHECK-NEXT: *_d_b += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } \ No newline at end of file diff --git a/test/ValidCodeGen/ValidCodeGen.C b/test/ValidCodeGen/ValidCodeGen.C index 11f1f6fc7..0e454de51 100644 --- a/test/ValidCodeGen/ValidCodeGen.C +++ b/test/ValidCodeGen/ValidCodeGen.C @@ -58,7 +58,10 @@ int main() { //CHECK-NEXT: } //CHECK: void fn_grad(double x, double *_d_x) { -//CHECK-NEXT: *_d_x += 1 * TN::coefficient; +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += 1 * TN::coefficient; +//CHECK-NEXT: _d_coefficient += x * 1; +//CHECK-NEXT: } //CHECK-NEXT: } //CHECK: void fn2_grad(double x, double y, double *_d_x, double *_d_y) { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 89b62ce8f..90af2f6d3 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -255,6 +255,24 @@ class CladTimerGroup { // handling of the differentiation plans. clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); + void ProcessTopLevelDecl(clang::Decl* D) { + if (llvm::isa(D) && m_DO.DumpDerivedFn) { + clang::LangOptions LangOpts; + LangOpts.CPlusPlus = true; + clang::PrintingPolicy Policy(LangOpts); + Policy.Bool = true; + D->print(llvm::outs(), Policy); + llvm::outs() << ";\n"; + } + DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D}; + assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!"); + AppendDelayed(DCI); + // We could not delay the process due to some strange way of + // initialization, inform the consumers now. + if (!m_Multiplexer) + m_CI.getASTConsumer().HandleTopLevelDecl(DCI.m_DGR); + } + private: void AppendDelayed(DelayedCallInfo DCI) { // Incremental processing handles the translation unit in chunks and it is @@ -268,16 +286,6 @@ class CladTimerGroup { void SendToMultiplexer(); bool CheckBuiltins(); void SetRequestOptions(RequestOptions& opts) const; - - void ProcessTopLevelDecl(clang::Decl* D) { - DelayedCallInfo DCI{CallKind::HandleTopLevelDecl, D}; - assert(!llvm::is_contained(m_DelayedCalls, DCI) && "Already exists!"); - AppendDelayed(DCI); - // We could not delay the process due to some strange way of - // initialization, inform the consumers now. - if (!m_Multiplexer) - m_CI.getASTConsumer().HandleTopLevelDecl(DCI.m_DGR); - } void HandleTopLevelDeclForClad(clang::DeclGroupRef DGR); }; @@ -286,6 +294,10 @@ class CladTimerGroup { return P.ProcessDiffRequest(request); } + void ProcessTopLevelDecl(CladPlugin& P, clang::Decl* D) { + P.ProcessTopLevelDecl(D); + } + template class Action : public clang::PluginASTAction { private: