diff --git a/include/clad/Differentiator/VectorForwardModeVisitor.h b/include/clad/Differentiator/VectorForwardModeVisitor.h index c5bccdeda..f906c6c14 100644 --- a/include/clad/Differentiator/VectorForwardModeVisitor.h +++ b/include/clad/Differentiator/VectorForwardModeVisitor.h @@ -30,13 +30,10 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { ///\brief Produces the first derivative of a given function with /// respect to multiple parameters. /// - ///\param[in] FD - the function that will be differentiated. - /// ///\returns The differentiated and potentially created enclosing /// context. /// - DerivativeAndOverload DeriveVectorMode(const clang::FunctionDecl* FD, - const DiffRequest& request); + DerivativeAndOverload DeriveVectorMode(); /// Builds an overload for the vector mode function that has derived params /// for all the arguments of the requested function and it calls the original diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index ada7153c6..946a68719 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -408,7 +408,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { result = V.DerivePushforward(); } else if (request.Mode == DiffMode::vector_forward_mode) { VectorForwardModeVisitor V(*this, request); - result = V.DeriveVectorMode(FD, request); + result = V.DeriveVectorMode(); } else if (request.Mode == DiffMode::experimental_vector_pushforward) { VectorPushForwardModeVisitor V(*this, request); result = V.DerivePushforward(); diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index c3756619a..c471c517d 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -53,20 +53,16 @@ void VectorForwardModeVisitor::SetIndependentVarsExpr(Expr* IndVarCountExpr) { m_IndVarCountExpr = IndVarCountExpr; } -DerivativeAndOverload -VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, - const DiffRequest& request) { - assert(m_DiffReq == request); +DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() { + const FunctionDecl* FD = m_DiffReq.Function; assert(m_DiffReq.Mode == DiffMode::vector_forward_mode); DiffParams args{}; - DiffInputVarsInfo DVI; - DVI = request.DVI; - for (auto dParam : DVI) + for (const auto& dParam : m_DiffReq.DVI) args.push_back(dParam.param); // Generate name for the derivative function. - std::string derivedFnName = request.BaseFunctionName + "_dvec"; + std::string derivedFnName = m_DiffReq.BaseFunctionName + "_dvec"; if (args.size() != FD->getNumParams()) { for (auto arg : args) { auto it = std::find(FD->param_begin(), FD->param_end(), arg);