From e2e5128487032b7daea84d0c61e3b62791676bbf Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Fri, 17 May 2024 23:14:51 +0200 Subject: [PATCH] Add custom derivative to DerivativeSet This is required if the definition of the custom derivative is not found in the current translation unit and is linked in from another. Adding it to the set of derivatives ensures that the custom derivative is not differentiated again using numerical differentiation due to an unavailable definition. --- include/clad/Differentiator/DerivativeBuilder.h | 4 ++-- include/clad/Differentiator/DerivedFnCollector.h | 3 +++ lib/Differentiator/DerivativeBuilder.cpp | 11 ++++++++++- lib/Differentiator/DerivedFnCollector.cpp | 6 +++++- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 202856c7d..f2981438f 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -83,7 +83,7 @@ namespace clad { clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; - const DerivedFnCollector& m_DFC; + DerivedFnCollector& m_DFC; clad::DynamicGraph& m_DiffRequestGraph; std::unique_ptr m_NodeCloner; clang::NamespaceDecl* m_BuiltinDerivativesNSD; @@ -135,7 +135,7 @@ namespace clad { public: DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& DRG); ~DerivativeBuilder(); /// Reset the model use for error estimation (if any). diff --git a/include/clad/Differentiator/DerivedFnCollector.h b/include/clad/Differentiator/DerivedFnCollector.h index b20160a44..909285e99 100644 --- a/include/clad/Differentiator/DerivedFnCollector.h +++ b/include/clad/Differentiator/DerivedFnCollector.h @@ -27,6 +27,9 @@ class DerivedFnCollector { /// Adds a derived function to the collection. void Add(const DerivedFnInfo& DFI); + /// Adds a function to derivative set. + void AddToDerivativeSet(const clang::FunctionDecl* FD); + /// Finds a `DerivedFnInfo` object in the collection that satisfies the /// given differentiation request. DerivedFnInfo Find(const DiffRequest& request) const; diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 861f1dc60..2500de15f 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -37,7 +37,7 @@ using namespace clang; namespace clad { DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& G) : m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC), m_DiffRequestGraph(G), @@ -244,6 +244,15 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { Expr* UnresolvedLookup = m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get(); + // Add the custom derivative to the set of derivatives. + // This is required in case the definition of the custom derivative + // is not found in the current translation unit and is linked in + // from another translation unit. + // Adding it to the set of derivatives ensures that the custom + // derivative is not differentiated again using numerical + // differentiation due to unavailable definition. + m_DFC.AddToDerivativeSet(R.getAsSingle()); + auto MARargs = llvm::MutableArrayRef(CallArgs); SourceLocation Loc; diff --git a/lib/Differentiator/DerivedFnCollector.cpp b/lib/Differentiator/DerivedFnCollector.cpp index f32883689..71bc22258 100644 --- a/lib/Differentiator/DerivedFnCollector.cpp +++ b/lib/Differentiator/DerivedFnCollector.cpp @@ -8,7 +8,11 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) { "`DerivedFnCollector::Add` more than once for the same derivative " ". Ideally, we shouldn't do either."); m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI); - m_DerivativeSet.insert(DFI.DerivedFn()); + AddToDerivativeSet(DFI.DerivedFn()); +} + +void DerivedFnCollector::AddToDerivativeSet(const clang::FunctionDecl* FD) { + m_DerivativeSet.insert(FD); } bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const {