Skip to content

Commit

Permalink
Remove DerivePullback from ReverseModeVisitor and generate pullbacks …
Browse files Browse the repository at this point in the history
…with Derive.
  • Loading branch information
PetroZarytskyi committed Jul 11, 2024
1 parent c2da01f commit 23f39bc
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 160 deletions.
2 changes: 0 additions & 2 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,6 @@ namespace clad {
/// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'.
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request);
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
Expand Down
16 changes: 6 additions & 10 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,17 +419,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
result = V.DerivePushforward(FD, request);
} else if (request.Mode == DiffMode::reverse) {
ReverseModeVisitor V(*this, request);
if (request.CallUpdateRequired) {
result = V.Derive(FD, request);
} else {
if (!m_ErrorEstHandler.empty()) {
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
V.AddExternalSource(*m_ErrorEstHandler.back());
}
result = V.DerivePullback(FD, request);
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty()) {
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
V.AddExternalSource(*m_ErrorEstHandler.back());
}
result = V.Derive(FD, request);
if (!request.CallUpdateRequired && !m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this, request);
result = V.Derive(FD, request);
Expand Down
156 changes: 8 additions & 148 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,15 +283,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
assert(m_DiffReq.Function && "Must not be null.");

DiffParams args{};
DiffInputVarsInfo DVI;
if (request.Args) {
DVI = request.DVI;
for (const auto& dParam : DVI)
if (!request.DVI.empty())
for (const auto& dParam : request.DVI)
args.push_back(dParam.param);
}
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));
if (args.empty())
if (args.empty() && (!isa<CXXMethodDecl>(FD) || utils::IsStaticMethod(FD)))
return {};

if (m_ExternalSource)
Expand All @@ -303,10 +300,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
outputArrayStr = m_DiffReq->getParamDecl(lastArgN)->getNameAsString();
}

auto derivativeBaseName = request.BaseFunctionName;
std::string gradientName = derivativeBaseName + funcPostfix(m_DiffReq);
std::string derivativeBaseName = request.BaseFunctionName;
std::string derivativeName = derivativeBaseName + funcPostfix(m_DiffReq);

IdentifierInfo* II = &m_Context.Idents.get(gradientName);
IdentifierInfo* II = &m_Context.Idents.get(derivativeName);
DeclarationNameInfo name(II, noLoc);

// If we are in error estimation mode, we have an extra `double&`
Expand All @@ -323,7 +320,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If reverse mode differentiates only part of the arguments it needs to
// generate an overload that can take in all the diff variables
bool shouldCreateOverload = false;
if (request.Mode != DiffMode::jacobian)
if (request.Mode != DiffMode::jacobian && m_DiffReq.CallUpdateRequired)
shouldCreateOverload = true;
if (!request.DeclarationOnly && !request.DerivedFDPrototypes.empty())
// If the overload is already created, we don't need to create it again.
Expand All @@ -347,7 +344,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
gradientName, DC, gradientFunctionType)) {
derivativeName, DC, gradientFunctionType)) {
// Set m_Derivative for creating the overload.
m_Derivative = customDerivative;
FunctionDecl* gradientOverloadFD = nullptr;
Expand Down Expand Up @@ -459,143 +456,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return DerivativeAndOverload{result.first, gradientOverloadFD};
}

/// Builds the pullback function of `FD` according to `request`.
///
/// The generated function is named after the effective name of the
/// differentiated function plus `funcPostfix(m_DiffReq)`, returns `void`, and
/// takes the parameter types produced by `ComputeParamTypes`. If
/// `LookupCustomDerivativeDecl` finds a declaration with that name and type,
/// it is returned directly and nothing is generated. Otherwise the original
/// declaration is cloned and, unless `request.DeclarationOnly` is set, its
/// body is visited and the derivative body is assembled from the global
/// statements, the forward pass and the reverse pass (in that order).
///
/// \param[in] FD the function to differentiate.
/// \param[in] request the differentiation request; it is also copied into
/// `m_DiffReq`.
/// \returns the derived (or custom) function declaration paired with a null
/// overload.
DerivativeAndOverload
ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request) {
// Run the TBR analysis on FD and keep its result for later use.
if (request.EnableTBRAnalysis) {
TBRAnalyzer analyzer(m_Context);
analyzer.Analyze(FD);
m_ToBeRecorded = analyzer.getResult();
}

// FIXME: Duplication of external source here is a workaround
// for the two 'Derive's being different functions.
if (m_ExternalSource)
m_ExternalSource->ActOnStartOfDerive();
// Diagnostics are silenced unless the request asks for verbose ones.
silenceDiags = !request.VerboseDiags;
// FIXME: We should not use const_cast to get the decl request here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<DiffRequest&>(m_DiffReq) = request;
assert(m_DiffReq.Function && "Must not be null.");

// Differentiate with respect to the parameters listed in request.DVI, or
// with respect to every parameter of FD when DVI is empty.
DiffParams args{};
if (!request.DVI.empty())
for (const auto& dParam : request.DVI)
args.push_back(dParam.param);
else
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));
// Debug builds only: an empty argument list is rejected when
// utils::IsStaticMethod(FD) is true.
#ifndef NDEBUG
bool isStaticMethod = utils::IsStaticMethod(FD);
assert((!args.empty() || !isStaticMethod) &&
"Cannot generate pullback function of a function "
"with no differentiable arguments");
#endif

