Skip to content

Commit

Permalink
Add custom derivative to DerivativeSet
Browse files Browse the repository at this point in the history
This is required if the definition of the custom derivative is not found in the current translation unit and is linked in from another.
Adding it to the set of derivatives ensures that the custom derivative is not differentiated again using numerical differentiation due to an unavailable definition.
  • Loading branch information
vaithak committed May 19, 2024
1 parent d1fec23 commit 13e4c13
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 8 deletions.
4 changes: 2 additions & 2 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ namespace clad {
clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
const DerivedFnCollector& m_DFC;
DerivedFnCollector& m_DFC;
clad::DynamicGraph<DiffRequest>& m_DiffRequestGraph;
std::unique_ptr<utils::StmtClone> m_NodeCloner;
clang::NamespaceDecl* m_BuiltinDerivativesNSD;
Expand Down Expand Up @@ -135,7 +135,7 @@ namespace clad {

public:
DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,
const DerivedFnCollector& DFC,
DerivedFnCollector& DFC,
clad::DynamicGraph<DiffRequest>& DRG);
~DerivativeBuilder();
/// Reset the model use for error estimation (if any).
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/DerivedFnCollector.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class DerivedFnCollector {
/// Adds a derived function to the collection.
void Add(const DerivedFnInfo& DFI);

/// Adds a function to derivative set.
void AddToDerivativeSet(const clang::FunctionDecl* FD);

/// Finds a `DerivedFnInfo` object in the collection that satisfies the
/// given differentiation request.
DerivedFnInfo Find(const DiffRequest& request) const;
Expand Down
13 changes: 12 additions & 1 deletion lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ using namespace clang;
namespace clad {

DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,
const DerivedFnCollector& DFC,
DerivedFnCollector& DFC,
clad::DynamicGraph<DiffRequest>& G)
: m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC),
m_DiffRequestGraph(G),
Expand Down Expand Up @@ -253,6 +253,17 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();

// Add the custom derivative to the set of derivatives.
// This is required in case the definition of the custom derivative
// is not found in the current translation unit and is linked in
// from another translation unit.
// Adding it to the set of derivatives ensures that the custom
// derivative is not differentiated again using numerical
// differentiation due to unavailable definition.
if (auto* CE = dyn_cast<CallExpr>(OverloadedFn))
if (auto* FD = CE->getDirectCallee())
m_DFC.AddToDerivativeSet(FD);
}
return OverloadedFn;
}
Expand Down
6 changes: 5 additions & 1 deletion lib/Differentiator/DerivedFnCollector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) {
"`DerivedFnCollector::Add` more than once for the same derivative "
". Ideally, we shouldn't do either.");
m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI);
m_DerivativeSet.insert(DFI.DerivedFn());
AddToDerivativeSet(DFI.DerivedFn());
}

void DerivedFnCollector::AddToDerivativeSet(const clang::FunctionDecl* FD) {
m_DerivativeSet.insert(FD);
}

bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const {
Expand Down
9 changes: 5 additions & 4 deletions tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,12 @@ class CladTimerGroup {
bool HandleTopLevelDecl(clang::DeclGroupRef D) override {
if (D.isSingleDecl())
if (auto* FD = llvm::dyn_cast<clang::FunctionDecl>(D.getSingleDecl()))
if (m_DFC.IsDerivative(FD)) {
assert(!m_Multiplexer &&
"Must happen only if we failed to rearrange the consumers");
// If we build the derivative in a non-standard (with no Multiplexer)
// setup, we exit early to give control to the non-standard setup for
// code generation.
// FIXME: This should go away if Cling starts using the clang driver.
if (!m_Multiplexer && m_DFC.IsDerivative(FD))
return true;
}

HandleTopLevelDeclForClad(D);
AppendDelayed({CallKind::HandleTopLevelDecl, D});
Expand Down

0 comments on commit 13e4c13

Please sign in to comment.