diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 202856c7d..f2981438f 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -83,7 +83,7 @@ namespace clad { clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; - const DerivedFnCollector& m_DFC; + DerivedFnCollector& m_DFC; clad::DynamicGraph& m_DiffRequestGraph; std::unique_ptr m_NodeCloner; clang::NamespaceDecl* m_BuiltinDerivativesNSD; @@ -135,7 +135,7 @@ namespace clad { public: DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& DRG); ~DerivativeBuilder(); /// Reset the model use for error estimation (if any). diff --git a/include/clad/Differentiator/DerivedFnCollector.h b/include/clad/Differentiator/DerivedFnCollector.h index b20160a44..909285e99 100644 --- a/include/clad/Differentiator/DerivedFnCollector.h +++ b/include/clad/Differentiator/DerivedFnCollector.h @@ -27,6 +27,9 @@ class DerivedFnCollector { /// Adds a derived function to the collection. void Add(const DerivedFnInfo& DFI); + /// Adds a function to derivative set. + void AddToDerivativeSet(const clang::FunctionDecl* FD); + /// Finds a `DerivedFnInfo` object in the collection that satisfies the /// given differentiation request. DerivedFnInfo Find(const DiffRequest& request) const; diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 861f1dc60..9fe899e7b 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -37,7 +37,7 @@ using namespace clang; namespace clad { DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& G) : m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC), m_DiffRequestGraph(G), @@ -253,6 +253,17 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { OverloadedFn = m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get(); + + // Add the custom derivative to the set of derivatives. + // This is required in case the definition of the custom derivative + // is not found in the current translation unit and is linked in + // from another translation unit. + // Adding it to the set of derivatives ensures that the custom + // derivative is not differentiated again using numerical + // differentiation due to unavailable definition. + if (auto* CE = dyn_cast(OverloadedFn)) + if (auto* FD = CE->getDirectCallee()) + m_DFC.AddToDerivativeSet(FD); } return OverloadedFn; } diff --git a/lib/Differentiator/DerivedFnCollector.cpp b/lib/Differentiator/DerivedFnCollector.cpp index f32883689..71bc22258 100644 --- a/lib/Differentiator/DerivedFnCollector.cpp +++ b/lib/Differentiator/DerivedFnCollector.cpp @@ -8,7 +8,11 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) { "`DerivedFnCollector::Add` more than once for the same derivative " ". Ideally, we shouldn't do either."); m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI); - m_DerivativeSet.insert(DFI.DerivedFn()); + AddToDerivativeSet(DFI.DerivedFn()); +} + +void DerivedFnCollector::AddToDerivativeSet(const clang::FunctionDecl* FD) { + m_DerivativeSet.insert(FD); } bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const { diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 3ec1f9be6..3ba226f31 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -170,11 +170,12 @@ class CladTimerGroup { bool HandleTopLevelDecl(clang::DeclGroupRef D) override { if (D.isSingleDecl()) if (auto* FD = llvm::dyn_cast(D.getSingleDecl())) - if (m_DFC.IsDerivative(FD)) { - assert(!m_Multiplexer && - "Must happen only if we failed to rearrange the consumers"); + // If we build the derivative in a non-standard (with no Multiplexer) + // setup, we exit early to give control to the non-standard setup for + // code generation. + // FIXME: This should go away if Cling starts using the clang driver. + if (!m_Multiplexer && m_DFC.IsDerivative(FD)) return true; - } HandleTopLevelDeclForClad(D); AppendDelayed({CallKind::HandleTopLevelDecl, D});