Skip to content

Commit

Permalink
Fix DVI info for pullback methods
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 3, 2024
1 parent 8788574 commit 9eba23e
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,25 @@ namespace clad {
void DiffRequest::UpdateDiffParamsInfo(Sema& semaRef) {
// Diff info for pullbacks is generated automatically,
// its parameters are not provided by the user.
if (Mode == DiffMode::experimental_pullback)
if (Mode == DiffMode::experimental_pullback) {
if (!Function->getPreviousDecl())
return;
const FunctionDecl* FD = Function->getPreviousDecl();
// Might need to update DVI args, as they may be pointing to the
// declaration parameters, not the definition parameters.
for (size_t i = 0, e = DVI.size(), paramIdx = 0;
i < e && paramIdx < FD->getNumParams(); ++i) {
auto* param = DVI[i].param;
while (paramIdx < FD->getNumParams() &&
FD->getParamDecl(paramIdx) != param) {
++paramIdx;
}
if (paramIdx != FD->getNumParams())
// Update the parameter to point to the definition parameter.
DVI[i].param = Function->getParamDecl(paramIdx);
}
return;
}
DVI.clear();
auto& C = semaRef.getASTContext();
const Expr* diffArgs = Args;
Expand Down

0 comments on commit 9eba23e

Please sign in to comment.