Skip to content

Commit

Permalink
Modify ordering of pullbacks in reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Apr 12, 2024
1 parent 33634e2 commit 8bf965f
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 41 deletions.
33 changes: 19 additions & 14 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,30 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
m_Derivative->setParams(params);
m_Derivative->setBody(nullptr);

beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
if (!request.DeclarationOnly) {
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();

beginBlock();
beginBlock(direction::reverse);

StmtDiff bodyDiff = Visit(m_Function->getBody());
Stmt* forward = bodyDiff.getStmt();
beginBlock();
beginBlock(direction::reverse);

for (Stmt* S : ReverseModeVisitor::m_Globals)
addToCurrentBlock(S);
StmtDiff bodyDiff = Visit(m_Function->getBody());
Stmt* forward = bodyDiff.getStmt();

if (auto* CS = dyn_cast<CompoundStmt>(forward))
for (Stmt* S : CS->body())
for (Stmt* S : ReverseModeVisitor::m_Globals)
addToCurrentBlock(S);

Stmt* fnBody = endBlock();
m_Derivative->setBody(fnBody);
endScope();
if (auto* CS = dyn_cast<CompoundStmt>(forward))
for (Stmt* S : CS->body())
addToCurrentBlock(S);

Stmt* fnBody = endBlock();
m_Derivative->setBody(fnBody);
endScope();

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope();
Expand Down
83 changes: 56 additions & 27 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,40 +535,45 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Derivative->setParams(params);
m_Derivative->setBody(nullptr);

if (m_ExternalSource)
m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope();
if (!request.DeclarationOnly) {
if (m_ExternalSource)
m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope();

beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();

beginBlock();
if (m_ExternalSource)
m_ExternalSource->ActOnStartOfDerivedFnBody(request);
beginBlock();
if (m_ExternalSource)
m_ExternalSource->ActOnStartOfDerivedFnBody(request);

StmtDiff bodyDiff = Visit(m_Function->getBody());
Stmt* forward = bodyDiff.getStmt();
Stmt* reverse = bodyDiff.getStmt_dx();
StmtDiff bodyDiff = Visit(m_Function->getBody());
Stmt* forward = bodyDiff.getStmt();
Stmt* reverse = bodyDiff.getStmt_dx();

// Create the body of the function.
// Firstly, all "global" Stmts are put into fn's body.
for (Stmt* S : m_Globals)
addToCurrentBlock(S, direction::forward);
// Forward pass.
if (auto* CS = dyn_cast<CompoundStmt>(forward))
for (Stmt* S : CS->body())
// Create the body of the function.
// Firstly, all "global" Stmts are put into fn's body.
for (Stmt* S : m_Globals)
addToCurrentBlock(S, direction::forward);
// Forward pass.
if (auto* CS = dyn_cast<CompoundStmt>(forward))
for (Stmt* S : CS->body())
addToCurrentBlock(S, direction::forward);

// Reverse pass.
if (auto* RCS = dyn_cast<CompoundStmt>(reverse))
for (Stmt* S : RCS->body())
addToCurrentBlock(S, direction::forward);
// Reverse pass.
if (auto* RCS = dyn_cast<CompoundStmt>(reverse))
for (Stmt* S : RCS->body())
addToCurrentBlock(S, direction::forward);

if (m_ExternalSource)
m_ExternalSource->ActOnEndOfDerivedFnBody();
if (m_ExternalSource)
m_ExternalSource->ActOnEndOfDerivedFnBody();

Stmt* fnBody = endBlock();
m_Derivative->setBody(fnBody);
endScope(); // Function body scope
Stmt* fnBody = endBlock();
m_Derivative->setBody(fnBody);
endScope(); // Function body scope

if (request.DerivedFDPrototype)
m_Derivative->setPreviousDeclaration(request.DerivedFDPrototype);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope
Expand Down Expand Up @@ -1794,7 +1799,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (DerivedCallOutputArgs[i + isaMethod])
pullbackRequest.DVI.push_back(FD->getParamDecl(i));
FunctionDecl* pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);

FunctionDecl* pullbackFD = nullptr;
if (!m_ExternalSource) {
// Derive the declaration of the pullback function.
pullbackRequest.DeclarationOnly = true;
pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);

// Add the request to derive the definition of the pullback function.
pullbackRequest.DeclarationOnly = false;
pullbackRequest.DerivedFDPrototype = pullbackFD;
plugin::AddRequestToSchedule(m_CladPlugin, pullbackRequest);
} else {
// FIXME: Error estimation currently uses singleton objects - m_ErrorEstHandler and m_EstModel, which is cleared after each error_estimate request. This requires the pullback to be derived at the same time to access the singleton objects.
pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
}


// Clad failed to derive it.
// FIXME: Add support for reference arguments to the numerical diff. If
// it already correctly support reference arguments then confirm the
Expand Down Expand Up @@ -1882,9 +1903,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
calleeFnForwPassReq.BaseFunctionName =
clad::utils::ComputeEffectiveFnName(FD);
calleeFnForwPassReq.VerboseDiags = true;

// Derive declaration of the the forward pass function.
calleeFnForwPassReq.DeclarationOnly = true;
FunctionDecl* calleeFnForwPassFD =
plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq);

// Add the request to derive the definition of the forward pass function.
calleeFnForwPassReq.DeclarationOnly = false;
calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD;
plugin::AddRequestToSchedule(m_CladPlugin, calleeFnForwPassReq);

assert(calleeFnForwPassFD &&
"Clad failed to generate callee function forward pass function");

Expand Down

0 comments on commit 8bf965f

Please sign in to comment.