From fa807221e89f1e6f1a7527d93e3cc39703a17e61 Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Mon, 18 Nov 2024 01:19:58 +0100 Subject: [PATCH 1/5] Mark variables passed by pointer/reference as varied --- lib/Differentiator/ActivityAnalyzer.cpp | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index eb4f2d2a7..f7810baf2 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -124,8 +124,22 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { MutableArrayRef FDparam = FD->parameters(); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { clang::Expr* par = CE->getArg(i); + + QualType parType = FDparam[i]->getType(); + while (parType->isPointerType()) + parType = parType->getPointeeType(); + if((parType->isReferenceType() || utils::isArrayOrPointerType(parType)) && !parType.isConstQualified()){ + m_Marking = true; + m_Varied = true; + } + TraverseStmt(par); - m_VariedDecls.insert(FDparam[i]); + + m_Marking = false; + m_Varied = false; + + if(!parType.isConstQualified()) + m_VariedDecls.insert(FDparam[i]); } } return true; @@ -133,12 +147,16 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { for (Decl* D : DS->decls()) { + QualType VDTy = cast(D)->getType(); + if(utils::isArrayOrPointerType(VDTy)){ + copyVarToCurBlock(cast(D)); + continue; + } if (Expr* init = cast(D)->getInit()) { m_Varied = false; TraverseStmt(init); m_Marking = true; - QualType VDTy = cast(D)->getType(); - if (m_Varied || utils::isArrayOrPointerType(VDTy)) + if (m_Varied ) copyVarToCurBlock(cast(D)); m_Marking = false; } From 96e9ceff870e8f575de511397e9a7474da754ebb Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Mon, 18 Nov 2024 01:20:21 +0100 Subject: [PATCH 2/5] Add tests --- test/Analyses/ActivityReverse.cpp | 80 +++++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 3 deletions(-) diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index 9979c85ce..4f5d3a2d6 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -264,7 +264,7 @@ double f8(double x){ // CHECK-NEXT: } // CHECK-NEXT: } -double fn9(double x, double const *obs) +double f9(double x, double const *obs) { double res = 0.0; for (int loopIdx0 = 0; loopIdx0 < 2; loopIdx0++) { @@ -273,7 +273,7 @@ double fn9(double x, double const *obs) return res; } -// CHECK: void fn9_grad(double x, const double *obs, double *_d_x, double *_d_obs) { +// CHECK: void f9_grad(double x, const double *obs, double *_d_x, double *_d_obs) { // CHECK-NEXT: int loopIdx0 = 0; // CHECK-NEXT: clad::tape _t1 = {}; // CHECK-NEXT: double _d_res = 0.; @@ -304,6 +304,55 @@ double fn9(double x, double const *obs) // CHECK-NEXT: } +void f10_1(double x, double* t){ + t[0] = x; +} + +double f10(double x){ + double t[3]; + f10_1(x, t); + return t[0]; +} +// CHECK: void f10_1_pullback(double x, double *t, double *_d_x, double *_d_t); +// CHECK-NEXT: void f10_grad(double x, double *_d_x) { +// CHECK-NEXT: double _d_t[3] = {0}; +// CHECK-NEXT: double t[3]; +// CHECK-NEXT: f10_1(x, t); +// CHECK-NEXT: _d_t[0] += 1; +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: f10_1_pullback(x, t, &_r0, _d_t); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double f11_1(double v, double& u){ + u = v; + return u; +} + +double f11(double x){ + double y; + double c = f11_1(x, y); + return y; +} + +// CHECK: void f11_1_pullback(double v, double &u, double _d_y, double *_d_v, double *_d_u); +// CHECK-NEXT: void f11_grad(double x, double *_d_x) { +// CHECK-NEXT: double _d_y = 0.; +// CHECK-NEXT: double y; +// CHECK-NEXT: double _t0 = y; +// CHECK-NEXT: double _d_c = 0.; +// CHECK-NEXT: double c = f11_1(x, y); +// CHECK-NEXT: _d_y += 1; +// CHECK-NEXT: { +// CHECK-NEXT: y = _t0; +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: f11_1_pullback(x, _t0, _d_c, &_r0, &_d_y); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -324,9 +373,11 @@ int main(){ TEST(f6, 3);// CHECK-EXEC: {0.00} TEST(f7, 3);// CHECK-EXEC: {1.00} TEST(f8, 3);// CHECK-EXEC: {1.00} - auto grad = clad::gradient(fn9); + auto grad = clad::gradient(f9); grad.execute(3, arr, &dx, darr); printf("%.2f\n", dx);// CHECK-EXEC: 2.00 + TEST(f10, 3);// CHECK-EXEC: {1.00} + TEST(f11, 3);// CHECK-EXEC: {1.00} } // CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { @@ -344,4 +395,27 @@ int main(){ // CHECK: void f8_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { // CHECK-NEXT: *_d_v += _d_y; +// CHECK-NEXT: } + +// CHECK: void f10_1_pullback(double x, double *t, double *_d_x, double *_d_t) { +// CHECK-NEXT: double _t0 = t[0]; +// CHECK-NEXT: t[0] = x; +// CHECK-NEXT: { +// CHECK-NEXT: t[0] = _t0; +// CHECK-NEXT: double _r_d0 = _d_t[0]; +// CHECK-NEXT: _d_t[0] = 0.; +// CHECK-NEXT: *_d_x += _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void f11_1_pullback(double v, double &u, double _d_y, double *_d_v, double *_d_u) { +// CHECK-NEXT: double _t0 = u; +// CHECK-NEXT: u = v; +// CHECK-NEXT: *_d_u += _d_y; +// CHECK-NEXT: { +// CHECK-NEXT: u = _t0; +// CHECK-NEXT: double _r_d0 = *_d_u; +// CHECK-NEXT: *_d_u = 0.; +// CHECK-NEXT: *_d_v += _r_d0; +// CHECK-NEXT: } // CHECK-NEXT: } \ No newline at end of file From f3eeeaf0d1ab43b0ca89c2ad392156a08b34e080 Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Thu, 21 Nov 2024 16:51:19 +0100 Subject: [PATCH 3/5] Don't mark nonvaried constant params --- include/clad/Differentiator/DiffPlanner.h | 7 ++++ lib/Differentiator/ActivityAnalyzer.cpp | 40 ++++++++++++++--------- lib/Differentiator/DiffPlanner.cpp | 17 ++-------- lib/Differentiator/ReverseModeVisitor.cpp | 29 ++++++++-------- test/Analyses/ActivityReverse.cpp | 28 +++++++++++++--- 5 files changed, 74 insertions(+), 47 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index d2b74592b..ccce6288c 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -144,6 +144,13 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; + + void setToBeRecorded(std::set init) { + this->m_ActivityRunInfo.ToBeRecorded = init; + } + std::set getToBeRecorded() const { + return this->m_ActivityRunInfo.ToBeRecorded; + } }; using DiffInterval = std::vector; diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index f7810baf2..7085826b7 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -126,19 +126,26 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { clang::Expr* par = CE->getArg(i); QualType parType = FDparam[i]->getType(); - while (parType->isPointerType()) - parType = parType->getPointeeType(); - if((parType->isReferenceType() || utils::isArrayOrPointerType(parType)) && !parType.isConstQualified()){ + QualType innermostType = parType; + while (innermostType->isPointerType()) + innermostType = innermostType->getPointeeType(); + + if ((parType->isReferenceType() || + utils::isArrayOrPointerType(parType)) && + !innermostType.isConstQualified()) { m_Marking = true; m_Varied = true; } TraverseStmt(par); + if ((parType->isReferenceType() || + utils::isArrayOrPointerType(parType)) && + !innermostType.isConstQualified()) { + m_Marking = false; //? + m_Varied = false; + } - m_Marking = false; - m_Varied = false; - - if(!parType.isConstQualified()) + if ((m_Varied || !innermostType.isConstQualified())) m_VariedDecls.insert(FDparam[i]); } } @@ -147,16 +154,24 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { for (Decl* D : DS->decls()) { + QualType VDTy = cast(D)->getType(); - if(utils::isArrayOrPointerType(VDTy)){ + QualType innermost = VDTy; + while (innermost->isPointerType()) + innermost = innermost->getPointeeType(); + if (VDTy->isPointerType() && !innermost.isConstQualified()) { + copyVarToCurBlock(cast(D)); + continue; + } else if (VDTy->isArrayType()) { copyVarToCurBlock(cast(D)); continue; } + if (Expr* init = cast(D)->getInit()) { m_Varied = false; TraverseStmt(init); m_Marking = true; - if (m_Varied ) + if (m_Varied) copyVarToCurBlock(cast(D)); m_Marking = false; } @@ -165,12 +180,7 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { } bool VariedAnalyzer::VisitUnaryOperator(UnaryOperator* UnOp) { - const auto opCode = UnOp->getOpcode(); Expr* E = UnOp->getSubExpr(); - if (opCode == UO_AddrOf || opCode == UO_Deref) { - m_Varied = true; - m_Marking = true; - } TraverseStmt(E); m_Marking = false; return true; @@ -181,7 +191,7 @@ bool VariedAnalyzer::VisitDeclRefExpr(DeclRefExpr* DRE) { if (!VD) return true; - if (isVaried(VD)) + if (isVaried(VD) || VD->getType()->isArrayType()) m_Varied = true; if (m_Varied && m_Marking) diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 1f1fe761f..b9bdfcf39 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -630,20 +630,9 @@ namespace clad { return true; if (!m_ActivityRunInfo.HasAnalysisRun) { - ArrayRef FDparam = Function->parameters(); - std::vector derivedParam; - - for (auto* parameter : FDparam) { - QualType parType = parameter->getType(); - while (parType->isPointerType()) - parType = parType->getPointeeType(); - if (!parType.isConstQualified()) - derivedParam.push_back(parameter); - } - - std::copy(derivedParam.begin(), derivedParam.end(), - std::inserter(m_ActivityRunInfo.ToBeRecorded, - m_ActivityRunInfo.ToBeRecorded.end())); + if (Args) + for (const auto& dParam : DVI) + m_ActivityRunInfo.ToBeRecorded.insert(cast(dParam.param)); VariedAnalyzer analyzer(Function->getASTContext(), m_ActivityRunInfo.ToBeRecorded); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index a7d4dc6cb..03bd8ffd6 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1954,20 +1954,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // call has a different name than the function's signature parameter. pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs; - pullbackRequest.BaseFunctionName = - clad::utils::ComputeEffectiveFnName(FD); - pullbackRequest.Mode = DiffMode::experimental_pullback; - // Silence diag outputs in nested derivation process. - pullbackRequest.VerboseDiags = false; - pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; - pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; - bool isaMethod = isa(FD); - for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) - if (MD && isLambdaCallOperator(MD)) { - if (const auto* paramDecl = FD->getParamDecl(i)) - pullbackRequest.DVI.push_back(paramDecl); - } else if (DerivedCallOutputArgs[i + isaMethod]) - pullbackRequest.DVI.push_back(FD->getParamDecl(i)); + pullbackRequest.BaseFunctionName = + clad::utils::ComputeEffectiveFnName(FD); + pullbackRequest.Mode = DiffMode::experimental_pullback; + // Silence diag outputs in nested derivation process. + pullbackRequest.VerboseDiags = false; + pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; + pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; + pullbackRequest.setToBeRecorded(m_DiffReq.getToBeRecorded()); + bool isaMethod = isa(FD); + for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) + if (MD && isLambdaCallOperator(MD)) { + if (const auto* paramDecl = FD->getParamDecl(i)) + pullbackRequest.DVI.push_back(paramDecl); + } else if (DerivedCallOutputArgs[i + isaMethod]) + pullbackRequest.DVI.push_back(FD->getParamDecl(i)); FunctionDecl* pullbackFD = nullptr; if (m_ExternalSource) diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index 4f5d3a2d6..0a93b93cf 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -6,6 +6,26 @@ #include "clad/Differentiator/Differentiator.h" +inline double interpolate1d(double low, double high, double val, unsigned int numBins, double const* vals) +{ + double binWidth = (high - low) / numBins; + int idx = val >= high ? numBins - 1 : std::abs((val - low) / binWidth); + + // interpolation + double central = low + (idx + 0.5) * binWidth; + if (val > low + 0.5 * binWidth && val < high - 0.5 * binWidth) { + double slope; + if (val < central) { + slope = vals[idx] - vals[idx - 1]; + } else { + slope = vals[idx + 1] - vals[idx]; + } + return vals[idx] + slope * (val - central) / binWidth; + } + + return vals[idx]; +} + double f1(double x){ double a = x*x; double b = 1; @@ -273,7 +293,7 @@ double f9(double x, double const *obs) return res; } -// CHECK: void f9_grad(double x, const double *obs, double *_d_x, double *_d_obs) { +// CHECK: void f9_grad_0(double x, const double *obs, double *_d_x) { // CHECK-NEXT: int loopIdx0 = 0; // CHECK-NEXT: clad::tape _t1 = {}; // CHECK-NEXT: double _d_res = 0.; @@ -364,7 +384,7 @@ int main(){ double arr[] = {1,2,3,4,5}; double darr[] = {0,0,0,0,0}; double result[3] = {}; - double dx; + double dx = 0; TEST(f1, 3);// CHECK-EXEC: {6.00} TEST(f2, 3);// CHECK-EXEC: {6.00} TEST(f3, 3);// CHECK-EXEC: {0.00} @@ -373,8 +393,8 @@ int main(){ TEST(f6, 3);// CHECK-EXEC: {0.00} TEST(f7, 3);// CHECK-EXEC: {1.00} TEST(f8, 3);// CHECK-EXEC: {1.00} - auto grad = clad::gradient(f9); - grad.execute(3, arr, &dx, darr); + auto grad9 = clad::gradient(f9, "x"); + grad9.execute(3, arr, &dx, darr); printf("%.2f\n", dx);// CHECK-EXEC: 2.00 TEST(f10, 3);// CHECK-EXEC: {1.00} TEST(f11, 3);// CHECK-EXEC: {1.00} From fac1aeef2a54116db9d6f6ab1c4ce7619d2f13ff Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Mon, 25 Nov 2024 18:28:25 +0100 Subject: [PATCH 4/5] Improve CallExpr analysis --- include/clad/Differentiator/DiffPlanner.h | 10 +- lib/Differentiator/ActivityAnalyzer.cpp | 39 ++- lib/Differentiator/DiffPlanner.cpp | 292 ++++++++++------------ lib/Differentiator/ReverseModeVisitor.cpp | 4 +- test/Analyses/ActivityReverse.cpp | 47 +++- 5 files changed, 199 insertions(+), 193 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index ccce6288c..a33e10f79 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -34,11 +34,12 @@ struct DiffRequest { } m_TbrRunInfo; mutable struct ActivityRunInfo { - std::set ToBeRecorded; bool HasAnalysisRun = false; } m_ActivityRunInfo; public: + /// All varied declarations. + static std::set AllVariedDecls; /// Function to be differentiated. const clang::FunctionDecl* Function = nullptr; /// Name of the base function to be differentiated. Can be different from @@ -144,13 +145,6 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; - - void setToBeRecorded(std::set init) { - this->m_ActivityRunInfo.ToBeRecorded = init; - } - std::set getToBeRecorded() const { - return this->m_ActivityRunInfo.ToBeRecorded; - } }; using DiffInterval = std::vector; diff --git a/lib/Differentiator/ActivityAnalyzer.cpp b/lib/Differentiator/ActivityAnalyzer.cpp index 7085826b7..38d65f037 100644 --- a/lib/Differentiator/ActivityAnalyzer.cpp +++ b/lib/Differentiator/ActivityAnalyzer.cpp @@ -121,6 +121,8 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { FunctionDecl* FD = CE->getDirectCallee(); bool noHiddenParam = (CE->getNumArgs() == FD->getNumParams()); if (noHiddenParam) { + bool restoreMarking = m_Marking; + bool restoreVaried = m_Varied; MutableArrayRef FDparam = FD->parameters(); for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { clang::Expr* par = CE->getArg(i); @@ -130,25 +132,24 @@ bool VariedAnalyzer::VisitCallExpr(CallExpr* CE) { while (innermostType->isPointerType()) innermostType = innermostType->getPointeeType(); - if ((parType->isReferenceType() || - utils::isArrayOrPointerType(parType)) && - !innermostType.isConstQualified()) { - m_Marking = true; - m_Varied = true; - } - + m_Varied = false; + m_Marking = false; TraverseStmt(par); - if ((parType->isReferenceType() || - utils::isArrayOrPointerType(parType)) && - !innermostType.isConstQualified()) { - m_Marking = false; //? - m_Varied = false; - } - - if ((m_Varied || !innermostType.isConstQualified())) + if (m_Varied) m_VariedDecls.insert(FDparam[i]); + else if ((parType->isReferenceType() || + (utils::isArrayOrPointerType(parType) && + !innermostType.isConstQualified()))) { + m_Varied = true; + m_Marking = true; + TraverseStmt(par); + m_VariedDecls.insert(FDparam[i]); + } } + m_Varied = restoreVaried; + m_Marking = restoreMarking; } + return true; } @@ -159,12 +160,10 @@ bool VariedAnalyzer::VisitDeclStmt(DeclStmt* DS) { QualType innermost = VDTy; while (innermost->isPointerType()) innermost = innermost->getPointeeType(); - if (VDTy->isPointerType() && !innermost.isConstQualified()) { - copyVarToCurBlock(cast(D)); - continue; - } else if (VDTy->isArrayType()) { + if (VDTy->isArrayType() || + (VDTy->isPointerType() && !innermost.isConstQualified())) { copyVarToCurBlock(cast(D)); - continue; + m_Varied = true; } if (Expr* init = cast(D)->getInit()) { diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index b9bdfcf39..f404391a3 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -22,158 +22,146 @@ using namespace clang; namespace clad { - static SourceLocation noLoc; - - /// Returns `DeclRefExpr` node corresponding to the function, method or - /// functor argument which is to be differentiated. - /// - /// \param[in] call A clad differentiation function call expression - /// \param SemaRef Reference to Sema - DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { - struct Finder : - RecursiveASTVisitor { - Sema& m_SemaRef; - SourceLocation m_BeginLoc; - DeclRefExpr* m_FnDRE = nullptr; - Finder(Sema& SemaRef, SourceLocation beginLoc) - : m_SemaRef(SemaRef), m_BeginLoc(beginLoc) {} - - // Required for visiting lambda declarations. - bool shouldVisitImplicitCode() const { return true; } - - bool VisitDeclRefExpr(DeclRefExpr* DRE) { - if (auto VD = dyn_cast(DRE->getDecl())) { - auto varType = VD->getType().getTypePtr(); - // If variable is of class type, set `m_FnDRE` to - // `DeclRefExpr` of overloaded call operator method of - // the class type. - if (varType->isStructureOrClassType()) { - auto RD = varType->getAsCXXRecordDecl(); - TraverseDecl(RD); - } else { - TraverseStmt(VD->getInit()); - } - } else if (isa(DRE->getDecl())) - m_FnDRE = DRE; - return false; +std::set DiffRequest::AllVariedDecls; +static SourceLocation noLoc; + +/// Returns `DeclRefExpr` node corresponding to the function, method or +/// functor argument which is to be differentiated. +/// +/// \param[in] call A clad differentiation function call expression +/// \param SemaRef Reference to Sema +DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { + struct Finder : RecursiveASTVisitor { + Sema& m_SemaRef; + SourceLocation m_BeginLoc; + DeclRefExpr* m_FnDRE = nullptr; + Finder(Sema& SemaRef, SourceLocation beginLoc) + : m_SemaRef(SemaRef), m_BeginLoc(beginLoc) {} + + // Required for visiting lambda declarations. + bool shouldVisitImplicitCode() const { return true; } + + bool VisitDeclRefExpr(DeclRefExpr* DRE) { + if (auto* VD = dyn_cast(DRE->getDecl())) { + const auto* varType = VD->getType().getTypePtr(); + // If variable is of class type, set `m_FnDRE` to + // `DeclRefExpr` of overloaded call operator method of + // the class type. + if (varType->isStructureOrClassType()) { + auto* RD = varType->getAsCXXRecordDecl(); + TraverseDecl(RD); + } else { + TraverseStmt(VD->getInit()); } + } else if (isa(DRE->getDecl())) + m_FnDRE = DRE; + return false; + } - bool VisitCXXRecordDecl(CXXRecordDecl* RD) { - auto callOperatorDeclName = - m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName( - OverloadedOperatorKind::OO_Call); - LookupResult R(m_SemaRef, - callOperatorDeclName, - noLoc, - Sema::LookupNameKind::LookupMemberName); - // We do not want diagnostics that would fire because of this lookup. - R.suppressDiagnostics(); - m_SemaRef.LookupQualifiedName(R, RD); - - // Emit error diagnostics - if (R.empty()) { - const char diagFmt[] = "'%0' has no defined operator()"; - auto diagId = - m_SemaRef.Diags.getCustomDiagID(DiagnosticsEngine::Level::Error, - diagFmt); - m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); - return false; - } else if (!R.isSingleResult()) { - const char diagFmt[] = - "'%0' has multiple definitions of operator(). " - "Multiple definitions of call operators are not supported."; - auto diagId = - m_SemaRef.Diags.getCustomDiagID(DiagnosticsEngine::Level::Error, - diagFmt); - m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); - - // Emit diagnostics for candidate functions - for (auto oper = R.begin(), operEnd = R.end(); oper != operEnd; - ++oper) { - auto candidateFn = cast(oper.getDecl()); - m_SemaRef.NoteOverloadCandidate(candidateFn, - cast(candidateFn)); - } - return false; - } else if (R.isSingleResult() == 1 && - cast(R.getFoundDecl())->getAccess() != - AccessSpecifier::AS_public) { - const char diagFmt[] = - "'%0' contains %1 call operator. Differentiation of " - "private/protected call operator is not supported."; - - auto diagId = - m_SemaRef.Diags.getCustomDiagID(DiagnosticsEngine::Level::Error, - diagFmt); - // Compute access specifier name so that it can be used in - // diagnostic message. - const char* callOperatorAS = - (cast(R.getFoundDecl())->getAccess() == - AccessSpecifier::AS_private - ? "private" - : "protected"); - m_SemaRef.Diag(m_BeginLoc, diagId) - << RD->getName() << callOperatorAS; - auto callOperator = cast(R.getFoundDecl()); - - bool isImplicit = true; - - // compute if the corresponding access specifier of the found - // call operator is implicit or explicit. - for (auto decl : RD->decls()) { - if (decl == callOperator) - break; - if (isa(decl)) { - isImplicit = false; - break; - } - } - - // Emit diagnostics for the found call operator - m_SemaRef.Diag(callOperator->getBeginLoc(), - diag::note_access_natural) - << (unsigned)(callOperator->getAccess() == - AccessSpecifier::AS_protected) - << isImplicit; - - return false; + bool VisitCXXRecordDecl(CXXRecordDecl* RD) { + auto callOperatorDeclName = + m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName( + OverloadedOperatorKind::OO_Call); + LookupResult R(m_SemaRef, callOperatorDeclName, noLoc, + Sema::LookupNameKind::LookupMemberName); + // We do not want diagnostics that would fire because of this lookup. + R.suppressDiagnostics(); + m_SemaRef.LookupQualifiedName(R, RD); + + // Emit error diagnostics + if (R.empty()) { + const char diagFmt[] = "'%0' has no defined operator()"; + auto diagId = m_SemaRef.Diags.getCustomDiagID( + DiagnosticsEngine::Level::Error, diagFmt); + m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); + } else if (!R.isSingleResult()) { + const char diagFmt[] = + "'%0' has multiple definitions of operator(). " + "Multiple definitions of call operators are not supported."; + auto diagId = m_SemaRef.Diags.getCustomDiagID( + DiagnosticsEngine::Level::Error, diagFmt); + m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName(); + + // Emit diagnostics for candidate functions + for (auto oper = R.begin(), operEnd = R.end(); oper != operEnd; + ++oper) { + auto* candidateFn = cast(oper.getDecl()); + m_SemaRef.NoteOverloadCandidate(candidateFn, + cast(candidateFn)); + } + } else if (R.isSingleResult() == 1 && + cast(R.getFoundDecl())->getAccess() != + AccessSpecifier::AS_public) { + const char diagFmt[] = + "'%0' contains %1 call operator. Differentiation of " + "private/protected call operator is not supported."; + + auto diagId = m_SemaRef.Diags.getCustomDiagID( + DiagnosticsEngine::Level::Error, diagFmt); + // Compute access specifier name so that it can be used in + // diagnostic message. + const char* callOperatorAS = + (cast(R.getFoundDecl())->getAccess() == + AccessSpecifier::AS_private + ? "private" + : "protected"); + m_SemaRef.Diag(m_BeginLoc, diagId) << RD->getName() << callOperatorAS; + auto* callOperator = cast(R.getFoundDecl()); + + bool isImplicit = true; + + // compute if the corresponding access specifier of the found + // call operator is implicit or explicit. + for (auto* decl : RD->decls()) { + if (decl == callOperator) + break; + if (isa(decl)) { + isImplicit = false; + break; } - - assert(R.isSingleResult() && - "Multiple definitions of call operators are not supported"); - assert(R.isSingleResult() == 1 && - cast(R.getFoundDecl())->getAccess() == - AccessSpecifier::AS_public && - "Differentiation of private/protected call operators are " - "not supported"); - auto callOperator = cast(R.getFoundDecl()); - // Creating `DeclRefExpr` of the found overloaded call operator - // method, to maintain consistency with member function - // differentiation. - CXXScopeSpec CSS; - utils::BuildNNS(m_SemaRef, callOperator->getDeclContext(), CSS, - /*addGlobalNS=*/true); - - // `ExprValueKind::VK_RValue` is used because functions are - // decomposed to function pointers and thus a temporary is - // created for the function pointer. - auto newFnDRE = clad_compat::GetResult( - m_SemaRef.BuildDeclRefExpr(callOperator, - callOperator->getType(), - CLAD_COMPAT_ExprValueKind_R_or_PR_Value, - noLoc, - &CSS)); - m_FnDRE = cast(newFnDRE); - return false; } - } finder(SemaRef, call->getArg(0)->getBeginLoc()); - finder.TraverseStmt(call->getArg(0)); - assert(cast(call->getDirectCallee()->getDeclContext()) - ->getName() == "clad" && - "Should be called for clad:: special functions!"); - return finder.m_FnDRE; - } + // Emit diagnostics for the found call operator + m_SemaRef.Diag(callOperator->getBeginLoc(), diag::note_access_natural) + << (unsigned)(callOperator->getAccess() == + AccessSpecifier::AS_protected) + << isImplicit; + + } else { + assert(R.isSingleResult() && + "Multiple definitions of call operators are not supported"); + assert(R.isSingleResult() == 1 && + cast(R.getFoundDecl())->getAccess() == + AccessSpecifier::AS_public && + "Differentiation of private/protected call operators are " + "not supported"); + auto* callOperator = cast(R.getFoundDecl()); + // Creating `DeclRefExpr` of the found overloaded call operator + // method, to maintain consistency with member function + // differentiation. + CXXScopeSpec CSS; + utils::BuildNNS(m_SemaRef, callOperator->getDeclContext(), CSS, + /*addGlobalNS=*/true); + + // `ExprValueKind::VK_RValue` is used because functions are + // decomposed to function pointers and thus a temporary is + // created for the function pointer. + auto* newFnDRE = + clad_compat::GetResult(m_SemaRef.BuildDeclRefExpr( + callOperator, callOperator->getType(), + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noLoc, &CSS)); + m_FnDRE = cast(newFnDRE); + } + return false; + } + } finder(SemaRef, call->getArg(0)->getBeginLoc()); + finder.TraverseStmt(call->getArg(0)); + + assert(cast(call->getDirectCallee()->getDeclContext()) + ->getName() == "clad" && + "Should be called for clad:: special functions!"); + return finder.m_FnDRE; +} void DiffRequest::updateCall(FunctionDecl* FD, FunctionDecl* OverloadedFD, Sema& SemaRef) { @@ -632,15 +620,13 @@ namespace clad { if (!m_ActivityRunInfo.HasAnalysisRun) { if (Args) for (const auto& dParam : DVI) - m_ActivityRunInfo.ToBeRecorded.insert(cast(dParam.param)); - - VariedAnalyzer analyzer(Function->getASTContext(), - m_ActivityRunInfo.ToBeRecorded); + AllVariedDecls.insert(cast(dParam.param)); + VariedAnalyzer analyzer(Function->getASTContext(), AllVariedDecls); analyzer.Analyze(Function); m_ActivityRunInfo.HasAnalysisRun = true; } - auto found = m_ActivityRunInfo.ToBeRecorded.find(VD); - return found != m_ActivityRunInfo.ToBeRecorded.end(); + auto found = AllVariedDecls.find(VD); + return found != AllVariedDecls.end(); } bool DiffCollector::VisitCallExpr(CallExpr* E) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 03bd8ffd6..61532f656 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1672,8 +1672,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. - if (utils::IsReferenceOrPointerArg(arg) || - !m_DiffReq.shouldHaveAdjoint(PVD)) { + if (utils::IsReferenceOrPointerArg(arg)) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); } else { @@ -1961,7 +1960,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.VerboseDiags = false; pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; - pullbackRequest.setToBeRecorded(m_DiffReq.getToBeRecorded()); bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) if (MD && isLambdaCallOperator(MD)) { diff --git a/test/Analyses/ActivityReverse.cpp b/test/Analyses/ActivityReverse.cpp index 0a93b93cf..490bafede 100644 --- a/test/Analyses/ActivityReverse.cpp +++ b/test/Analyses/ActivityReverse.cpp @@ -373,6 +373,32 @@ double f11(double x){ // CHECK-NEXT: } // CHECK-NEXT: } +double f12_1(double y, const double* obs){ + double nopull = interpolate1d(1.0, 3.0, 5.0, 5, obs); + return nopull; +} +double f12(double x, const double* obs){ + double pull = f12_1(x, obs); + return pull*x; +} + +// CHECK: void f12_1_pullback(double y, const double *obs, double _d_y0, double *_d_y); +// CHECK-NEXT: void f12_grad_0(double x, const double *obs, double *_d_x) { +// CHECK-NEXT: double _d_pull = 0.; +// CHECK-NEXT: double pull = f12_1(x, obs); +// CHECK-NEXT: { +// CHECK-NEXT: _d_pull += 1 * x; +// CHECK-NEXT: *_d_x += pull * 1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0.; +// CHECK-NEXT: f12_1_pullback(x, obs, _d_pull, &_r0); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + + + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -384,7 +410,6 @@ int main(){ double arr[] = {1,2,3,4,5}; double darr[] = {0,0,0,0,0}; double result[3] = {}; - double dx = 0; TEST(f1, 3);// CHECK-EXEC: {6.00} TEST(f2, 3);// CHECK-EXEC: {6.00} TEST(f3, 3);// CHECK-EXEC: {0.00} @@ -393,24 +418,24 @@ int main(){ TEST(f6, 3);// CHECK-EXEC: {0.00} TEST(f7, 3);// CHECK-EXEC: {1.00} TEST(f8, 3);// CHECK-EXEC: {1.00} + double dx9 = 0; auto grad9 = clad::gradient(f9, "x"); - grad9.execute(3, arr, &dx, darr); - printf("%.2f\n", dx);// CHECK-EXEC: 2.00 + grad9.execute(3, arr, &dx9, darr); + printf("%.2f\n", dx9);// CHECK-EXEC: 2.00 TEST(f10, 3);// CHECK-EXEC: {1.00} TEST(f11, 3);// CHECK-EXEC: {1.00} + double dx12 = 0; + auto grad = clad::gradient(f12, "x"); + grad.execute(3, arr, &dx12, darr); + printf("%.2f\n", dx12);// CHECK-EXEC: 5.00 } // CHECK: void f4_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { -// CHECK-NEXT: double _d_k = 0.; // CHECK-NEXT: double k = 2 * u; // CHECK-NEXT: double _d_n = 0.; // CHECK-NEXT: double n = 2 * v; -// CHECK-NEXT: { -// CHECK-NEXT: _d_n += _d_y * k; -// CHECK-NEXT: _d_k += n * _d_y; -// CHECK-NEXT: } +// CHECK-NEXT: _d_n += _d_y * k; // CHECK-NEXT: *_d_v += 2 * _d_n; -// CHECK-NEXT: *_d_u += 2 * _d_k; // CHECK-NEXT: } // CHECK: void f8_1_pullback(double v, double u, double _d_y, double *_d_v, double *_d_u) { @@ -438,4 +463,8 @@ int main(){ // CHECK-NEXT: *_d_u = 0.; // CHECK-NEXT: *_d_v += _r_d0; // CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void f12_1_pullback(double y, const double *obs, double _d_y0, double *_d_y) { +// CHECK-NEXT: double nopull = interpolate1d(1., 3., 5., 5, obs); // CHECK-NEXT: } \ No newline at end of file From 7bf55f3a56597cbee522c586f599ce4be64c0bcb Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Thu, 12 Dec 2024 16:21:19 +0100 Subject: [PATCH 5/5] Remove static set from DiffReq --- .../clad/Differentiator/DerivativeBuilder.h | 2 ++ include/clad/Differentiator/DiffPlanner.h | 24 +++++++++++--- lib/Differentiator/DiffPlanner.cpp | 32 ++++++++++++------- lib/Differentiator/ReverseModeVisitor.cpp | 32 ++++++++++--------- 4 files changed, 59 insertions(+), 31 deletions(-) diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 5e9d54ac2..8a70fc6f3 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -13,6 +13,7 @@ #include "clang/Sema/Sema.h" #include "clad/Differentiator/DerivedFnCollector.h" #include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/DynamicGraph.h" #include #include @@ -72,6 +73,7 @@ namespace clad { class DerivativeBuilder { private: friend class VisitorBase; + friend class DiffRequest; friend class BaseForwardModeVisitor; friend class PushForwardModeVisitor; friend class VectorForwardModeVisitor; diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index a33e10f79..b025b5c8f 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -3,12 +3,13 @@ #include "clang/AST/RecursiveASTVisitor.h" #include "llvm/ADT/SmallSet.h" +#include +#include +#include "clad/Differentiator/DerivativeBuilder.h" #include "clad/Differentiator/DiffMode.h" #include "clad/Differentiator/DynamicGraph.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" -#include -#include namespace clang { class CallExpr; class CompilerInstance; @@ -21,6 +22,7 @@ class Type; } // namespace clang namespace clad { +class DerivativeBuilder; /// A struct containing information about request to differentiate a function. struct DiffRequest { @@ -34,12 +36,13 @@ struct DiffRequest { } m_TbrRunInfo; mutable struct ActivityRunInfo { + std::set VariedDecls; bool HasAnalysisRun = false; } m_ActivityRunInfo; public: - /// All varied declarations. - static std::set AllVariedDecls; + const DerivativeBuilder* Builder = nullptr; + // static std::set AllVariedDecls; /// Function to be differentiated. const clang::FunctionDecl* Function = nullptr; /// Name of the base function to be differentiated. Can be different from @@ -128,7 +131,8 @@ struct DiffRequest { Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && EnableVariedAnalysis == other.EnableVariedAnalysis && DVI == other.DVI && use_enzyme == other.use_enzyme && - DeclarationOnly == other.DeclarationOnly; + DeclarationOnly == other.DeclarationOnly && + getVariedDecls() == other.getVariedDecls(); } const clang::FunctionDecl* operator->() const { return Function; } @@ -145,6 +149,16 @@ struct DiffRequest { bool shouldBeRecorded(clang::Expr* E) const; bool shouldHaveAdjoint(const clang::VarDecl* VD) const; + + void setVariedDecls(std::set init) { + for (auto* vd : init) + this->m_ActivityRunInfo.VariedDecls.insert(vd); + } + std::set getVariedDecls() const { + return this->m_ActivityRunInfo.VariedDecls; + } + DiffRequest() {} + DiffRequest(DerivativeBuilder& builder) : Builder(&builder) {} }; using DiffInterval = std::vector; diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index f404391a3..ec5aaf6ed 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -22,8 +22,8 @@ using namespace clang; namespace clad { -std::set DiffRequest::AllVariedDecls; -static SourceLocation noLoc; +// std::set DiffRequest::AllVariedDecls; +static SourceLocation noloc; /// Returns `DeclRefExpr` node corresponding to the function, method or /// functor argument which is to be differentiated. @@ -62,7 +62,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { auto callOperatorDeclName = m_SemaRef.getASTContext().DeclarationNames.getCXXOperatorName( OverloadedOperatorKind::OO_Call); - LookupResult R(m_SemaRef, callOperatorDeclName, noLoc, + LookupResult R(m_SemaRef, callOperatorDeclName, noloc, Sema::LookupNameKind::LookupMemberName); // We do not want diagnostics that would fire because of this lookup. R.suppressDiagnostics(); @@ -149,7 +149,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { auto* newFnDRE = clad_compat::GetResult(m_SemaRef.BuildDeclRefExpr( callOperator, callOperator->getType(), - CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noLoc, &CSS)); + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, noloc, &CSS)); m_FnDRE = cast(newFnDRE); } return false; @@ -198,7 +198,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { auto kernelArgIdx = numArgs - 1; auto* cudaKernelFlag = SemaRef - .ActOnCXXBoolLiteral(noLoc, + .ActOnCXXBoolLiteral(noloc, replacementFD->hasAttr() ? tok::kw_true : tok::kw_false) @@ -209,7 +209,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { // Create ref to generated FD. DeclRefExpr* DRE = - DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noLoc, replacementFD, + DeclRefExpr::Create(C, oldDRE->getQualifierLoc(), noloc, replacementFD, /*RefersToEnclosingVariableOrCapture=*/false, replacementFD->getNameInfo(), replacementFD->getType(), oldDRE->getValueKind()); @@ -225,7 +225,7 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { // Add the "&" operator auto* newUnOp = SemaRef - .BuildUnaryOp(nullptr, noLoc, UnaryOperatorKind::UO_AddrOf, DRE) + .BuildUnaryOp(nullptr, noloc, UnaryOperatorKind::UO_AddrOf, DRE) .get(); call->setArg(derivedFnArgIdx, newUnOp); } @@ -618,15 +618,25 @@ DeclRefExpr* getArgFunction(CallExpr* call, Sema& SemaRef) { return true; if (!m_ActivityRunInfo.HasAnalysisRun) { + if (Builder) + for (auto diffreq : this->Builder->m_DiffRequestGraph.getNodes()) + for (auto vd : diffreq.getVariedDecls()) + m_ActivityRunInfo.VariedDecls.insert(vd); + if (Args) for (const auto& dParam : DVI) - AllVariedDecls.insert(cast(dParam.param)); - VariedAnalyzer analyzer(Function->getASTContext(), AllVariedDecls); + m_ActivityRunInfo.VariedDecls.insert(cast(dParam.param)); + VariedAnalyzer analyzer(Function->getASTContext(), + m_ActivityRunInfo.VariedDecls); analyzer.Analyze(Function); m_ActivityRunInfo.HasAnalysisRun = true; + if (Builder) + this->Builder->m_DiffRequestGraph.addNode(*this); } - auto found = AllVariedDecls.find(VD); - return found != AllVariedDecls.end(); + auto found = m_ActivityRunInfo.VariedDecls.find(VD); + return found != m_ActivityRunInfo.VariedDecls.end(); + + return false; } bool DiffCollector::VisitCallExpr(CallExpr* E) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 61532f656..7b8f9966c 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -216,6 +216,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource->ActAfterParsingDiffArgs(m_DiffReq, args); auto derivativeBaseName = m_DiffReq.BaseFunctionName; + // llvm::errs() << "\nBaseFunctionName: " << derivativeBaseName << "\n"; std::string gradientName = derivativeBaseName + funcPostfix(); // To be consistent with older tests, nothing is appended to 'f_grad' if // we differentiate w.r.t. all the parameters at once. @@ -1946,27 +1947,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Overloaded derivative was not found, request the CladPlugin to // derive the called function. - DiffRequest pullbackRequest{}; + DiffRequest pullbackRequest(m_Builder); pullbackRequest.Function = FD; // Mark the indexes of the global args. Necessary if the argument of the // call has a different name than the function's signature parameter. pullbackRequest.CUDAGlobalArgsIndexes = globalCallArgs; - pullbackRequest.BaseFunctionName = - clad::utils::ComputeEffectiveFnName(FD); - pullbackRequest.Mode = DiffMode::experimental_pullback; - // Silence diag outputs in nested derivation process. - pullbackRequest.VerboseDiags = false; - pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; - pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; - bool isaMethod = isa(FD); - for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) - if (MD && isLambdaCallOperator(MD)) { - if (const auto* paramDecl = FD->getParamDecl(i)) - pullbackRequest.DVI.push_back(paramDecl); - } else if (DerivedCallOutputArgs[i + isaMethod]) - pullbackRequest.DVI.push_back(FD->getParamDecl(i)); + pullbackRequest.BaseFunctionName = + clad::utils::ComputeEffectiveFnName(FD); + pullbackRequest.Mode = DiffMode::experimental_pullback; + // Silence diag outputs in nested derivation process. + pullbackRequest.VerboseDiags = false; + pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; + pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis; + pullbackRequest.setVariedDecls(m_DiffReq.getVariedDecls()); + bool isaMethod = isa(FD); + for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) + if (MD && isLambdaCallOperator(MD)) { + if (const auto* paramDecl = FD->getParamDecl(i)) + pullbackRequest.DVI.push_back(paramDecl); + } else if (DerivedCallOutputArgs[i + isaMethod]) + pullbackRequest.DVI.push_back(FD->getParamDecl(i)); FunctionDecl* pullbackFD = nullptr; if (m_ExternalSource)