From 4959e67afe283edb3d521e9281b276a957a0de84 Mon Sep 17 00:00:00 2001 From: kchristin Date: Mon, 28 Oct 2024 12:16:38 +0200 Subject: [PATCH] Specify true args in BuildCallToCustomDerivativeOrNumericalDiff --- lib/Differentiator/ReverseModeVisitor.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ef41f42b2..04f626286 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1952,7 +1952,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforward, pushforwardCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext()), true, true, + const_cast(FD->getDeclContext()), + /*forCustomDerv=*/true, /*namespaceShouldExist=*/true, CUDAExecConfig); if (OverloadedDerivedFn) asGrad = false; @@ -2054,7 +2055,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPullback, pullbackCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext()), true, true, + const_cast(FD->getDeclContext()), + /*forCustomDerv=*/true, /*namespaceShouldExist=*/true, CUDAExecConfig); if (baseDiff.getExpr()) pullbackCallArgs.erase(pullbackCallArgs.begin());