Skip to content

Commit

Permalink
Refactor handling of nested differentiation requests
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jun 13, 2024
1 parent eac0b6e commit c399190
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 102 deletions.
6 changes: 6 additions & 0 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ namespace clad {
/// is already derived or not.
void AddEdgeToGraph(const DiffRequest& request,
bool alreadyDerived = false);

/// Handles processing of a diff request when an existing derivative is
/// being processed.
/// \param[in] Request The request to be processed.
/// \returns The derivative function if found, nullptr otherwise.
clang::FunctionDecl* HandleNestedDiffRequest(DiffRequest& request);
};

} // end namespace clad
Expand Down
14 changes: 1 addition & 13 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,19 +1209,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

// Check if request already derived in DerivedFunctions.
FunctionDecl* pushforwardFD =
m_Builder.FindDerivedFunction(pushforwardFnRequest);
if (!pushforwardFD) {
// Derive declaration of the pushforward function.
pushforwardFnRequest.DeclarationOnly = true;
pushforwardFD =
plugin::ProcessDiffRequest(m_CladPlugin, pushforwardFnRequest);

// Add the request to derive the definition of the pushforward function
// into the queue.
pushforwardFnRequest.DeclarationOnly = false;
pushforwardFnRequest.DerivedFDPrototype = pushforwardFD;
}
m_Builder.AddEdgeToGraph(pushforwardFnRequest);
m_Builder.HandleNestedDiffRequest(pushforwardFnRequest);

if (pushforwardFD) {
if (baseDiff.getExpr()) {
Expand Down
31 changes: 29 additions & 2 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,33 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
return OverloadedFn;
}

clang::FunctionDecl*
DerivativeBuilder::HandleNestedDiffRequest(DiffRequest& request) {
// FIXME: Find a way to do this without accessing plugin namespace functions
bool alreadyDerived = true;
FunctionDecl* derivative = this->FindDerivedFunction(request);
if (!derivative) {
alreadyDerived = false;
// Derive declaration of the the forward mode derivative.
request.DeclarationOnly = true;
derivative = plugin::ProcessDiffRequest(m_CladPlugin, request);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
// again.
if (derivative &&
(derivative->isDefined() || m_DFC.IsCustomDerivative(derivative)))
alreadyDerived = true;

// Add the request to derive the definition of the forward mode derivative
// to the schedule.
request.DeclarationOnly = false;
request.DerivedFDPrototype = derivative;
}
this->AddEdgeToGraph(request, alreadyDerived);
return derivative;
}

void DerivativeBuilder::AddErrorEstimationModel(
std::unique_ptr<FPErrorEstimationModel> estModel) {
m_EstModel.push_back(std::move(estModel));
Expand Down Expand Up @@ -423,9 +450,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
// FIXME: if the derivatives aren't registered in this order and the
// derivative is a member function it goes into an infinite loop
if (!m_DFC.IsCustomDerivative(result.derivative)) {
if (auto FD = result.derivative)
if (auto* FD = result.derivative)
registerDerivative(FD, m_Sema);
if (auto OFD = result.overload)
if (auto* OFD = result.overload)
registerDerivative(OFD, m_Sema);
}

Expand Down
43 changes: 2 additions & 41 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,8 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
IndependentArgRequest.CallUpdateRequired = false;
IndependentArgRequest.UpdateDiffParamsInfo(SemaRef);
// FIXME: Find a way to do this without accessing plugin namespace functions
bool alreadyDerived = true;
FunctionDecl* firstDerivative =
Builder.FindDerivedFunction(IndependentArgRequest);
if (!firstDerivative) {
alreadyDerived = false;
// Derive declaration of the the forward mode derivative.
IndependentArgRequest.DeclarationOnly = true;
firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
// again.
if (firstDerivative->isDefined() || DFC.IsCustomDerivative(firstDerivative))
alreadyDerived = true;

// Add the request to derive the definition of the forward mode derivative
// to the schedule.
IndependentArgRequest.DeclarationOnly = false;
IndependentArgRequest.DerivedFDPrototype = firstDerivative;
}
Builder.AddEdgeToGraph(IndependentArgRequest, alreadyDerived);
Builder.HandleNestedDiffRequest(IndependentArgRequest);

// Further derives function w.r.t to ReverseModeArgs
DiffRequest ReverseModeRequest{};
Expand All @@ -89,28 +70,8 @@ static FunctionDecl* DeriveUsingForwardAndReverseMode(
ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString();
ReverseModeRequest.UpdateDiffParamsInfo(SemaRef);

alreadyDerived = true;
FunctionDecl* secondDerivative =
Builder.FindDerivedFunction(ReverseModeRequest);
if (!secondDerivative) {
alreadyDerived = false;
// Derive declaration of the the reverse mode derivative.
ReverseModeRequest.DeclarationOnly = true;
secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
// again.
if (secondDerivative->isDefined() ||
DFC.IsCustomDerivative(secondDerivative))
alreadyDerived = true;

// Add the request to derive the definition of the reverse mode
// derivative to the schedule.
ReverseModeRequest.DeclarationOnly = false;
ReverseModeRequest.DerivedFDPrototype = secondDerivative;
}
Builder.AddEdgeToGraph(ReverseModeRequest, alreadyDerived);
Builder.HandleNestedDiffRequest(ReverseModeRequest);
return secondDerivative;
}

Expand Down
65 changes: 19 additions & 46 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Check if the function is already declared as a custom derivative.
auto* DC = const_cast<DeclContext*>(m_DiffReq->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
gradientName, DC, gradientFunctionType))
return DerivativeAndOverload{customDerivative, nullptr};
gradientName, DC, gradientFunctionType)) {
// Set m_Derivative for creating the overload.
m_Derivative = customDerivative;
FunctionDecl* gradientOverloadFD = nullptr;
if (shouldCreateOverload)
gradientOverloadFD = CreateGradientOverload();
return DerivativeAndOverload{customDerivative, gradientOverloadFD};
}

