Skip to content

Commit

Permalink
Add test for calling functions with no defn
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed May 20, 2024
1 parent 12fd325 commit 07d2f4f
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 0 deletions.
10 changes: 10 additions & 0 deletions unittests/Misc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
add_clad_unittest(MiscTests
main.cpp
CallDeclOnly.cpp
Defs.cpp
DynamicGraph.cpp
)

# Create a library from Defs.cpp
add_library(Defs SHARED Defs.cpp)
enable_clad_for_executable(Defs)

# Link the library to the test
target_link_libraries(MiscTests PRIVATE Defs)

78 changes: 78 additions & 0 deletions unittests/Misc/CallDeclOnly.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include "clad/Differentiator/Differentiator.h"

#include <iostream>
#include <string>

#include "gtest/gtest.h"

double foo(double x, double alpha, double theta, double x0 = 0);

double wrapper1(double* params) {
const double ix = 1 + params[0];
return foo(10., ix, 1.0);
}

TEST(CallDeclOnly, CheckNumDiff) {
auto grad = clad::gradient(wrapper1, "params");
// Collect output of grad.dump() into a string as it ouputs using llvm::outs()
std::string actual;
testing::internal::CaptureStdout();
grad.dump();
actual = testing::internal::GetCapturedStdout();

// Check the generated code from grad.dump()
std::string expected = R"(The code is:
void wrapper1_grad(double *params, double *_d_params) {
double _d_ix = 0;
const double ix = 1 + params[0];
goto _label0;
_label0:
{
double _r0 = 0;
double _r1 = 0;
double _r2 = 0;
double _r3 = 0;
double _grad0[4] = {0};
numerical_diff::central_difference(foo, _grad0, 0, 10., ix, 1., 0);
_r0 += 1 * _grad0[0];
_r1 += 1 * _grad0[1];
_r2 += 1 * _grad0[2];
_r3 += 1 * _grad0[3];
_d_ix += _r1;
}
_d_params[0] += _d_ix;
}
)";
EXPECT_EQ(actual, expected);
}

namespace clad {
namespace custom_derivatives {
// Custom pushforward for the square function but definition will be linked from
// another file.
clad::ValueAndPushforward<double, double> sq_pushforward(double x, double _d_x);
} // namespace custom_derivatives
} // namespace clad

double sq(double x) { return x * x; }

double wrapper2(double* params) { return sq(params[0]); }

TEST(CallDeclOnly, CheckCustomDiff) {
auto grad = clad::hessian(wrapper2, "params[0]");
// Collect output of grad.dump() into a string as it ouputs using llvm::outs()
std::string actual;
testing::internal::CaptureStdout();
grad.dump();
actual = testing::internal::GetCapturedStdout();

// Check the generated code from grad.dump()
std::string expected = R"(The code is:
void wrapper2_hessian(double *params, double *hessianMatrix) {
wrapper2_darg0_0_grad(params, hessianMatrix + 0UL);
}
)";
EXPECT_EQ(actual, expected);
}
25 changes: 25 additions & 0 deletions unittests/Misc/Defs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "clad/Differentiator/Differentiator.h"

double foo(double x, double alpha, double theta, double x0 = 0) {
return x * alpha * theta * x0;
}

namespace clad {
namespace custom_derivatives {
clad::ValueAndPushforward<double, double> sq_pushforward(double x,
double _d_x) {
return {x * x, 2 * x};
}

void sq_pushforward_pullback(double x, double _dx,
clad::ValueAndPushforward<double, double> _d_y,
double* _d_x, double* _d__d_x) {
goto _label0;
_label0 : {
*_d_x += _d_y.value * x;
*_d_x += x * _d_y.value;
*_d_x += 2 * _d_y.pushforward;
}
}
} // namespace custom_derivatives
} // namespace clad

0 comments on commit 07d2f4f

Please sign in to comment.