Skip to content

Commit

Permalink
Bring back DerivativeSet and use it to fix Hessian ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 30, 2024
1 parent 3c248d1 commit 153c3a1
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 98 deletions.
4 changes: 4 additions & 0 deletions include/clad/Differentiator/DerivedFnCollector.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "clang/AST/Decl.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"

namespace clad {
Expand All @@ -14,10 +15,13 @@ namespace clad {
/// making it possible to reuse previously computed derivatives.
class DerivedFnCollector {
using DerivedFns = llvm::SmallVector<DerivedFnInfo, 16>;
using DerivativeSet = llvm::SmallSet<const clang::FunctionDecl*, 16>;
/// Mapping to efficiently find out information about all the derivatives of
/// a function.
llvm::DenseMap<const clang::FunctionDecl*, DerivedFns>
m_DerivedFnInfoCollection;
/// Set to keep track of all the functions that are derivatives.
DerivativeSet m_DerivativeSet;

public:
/// Adds a derived function to the collection.
Expand Down
8 changes: 0 additions & 8 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,6 @@ inline const char* DiffModeToString(DiffMode mode) {
return "unknown";
}
}

/// Tells whether \p mode is one of the pullback/pushforward differentiation
/// modes (experimental pushforward, experimental pullback, experimental
/// vector pushforward, or the reverse-mode forward pass).
inline bool IsPullbackOrPushforwardMode(DiffMode mode) {
  switch (mode) {
  case DiffMode::experimental_pushforward:
  case DiffMode::experimental_pullback:
  case DiffMode::experimental_vector_pushforward:
  case DiffMode::reverse_mode_forward_pass:
    return true;
  default:
    // Every other mode is not a pullback/pushforward mode.
    return false;
  }
}
}

#endif
9 changes: 3 additions & 6 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
assert(FD && "Must not be null.");
// If FD is only a declaration, try to find its definition.
if (!FD->getDefinition()) {
// If only a declaration is requested, allow this for modes other than
// pullback/pushforward. For example, this is required for Hessians, where
// forward mode is followed by reverse mode, but only the declaration of
// the forward-mode derivative is needed initially.
if (!request.DeclarationOnly ||
IsPullbackOrPushforwardMode(request.Mode)) {
// If only a declaration is requested, allow a missing definition for
// clad-generated functions.
if (!request.DeclarationOnly || !m_DFC.IsDerivative(FD)) {
if (request.VerboseDiags)
diag(DiagnosticsEngine::Error,
request.CallContext ? request.CallContext->getBeginLoc() : noLoc,
Expand Down
5 changes: 5 additions & 0 deletions lib/Differentiator/DerivedFnCollector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) {
"`DerivedFnCollector::Add` more than once for the same derivative "
". Ideally, we shouldn't do either.");
m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI);
m_DerivativeSet.insert(DFI.DerivedFn());
}

bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const {
Expand Down Expand Up @@ -36,4 +37,8 @@ DerivedFnInfo DerivedFnCollector::Find(const DiffRequest& request) const {
return DerivedFnInfo();
return *it;
}

/// Reports whether \p FD has been recorded in the derivative set, i.e. it
/// was registered as the derived function of some entry passed to `Add`.
bool DerivedFnCollector::IsDerivative(const clang::FunctionDecl* FD) const {
  // `count` yields 0 or 1 for a set; make the conversion to bool explicit.
  const bool isRecorded = m_DerivativeSet.count(FD) != 0;
  return isRecorded;
}
} // namespace clad
90 changes: 52 additions & 38 deletions test/Hessian/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@ float f1(float x) {
return sin(x) + cos(x);
}

// CHECK: float f1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, _d_x);
// CHECK-NEXT: ValueAndPushforward<float, float> _t1 = clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward + _t1.pushforward;
// CHECK-NEXT: }
// CHECK: float f1_darg0(float x);

// CHECK: void f1_darg0_grad(float x, float *_d_x);

Expand All @@ -30,11 +25,7 @@ float f2(float x) {
return exp(x);
}

// CHECK: float f2_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f2_darg0(float x);

// CHECK: void f2_darg0_grad(float x, float *_d_x);

Expand All @@ -47,11 +38,7 @@ float f3(float x) {
return log(x);
}

// CHECK: float f3_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::log_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f3_darg0(float x);

// CHECK: void f3_darg0_grad(float x, float *_d_x);

Expand All @@ -64,11 +51,7 @@ float f4(float x) {
return pow(x, 4.0F);
}

// CHECK: float f4_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x, 0.F);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f4_darg0(float x);

// CHECK: void f4_darg0_grad(float x, float *_d_x);

Expand All @@ -81,11 +64,7 @@ float f5(float x) {
return pow(2.0F, x);
}

// CHECK: float f5_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f5_darg0(float x);

// CHECK: void f5_darg0_grad(float x, float *_d_x);

Expand All @@ -98,21 +77,11 @@ float f6(float x, float y) {
return pow(x, y);
}

// CHECK: float f6_darg0(float x, float y) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f6_darg0(float x, float y);

// CHECK: void f6_darg0_grad(float x, float y, float *_d_x, float *_d_y);

// CHECK: float f6_darg1(float x, float y) {
// CHECK-NEXT: float _d_x = 0;
// CHECK-NEXT: float _d_y = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f6_darg1(float x, float y);

// CHECK: void f6_darg1_grad(float x, float y, float *_d_x, float *_d_y);

Expand Down Expand Up @@ -147,6 +116,13 @@ int main() {
TEST1(f5, 3); // CHECK-EXEC: Result is = {3.84}
TEST2(f6, 3, 4); // CHECK-EXEC: Result is = {108.00, 145.65, 145.65, 97.76}

// CHECK: float f1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, _d_x);
// CHECK-NEXT: ValueAndPushforward<float, float> _t1 = clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward + _t1.pushforward;
// CHECK-NEXT: }

// CHECK: void sin_pushforward_pullback(float x, float d_x, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_d_x);

// CHECK: void cos_pushforward_pullback(float x, float d_x, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_d_x);
Expand Down Expand Up @@ -180,6 +156,12 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f2_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void exp_pushforward_pullback(float x, float d_x, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_d_x);

// CHECK: void f2_darg0_grad(float x, float *_d_x) {
Expand All @@ -199,6 +181,12 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f3_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::log_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void log_pushforward_pullback(float x, float d_x, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_d_x);

// CHECK: void f3_darg0_grad(float x, float *_d_x) {
Expand All @@ -218,6 +206,12 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f4_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x, 0.F);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void pow_pushforward_pullback(float x, float exponent, float d_x, float d_exponent, ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d_y, float *_d_x, float *_d_exponent, float *_d_d_x, float *_d_d_exponent);

// CHECK: void f4_darg0_grad(float x, float *_d_x) {
Expand All @@ -239,6 +233,12 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f5_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void f5_darg0_grad(float x, float *_d_x) {
// CHECK-NEXT: float _d__d_x = 0;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d__t0 = {};
Expand All @@ -258,6 +258,13 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f6_darg0(float x, float y) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void f6_darg0_grad(float x, float y, float *_d_x, float *_d_y) {
// CHECK-NEXT: float _d__d_x = 0;
// CHECK-NEXT: float _d__d_y = 0;
Expand All @@ -281,6 +288,13 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f6_darg1(float x, float y) {
// CHECK-NEXT: float _d_x = 0;
// CHECK-NEXT: float _d_y = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void f6_darg1_grad(float x, float y, float *_d_x, float *_d_y) {
// CHECK-NEXT: float _d__d_x = 0;
// CHECK-NEXT: float _d__d_y = 0;
Expand Down
16 changes: 9 additions & 7 deletions test/Hessian/Hessians.C
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@
__attribute__((always_inline)) double f_cubed_add1(double a, double b) {
return a * a * a + b * b * b;
}
//CHECK:{{[__attribute__((always_inline)) ]*}}double f_cubed_add1_darg0(double a, double b){{[ __attribute__((always_inline))]*}} {
//CHECK-NEXT: double _d_a = 1;
//CHECK-NEXT: double _d_b = 0;
//CHECK-NEXT: double _t0 = a * a;
//CHECK-NEXT: double _t1 = b * b;
//CHECK-NEXT: return (_d_a * a + a * _d_a) * a + _t0 * _d_a + (_d_b * b + b * _d_b) * b + _t1 * _d_b;
//CHECK-NEXT:}
//CHECK:{{[__attribute__((always_inline)) ]*}}double f_cubed_add1_darg0(double a, double b){{[ __attribute__((always_inline))]*}};

void f_cubed_add1_darg0_grad(double a, double b, double *_d_a, double *_d_b);
//CHECK:{{[__attribute__((always_inline)) ]*}}void f_cubed_add1_darg0_grad(double a, double b, double *_d_a, double *_d_b){{[ __attribute__((always_inline))]*}};
Expand Down Expand Up @@ -212,6 +206,14 @@ int main() {
TEST3(&Widget::memFn_2, W, 7, 9); // CHECK-EXEC: Result is = {5400.00, 4200.00, 4200.00, 0.00}
TEST2(fn_def_arg, 3, 5); // CHECK-EXEC: Result is = {0.00, 2.00, 2.00, 0.00}

//CHECK:{{[__attribute__((always_inline)) ]*}}double f_cubed_add1_darg0(double a, double b){{[ __attribute__((always_inline))]*}} {
//CHECK-NEXT: double _d_a = 1;
//CHECK-NEXT: double _d_b = 0;
//CHECK-NEXT: double _t0 = a * a;
//CHECK-NEXT: double _t1 = b * b;
//CHECK-NEXT: return (_d_a * a + a * _d_a) * a + _t0 * _d_a + (_d_b * b + b * _d_b) * b + _t1 * _d_b;
//CHECK-NEXT:}

//CHECK:{{[__attribute__((always_inline)) ]*}}void f_cubed_add1_darg0_grad(double a, double b, double *_d_a, double *_d_b){{[ __attribute__((always_inline))]*}} {
//CHECK-NEXT: double _d__d_a = 0;
//CHECK-NEXT: double _d__d_b = 0;
Expand Down
48 changes: 22 additions & 26 deletions test/Hessian/NestedFunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,27 @@ double f2(double x, double y){
return ans;
}

// CHECK: clad::ValueAndPushforward<double, double> f_pushforward(double x, double y, double _d_x, double _d_y);

// CHECK: double f2_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = f_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: double _d_ans = _t0.pushforward;
// CHECK-NEXT: double ans = _t0.value;
// CHECK-NEXT: return _d_ans;
// CHECK-NEXT: }


// CHECK: double f2_darg0(double x, double y);
// CHECK: void f2_darg0_grad(double x, double y, double *_d_x, double *_d_y);


// CHECK: double f2_darg1(double x, double y) {
// CHECK-NEXT: double _d_x = 0;
// CHECK-NEXT: double _d_y = 1;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = f_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: double _d_ans = _t0.pushforward;
// CHECK-NEXT: double ans = _t0.value;
// CHECK-NEXT: return _d_ans;
// CHECK-NEXT: }

// CHECK: double f2_darg1(double x, double y);
// CHECK: void f2_darg1_grad(double x, double y, double *_d_x, double *_d_y);

// CHECK: void f2_hessian(double x, double y, double *hessianMatrix) {
// CHECK-NEXT: f2_darg0_grad(x, y, hessianMatrix + {{0U|0UL}}, hessianMatrix + {{1U|1UL}});
// CHECK-NEXT: f2_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}});
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double, double> f_pushforward(double x, double y, double _d_x, double _d_y);

// CHECK: clad::ValueAndPushforward<double, double> f_pushforward(double x, double y, double _d_x, double _d_y) {
// CHECK-NEXT: return {x * x + y * y, _d_x * x + x * _d_x + _d_y * y + y * _d_y};
// CHECK: double f2_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = f_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: double _d_ans = _t0.pushforward;
// CHECK-NEXT: double ans = _t0.value;
// CHECK-NEXT: return _d_ans;
// CHECK-NEXT: }


// CHECK: void f_pushforward_pullback(double x, double y, double _d_x, double _d_y, clad::ValueAndPushforward<double, double> _d_y0, double *_d_x, double *_d_y, double *_d__d_x, double *_d__d_y);

// CHECK: void f2_darg0_grad(double x, double y, double *_d_x, double *_d_y) {
Expand Down Expand Up @@ -85,6 +69,14 @@ double f2(double x, double y){
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: double f2_darg1(double x, double y) {
// CHECK-NEXT: double _d_x = 0;
// CHECK-NEXT: double _d_y = 1;
// CHECK-NEXT: clad::ValueAndPushforward<double, double> _t0 = f_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: double _d_ans = _t0.pushforward;
// CHECK-NEXT: double ans = _t0.value;
// CHECK-NEXT: return _d_ans;
// CHECK-NEXT: }

// CHECK: void f2_darg1_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: double _d__d_x = 0;
Expand Down Expand Up @@ -115,6 +107,10 @@ double f2(double x, double y){
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double, double> f_pushforward(double x, double y, double _d_x, double _d_y) {
// CHECK-NEXT: return {x * x + y * y, _d_x * x + x * _d_x + _d_y * y + y * _d_y};
// CHECK-NEXT: }

// CHECK: void f_pushforward_pullback(double x, double y, double _d_x, double _d_y, clad::ValueAndPushforward<double, double> _d_y0, double *_d_x, double *_d_y, double *_d__d_x, double *_d__d_y) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
Expand Down
Loading

0 comments on commit 153c3a1

Please sign in to comment.