From 7bfd4aa78509bb9efaacd253042d442542207aaf Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Wed, 3 Apr 2024 19:50:19 +0200 Subject: [PATCH] Fix DVI info for pullback methods --- lib/Differentiator/BaseForwardModeVisitor.cpp | 3 +- lib/Differentiator/DiffPlanner.cpp | 19 +++++++++- test/Gradient/FunctionCalls.C | 35 +++++++++++++++++++ tools/ClangPlugin.h | 3 +- 4 files changed, 57 insertions(+), 3 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index c27c1dbf5..13ba31176 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -438,8 +438,9 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, auto* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; + SourceLocation loc{m_Function->getLocation()}; DeclWithContext cloneFunctionResult = m_Builder.cloneFunction( - m_Function, *this, DC, noLoc, derivedFnName, derivedFnType); + m_Function, *this, DC, loc, derivedFnName, derivedFnType); m_Derivative = cloneFunctionResult.first; llvm::SmallVector params; diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index a86576d8e..7d80c3e39 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -275,8 +275,25 @@ namespace clad { void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) { // Diff info for pullbacks is generated automatically, // its parameters are not provided by the user. - if (Mode == DiffMode::experimental_pullback) + if (Mode == DiffMode::experimental_pullback) { + if (!Function->getPreviousDecl()) + return; + const FunctionDecl* FD = Function->getPreviousDecl(); + // Might need to update DVI args, as they may be pointing to the + // declaration parameters, not the definition parameters. + for (size_t i = 0, e = DVI.size(), paramIdx = 0; + i < e && paramIdx < FD->getNumParams(); ++i) { + const auto* param = DVI[i].param; + while (paramIdx < FD->getNumParams() && + FD->getParamDecl(paramIdx) != param) { + ++paramIdx; + } + if (paramIdx != FD->getNumParams()) + // Update the parameter to point to the definition parameter. + DVI[i].param = Function->getParamDecl(paramIdx); + } return; + } DVI.clear(); auto& C = semaRef.getASTContext(); const Expr* diffArgs = Args; diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index e522d4227..93a7cb89a 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -811,6 +811,34 @@ double fn17 (double x, double* y) { //CHECK-NEXT: } //CHECK-NEXT: } +double sq_defined_later(double x); + +// CHECK: void sq_defined_later_pullback(double x, double _d_y, double *_d_x) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: *_d_x += _d_y * x; +// CHECK-NEXT: *_d_x += x * _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double fn18(double x, double y) { + return sq_defined_later(x) + sq_defined_later(y); +} + +// CHECK: void fn18_grad(double x, double y, double *_d_x, double *_d_y) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: sq_defined_later_pullback(x, 1, &_r0); +// CHECK-NEXT: *_d_x += _r0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: sq_defined_later_pullback(y, 1, &_r1); +// CHECK-NEXT: *_d_y += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } + template void reset(T* arr, int n) { for (int i=0; i