diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 202856c7d..f2981438f 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -83,7 +83,7 @@ namespace clad { clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; - const DerivedFnCollector& m_DFC; + DerivedFnCollector& m_DFC; clad::DynamicGraph& m_DiffRequestGraph; std::unique_ptr m_NodeCloner; clang::NamespaceDecl* m_BuiltinDerivativesNSD; @@ -135,7 +135,7 @@ namespace clad { public: DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& DRG); ~DerivativeBuilder(); /// Reset the model use for error estimation (if any). diff --git a/include/clad/Differentiator/DerivedFnCollector.h b/include/clad/Differentiator/DerivedFnCollector.h index b20160a44..909285e99 100644 --- a/include/clad/Differentiator/DerivedFnCollector.h +++ b/include/clad/Differentiator/DerivedFnCollector.h @@ -27,6 +27,9 @@ class DerivedFnCollector { /// Adds a derived function to the collection. void Add(const DerivedFnInfo& DFI); + /// Adds a function to derivative set. + void AddToDerivativeSet(const clang::FunctionDecl* FD); + /// Finds a `DerivedFnInfo` object in the collection that satisfies the /// given differentiation request. DerivedFnInfo Find(const DiffRequest& request) const; diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index ae84e0b73..55e5ad350 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -98,6 +98,10 @@ template class DynamicGraph { m_currentId = -1; } + /// Check if currently processing a node. + /// \returns True if currently processing a node, false otherwise. + bool isProcessingNode() { return m_currentId != -1; } + /// Get the nodes in the graph. std::vector getNodes() { return m_nodes; } diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 861f1dc60..59adfd7e2 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -37,7 +37,7 @@ using namespace clang; namespace clad { DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& G) : m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC), m_DiffRequestGraph(G), @@ -253,6 +253,17 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { OverloadedFn = m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get(); + + // Add the custom derivative to the set of derivatives. + // This is required in case the definition of the custom derivative + // is not found in the current translation unit and is linked in + // from another translation unit. + // Adding it to the set of derivatives ensures that the custom + // derivative is not differentiated again using numerical + // differentiation due to unavailable definition. + if (auto* CE = dyn_cast(OverloadedFn)) + if (FunctionDecl* FD = CE->getDirectCallee()) + m_DFC.AddToDerivativeSet(FD); } return OverloadedFn; } diff --git a/lib/Differentiator/DerivedFnCollector.cpp b/lib/Differentiator/DerivedFnCollector.cpp index f32883689..71bc22258 100644 --- a/lib/Differentiator/DerivedFnCollector.cpp +++ b/lib/Differentiator/DerivedFnCollector.cpp @@ -8,7 +8,11 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) { "`DerivedFnCollector::Add` more than once for the same derivative " ". Ideally, we shouldn't do either."); m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI); - m_DerivativeSet.insert(DFI.DerivedFn()); + AddToDerivativeSet(DFI.DerivedFn()); +} + +void DerivedFnCollector::AddToDerivativeSet(const clang::FunctionDecl* FD) { + m_DerivativeSet.insert(FD); } bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const { diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index c667ddd2c..c4aa93d4b 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -407,12 +407,17 @@ namespace clad { Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled); Sema::LocalEagerInstantiationScope LocalInstantiations(S); - DiffRequest request = m_DiffRequestGraph.getNextToProcessNode(); - while (request.Function != nullptr) { - m_DiffRequestGraph.setCurrentProcessingNode(request); - ProcessDiffRequest(request); - m_DiffRequestGraph.markCurrentNodeProcessed(); - request = m_DiffRequestGraph.getNextToProcessNode(); + if (!m_DiffRequestGraph.isProcessingNode()) { + // This check is to avoid recursive processing of the graph, as + // HandleTopLevelDecl can be called recursively in non-standard + // setup for code generation. + DiffRequest request = m_DiffRequestGraph.getNextToProcessNode(); + while (request.Function) { + m_DiffRequestGraph.setCurrentProcessingNode(request); + ProcessDiffRequest(request); + m_DiffRequestGraph.markCurrentNodeProcessed(); + request = m_DiffRequestGraph.getNextToProcessNode(); + } } // Put the TUScope in a consistent state after clad is done. diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 3ec1f9be6..3ba226f31 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -170,11 +170,12 @@ class CladTimerGroup { bool HandleTopLevelDecl(clang::DeclGroupRef D) override { if (D.isSingleDecl()) if (auto* FD = llvm::dyn_cast(D.getSingleDecl())) - if (m_DFC.IsDerivative(FD)) { - assert(!m_Multiplexer && - "Must happen only if we failed to rearrange the consumers"); + // If we build the derivative in a non-standard (with no Multiplexer) + // setup, we exit early to give control to the non-standard setup for + // code generation. + // FIXME: This should go away if Cling starts using the clang driver. + if (!m_Multiplexer && m_DFC.IsDerivative(FD)) return true; - } HandleTopLevelDeclForClad(D); AppendDelayed({CallKind::HandleTopLevelDecl, D}); diff --git a/unittests/Misc/CMakeLists.txt b/unittests/Misc/CMakeLists.txt index dc13b2417..4aaddf9d2 100644 --- a/unittests/Misc/CMakeLists.txt +++ b/unittests/Misc/CMakeLists.txt @@ -1,4 +1,14 @@ add_clad_unittest(MiscTests main.cpp + CallDeclOnly.cpp + Defs.cpp DynamicGraph.cpp ) + +# Create a library from Defs.cpp +add_library(Defs SHARED Defs.cpp) +target_include_directories(Defs PUBLIC ${CMAKE_SOURCE_DIR}/include) +target_link_libraries(Defs PUBLIC stdc++ pthread m) + +# Link the library to the test +target_link_libraries(MiscTests PRIVATE Defs) diff --git a/unittests/Misc/CallDeclOnly.cpp b/unittests/Misc/CallDeclOnly.cpp new file mode 100644 index 000000000..4ad44d29c --- /dev/null +++ b/unittests/Misc/CallDeclOnly.cpp @@ -0,0 +1,69 @@ +#include "clad/Differentiator/Differentiator.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +double foo(double x, double alpha, double theta, double x0 = 0); + +double wrapper1(double* params) { + const double ix = 1 + params[0]; + return foo(10., ix, 1.0); +} + +TEST(CallDeclOnly, CheckNumDiff) { + auto grad = clad::gradient(wrapper1, "params"); + // Collect output of grad.dump() into a string as it ouputs using llvm::outs() + std::string actual; + testing::internal::CaptureStdout(); + grad.dump(); + actual = testing::internal::GetCapturedStdout(); + + // Check the generated code from grad.dump() + std::string expected = R"(The code is: +void wrapper1_grad(double *params, double *_d_params) { + double _d_ix = 0; + const double ix = 1 + params[0]; + goto _label0; + _label0: + { + double _r0 = 0; + double _r1 = 0; + double _r2 = 0; + double _r3 = 0; + double _grad0[4] = {0}; + numerical_diff::central_difference(foo, _grad0, 0, 10., ix, 1., 0); + _r0 += 1 * _grad0[0]; + _r1 += 1 * _grad0[1]; + _r2 += 1 * _grad0[2]; + _r3 += 1 * _grad0[3]; + _d_ix += _r1; + } + _d_params[0] += _d_ix; +} + +)"; + EXPECT_EQ(actual, expected); +} + +namespace clad { +namespace custom_derivatives { +// Custom pushforward for the square function but definition will be linked from +// another file. +clad::ValueAndPushforward sq_pushforward(double x, double _d_x); +} // namespace custom_derivatives +} // namespace clad + +double sq(double x) { return x * x; } + +double wrapper2(double* params) { return sq(params[0]); } + +TEST(CallDeclOnly, CheckCustomDiff) { + auto grad = clad::hessian(wrapper2, "params[0]"); + double x = 4.0; + double dx = 0.0; + grad.execute(&x, &dx); + EXPECT_DOUBLE_EQ(dx, 2.0); +} \ No newline at end of file diff --git a/unittests/Misc/Defs.cpp b/unittests/Misc/Defs.cpp new file mode 100644 index 000000000..bf0a2aa62 --- /dev/null +++ b/unittests/Misc/Defs.cpp @@ -0,0 +1,25 @@ +#include "clad/Differentiator/Differentiator.h" + +double foo(double x, double alpha, double theta, double x0 = 0) { + return x * alpha * theta * x0; +} + +namespace clad { +namespace custom_derivatives { +clad::ValueAndPushforward sq_pushforward(double x, + double _d_x) { + return {x * x, 2 * x}; +} + +void sq_pushforward_pullback(double x, double _dx, + clad::ValueAndPushforward _d_y, + double* _d_x, double* _d__d_x) { + goto _label0; +_label0: { + *_d_x += _d_y.value * x; + *_d_x += x * _d_y.value; + *_d_x += 2 * _d_y.pushforward; +} +} +} // namespace custom_derivatives +} // namespace clad \ No newline at end of file