diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index ef00863c6..9ac7165bf 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -208,6 +208,12 @@ namespace clad { /// is already derived or not. void AddEdgeToGraph(const DiffRequest& request, bool alreadyDerived = false); + + /// Handles processing of a diff request when an existing derivative is + /// being processed. + /// \param[in] Request The request to be processed. + /// \returns The derivative function if found, nullptr otherwise. + clang::FunctionDecl* HandleNestedDiffRequest(DiffRequest& request); }; } // end namespace clad diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 461155a44..075656c35 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1209,19 +1209,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // Check if request already derived in DerivedFunctions. FunctionDecl* pushforwardFD = - m_Builder.FindDerivedFunction(pushforwardFnRequest); - if (!pushforwardFD) { - // Derive declaration of the pushforward function. - pushforwardFnRequest.DeclarationOnly = true; - pushforwardFD = - plugin::ProcessDiffRequest(m_CladPlugin, pushforwardFnRequest); - - // Add the request to derive the definition of the pushforward function - // into the queue. - pushforwardFnRequest.DeclarationOnly = false; - pushforwardFnRequest.DerivedFDPrototype = pushforwardFD; - } - m_Builder.AddEdgeToGraph(pushforwardFnRequest); + m_Builder.HandleNestedDiffRequest(pushforwardFnRequest); if (pushforwardFD) { if (baseDiff.getExpr()) { diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 43fd1ccaa..24cd53b0d 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -295,6 +295,33 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { return OverloadedFn; } + clang::FunctionDecl* + DerivativeBuilder::HandleNestedDiffRequest(DiffRequest& request) { + // FIXME: Find a way to do this without accessing plugin namespace functions + bool alreadyDerived = true; + FunctionDecl* derivative = this->FindDerivedFunction(request); + if (!derivative) { + alreadyDerived = false; + // Derive declaration of the the forward mode derivative. + request.DeclarationOnly = true; + derivative = plugin::ProcessDiffRequest(m_CladPlugin, request); + + // It is possible that user has provided a custom derivative for the + // derivative function. In that case, we should not derive the definition + // again. + if (derivative && + (derivative->isDefined() || m_DFC.IsCustomDerivative(derivative))) + alreadyDerived = true; + + // Add the request to derive the definition of the forward mode derivative + // to the schedule. + request.DeclarationOnly = false; + request.DerivedFDPrototype = derivative; + } + this->AddEdgeToGraph(request, alreadyDerived); + return derivative; + } + void DerivativeBuilder::AddErrorEstimationModel( std::unique_ptr estModel) { m_EstModel.push_back(std::move(estModel)); @@ -423,9 +450,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { // FIXME: if the derivatives aren't registered in this order and the // derivative is a member function it goes into an infinite loop if (!m_DFC.IsCustomDerivative(result.derivative)) { - if (auto FD = result.derivative) + if (auto* FD = result.derivative) registerDerivative(FD, m_Sema); - if (auto OFD = result.overload) + if (auto* OFD = result.overload) registerDerivative(OFD, m_Sema); } diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index a442f6e88..6d4b2982f 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -59,27 +59,8 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( IndependentArgRequest.CallUpdateRequired = false; IndependentArgRequest.UpdateDiffParamsInfo(SemaRef); // FIXME: Find a way to do this without accessing plugin namespace functions - bool alreadyDerived = true; FunctionDecl* firstDerivative = - Builder.FindDerivedFunction(IndependentArgRequest); - if (!firstDerivative) { - alreadyDerived = false; - // Derive declaration of the the forward mode derivative. - IndependentArgRequest.DeclarationOnly = true; - firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest); - - // It is possible that user has provided a custom derivative for the - // derivative function. In that case, we should not derive the definition - // again. - if (firstDerivative->isDefined() || DFC.IsCustomDerivative(firstDerivative)) - alreadyDerived = true; - - // Add the request to derive the definition of the forward mode derivative - // to the schedule. - IndependentArgRequest.DeclarationOnly = false; - IndependentArgRequest.DerivedFDPrototype = firstDerivative; - } - Builder.AddEdgeToGraph(IndependentArgRequest, alreadyDerived); + Builder.HandleNestedDiffRequest(IndependentArgRequest); // Further derives function w.r.t to ReverseModeArgs DiffRequest ReverseModeRequest{}; @@ -89,28 +70,8 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode( ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString(); ReverseModeRequest.UpdateDiffParamsInfo(SemaRef); - alreadyDerived = true; FunctionDecl* secondDerivative = - Builder.FindDerivedFunction(ReverseModeRequest); - if (!secondDerivative) { - alreadyDerived = false; - // Derive declaration of the the reverse mode derivative. - ReverseModeRequest.DeclarationOnly = true; - secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest); - - // It is possible that user has provided a custom derivative for the - // derivative function. In that case, we should not derive the definition - // again. - if (secondDerivative->isDefined() || - DFC.IsCustomDerivative(secondDerivative)) - alreadyDerived = true; - - // Add the request to derive the definition of the reverse mode - // derivative to the schedule. - ReverseModeRequest.DeclarationOnly = false; - ReverseModeRequest.DerivedFDPrototype = secondDerivative; - } - Builder.AddEdgeToGraph(ReverseModeRequest, alreadyDerived); + Builder.HandleNestedDiffRequest(ReverseModeRequest); return secondDerivative; } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 958ac1788..9a1a53477 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -364,8 +364,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // Check if the function is already declared as a custom derivative. auto* DC = const_cast(m_DiffReq->getDeclContext()); if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( - gradientName, DC, gradientFunctionType)) - return DerivativeAndOverload{customDerivative, nullptr}; + gradientName, DC, gradientFunctionType)) { + // Set m_Derivative for creating the overload. + m_Derivative = customDerivative; + FunctionDecl* gradientOverloadFD = nullptr; + if (shouldCreateOverload) + gradientOverloadFD = CreateGradientOverload(); + return DerivativeAndOverload{customDerivative, gradientOverloadFD}; + } // Create the gradient function declaration. llvm::SaveAndRestore SaveContext(m_Sema.CurContext); @@ -1682,29 +1688,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (DerivedCallOutputArgs[i + isaMethod]) pullbackRequest.DVI.push_back(FD->getParamDecl(i)); - FunctionDecl* pullbackFD = - m_Builder.FindDerivedFunction(pullbackRequest); - if (!pullbackFD) { - if (!m_ExternalSource) { - // Derive the declaration of the pullback function. - pullbackRequest.DeclarationOnly = true; - pullbackFD = - plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); - - // Add the request to derive the definition of the pullback - // function. - pullbackRequest.DeclarationOnly = false; - pullbackRequest.DerivedFDPrototype = pullbackFD; - } else { - // FIXME: Error estimation currently uses singleton objects - - // m_ErrorEstHandler and m_EstModel, which is cleared after each - // error_estimate request. This requires the pullback to be derived - // at the same time to access the singleton objects. - pullbackFD = - plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); - } - } - m_Builder.AddEdgeToGraph(pullbackRequest); + FunctionDecl* pullbackFD = nullptr; + if (m_ExternalSource) + // FIXME: Error estimation currently uses singleton objects - + // m_ErrorEstHandler and m_EstModel, which is cleared after each + // error_estimate request. This requires the pullback to be derived + // at the same time to access the singleton objects. + pullbackFD = + plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); + else + pullbackFD = m_Builder.HandleNestedDiffRequest(pullbackRequest); // Clad failed to derive it. // FIXME: Add support for reference arguments to the numerical diff. If @@ -1794,28 +1787,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, clad::utils::ComputeEffectiveFnName(FD); calleeFnForwPassReq.VerboseDiags = true; - bool alreadyDerived = true; FunctionDecl* calleeFnForwPassFD = - m_Builder.FindDerivedFunction(calleeFnForwPassReq); - if (!calleeFnForwPassFD) { - alreadyDerived = false; - // Derive declaration of the the forward pass function. - calleeFnForwPassReq.DeclarationOnly = true; - calleeFnForwPassFD = - plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); - - // It is possible that user has provided a custom derivative for the - // derivative function. In that case, we should not derive the - // definition again. - if (calleeFnForwPassFD->getDefinition()) - alreadyDerived = true; - - // Add the request to derive the definition of the forward pass - // function. - calleeFnForwPassReq.DeclarationOnly = false; - calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD; - } - m_Builder.AddEdgeToGraph(calleeFnForwPassReq, alreadyDerived); + m_Builder.HandleNestedDiffRequest(calleeFnForwPassReq); assert(calleeFnForwPassFD && "Clad failed to generate callee function forward pass function");