Skip to content

Commit

Permalink
Specify true args in BuildCallToCustomDerivativeOrNumericalDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Oct 28, 2024
1 parent 0d23492 commit 5f8b54b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1952,7 +1952,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()), true, true,
const_cast<DeclContext*>(FD->getDeclContext()),
/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (OverloadedDerivedFn)
asGrad = false;
Expand Down Expand Up @@ -2054,7 +2055,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()), true, true,
const_cast<DeclContext*>(FD->getDeclContext()),
/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (baseDiff.getExpr())
pullbackCallArgs.erase(pullbackCallArgs.begin());
Expand Down

0 comments on commit 5f8b54b

Please sign in to comment.