From 2f114580f8e9f14168dfaa49b75a81ed6b31fa52 Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Mon, 25 Nov 2024 01:21:03 +0100 Subject: [PATCH] Improve CallExpr analysis --- include/clad/Differentiator/DiffPlanner.h | 19 +- lib/Differentiator/ActivityAnalyzer.cpp | 35 +-- lib/Differentiator/DiffPlanner.cpp | 293 +++++++++++----------- lib/Differentiator/ReverseModeVisitor.cpp | 4 +- test/Analyses/ActivityReverse.cpp | 47 +++- 5 files changed, 210 insertions(+), 188 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index ccce6288c..ff120cd84 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -34,11 +34,13 @@ struct DiffRequest { } m_TbrRunInfo; mutable struct ActivityRunInfo { - std::set ToBeRecorded; + // std::set ToBeRecorded; bool HasAnalysisRun = false; } m_ActivityRunInfo; public: + static std::set AllVariedDecls; + /// Function to be differentiated. const clang::FunctionDecl* Function = nullptr; /// Name of the base function to be differentiated. Can be different from @@ -145,12 +147,15 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; - void setToBeRecorded(std::set init) { - this->m_ActivityRunInfo.ToBeRecorded = init; - } - std::set getToBeRecorded() const { - return this->m_ActivityRunInfo.ToBeRecorded; - } + // void setToBeRecorded(std::set init) { + // this->m_ActivityRunInfo.ToBeRecorded = init; + // } + // std::set getToBeRecorded() const { + // for(auto i: m_ActivityRunInfo.ToBeRecorded){ + // AllVariedDecls.insert(i); + // } + // //return this->m_ActivityRunInfo.ToBeRecorded; + // } }; using DiffInterval = std::vector; diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index 7085826b7..d9623bbbf 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -121,6 +121,8 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { FunctionDecl* FD = CE->getDirectCallee(); bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); if (noHiddenParam) { + bool restoreMarking = m_Marking; + bool restoreVaried = m_Varied; MutableArrayRef FDparam = FD->parameters(); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { clang::Expr* par = CE->getArg(i); @@ -130,25 +132,24 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { while (innermostType->isPointerType()) innermostType = innermostType->getPointeeType(); - if ((parType->isReferenceType() || - utils::isArrayOrPointerType(parType)) && - !innermostType.isConstQualified()) { - m_Marking = true; - m_Varied = true; - } - + m_Varied = false; + m_Marking = false; TraverseStmt(par); - if ((parType->isReferenceType() || - utils::isArrayOrPointerType(parType)) && - !innermostType.isConstQualified()) { - m_Marking = false; //? - m_Varied = false; - } - - if ((m_Varied || !innermostType.isConstQualified())) + if (m_Varied) + m_VariedDecls.insert(FDparam[i]); + else if ((parType->isReferenceType() || + (utils::isArrayOrPointerType(parType) && + !innermostType.isConstQualified()))) { + m_Varied = true; + m_Marking = true; + TraverseStmt(par); m_VariedDecls.insert(FDparam[i]); + } } + m_Varied = restoreVaried; + m_Marking = restoreMarking; } + return true; } @@ -161,10 +162,10 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { innermost = innermost->getPointeeType(); if (VDTy->isPointerType() && !innermost.isConstQualified()) { copyVarToCurBlock(cast(D)); - continue; + m_Varied = true; } else if (VDTy->isArrayType()) { copyVarToCurBlock(cast(D)); - continue; + m_Varied = true; } if (Expr* init = cast(D)->getInit()) { diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index b9bdfcf39..ca0501fa3 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -22,158 +22,149 @@ using namespace clang; namespace clad { - static SourceLocation noLoc; - - /// Returns `DeclRefExpr` node corresponding to the function, method or - /// functor argument which is to be differentiated. - /// - /// \param[in] call A clad differentiation function call expression - /// \param SemaRef Reference to Sema - DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { - struct Finder : - RecursiveASTVisitor { - Sema& m_SemaRef; - SourceLocation m_BeginLoc; - DeclRefExpr* m_FnDRE = nullptr; - Finder(Sema& SemaRef, SourceLocation beginLoc) - : m_SemaRef(SemaRef), m_BeginLoc(beginLoc) {} - - // Required for visiting lambda declarations. - bool shouldVisitImplicitCode() const { return true; } - - bool VisitDeclRefExpr(DeclRefExpr* DRE) { - if (auto VD = dyn_cast(DRE->getDecl())) { - auto varType = VD->getType().getTypePtr(); - // If variable is of class type, set `m_FnDRE` to - // `DeclRefExpr` of overloaded call operator method of - // the class type. - if (varType->isStructureOrClassType()) { - auto RD = varType->getAsCXXRecordDecl(); - TraverseDecl(RD); - } else { - TraverseStmt(VD->getInit()); - } - } else if (isa(DRE->getDecl())) - m_FnDRE = DRE; - return false; +std::set DiffRequest::AllVariedDecls; + +static SourceLocation noLoc; + +/// Returns `DeclRefExpr` node corresponding to the function, method or +/// functor argument which is to be differentiated. +/// +/// \param[in] call A clad differentiation function call expression +/// \param SemaRef Reference to Sema +DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { + struct Finder : RecursiveASTVisitor { + Sema& m_SemaRef; + SourceLocation m_BeginLoc; + DeclRefExpr* m_FnDRE = nullptr; + Finder(Sema& SemaRef, SourceLocation beginLoc) + : m_SemaRef(SemaRef), m_BeginLoc(beginLoc) {} + + // Required for visiting lambda declarations. + bool shouldVisitImplicitCode() const { return true; } + + bool VisitDeclRefExpr(DeclRefExpr* DRE) { + if (auto VD = dyn_cast(DRE->getDecl())) { + auto varType = VD->getType().getTypePtr(); + // If variable is of class type, set `m_FnDRE` to + // `DeclRefExpr` of overloaded call operator method of + // the class type. + if (varType->isStructureOrClassType()) { + auto RD = varType->getAsCXXRecordDecl(); + TraverseDecl(RD); + } else { + TraverseStmt(VD->getInit()); } + } else if (isa(DRE->getDecl())) + m_FnDRE = DRE; + return false; + } - bool VisitCXXRecordDecl(CXXRecordDecl* RD) { - auto callOperatorDeclName = - m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName( - OverloadedOperatorKind::OO_Call); - LookupResult R(m_SemaRef, - callOperatorDeclName, - noLoc, - Sema::LookupNameKind::LookupMemberName); - // We do not want diagnostics that would fire because of this lookup. - R.suppressDiagnostics(); - m_SemaRef.LookupQualifiedName(R, RD); - - // Emit error diagnostics - if (R.empty()) { - const char diagFmt[] = "'%0' has no defined operator()"; - auto diagId = - m_SemaRef.Diags.getCustomDiagID(DiagnosticsEngine::Level::Error, - diagFmt); - m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); - return false; - } else if (!R.isSingleResult()) { - const char diagFmt[] = - "'%0' has multiple definitions of operator(). " - "Multiple definitions of call operators are not supported."; - auto diagId = - m_SemaRef.Diags.getCustomDiagID(DiagnosticsEngine::Level::Error, - diagFmt); - m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); - - // Emit diagnostics for candidate functions - for (auto oper = R.begin(), operEnd = R.end(); oper != operEnd; - ++oper) { - auto candidateFn = cast(oper.getDecl()); - m_SemaRef.NoteOverloadCandidate(candidateFn, - cast(candidateFn)); - } - return false; - } else if (R.isSingleResult() == 1 && - cast(R.getFoundDecl())->getAccess() != - AccessSpecifier::AS_public) { - const char diagFmt[] = - "'%0' contains %1 call operator. Differentiation of " - "private/protected call operator is not supported."; - - auto diagId = - m_SemaRef.Diags.getCustomDiagID(DiagnosticsEngine::Level::Error, - diagFmt); - // Compute access specifier name so that it can be used in - // diagnostic message. - const char* callOperatorAS = - (cast(R.getFoundDecl())->getAccess() == - AccessSpecifier::AS_private - ? "private" - : "protected"); - m_SemaRef.Diag(m_BeginLoc, diagId) - << RD->getName() << callOperatorAS; - auto callOperator = cast(R.getFoundDecl()); - - bool isImplicit = true; - - // compute if the corresponding access specifier of the found - // call operator is implicit or explicit. - for (auto decl : RD->decls()) { - if (decl == callOperator) - break; - if (isa(decl)) { - isImplicit = false; - break; - } - } + bool VisitCXXRecordDecl(CXXRecordDecl* RD) { + auto callOperatorDeclName = + m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName( + OverloadedOperatorKind::OO_Call); + LookupResult R(m_SemaRef, callOperatorDeclName, noLoc, + Sema::LookupNameKind::LookupMemberName); + // We do not want diagnostics that would fire because of this lookup. + R.suppressDiagnostics(); + m_SemaRef.LookupQualifiedName(R, RD); + + // Emit error diagnostics + if (R.empty()) { + const char diagFmt[] = "'%0' has no defined operator()"; + auto diagId = m_SemaRef.Diags.getCustomDiagID( + DiagnosticsEngine::Level::Error, diagFmt); + m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); + return false; + } else if (!R.isSingleResult()) { + const char diagFmt[] = + "'%0' has multiple definitions of operator(). " + "Multiple definitions of call operators are not supported."; + auto diagId = m_SemaRef.Diags.getCustomDiagID( + DiagnosticsEngine::Level::Error, diagFmt); + m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); + + // Emit diagnostics for candidate functions + for (auto oper = R.begin(), operEnd = R.end(); oper != operEnd; + ++oper) { + auto candidateFn = cast(oper.getDecl()); + m_SemaRef.NoteOverloadCandidate(candidateFn, + cast(candidateFn)); + } + return false; + } else if (R.isSingleResult() == 1 && + cast(R.getFoundDecl())->getAccess() != + AccessSpecifier::AS_public) { + const char diagFmt[] = + "'%0' contains %1 call operator. Differentiation of " + "private/protected call operator is not supported."; + + auto diagId = m_SemaRef.Diags.getCustomDiagID( + DiagnosticsEngine::Level::Error, diagFmt); + // Compute access specifier name so that it can be used in + // diagnostic message. + const char* callOperatorAS = + (cast(R.getFoundDecl())->getAccess() == + AccessSpecifier::AS_private + ? "private" + : "protected"); + m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName() << callOperatorAS; + auto callOperator = cast(R.getFoundDecl()); + + bool isImplicit = true; + + // compute if the corresponding access specifier of the found + // call operator is implicit or explicit. + for (auto decl : RD->decls()) { + if (decl == callOperator) + break; + if (isa(decl)) { + isImplicit = false; + break; + } + } - // Emit diagnostics for the found call operator - m_SemaRef.Diag(callOperator->getBeginLoc(), - diag::note_access_natural) - << (unsigned)(callOperator->getAccess() == - AccessSpecifier::AS_protected) - << isImplicit; + // Emit diagnostics for the found call operator + m_SemaRef.Diag(callOperator->getBeginLoc(), diag::note_access_natural) + << (unsigned)(callOperator->getAccess() == + AccessSpecifier::AS_protected) + << isImplicit; - return false; - } + return false; + } - assert(R.isSingleResult() && - "Multiple definitions of call operators are not supported"); - assert(R.isSingleResult() == 1 && - cast(R.getFoundDecl())->getAccess() == - AccessSpecifier::AS_public && - "Differentiation of private/protected call operators are " - "not supported"); - auto callOperator = cast(R.getFoundDecl()); - // Creating `DeclRefExpr` of the found overloaded call operator - // method, to maintain consistency with member function - // differentiation. - CXXScopeSpec CSS; - utils::BuildNNS(m_SemaRef, callOperator->getDeclContext(), CSS, - /*addGlobalNS=*/true); - - // `ExprValueKind::VK_RValue` is used because functions are - // decomposed to function pointers and thus a temporary is - // created for the function pointer. - auto newFnDRE = clad_compat::GetResult( - m_SemaRef.BuildDeclRefExpr(callOperator, - callOperator->getType(), - CLAD_COMPAT_ExprValueKind_R_or_PR_Value, - noLoc, - &CSS)); - m_FnDRE = cast(newFnDRE); - return false; - } - } finder(SemaRef, call->getArg(0)->getBeginLoc()); - finder.TraverseStmt(call->getArg(0)); + assert(R.isSingleResult() && + "Multiple definitions of call operators are not supported"); + assert(R.isSingleResult() == 1 && + cast(R.getFoundDecl())->getAccess() == + AccessSpecifier::AS_public && + "Differentiation of private/protected call operators are " + "not supported"); + auto callOperator = cast(R.getFoundDecl()); + // Creating `DeclRefExpr` of the found overloaded call operator + // method, to maintain consistency with member function + // differentiation. + CXXScopeSpec CSS; + utils::BuildNNS(m_SemaRef, callOperator->getDeclContext(), CSS, + /*addGlobalNS=*/true); + + // `ExprValueKind::VK_RValue` is used because functions are + // decomposed to function pointers and thus a temporary is + // created for the function pointer. + auto newFnDRE = clad_compat::GetResult(m_SemaRef.BuildDeclRefExpr( + callOperator, callOperator->getType(), + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noLoc, &CSS)); + m_FnDRE = cast(newFnDRE); + return false; + } + } finder(SemaRef, call->getArg(0)->getBeginLoc()); + finder.TraverseStmt(call->getArg(0)); - assert(cast(call->getDirectCallee()->getDeclContext()) - ->getName() == "clad" && - "Should be called for clad:: special functions!"); - return finder.m_FnDRE; - } + assert(cast(call->getDirectCallee()->getDeclContext()) + ->getName() == "clad" && + "Should be called for clad:: special functions!"); + return finder.m_FnDRE; +} void DiffRequest::updateCall(FunctionDecl* FD, FunctionDecl* OverloadedFD, Sema& SemaRef) { @@ -632,15 +623,13 @@ namespace clad { if (!m_ActivityRunInfo.HasAnalysisRun) { if (Args) for (const auto& dParam : DVI) - m_ActivityRunInfo.ToBeRecorded.insert(cast(dParam.param)); - - VariedAnalyzer analyzer(Function->getASTContext(), - m_ActivityRunInfo.ToBeRecorded); + AllVariedDecls.insert(cast(dParam.param)); + VariedAnalyzer analyzer(Function->getASTContext(), AllVariedDecls); analyzer.Analyze(Function); m_ActivityRunInfo.HasAnalysisRun = true; } - auto found = m_ActivityRunInfo.ToBeRecorded.find(VD); - return found != m_ActivityRunInfo.ToBeRecorded.end(); + auto found = AllVariedDecls.find(VD); + return found != AllVariedDecls.end(); } bool DiffCollector::VisitCallExpr(CallExpr* E) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index fe04e0e9d..ae2e9560a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1677,8 +1677,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. - if (utils::IsReferenceOrPointerArg(arg) || - !m_DiffReq.shouldHaveAdjoint(PVD)) { + if (utils::IsReferenceOrPointerArg(arg)) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); } else { @@ -1981,7 +1980,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.VerboseDiags = false; pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; - pullbackRequest.setToBeRecorded(m_DiffReq.getToBeRecorded()); bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) if (MD && isLambdaCallOperator(MD)) { diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index 0a93b93cf..490bafede 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -373,6 +373,32 @@ double f11(double x){ // CHECK-NEXT: } // CHECK-NEXT: } +double f12_1(double y, const double* obs){ + double nopull = interpolate1d(1.0, 3.0, 5.0, 5, obs); + return nopull; +} +double f12(double x, const double* obs){ + double pull = f12_1(x, obs); + return pull*x; +} + +// CHECK: void f12_1_pullback(double y, const double *obs, double _d_y0, double *_d_y); +// CHECK-NEXT: void f12_grad_0(double x, const double *obs, double *_d_x) { +// CHECK-NEXT: double _d_pull = 0.; +// CHECK-NEXT: double pull = f12_1(x, obs); +// CHECK-NEXT: { +// CHECK-NEXT: _d_pull += 1 * x; +// CHECK-NEXT: *_d_x += pull * 1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: f12_1_pullback(x, obs, _d_pull, &_r0); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + + + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -384,7 +410,6 @@ int main(){ double arr[] = {1,2,3,4,5}; double darr[] = {0,0,0,0,0}; double result[3] = {}; - double dx = 0; TEST(f1, 3);// CHECK-EXEC: {6.00} TEST(f2, 3);// CHECK-EXEC: {6.00} TEST(f3, 3);// CHECK-EXEC: {0.00} @@ -393,24 +418,24 @@ int main(){ TEST(f6, 3);// CHECK-EXEC: {0.00} TEST(f7, 3);// CHECK-EXEC: {1.00} TEST(f8, 3);// CHECK-EXEC: {1.00} + double dx9 = 0; auto grad9 = clad::gradient(f9, "x"); - grad9.execute(3, arr, &dx, darr); - printf("%.2f\n", dx);// CHECK-EXEC: 2.00 + grad9.execute(3, arr, &dx9, darr); + printf("%.2f\n", dx9);// CHECK-EXEC: 2.00 TEST(f10, 3);// CHECK-EXEC: {1.00} TEST(f11, 3);// CHECK-EXEC: {1.00} + double dx12 = 0; + auto grad = clad::gradient(f12, "x"); + grad.execute(3, arr, &dx12, darr); + printf("%.2f\n", dx12);// CHECK-EXEC: 5.00 } // CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { -// CHECK-NEXT: double _d_k = 0.; // CHECK-NEXT: double k = 2 * u; // CHECK-NEXT: double _d_n = 0.; // CHECK-NEXT: double n = 2 * v; -// CHECK-NEXT: { -// CHECK-NEXT: _d_n += _d_y * k; -// CHECK-NEXT: _d_k += n * _d_y; -// CHECK-NEXT: } +// CHECK-NEXT: _d_n += _d_y * k; // CHECK-NEXT: *_d_v += 2 * _d_n; -// CHECK-NEXT: *_d_u += 2 * _d_k; // CHECK-NEXT: } // CHECK: void f8_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { @@ -438,4 +463,8 @@ int main(){ // CHECK-NEXT: *_d_u = 0.; // CHECK-NEXT: *_d_v += _r_d0; // CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void f12_1_pullback(double y, const double *obs, double _d_y0, double *_d_y) { +// CHECK-NEXT: double nopull = interpolate1d(1., 3., 5., 5, obs); // CHECK-NEXT: } \ No newline at end of file