Skip to content

Commit

Permalink
Add support for custom derivatives for top level derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vaithak committed Jun 11, 2024
1 parent 6efcd08 commit 10e85f4
Show file tree
Hide file tree
Showing 11 changed files with 270 additions and 39 deletions.
31 changes: 30 additions & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,32 @@ namespace clad {
stream << arg;
}

/// Lookup the result of finding a custom derivative or numerical
/// differentiation function.
///
/// \param[in] Name The name of the function to look up.
/// \param[in] originalFnDC The original function's DeclContext.
/// \param[in] SS The CXXScopeSpec to extend with the namespace of the
/// function.
/// \param[in] forCustomDerv A flag to keep track of which
/// namespace we should look in for the overloads.
/// \param[in] namespaceShouldExist A flag to enforce assertion failure
/// if the overload function namespace was not found. If false and
/// the function containing namespace was not found, an empty lookup
/// result is returned.
clang::LookupResult LookupCustomDerivativeOrNumericalDiff(
const std::string& Name, clang::DeclContext* originalFnDC,
clang::CXXScopeSpec& SS, bool forCustomDerv = true,
bool namespaceShouldExist = true);

/// Looks up if the user has defined a custom derivative for the given
/// derivative function.
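/// \param[in] Name The name of the derivative function to look up.
/// \param[in] originalFnDC The original function's DeclContext.
/// \param[in] functionType The function type the custom derivative must
/// have to be accepted.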
/// \returns The custom derivative function if found, nullptr otherwise.
clang::FunctionDecl*
LookupCustomDerivativeDecl(const std::string& Name,
clang::DeclContext* originalFnDC,
clang::QualType functionType);

public:
DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,
DerivedFnCollector& DFC,
Expand Down Expand Up @@ -175,7 +201,10 @@ namespace clad {
/// graph.
///
/// \param[in] request The request to add the edge to.
void AddEdgeToGraph(const DiffRequest& request);
/// \param[in] alreadyDerived A flag to keep track of whether the request
/// is already derived or not.
void AddEdgeToGraph(const DiffRequest& request,
bool alreadyDerived = false);
};

} // end namespace clad
Expand Down
5 changes: 4 additions & 1 deletion include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,12 @@ template <typename T> class DynamicGraph {
/// Add an edge from the current node being processed to the
/// destination node.
/// \param dest The destination node of the edge.
void addEdgeToCurrentNode(const T& dest) {
/// \param alreadyProcessed If the destination node is already processed.
void addEdgeToCurrentNode(const T& dest, bool alreadyProcessed = false) {
if (m_currentId != -1)
addEdge(m_nodes[m_currentId], dest);
if (alreadyProcessed)
m_nodeMap[dest].first = true;
}

/// Set the current node being processed.
Expand Down
13 changes: 10 additions & 3 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,20 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
for (auto field : diffVarInfo.fields)
argInfo += "_" + field;

IdentifierInfo* II = &m_Context.Idents.get(
request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix);
// Check if the function is already declared as a custom derivative.
std::string gradientName =
request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix;
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
if (FunctionDecl* customDerivative =
m_Builder.LookupCustomDerivativeDecl(gradientName, DC, FD->getType()))
return DerivativeAndOverload{customDerivative, nullptr};

IdentifierInfo* II = &m_Context.Idents.get(gradientName);
SourceLocation validLoc{m_Function->getLocation()};
DeclarationNameInfo name(II, validLoc);
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());

m_Sema.CurContext = DC;
DeclWithContext result =
m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType());
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,9 @@ namespace clad {
DeclContext* DC = DC1;
for (int i = contexts.size() - 1; i >= 0; --i) {
NamespaceDecl* ND = cast<NamespaceDecl>(contexts[i]);
DC = LookupNSD(semaRef, ND->getIdentifier()->getName(),
/*shouldExist=*/false, DC1);
if (ND->getIdentifier())
DC = LookupNSD(semaRef, ND->getIdentifier()->getName(),
/*shouldExist=*/false, DC1);
if (!DC)
return nullptr;
DC1 = DC;
Expand Down
56 changes: 43 additions & 13 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,16 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
return false;
}

Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {
LookupResult DerivativeBuilder::LookupCustomDerivativeOrNumericalDiff(
const std::string& Name, clang::DeclContext* originalFnDC,
CXXScopeSpec& SS, bool forCustomDerv /*=true*/,
bool namespaceShouldExist /*=true*/) {

IdentifierInfo* II = &m_Context.Idents.get(Name);
DeclarationName name(II);
DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema));
LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName);

NamespaceDecl* NSD = nullptr;
std::string namespaceID;
if (forCustomDerv) {
Expand All @@ -201,10 +207,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
"flag, this means that every try to numerically differentiate a "
"function will fail! Remove the flag to revert to default "
"behaviour.");
return nullptr;
return R;
}
}
CXXScopeSpec SS;
DeclContext* DC = NSD;

// FIXME: Here `if` branch should be removed once we update
Expand All @@ -223,13 +228,37 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
} else {
SS.Extend(m_Context, NSD, noLoc, noLoc);
}
IdentifierInfo* II = &m_Context.Idents.get(Name);
DeclarationName name(II);
DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema));

LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName);
if (DC)
m_Sema.LookupQualifiedName(R, DC);
return R;
}

/// Searches the custom-derivative lookup results for `Name` and returns the
/// first function declaration that can serve as the derivative.
///
/// \param[in] Name The name of the derivative function to look up.
/// \param[in] originalFnDC The original function's DeclContext.
/// \param[in] functionType The type a candidate must have (compared through
/// utils::SameCanonicalType).
/// \returns The first matching function declaration, or nullptr if the
/// lookup yields no acceptable candidate.
FunctionDecl* DerivativeBuilder::LookupCustomDerivativeDecl(
    const std::string& Name, clang::DeclContext* originalFnDC,
    QualType functionType) {
  CXXScopeSpec scopeSpec;
  LookupResult candidates =
      LookupCustomDerivativeOrNumericalDiff(Name, originalFnDC, scopeSpec);

  for (NamedDecl* candidate : candidates) {
    auto* candidateFn = dyn_cast<FunctionDecl>(candidate);
    if (!candidateFn)
      continue;

    // Only declarations with the requested signature qualify.
    if (!utils::SameCanonicalType(candidateFn->getType(), functionType))
      continue;

    // Reject a declaration that has no definition and that the derived
    // function collector reports as a derivative.
    // NOTE(review): presumably such a declaration is a prototype clad
    // generated itself rather than user code -- confirm.
    if (!candidateFn->isDefined() && m_DFC.IsDerivative(candidateFn))
      continue;

    return candidateFn;
  }

  return nullptr;
}

Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {

CXXScopeSpec SS;
LookupResult R = LookupCustomDerivativeOrNumericalDiff(
Name, originalFnDC, SS, forCustomDerv, namespaceShouldExist);

Expr* OverloadedFn = nullptr;
if (!R.empty()) {
// FIXME: We should find a way to specify nested name specifier
Expand Down Expand Up @@ -402,7 +431,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
return nullptr;
}

void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& request) {
m_DiffRequestGraph.addEdgeToCurrentNode(request);
void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& request,
bool alreadyDerived /*=false*/) {
m_DiffRequestGraph.addEdgeToCurrentNode(request, alreadyDerived);
}
}// end namespace clad
32 changes: 26 additions & 6 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,27 @@ namespace clad {
IndependentArgRequest.CallUpdateRequired = false;
IndependentArgRequest.UpdateDiffParamsInfo(SemaRef);
// FIXME: Find a way to do this without accessing plugin namespace functions
bool alreadyDerived = true;
FunctionDecl* firstDerivative =
Builder.FindDerivedFunction(IndependentArgRequest);
if (!firstDerivative) {
alreadyDerived = false;
// Derive declaration of the forward mode derivative.
IndependentArgRequest.DeclarationOnly = true;
firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
// again.
if (firstDerivative->isDefined())
alreadyDerived = true;

// Add the request to derive the definition of the forward mode derivative
// to the schedule.
IndependentArgRequest.DeclarationOnly = false;
IndependentArgRequest.DerivedFDPrototype = firstDerivative;
}
Builder.AddEdgeToGraph(IndependentArgRequest);
Builder.AddEdgeToGraph(IndependentArgRequest, alreadyDerived);

// Further derives function w.r.t. ReverseModeArgs
DiffRequest ReverseModeRequest{};
Expand All @@ -81,20 +89,27 @@ namespace clad {
ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString();
ReverseModeRequest.UpdateDiffParamsInfo(SemaRef);

alreadyDerived = true;
FunctionDecl* secondDerivative =
Builder.FindDerivedFunction(ReverseModeRequest);
if (!secondDerivative) {
alreadyDerived = false;
// Derive declaration of the reverse mode derivative.
ReverseModeRequest.DeclarationOnly = true;
secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest);

// Add the request to derive the definition of the reverse mode derivative
// to the schedule.
// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the definition
// again.
if (secondDerivative->isDefined())
alreadyDerived = true;

// Add the request to derive the definition of the reverse mode
// derivative to the schedule.
ReverseModeRequest.DeclarationOnly = false;
ReverseModeRequest.DerivedFDPrototype = secondDerivative;
}
Builder.AddEdgeToGraph(ReverseModeRequest);

Builder.AddEdgeToGraph(ReverseModeRequest, alreadyDerived);
return secondDerivative;
}

Expand Down Expand Up @@ -249,8 +264,13 @@ namespace clad {
// Cast to function pointer.
originalFnProtoType->getExtProtoInfo());

// Create the gradient function declaration.
// Check if the function is already declared as a custom derivative.
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
hessianFuncName, DC, hessianFunctionType))
return DerivativeAndOverload{customDerivative, nullptr};

// Create the gradient function declaration.
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
Expand Down
8 changes: 7 additions & 1 deletion lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
llvm::SaveAndRestore<Scope*> saveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
m_Sema.CurContext = const_cast<DeclContext*>(m_Function->getDeclContext());
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());

// Check if the function is already declared as a custom derivative.
if (FunctionDecl* customDerivative =
m_Builder.LookupCustomDerivativeDecl(fnName, DC, fnType))
return DerivativeAndOverload{customDerivative, nullptr};

m_Sema.CurContext = DC;
SourceLocation validLoc{m_Function->getLocation()};
DeclWithContext fnBuildRes = m_Builder.cloneFunction(
m_Function, *this, m_Sema.CurContext, validLoc, fnDNI, fnType);
Expand Down
17 changes: 15 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// Cast to function pointer.
originalFnType->getExtProtoInfo());

// Check if the function is already declared as a custom derivative.
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
if (FunctionDecl* customDerivative = m_Builder.LookupCustomDerivativeDecl(
gradientName, DC, gradientFunctionType))
return DerivativeAndOverload{customDerivative, nullptr};

// Create the gradient function declaration.
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope(),
getEnclosingNamespaceOrTUScope());
auto* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result = m_Builder.cloneFunction(
m_Function, *this, DC, noLoc, name, gradientFunctionType);
Expand Down Expand Up @@ -1775,20 +1780,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
clad::utils::ComputeEffectiveFnName(FD);
calleeFnForwPassReq.VerboseDiags = true;

bool alreadyDerived = true;
FunctionDecl* calleeFnForwPassFD =
m_Builder.FindDerivedFunction(calleeFnForwPassReq);
if (!calleeFnForwPassFD) {
alreadyDerived = false;
// Derive declaration of the forward pass function.
calleeFnForwPassReq.DeclarationOnly = true;
calleeFnForwPassFD =
plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq);

// It is possible that user has provided a custom derivative for the
// derivative function. In that case, we should not derive the
// definition again.
if (calleeFnForwPassFD->getDefinition())
alreadyDerived = true;

// Add the request to derive the definition of the forward pass
// function.
calleeFnForwPassReq.DeclarationOnly = false;
calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD;
}
m_Builder.AddEdgeToGraph(calleeFnForwPassReq);
m_Builder.AddEdgeToGraph(calleeFnForwPassReq, alreadyDerived);

assert(calleeFnForwPassFD &&
"Clad failed to generate callee function forward pass function");
Expand Down
12 changes: 9 additions & 3 deletions test/FirstDerivative/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,20 @@
#include "../TestUtils.h"
extern "C" int printf(const char* fmt, ...);

// User-provided custom derivative of f1 with respect to its first
// parameter. The expected output for f1 further down matches this body
// (`return cos(x);`) rather than a generated pushforward call.
namespace clad{
namespace custom_derivatives{
float f1_darg0(float x) {
return cos(x);
}
}
}

// Function under test: the expected output that follows shows its
// derivative w.r.t. x being taken from clad::custom_derivatives::f1_darg0.
float f1(float x) {
return sin(x);
}

// CHECK: float f1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: {{(clad::)?}}ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: return cos(x);
// CHECK-NEXT: }

float f2(float x) {
Expand Down
Loading

0 comments on commit 10e85f4

Please sign in to comment.