From c30acaa950fb52ea3a042cd69bb0619961031fc7 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Tue, 8 Oct 2024 11:15:53 +0300 Subject: [PATCH 1/2] Remove excessive FD and request parameters from DeriveVectorMode --- .../clad/Differentiator/VectorForwardModeVisitor.h | 5 +---- lib/Differentiator/DerivativeBuilder.cpp | 2 +- lib/Differentiator/VectorForwardModeVisitor.cpp | 12 ++++-------- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/include/clad/Differentiator/VectorForwardModeVisitor.h b/include/clad/Differentiator/VectorForwardModeVisitor.h index c5bccdeda..f906c6c14 100644 --- a/include/clad/Differentiator/VectorForwardModeVisitor.h +++ b/include/clad/Differentiator/VectorForwardModeVisitor.h @@ -30,13 +30,10 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { ///\brief Produces the first derivative of a given function with /// respect to multiple parameters. /// - ///\param[in] FD - the function that will be differentiated. - /// ///\returns The differentiated and potentially created enclosing /// context. /// - DerivativeAndOverload DeriveVectorMode(const clang::FunctionDecl* FD, - const DiffRequest& request); + DerivativeAndOverload DeriveVectorMode(); /// Builds an overload for the vector mode function that has derived params /// for all the arguments of the requested function and it calls the original diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index ada7153c6..946a68719 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -408,7 +408,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { result = V.DerivePushforward(); } else if (request.Mode == DiffMode::vector_forward_mode) { VectorForwardModeVisitor V(*this, request); - result = V.DeriveVectorMode(FD, request); + result = V.DeriveVectorMode(); } else if (request.Mode == DiffMode::experimental_vector_pushforward) { VectorPushForwardModeVisitor V(*this, request); result = V.DerivePushforward(); diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index c3756619a..7bce4dd4b 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -53,20 +53,16 @@ void VectorForwardModeVisitor::SetIndependentVarsExpr(Expr* IndVarCountExpr) { m_IndVarCountExpr = IndVarCountExpr; } -DerivativeAndOverload -VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, - const DiffRequest& request) { - assert(m_DiffReq == request); +DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() { + const FunctionDecl* FD = m_DiffReq.Function; assert(m_DiffReq.Mode == DiffMode::vector_forward_mode); DiffParams args{}; - DiffInputVarsInfo DVI; - DVI = request.DVI; - for (auto dParam : DVI) + for (auto dParam : m_DiffReq.DVI) args.push_back(dParam.param); // Generate name for the derivative function. - std::string derivedFnName = request.BaseFunctionName + "_dvec"; + std::string derivedFnName = m_DiffReq.BaseFunctionName + "_dvec"; if (args.size() != FD->getNumParams()) { for (auto arg : args) { auto it = std::find(FD->param_begin(), FD->param_end(), arg); From 75894b5cffd3afbed70df70e0223b73586d98a25 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Wed, 13 Nov 2024 08:09:16 +0200 Subject: [PATCH 2/2] Update lib/Differentiator/VectorForwardModeVisitor.cpp Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- lib/Differentiator/VectorForwardModeVisitor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index 7bce4dd4b..c471c517d 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -58,7 +58,7 @@ DerivativeAndOverload VectorForwardModeVisitor::DeriveVectorMode() { assert(m_DiffReq.Mode == DiffMode::vector_forward_mode); DiffParams args{}; - for (auto dParam : m_DiffReq.DVI) + for (const auto& dParam : m_DiffReq.DVI) args.push_back(dParam.param); // Generate name for the derivative function.