From 67190a884bf1413ce7a687be7e7b156eab6d4c83 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 1 Jul 2024 21:01:48 +0300 Subject: [PATCH] Remove DerivePullback from ReverseModeVisitor and generate pullbacks with Derive. --- .../clad/Differentiator/ReverseModeVisitor.h | 2 - lib/Differentiator/DerivativeBuilder.cpp | 16 +- lib/Differentiator/ReverseModeVisitor.cpp | 156 +----------------- 3 files changed, 14 insertions(+), 160 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 346ec8155..7ea61bbbc 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -381,8 +381,6 @@ namespace clad { /// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'. DerivativeAndOverload Derive(const clang::FunctionDecl* FD, const DiffRequest& request); - DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD, - const DiffRequest& request); StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); StmtDiff VisitCallExpr(const clang::CallExpr* CE); diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index d688a64ae..7c86d4fe6 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -419,17 +419,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { result = V.DerivePushforward(FD, request); } else if (request.Mode == DiffMode::reverse) { ReverseModeVisitor V(*this, request); - if (request.CallUpdateRequired) { - result = V.Derive(FD, request); - } else { - if (!m_ErrorEstHandler.empty()) { - InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request); - V.AddExternalSource(*m_ErrorEstHandler.back()); - } - result = V.DerivePullback(FD, request); - if (!m_ErrorEstHandler.empty()) - CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); + if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty()) { + InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request); + V.AddExternalSource(*m_ErrorEstHandler.back()); } + result = V.Derive(FD, request); + if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty()) + CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { ReverseModeForwPassVisitor V(*this, request); result = V.Derive(FD, request); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6fc76b9e2..0716107d8 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -283,15 +283,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, assert(m_DiffReq.Function && "Must not be null."); DiffParams args{}; - DiffInputVarsInfo DVI; - if (request.Args) { - DVI = request.DVI; - for (const auto& dParam : DVI) + if (!request.DVI.empty()) + for (const auto& dParam : request.DVI) args.push_back(dParam.param); - } else std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - if (args.empty()) + if (args.empty() && (!isa(FD) || utils::IsStaticMethod(FD))) return {}; if (m_ExternalSource) @@ -303,10 +300,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, outputArrayStr = m_DiffReq->getParamDecl(lastArgN)->getNameAsString(); } - auto derivativeBaseName = request.BaseFunctionName; - std::string gradientName = derivativeBaseName + funcPostfix(m_DiffReq); + std::string derivativeBaseName = request.BaseFunctionName; + std::string derivativeName = derivativeBaseName + funcPostfix(m_DiffReq); - IdentifierInfo* II = &m_Context.Idents.get(gradientName); + IdentifierInfo* II = &m_Context.Idents.get(derivativeName); DeclarationNameInfo name(II, noLoc); // If we are in error estimation mode, we have an extra `double&` @@ -323,7 +320,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If reverse mode differentiates only part of the arguments it needs to // generate an overload that can take in all the diff variables bool shouldCreateOverload = false; - if (request.Mode != DiffMode::jacobian) + if (request.Mode != DiffMode::jacobian && m_DiffReq.CallUpdateRequired) shouldCreateOverload = true; if (!request.DeclarationOnly && !request.DerivedFDPrototypes.empty()) // If the overload is already created, we don't need to create it again. @@ -347,7 +344,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) auto* DC = const_cast(m_DiffReq->getDeclContext()); if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( - gradientName, DC, gradientFunctionType)) { + derivativeName, DC, gradientFunctionType)) { // Set m_Derivative for creating the overload. m_Derivative = customDerivative; FunctionDecl* gradientOverloadFD = nullptr; @@ -459,143 +456,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return DerivativeAndOverload{result.first, gradientOverloadFD}; } - DerivativeAndOverload - ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, - const DiffRequest& request) { - if (request.EnableTBRAnalysis) { - TBRAnalyzer analyzer(m_Context); - analyzer.Analyze(FD); - m_ToBeRecorded = analyzer.getResult(); - } - - // FIXME: Duplication of external source here is a workaround - // for the two 'Derive's being different functions. - if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerive(); - silenceDiags = !request.VerboseDiags; - // FIXME: We should not use const_cast to get the decl request here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(m_DiffReq) = request; - assert(m_DiffReq.Function && "Must not be null."); - - DiffParams args{}; - if (!request.DVI.empty()) - for (const auto& dParam : request.DVI) - args.push_back(dParam.param); - else - std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); -#ifndef NDEBUG - bool isStaticMethod = utils::IsStaticMethod(FD); - assert((!args.empty() || !isStaticMethod) && - "Cannot generate pullback function of a function " - "with no differentiable arguments"); -#endif - - if (m_ExternalSource) - m_ExternalSource->ActAfterParsingDiffArgs(request, args); - - auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) + - funcPostfix(m_DiffReq); - auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); - - auto paramTypes = ComputeParamTypes(args); - const auto* originalFnType = - dyn_cast(m_DiffReq->getType()); - - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); - - QualType pullbackFnType = m_Context.getFunctionType( - m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); - - // Check if the function is already declared as a custom derivative. - // FIXME: We should not use const_cast to get the decl context here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - auto* DC = const_cast(m_DiffReq->getDeclContext()); - if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl( - derivativeName, DC, pullbackFnType)) - return DerivativeAndOverload{customDerivative, nullptr}; - - llvm::SaveAndRestore saveContext(m_Sema.CurContext); - llvm::SaveAndRestore saveScope(getCurrentScope(), - getEnclosingNamespaceOrTUScope()); - // FIXME: We should not use const_cast to get the decl context here. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - m_Sema.CurContext = const_cast(m_DiffReq->getDeclContext()); - - SourceLocation validLoc{m_DiffReq->getLocation()}; - DeclWithContext fnBuildRes = - m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext, - validLoc, DNI, pullbackFnType); - m_Derivative = fnBuildRes.first; - - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnScope(); - - beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | - Scope::DeclScope); - m_Sema.PushFunctionScope(); - m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); - - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnScope(); - - auto params = BuildParams(args); - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParams(params); - - m_Derivative->setParams(params); - m_Derivative->setBody(nullptr); - - if (!request.DeclarationOnly) { - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope(); - - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); - - beginBlock(); - if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerivedFnBody(request); - - StmtDiff bodyDiff = Visit(m_DiffReq->getBody()); - Stmt* forward = bodyDiff.getStmt(); - Stmt* reverse = bodyDiff.getStmt_dx(); - - // Create the body of the function. - // Firstly, all "global" Stmts are put into fn's body. - for (Stmt* S : m_Globals) - addToCurrentBlock(S, direction::forward); - // Forward pass. - if (auto* CS = dyn_cast(forward)) - for (Stmt* S : CS->body()) - addToCurrentBlock(S, direction::forward); - - // Reverse pass. - if (auto* RCS = dyn_cast(reverse)) - for (Stmt* S : RCS->body()) - addToCurrentBlock(S, direction::forward); - - if (m_ExternalSource) - m_ExternalSource->ActOnEndOfDerivedFnBody(); - - Stmt* fnBody = endBlock(); - m_Derivative->setBody(fnBody); - endScope(); // Function body scope - - // Size >= current derivative order means that there exists a declaration - // or prototype for the currently derived function. - if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder) - m_Derivative->setPreviousDeclaration( - request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]); - } - m_Sema.PopFunctionScopeInfo(); - m_Sema.PopDeclContext(); - endScope(); // Function decl scope - - return DerivativeAndOverload{fnBuildRes.first, nullptr}; - } - void ReverseModeVisitor::DifferentiateWithClad() { if (m_DiffReq.EnableTBRAnalysis) { TBRAnalyzer analyzer(m_Context);