if (m_ExternalSource)
m_ExternalSource->ActAfterParsingDiffArgs(request, args);

// Derivative name = effective name of the original function + postfix.
auto derivativeName = utils::ComputeEffectiveFnName(m_DiffReq.Function) +
funcPostfix(m_DiffReq);
auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName);

auto paramTypes = ComputeParamTypes(args);
// NOTE(review): the dyn_cast result is dereferenced below without a null
// check -- presumably the type is always a FunctionProtoType here; confirm.
const auto* originalFnType =
dyn_cast<FunctionProtoType>(m_DiffReq->getType());

if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes);

// The pullback returns void and reuses the original function's
// ExtProtoInfo.
QualType pullbackFnType = m_Context.getFunctionType(
m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo());

// Check if the function is already declared as a custom derivative.
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
// A matching custom derivative is returned as is, without an overload.
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
derivativeName, DC, pullbackFnType))
return DerivativeAndOverload{customDerivative, nullptr};

// Sema's current context and the current scope are restored when these
// guards go out of scope.
llvm::SaveAndRestore<DeclContext*> saveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
m_Sema.CurContext = const_cast<DeclContext*>(m_DiffReq->getDeclContext());

// Clone the original declaration under the derivative name and type.
SourceLocation validLoc{m_DiffReq->getLocation()};
DeclWithContext fnBuildRes =
m_Builder.cloneFunction(m_DiffReq.Function, *this, m_Sema.CurContext,
validLoc, DNI, pullbackFnType);
m_Derivative = fnBuildRes.first;

if (m_ExternalSource)
m_ExternalSource->ActBeforeCreatingDerivedFnScope();

// Function declaration scope; closed by the last endScope() below.
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), m_Derivative);

if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnScope();

auto params = BuildParams(args);
if (m_ExternalSource)
m_ExternalSource->ActAfterCreatingDerivedFnParams(params);

m_Derivative->setParams(params);
// The body stays null when only a declaration is requested.
m_Derivative->setBody(nullptr);

if (!request.DeclarationOnly) {
if (m_ExternalSource)
m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope();

beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();

beginBlock();
if (m_ExternalSource)
m_ExternalSource->ActOnStartOfDerivedFnBody(request);

// Differentiate the original body; the result carries both the forward
// and the reverse statements.
StmtDiff bodyDiff = Visit(m_DiffReq->getBody());
Stmt* forward = bodyDiff.getStmt();
Stmt* reverse = bodyDiff.getStmt_dx();

// Create the body of the function.
// Firstly, all "global" Stmts are put into fn's body.
for (Stmt* S : m_Globals)
addToCurrentBlock(S, direction::forward);
// Forward pass.
if (auto* CS = dyn_cast<CompoundStmt>(forward))
for (Stmt* S : CS->body())
addToCurrentBlock(S, direction::forward);

// Reverse pass.
// The reverse statements are appended to the same (forward) block, so they
// follow the forward-pass statements in the function body.
if (auto* RCS = dyn_cast<CompoundStmt>(reverse))
for (Stmt* S : RCS->body())
addToCurrentBlock(S, direction::forward);

if (m_ExternalSource)
m_ExternalSource->ActOnEndOfDerivedFnBody();

Stmt* fnBody = endBlock();
m_Derivative->setBody(fnBody);
endScope(); // Function body scope

// Size >= current derivative order means that there exists a declaration
// or prototype for the currently derived function.
if (request.DerivedFDPrototypes.size() >= request.CurrentDerivativeOrder)
m_Derivative->setPreviousDeclaration(
request.DerivedFDPrototypes[request.CurrentDerivativeOrder - 1]);
}
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope

// No overload is generated for pullbacks.
return DerivativeAndOverload{fnBuildRes.first, nullptr};
}

void ReverseModeVisitor::DifferentiateWithClad() {
if (m_DiffReq.EnableTBRAnalysis) {
TBRAnalyzer analyzer(m_Context);
Expand Down

0 comments on commit 23f39bc

Please sign in to comment.