// Create the gradient function declaration.
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
Expand Down Expand Up @@ -1682,29 +1688,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (DerivedCallOutputArgs[i + isaMethod])
pullbackRequest.DVI.push_back(FD->getParamDecl(i));

FunctionDecl* pullbackFD =
m_Builder.FindDerivedFunction(pullbackRequest);
if (!pullbackFD) {
if (!m_ExternalSource) {
// Derive the declaration of the pullback function.
pullbackRequest.DeclarationOnly = true;
pullbackFD =
plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);

// Add the request to derive the definition of the pullback
// function.
pullbackRequest.DeclarationOnly = false;
pullbackRequest.DerivedFDPrototype = pullbackFD;
} else {
// FIXME: Error estimation currently uses singleton objects -
// m_ErrorEstHandler and m_EstModel, which is cleared after each
// error_estimate request. This requires the pullback to be derived
// at the same time to access the singleton objects.
pullbackFD =
plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
}
}
m_Builder.AddEdgeToGraph(pullbackRequest);
FunctionDecl* pullbackFD = nullptr;
if (m_ExternalSource)
// FIXME: Error estimation currently uses singleton objects -
// m_ErrorEstHandler and m_EstModel, which is cleared after each
// error_estimate request. This requires the pullback to be derived
// at the same time to access the singleton objects.
pullbackFD =
plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
else
pullbackFD = m_Builder.HandleNestedDiffRequest(pullbackRequest);

// Clad failed to derive it.
// FIXME: Add support for reference arguments to the numerical diff. If
Expand Down Expand Up @@ -1794,28 +1787,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
clad::utils::ComputeEffectiveFnName(FD);
calleeFnForwPassReq.VerboseDiags = true;

bool alreadyDerived = true;
FunctionDecl* calleeFnForwPassFD =
m_Builder.FindDerivedFunction(calleeFnForwPassReq);
if (!calleeFnForwPassFD) {
alreadyDerived = false;
// Derive declaration of the the forward pass function.
calleeFnForwPassReq.DeclarationOnly = true;
calleeFnForwPassFD =
plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the
// definition again.
if (calleeFnForwPassFD->getDefinition())
alreadyDerived = true;

// Add the request to derive the definition of the forward pass
// function.
calleeFnForwPassReq.DeclarationOnly = false;
calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD;
}
m_Builder.AddEdgeToGraph(calleeFnForwPassReq, alreadyDerived);
m_Builder.HandleNestedDiffRequest(calleeFnForwPassReq);

assert(calleeFnForwPassFD &&
"Clad failed to generate callee function forward pass function");
Expand Down

0 comments on commit c399190

Please sign in to comment.