From 8bf965f7edca986a7a27a138192339f392f350bd Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Thu, 4 Apr 2024 20:55:55 +0200 Subject: [PATCH] Modify ordering of pullbacks in reverse mode --- .../ReverseModeForwPassVisitor.cpp | 33 ++++---- lib/Differentiator/ReverseModeVisitor.cpp | 83 +++++++++++++------ 2 files changed, 75 insertions(+), 41 deletions(-) diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index cce53a594..1d51bb314 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -57,25 +57,30 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, m_Derivative->setParams(params); m_Derivative->setBody(nullptr); - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); + if (!request.DeclarationOnly) { + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); - beginBlock(); - beginBlock(direction::reverse); - - StmtDiff bodyDiff = Visit(m_Function->getBody()); - Stmt* forward = bodyDiff.getStmt(); + beginBlock(); + beginBlock(direction::reverse); - for (Stmt* S : ReverseModeVisitor::m_Globals) - addToCurrentBlock(S); + StmtDiff bodyDiff = Visit(m_Function->getBody()); + Stmt* forward = bodyDiff.getStmt(); - if (auto* CS = dyn_cast(forward)) - for (Stmt* S : CS->body()) + for (Stmt* S : ReverseModeVisitor::m_Globals) addToCurrentBlock(S); - Stmt* fnBody = endBlock(); - m_Derivative->setBody(fnBody); - endScope(); + if (auto* CS = dyn_cast(forward)) + for (Stmt* S : CS->body()) + addToCurrentBlock(S); + + Stmt* fnBody = endBlock(); + m_Derivative->setBody(fnBody); + endScope(); + + if (request.DerivedFDPrototype) + m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype); + } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); endScope(); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d9283efa3..c9f8e97f3 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -535,40 +535,45 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Derivative->setParams(params); m_Derivative->setBody(nullptr); - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope(); + if (!request.DeclarationOnly) { + if (m_ExternalSource) + m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope(); - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); - beginBlock(); - if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerivedFnBody(request); + beginBlock(); + if (m_ExternalSource) + m_ExternalSource->ActOnStartOfDerivedFnBody(request); - StmtDiff bodyDiff = Visit(m_Function->getBody()); - Stmt* forward = bodyDiff.getStmt(); - Stmt* reverse = bodyDiff.getStmt_dx(); + StmtDiff bodyDiff = Visit(m_Function->getBody()); + Stmt* forward = bodyDiff.getStmt(); + Stmt* reverse = bodyDiff.getStmt_dx(); - // Create the body of the function. - // Firstly, all "global" Stmts are put into fn's body. - for (Stmt* S : m_Globals) - addToCurrentBlock(S, direction::forward); - // Forward pass. - if (auto* CS = dyn_cast(forward)) - for (Stmt* S : CS->body()) + // Create the body of the function. + // Firstly, all "global" Stmts are put into fn's body. + for (Stmt* S : m_Globals) addToCurrentBlock(S, direction::forward); + // Forward pass. + if (auto* CS = dyn_cast(forward)) + for (Stmt* S : CS->body()) + addToCurrentBlock(S, direction::forward); - // Reverse pass. - if (auto* RCS = dyn_cast(reverse)) - for (Stmt* S : RCS->body()) - addToCurrentBlock(S, direction::forward); + // Reverse pass. + if (auto* RCS = dyn_cast(reverse)) + for (Stmt* S : RCS->body()) + addToCurrentBlock(S, direction::forward); - if (m_ExternalSource) - m_ExternalSource->ActOnEndOfDerivedFnBody(); + if (m_ExternalSource) + m_ExternalSource->ActOnEndOfDerivedFnBody(); - Stmt* fnBody = endBlock(); - m_Derivative->setBody(fnBody); - endScope(); // Function body scope + Stmt* fnBody = endBlock(); + m_Derivative->setBody(fnBody); + endScope(); // Function body scope + + if (request.DerivedFDPrototype) + m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype); + } m_Sema.PopFunctionScopeInfo(); m_Sema.PopDeclContext(); endScope(); // Function decl scope @@ -1794,7 +1799,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) if (DerivedCallOutputArgs[i + isaMethod]) pullbackRequest.DVI.push_back(FD->getParamDecl(i)); - FunctionDecl* pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); + + FunctionDecl* pullbackFD = nullptr; + if (!m_ExternalSource) { + // Derive the declaration of the pullback function. + pullbackRequest.DeclarationOnly = true; + pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); + + // Add the request to derive the definition of the pullback function. + pullbackRequest.DeclarationOnly = false; + pullbackRequest.DerivedFDPrototype = pullbackFD; + plugin::AddRequestToSchedule(m_CladPlugin, pullbackRequest); + } else { + // FIXME: Error estimation currently uses singleton objects - m_ErrorEstHandler and m_EstModel, which is cleared after each error_estimate request. This requires the pullback to be derived at the same time to access the singleton objects. + pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); + } + + // Clad failed to derive it. // FIXME: Add support for reference arguments to the numerical diff. If // it already correctly support reference arguments then confirm the @@ -1882,9 +1903,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, calleeFnForwPassReq.BaseFunctionName = clad::utils::ComputeEffectiveFnName(FD); calleeFnForwPassReq.VerboseDiags = true; + + // Derive declaration of the the forward pass function. + calleeFnForwPassReq.DeclarationOnly = true; FunctionDecl* calleeFnForwPassFD = plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); + // Add the request to derive the definition of the forward pass function. + calleeFnForwPassReq.DeclarationOnly = false; + calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD; + plugin::AddRequestToSchedule(m_CladPlugin, calleeFnForwPassReq); + assert(calleeFnForwPassFD && "Clad failed to generate callee function forward pass function");