Skip to content

Commit

Permalink
Fix errors due to recursive calling of HandleTopLevelDecl
Browse files Browse the repository at this point in the history
- Add custom derivative to DerivativeSet: This is required if the definition of the custom derivative is not found in the current translation unit and is linked in from another.
Adding it to the set of derivatives ensures that the custom derivative is not differentiated again using numerical differentiation due to an unavailable definition.

- Fix recursive processing of DiffRequests: There can be cases where `m_Multiplexer` is not provided. Hence, we don't delay HandleTranslationUnit at the end and it is called repeatedly. This resulted in HandleTopLevelDecl being called recursively (from PerformPendingInstantiations). This commit adds conditional checks to ensure this doesn't perturb the execution of the differentiation plan.
  • Loading branch information
vaithak committed May 20, 2024
1 parent d1fec23 commit bf2f64a
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 14 deletions.
4 changes: 2 additions & 2 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ namespace clad {
clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
const DerivedFnCollector& m_DFC;
DerivedFnCollector& m_DFC;
clad::DynamicGraph<DiffRequest>& m_DiffRequestGraph;
std::unique_ptr<utils::StmtClone> m_NodeCloner;
clang::NamespaceDecl* m_BuiltinDerivativesNSD;
Expand Down Expand Up @@ -135,7 +135,7 @@ namespace clad {

public:
DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,
const DerivedFnCollector& DFC,
DerivedFnCollector& DFC,
clad::DynamicGraph<DiffRequest>& DRG);
~DerivativeBuilder();
/// Reset the model use for error estimation (if any).
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/DerivedFnCollector.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class DerivedFnCollector {
/// Adds a derived function to the collection.
void Add(const DerivedFnInfo& DFI);

/// Adds a function to derivative set.
void AddToDerivativeSet(const clang::FunctionDecl* FD);

/// Finds a `DerivedFnInfo` object in the collection that satisfies the
/// given differentiation request.
DerivedFnInfo Find(const DiffRequest& request) const;
Expand Down
4 changes: 4 additions & 0 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ template <typename T> class DynamicGraph {
m_currentId = -1;
}

/// Check if currently processing a node.
/// \returns True if currently processing a node, false otherwise.
bool isProcessingNode() { return m_currentId != -1; }

/// Get the nodes in the graph.
std::vector<T> getNodes() { return m_nodes; }

Expand Down
13 changes: 12 additions & 1 deletion lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ using namespace clang;
namespace clad {

DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P,
const DerivedFnCollector& DFC,
DerivedFnCollector& DFC,
clad::DynamicGraph<DiffRequest>& G)
: m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC),
m_DiffRequestGraph(G),
Expand Down Expand Up @@ -253,6 +253,17 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();

// Add the custom derivative to the set of derivatives.
// This is required in case the definition of the custom derivative
// is not found in the current translation unit and is linked in
// from another translation unit.
// Adding it to the set of derivatives ensures that the custom
// derivative is not differentiated again using numerical
// differentiation due to unavailable definition.
if (auto* CE = dyn_cast<CallExpr>(OverloadedFn))
if (FunctionDecl* FD = CE->getDirectCallee())
m_DFC.AddToDerivativeSet(FD);
}
return OverloadedFn;
}
Expand Down
6 changes: 5 additions & 1 deletion lib/Differentiator/DerivedFnCollector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) {
"`DerivedFnCollector::Add` more than once for the same derivative "
". Ideally, we shouldn't do either.");
m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI);
m_DerivativeSet.insert(DFI.DerivedFn());
AddToDerivativeSet(DFI.DerivedFn());
}

void DerivedFnCollector::AddToDerivativeSet(const clang::FunctionDecl* FD) {
m_DerivativeSet.insert(FD);
}

bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const {
Expand Down
17 changes: 11 additions & 6 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,12 +407,17 @@ namespace clad {
Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled);
Sema::LocalEagerInstantiationScope LocalInstantiations(S);

DiffRequest request = m_DiffRequestGraph.getNextToProcessNode();
while (request.Function != nullptr) {
m_DiffRequestGraph.setCurrentProcessingNode(request);
ProcessDiffRequest(request);
m_DiffRequestGraph.markCurrentNodeProcessed();
request = m_DiffRequestGraph.getNextToProcessNode();
if (!m_DiffRequestGraph.isProcessingNode()) {
// This check is to avoid recursive processing of the graph, as
// HandleTopLevelDecl can be called recursively in non-standard
// setup for code generation.
DiffRequest request = m_DiffRequestGraph.getNextToProcessNode();
while (request.Function) {
m_DiffRequestGraph.setCurrentProcessingNode(request);
ProcessDiffRequest(request);
m_DiffRequestGraph.markCurrentNodeProcessed();
request = m_DiffRequestGraph.getNextToProcessNode();
}
}

// Put the TUScope in a consistent state after clad is done.
Expand Down
9 changes: 5 additions & 4 deletions tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,12 @@ class CladTimerGroup {
bool HandleTopLevelDecl(clang::DeclGroupRef D) override {
if (D.isSingleDecl())
if (auto* FD = llvm::dyn_cast<clang::FunctionDecl>(D.getSingleDecl()))
if (m_DFC.IsDerivative(FD)) {
assert(!m_Multiplexer &&
"Must happen only if we failed to rearrange the consumers");
// If we build the derivative in a non-standard (with no Multiplexer)
// setup, we exit early to give control to the non-standard setup for
// code generation.
// FIXME: This should go away if Cling starts using the clang driver.
if (!m_Multiplexer && m_DFC.IsDerivative(FD))
return true;
}

HandleTopLevelDeclForClad(D);
AppendDelayed({CallKind::HandleTopLevelDecl, D});
Expand Down
10 changes: 10 additions & 0 deletions unittests/Misc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
add_clad_unittest(MiscTests
main.cpp
CallDeclOnly.cpp
Defs.cpp
DynamicGraph.cpp
)

# Create a library from Defs.cpp
add_library(Defs SHARED Defs.cpp)
enable_clad_for_executable(Defs)

# Link the library to the test
target_link_libraries(MiscTests PRIVATE Defs)

69 changes: 69 additions & 0 deletions unittests/Misc/CallDeclOnly.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#include "clad/Differentiator/Differentiator.h"

#include <iostream>
#include <regex>
#include <string>

#include "gtest/gtest.h"

double foo(double x, double alpha, double theta, double x0 = 0);

double wrapper1(double* params) {
const double ix = 1 + params[0];
return foo(10., ix, 1.0);
}

TEST(CallDeclOnly, CheckNumDiff) {
auto grad = clad::gradient(wrapper1, "params");
// Collect output of grad.dump() into a string as it ouputs using llvm::outs()
std::string actual;
testing::internal::CaptureStdout();
grad.dump();
actual = testing::internal::GetCapturedStdout();

// Check the generated code from grad.dump()
std::string expected = R"(The code is:
void wrapper1_grad(double *params, double *_d_params) {
double _d_ix = 0;
const double ix = 1 + params[0];
goto _label0;
_label0:
{
double _r0 = 0;
double _r1 = 0;
double _r2 = 0;
double _r3 = 0;
double _grad0[4] = {0};
numerical_diff::central_difference(foo, _grad0, 0, 10., ix, 1., 0);
_r0 += 1 * _grad0[0];
_r1 += 1 * _grad0[1];
_r2 += 1 * _grad0[2];
_r3 += 1 * _grad0[3];
_d_ix += _r1;
}
_d_params[0] += _d_ix;
}
)";
EXPECT_EQ(actual, expected);
}

namespace clad {
namespace custom_derivatives {
// Custom pushforward for the square function but definition will be linked from
// another file.
clad::ValueAndPushforward<double, double> sq_pushforward(double x, double _d_x);
} // namespace custom_derivatives
} // namespace clad

double sq(double x) { return x * x; }

double wrapper2(double* params) { return sq(params[0]); }

TEST(CallDeclOnly, CheckCustomDiff) {
auto grad = clad::hessian(wrapper2, "params[0]");
double x = 4.0;
double dx = 0.0;
grad.execute(&x, &dx);
EXPECT_DOUBLE_EQ(dx, 2.0);
}
25 changes: 25 additions & 0 deletions unittests/Misc/Defs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "clad/Differentiator/Differentiator.h"

double foo(double x, double alpha, double theta, double x0 = 0) {
return x * alpha * theta * x0;
}

namespace clad {
namespace custom_derivatives {
clad::ValueAndPushforward<double, double> sq_pushforward(double x,
double _d_x) {
return {x * x, 2 * x};
}

void sq_pushforward_pullback(double x, double _dx,
clad::ValueAndPushforward<double, double> _d_y,
double* _d_x, double* _d__d_x) {
goto _label0;
_label0: {
*_d_x += _d_y.value * x;
*_d_x += x * _d_y.value;
*_d_x += 2 * _d_y.pushforward;
}
}
} // namespace custom_derivatives
} // namespace clad

0 comments on commit bf2f64a

Please sign in to comment.