diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 202856c7d..f2981438f 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -83,7 +83,7 @@ namespace clad { clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; - const DerivedFnCollector& m_DFC; + DerivedFnCollector& m_DFC; clad::DynamicGraph& m_DiffRequestGraph; std::unique_ptr m_NodeCloner; clang::NamespaceDecl* m_BuiltinDerivativesNSD; @@ -135,7 +135,7 @@ namespace clad { public: DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& DRG); ~DerivativeBuilder(); /// Reset the model use for error estimation (if any). diff --git a/include/clad/Differentiator/DerivedFnCollector.h b/include/clad/Differentiator/DerivedFnCollector.h index b20160a44..909285e99 100644 --- a/include/clad/Differentiator/DerivedFnCollector.h +++ b/include/clad/Differentiator/DerivedFnCollector.h @@ -27,6 +27,9 @@ class DerivedFnCollector { /// Adds a derived function to the collection. void Add(const DerivedFnInfo& DFI); + /// Adds a function to derivative set. + void AddToDerivativeSet(const clang::FunctionDecl* FD); + /// Finds a `DerivedFnInfo` object in the collection that satisfies the /// given differentiation request. DerivedFnInfo Find(const DiffRequest& request) const; diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 861f1dc60..2500de15f 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -37,7 +37,7 @@ using namespace clang; namespace clad { DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& G) : m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC), m_DiffRequestGraph(G), @@ -244,6 +244,15 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { Expr* UnresolvedLookup = m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get(); + // Add the custom derivative to the set of derivatives. + // This is required in case the definition of the custom derivative + // is not found in the current translation unit and is linked in + // from another translation unit. + // Adding it to the set of derivatives ensures that the custom + // derivative is not differentiated again using numerical + // differentiation due to unavailable definition. + m_DFC.AddToDerivativeSet(R.getAsSingle()); + auto MARargs = llvm::MutableArrayRef(CallArgs); SourceLocation Loc; diff --git a/lib/Differentiator/DerivedFnCollector.cpp b/lib/Differentiator/DerivedFnCollector.cpp index f32883689..71bc22258 100644 --- a/lib/Differentiator/DerivedFnCollector.cpp +++ b/lib/Differentiator/DerivedFnCollector.cpp @@ -8,7 +8,11 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) { "`DerivedFnCollector::Add` more than once for the same derivative " ". Ideally, we shouldn't do either."); m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI); - m_DerivativeSet.insert(DFI.DerivedFn()); + AddToDerivativeSet(DFI.DerivedFn()); +} + +void DerivedFnCollector::AddToDerivativeSet(const clang::FunctionDecl* FD) { + m_DerivativeSet.insert(FD); } bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const {