From bf2f64ad00ffed4c3a8b58be54f5a55acb94d6ab Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Sun, 19 May 2024 09:34:44 +0200 Subject: [PATCH] Fix errors due to recursive calling of HandleTopLevelDecl - Add custom derivative to DerivativeSet: This is required if the definition of the custom derivative is not found in the current translation unit and is linked in from another. Adding it to the set of derivatives ensures that the custom derivative is not differentiated again using numerical differentiation due to an unavailable definition. - Fix recursive processing of DiffRequests: There can be cases where `m_Multiplexer` is not provided. Hence, we don't delay HandleTranslationUnit at the end and it is called repeatedly. This resulted in HandleTopLevelDecl being called recursively (from PerformPendingInstantiations). This commit adds conditional checks to ensure this doesn't perturb the execution of the differentiation plan. --- .../clad/Differentiator/DerivativeBuilder.h | 4 +- .../clad/Differentiator/DerivedFnCollector.h | 3 + include/clad/Differentiator/DynamicGraph.h | 4 ++ lib/Differentiator/DerivativeBuilder.cpp | 13 +++- lib/Differentiator/DerivedFnCollector.cpp | 6 +- tools/ClangPlugin.cpp | 17 +++-- tools/ClangPlugin.h | 9 +-- unittests/Misc/CMakeLists.txt | 10 +++ unittests/Misc/CallDeclOnly.cpp | 69 +++++++++++++++++++ unittests/Misc/Defs.cpp | 25 +++++++ 10 files changed, 146 insertions(+), 14 deletions(-) create mode 100644 unittests/Misc/CallDeclOnly.cpp create mode 100644 unittests/Misc/Defs.cpp diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 202856c7d..f2981438f 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -83,7 +83,7 @@ namespace clad { clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; - const DerivedFnCollector& m_DFC; + DerivedFnCollector& m_DFC; clad::DynamicGraph& m_DiffRequestGraph; std::unique_ptr m_NodeCloner; clang::NamespaceDecl* m_BuiltinDerivativesNSD; @@ -135,7 +135,7 @@ namespace clad { public: DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& DRG); ~DerivativeBuilder(); /// Reset the model use for error estimation (if any). diff --git a/include/clad/Differentiator/DerivedFnCollector.h b/include/clad/Differentiator/DerivedFnCollector.h index b20160a44..909285e99 100644 --- a/include/clad/Differentiator/DerivedFnCollector.h +++ b/include/clad/Differentiator/DerivedFnCollector.h @@ -27,6 +27,9 @@ class DerivedFnCollector { /// Adds a derived function to the collection. void Add(const DerivedFnInfo& DFI); + /// Adds a function to derivative set. + void AddToDerivativeSet(const clang::FunctionDecl* FD); + /// Finds a `DerivedFnInfo` object in the collection that satisfies the /// given differentiation request. DerivedFnInfo Find(const DiffRequest& request) const; diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index ae84e0b73..55e5ad350 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -98,6 +98,10 @@ template class DynamicGraph { m_currentId = -1; } + /// Check if currently processing a node. + /// \returns True if currently processing a node, false otherwise. + bool isProcessingNode() { return m_currentId != -1; } + /// Get the nodes in the graph. std::vector getNodes() { return m_nodes; } diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 861f1dc60..59adfd7e2 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -37,7 +37,7 @@ using namespace clang; namespace clad { DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC, + DerivedFnCollector& DFC, clad::DynamicGraph& G) : m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC), m_DiffRequestGraph(G), @@ -253,6 +253,17 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { OverloadedFn = m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get(); + + // Add the custom derivative to the set of derivatives. + // This is required in case the definition of the custom derivative + // is not found in the current translation unit and is linked in + // from another translation unit. + // Adding it to the set of derivatives ensures that the custom + // derivative is not differentiated again using numerical + // differentiation due to unavailable definition. + if (auto* CE = dyn_cast(OverloadedFn)) + if (FunctionDecl* FD = CE->getDirectCallee()) + m_DFC.AddToDerivativeSet(FD); } return OverloadedFn; } diff --git a/lib/Differentiator/DerivedFnCollector.cpp b/lib/Differentiator/DerivedFnCollector.cpp index f32883689..71bc22258 100644 --- a/lib/Differentiator/DerivedFnCollector.cpp +++ b/lib/Differentiator/DerivedFnCollector.cpp @@ -8,7 +8,11 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) { "`DerivedFnCollector::Add` more than once for the same derivative " ". Ideally, we shouldn't do either."); m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI); - m_DerivativeSet.insert(DFI.DerivedFn()); + AddToDerivativeSet(DFI.DerivedFn()); +} + +void DerivedFnCollector::AddToDerivativeSet(const clang::FunctionDecl* FD) { + m_DerivativeSet.insert(FD); } bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const { diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index c667ddd2c..c4aa93d4b 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -407,12 +407,17 @@ namespace clad { Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled); Sema::LocalEagerInstantiationScope LocalInstantiations(S); - DiffRequest request = m_DiffRequestGraph.getNextToProcessNode(); - while (request.Function != nullptr) { - m_DiffRequestGraph.setCurrentProcessingNode(request); - ProcessDiffRequest(request); - m_DiffRequestGraph.markCurrentNodeProcessed(); - request = m_DiffRequestGraph.getNextToProcessNode(); + if (!m_DiffRequestGraph.isProcessingNode()) { + // This check is to avoid recursive processing of the graph, as + // HandleTopLevelDecl can be called recursively in non-standard + // setup for code generation. + DiffRequest request = m_DiffRequestGraph.getNextToProcessNode(); + while (request.Function) { + m_DiffRequestGraph.setCurrentProcessingNode(request); + ProcessDiffRequest(request); + m_DiffRequestGraph.markCurrentNodeProcessed(); + request = m_DiffRequestGraph.getNextToProcessNode(); + } } // Put the TUScope in a consistent state after clad is done. diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 3ec1f9be6..3ba226f31 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -170,11 +170,12 @@ class CladTimerGroup { bool HandleTopLevelDecl(clang::DeclGroupRef D) override { if (D.isSingleDecl()) if (auto* FD = llvm::dyn_cast(D.getSingleDecl())) - if (m_DFC.IsDerivative(FD)) { - assert(!m_Multiplexer && - "Must happen only if we failed to rearrange the consumers"); + // If we build the derivative in a non-standard (with no Multiplexer) + // setup, we exit early to give control to the non-standard setup for + // code generation. + // FIXME: This should go away if Cling starts using the clang driver. + if (!m_Multiplexer && m_DFC.IsDerivative(FD)) return true; - } HandleTopLevelDeclForClad(D); AppendDelayed({CallKind::HandleTopLevelDecl, D}); diff --git a/unittests/Misc/CMakeLists.txt b/unittests/Misc/CMakeLists.txt index dc13b2417..e951dedfd 100644 --- a/unittests/Misc/CMakeLists.txt +++ b/unittests/Misc/CMakeLists.txt @@ -1,4 +1,14 @@ add_clad_unittest(MiscTests main.cpp + CallDeclOnly.cpp + Defs.cpp DynamicGraph.cpp ) + +# Create a library from Defs.cpp +add_library(Defs SHARED Defs.cpp) +enable_clad_for_executable(Defs) + +# Link the library to the test +target_link_libraries(MiscTests PRIVATE Defs) + diff --git a/unittests/Misc/CallDeclOnly.cpp b/unittests/Misc/CallDeclOnly.cpp new file mode 100644 index 000000000..4ad44d29c --- /dev/null +++ b/unittests/Misc/CallDeclOnly.cpp @@ -0,0 +1,69 @@ +#include "clad/Differentiator/Differentiator.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +double foo(double x, double alpha, double theta, double x0 = 0); + +double wrapper1(double* params) { + const double ix = 1 + params[0]; + return foo(10., ix, 1.0); +} + +TEST(CallDeclOnly, CheckNumDiff) { + auto grad = clad::gradient(wrapper1, "params"); + // Collect output of grad.dump() into a string as it ouputs using llvm::outs() + std::string actual; + testing::internal::CaptureStdout(); + grad.dump(); + actual = testing::internal::GetCapturedStdout(); + + // Check the generated code from grad.dump() + std::string expected = R"(The code is: +void wrapper1_grad(double *params, double *_d_params) { + double _d_ix = 0; + const double ix = 1 + params[0]; + goto _label0; + _label0: + { + double _r0 = 0; + double _r1 = 0; + double _r2 = 0; + double _r3 = 0; + double _grad0[4] = {0}; + numerical_diff::central_difference(foo, _grad0, 0, 10., ix, 1., 0); + _r0 += 1 * _grad0[0]; + _r1 += 1 * _grad0[1]; + _r2 += 1 * _grad0[2]; + _r3 += 1 * _grad0[3]; + _d_ix += _r1; + } + _d_params[0] += _d_ix; +} + +)"; + EXPECT_EQ(actual, expected); +} + +namespace clad { +namespace custom_derivatives { +// Custom pushforward for the square function but definition will be linked from +// another file. +clad::ValueAndPushforward sq_pushforward(double x, double _d_x); +} // namespace custom_derivatives +} // namespace clad + +double sq(double x) { return x * x; } + +double wrapper2(double* params) { return sq(params[0]); } + +TEST(CallDeclOnly, CheckCustomDiff) { + auto grad = clad::hessian(wrapper2, "params[0]"); + double x = 4.0; + double dx = 0.0; + grad.execute(&x, &dx); + EXPECT_DOUBLE_EQ(dx, 2.0); +} \ No newline at end of file diff --git a/unittests/Misc/Defs.cpp b/unittests/Misc/Defs.cpp new file mode 100644 index 000000000..bf0a2aa62 --- /dev/null +++ b/unittests/Misc/Defs.cpp @@ -0,0 +1,25 @@ +#include "clad/Differentiator/Differentiator.h" + +double foo(double x, double alpha, double theta, double x0 = 0) { + return x * alpha * theta * x0; +} + +namespace clad { +namespace custom_derivatives { +clad::ValueAndPushforward sq_pushforward(double x, + double _d_x) { + return {x * x, 2 * x}; +} + +void sq_pushforward_pullback(double x, double _dx, + clad::ValueAndPushforward _d_y, + double* _d_x, double* _d__d_x) { + goto _label0; +_label0: { + *_d_x += _d_y.value * x; + *_d_x += x * _d_y.value; + *_d_x += 2 * _d_y.pushforward; +} +} +} // namespace custom_derivatives +} // namespace clad \ No newline at end of file