Skip to content

Commit

Permalink
Fix DVI info for pullback methods
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Apr 12, 2024
1 parent ecda4d1 commit 33634e2
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 2 deletions.
3 changes: 2 additions & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,9 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;

SourceLocation loc{m_Function->getLocation()};
DeclWithContext cloneFunctionResult = m_Builder.cloneFunction(
m_Function, *this, DC, noLoc, derivedFnName, derivedFnType);
m_Function, *this, DC, loc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

llvm::SmallVector<ParmVarDecl*, 16> params;
Expand Down
19 changes: 18 additions & 1 deletion lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,25 @@ namespace clad {
void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) {
// Diff info for pullbacks is generated automatically,
// its parameters are not provided by the user.
if (Mode == DiffMode::experimental_pullback)
if (Mode == DiffMode::experimental_pullback) {
if (!Function->getPreviousDecl())
return;
const FunctionDecl* FD = Function->getPreviousDecl();
// Might need to update DVI args, as they may be pointing to the
// declaration parameters, not the definition parameters.
for (size_t i = 0, e = DVI.size(), paramIdx = 0;
i < e && paramIdx < FD->getNumParams(); ++i) {
const auto* param = DVI[i].param;
while (paramIdx < FD->getNumParams() &&
FD->getParamDecl(paramIdx) != param) {
++paramIdx;
}
if (paramIdx != FD->getNumParams())
// Update the parameter to point to the definition parameter.
DVI[i].param = Function->getParamDecl(paramIdx);
}
return;
}
DVI.clear();
auto& C = semaRef.getASTContext();
const Expr* diffArgs = Args;
Expand Down
35 changes: 35 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,34 @@ double fn17 (double x, double* y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double sq_defined_later(double x);

// CHECK: void sq_defined_later_pullback(double x, double _d_y, double *_d_x) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: *_d_x += _d_y * x;
// CHECK-NEXT: *_d_x += x * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn18(double x, double y) {
return sq_defined_later(x) + sq_defined_later(y);
}

// CHECK: void fn18_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 0;
// CHECK-NEXT: sq_defined_later_pullback(x, 1, &_r0);
// CHECK-NEXT: *_d_x += _r0;
// CHECK-NEXT: double _r1 = 0;
// CHECK-NEXT: sq_defined_later_pullback(y, 1, &_r1);
// CHECK-NEXT: *_d_y += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -902,4 +930,11 @@ int main() {
double y[] = {3.0, 2.0}, dx = 0;
fn17_grad_0.execute(5, y, &dx);
printf("{%.2f}\n", dx); // CHECK-EXEC: {2.00}

INIT(fn18);
TEST2(fn18, 3, 5); // CHECK-EXEC: {6.00, 10.00}
}

double sq_defined_later(double x) {
return x*x;
}

0 comments on commit 33634e2

Please sign in to comment.