Skip to content

Commit

Permalink
Bring back DerivativeSet and use it to fix Hessian ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 30, 2024
1 parent 3c248d1 commit b2ed85b
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 116 deletions.
5 changes: 0 additions & 5 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,6 @@ namespace clad {
///
/// \param[in] request The request to add the edge to.
void AddEdgeToGraph(const DiffRequest& request);
/// Add edge between two requests in the DiffRequest graph.
///
/// \param[in] from The source request.
/// \param[in] to The destination request.
void AddEdgeToGraph(const DiffRequest& from, const DiffRequest& to);
};

} // end namespace clad
Expand Down
4 changes: 4 additions & 0 deletions include/clad/Differentiator/DerivedFnCollector.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "clang/AST/Decl.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"

namespace clad {
Expand All @@ -14,10 +15,13 @@ namespace clad {
/// making it possible to reuse previously computed derivatives.
class DerivedFnCollector {
using DerivedFns = llvm::SmallVector<DerivedFnInfo, 16>;
using DerivativeSet = llvm::SmallSet<const clang::FunctionDecl*, 16>;
/// Mapping to efficiently find out information about all the derivatives of
/// a function.
llvm::DenseMap<const clang::FunctionDecl*, DerivedFns>
m_DerivedFnInfoCollection;
/// Set to keep track of all the functions that are derivatives.
DerivativeSet m_DerivativeSet;

public:
/// Adds a derived function to the collection.
Expand Down
8 changes: 0 additions & 8 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,6 @@ inline const char* DiffModeToString(DiffMode mode) {
return "unknown";
}
}

/// Returns true if the given mode is a pullback/pushforward mode.
inline bool IsPullbackOrPushforwardMode(DiffMode mode) {
return mode == DiffMode::experimental_pushforward ||
mode == DiffMode::experimental_pullback ||
mode == DiffMode::experimental_vector_pushforward ||
mode == DiffMode::reverse_mode_forward_pass;
}
}

#endif
4 changes: 1 addition & 3 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ struct DiffRequest {
/// Define the == operator for DiffRequest.
bool operator==(const DiffRequest& other) const {
// either function match or previous declaration match
return (Function == other.Function ||
Function->getPreviousDecl() == other.Function ||
Function == other.Function->getPreviousDecl()) &&
return Function == other.Function &&
BaseFunctionName == other.BaseFunctionName &&
CurrentDerivativeOrder == other.CurrentDerivativeOrder &&
RequestedDerivativeOrder == other.RequestedDerivativeOrder &&
Expand Down
5 changes: 2 additions & 3 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ template <typename T> class DynamicGraph {

// Adds the edge from the current node to the destination node.
void addEdgeToCurrentNode(const T& dest) {
if (m_currentId == -1)
return;
addEdge(m_nodes[m_currentId], dest);
if (m_currentId != -1)
addEdge(m_nodes[m_currentId], dest);
}

// Set the current node to the node with the given id.
Expand Down
14 changes: 3 additions & 11 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
assert(FD && "Must not be null.");
// If FD is only a declaration, try to find its definition.
if (!FD->getDefinition()) {
// If only declaration is requested, allow this for non
// pullback/pushforward modes. For ex, this is required for Hessian -
// where we have forward mode followed by reverse mode, but we only need
// the declaration of the forward mode initially.
if (!request.DeclarationOnly ||
IsPullbackOrPushforwardMode(request.Mode)) {
// If only declaration is requested, allow this for clad-generated
// functions.
if (!request.DeclarationOnly || !m_DFC.IsDerivative(FD)) {
if (request.VerboseDiags)
diag(DiagnosticsEngine::Error,
request.CallContext ? request.CallContext->getBeginLoc() : noLoc,
Expand Down Expand Up @@ -404,9 +401,4 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& request) {
m_DiffRequestGraph.addEdgeToCurrentNode(request);
}

void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& from,
const DiffRequest& to) {
m_DiffRequestGraph.addEdge(from, to);
}
}// end namespace clad
5 changes: 5 additions & 0 deletions lib/Differentiator/DerivedFnCollector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) {
"`DerivedFnCollector::Add` more than once for the same derivative "
". Ideally, we shouldn't do either.");
m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI);
m_DerivativeSet.insert(DFI.DerivedFn());
}

bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const {
Expand Down Expand Up @@ -36,4 +37,8 @@ DerivedFnInfo DerivedFnCollector::Find(const DiffRequest& request) const {
return DerivedFnInfo();
return *it;
}

bool DerivedFnCollector::IsDerivative(const clang::FunctionDecl* FD) const {
return m_DerivativeSet.count(FD);
}
} // namespace clad
90 changes: 52 additions & 38 deletions test/Hessian/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@ float f1(float x) {
return sin(x) + cos(x);
}

// CHECK: float f1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, _d_x);
// CHECK-NEXT: ValueAndPushforward<float, float> _t1 = clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward + _t1.pushforward;
// CHECK-NEXT: }
// CHECK: float f1_darg0(float x);

// CHECK: void f1_darg0_grad(float x, float *_d_x);

Expand All @@ -30,11 +25,7 @@ float f2(float x) {
return exp(x);
}

// CHECK: float f2_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f2_darg0(float x);

// CHECK: void f2_darg0_grad(float x, float *_d_x);

Expand All @@ -47,11 +38,7 @@ float f3(float x) {
return log(x);
}

// CHECK: float f3_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::log_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f3_darg0(float x);

// CHECK: void f3_darg0_grad(float x, float *_d_x);

Expand All @@ -64,11 +51,7 @@ float f4(float x) {
return pow(x, 4.0F);
}

// CHECK: float f4_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x, 0.F);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f4_darg0(float x);

// CHECK: void f4_darg0_grad(float x, float *_d_x);

Expand All @@ -81,11 +64,7 @@ float f5(float x) {
return pow(2.0F, x);
}

// CHECK: float f5_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f5_darg0(float x);

// CHECK: void f5_darg0_grad(float x, float *_d_x);

Expand All @@ -98,21 +77,11 @@ float f6(float x, float y) {
return pow(x, y);
}

// CHECK: float f6_darg0(float x, float y) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f6_darg0(float x, float y);

// CHECK: void f6_darg0_grad(float x, float y, float *_d_x, float *_d_y);

// CHECK: float f6_darg1(float x, float y) {
// CHECK-NEXT: float _d_x = 0;
// CHECK-NEXT: float _d_y = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }
// CHECK: float f6_darg1(float x, float y);

// CHECK: void f6_darg1_grad(float x, float y, float *_d_x, float *_d_y);

Expand Down Expand Up @@ -147,6 +116,13 @@ int main() {
TEST1(f5, 3); // CHECK-EXEC: Result is = {3.84}
TEST2(f6, 3, 4); // CHECK-EXEC: Result is = {108.00, 145.65, 145.65, 97.76}

// CHECK: float f1_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, _d_x);
// CHECK-NEXT: ValueAndPushforward<float, float> _t1 = clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward + _t1.pushforward;
// CHECK-NEXT: }

// CHECK: void sin_pushforward_pullback(float x, float d_x, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_d_x);

// CHECK: void cos_pushforward_pullback(float x, float d_x, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_d_x);
Expand Down Expand Up @@ -180,6 +156,12 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f2_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void exp_pushforward_pullback(float x, float d_x, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_d_x);

// CHECK: void f2_darg0_grad(float x, float *_d_x) {
Expand All @@ -199,6 +181,12 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f3_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<float, float> _t0 = clad::custom_derivatives{{(::std)?}}::log_pushforward(x, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void log_pushforward_pullback(float x, float d_x, ValueAndPushforward<float, float> _d_y, float *_d_x, float *_d_d_x);

// CHECK: void f3_darg0_grad(float x, float *_d_x) {
Expand All @@ -218,6 +206,12 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f4_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x, 0.F);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void pow_pushforward_pullback(float x, float exponent, float d_x, float d_exponent, ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d_y, float *_d_x, float *_d_exponent, float *_d_d_x, float *_d_d_exponent);

// CHECK: void f4_darg0_grad(float x, float *_d_x) {
Expand All @@ -239,6 +233,12 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f5_darg0(float x) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void f5_darg0_grad(float x, float *_d_x) {
// CHECK-NEXT: float _d__d_x = 0;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _d__t0 = {};
Expand All @@ -258,6 +258,13 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f6_darg0(float x, float y) {
// CHECK-NEXT: float _d_x = 1;
// CHECK-NEXT: float _d_y = 0;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void f6_darg0_grad(float x, float y, float *_d_x, float *_d_y) {
// CHECK-NEXT: float _d__d_x = 0;
// CHECK-NEXT: float _d__d_y = 0;
Expand All @@ -281,6 +288,13 @@ int main() {
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: float f6_darg1(float x, float y) {
// CHECK-NEXT: float _d_x = 0;
// CHECK-NEXT: float _d_y = 1;
// CHECK-NEXT: ValueAndPushforward<decltype(::std::pow(float(), float())), decltype(::std::pow(float(), float()))> _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y);
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

// CHECK: void f6_darg1_grad(float x, float y, float *_d_x, float *_d_y) {
// CHECK-NEXT: float _d__d_x = 0;
// CHECK-NEXT: float _d__d_y = 0;
Expand Down
16 changes: 9 additions & 7 deletions test/Hessian/Hessians.C
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@
__attribute__((always_inline)) double f_cubed_add1(double a, double b) {
return a * a * a + b * b * b;
}
//CHECK:{{[__attribute__((always_inline)) ]*}}double f_cubed_add1_darg0(double a, double b){{[ __attribute__((always_inline))]*}} {
//CHECK-NEXT: double _d_a = 1;
//CHECK-NEXT: double _d_b = 0;
//CHECK-NEXT: double _t0 = a * a;
//CHECK-NEXT: double _t1 = b * b;
//CHECK-NEXT: return (_d_a * a + a * _d_a) * a + _t0 * _d_a + (_d_b * b + b * _d_b) * b + _t1 * _d_b;
//CHECK-NEXT:}
//CHECK:{{[__attribute__((always_inline)) ]*}}double f_cubed_add1_darg0(double a, double b){{[ __attribute__((always_inline))]*}};

void f_cubed_add1_darg0_grad(double a, double b, double *_d_a, double *_d_b);
//CHECK:{{[__attribute__((always_inline)) ]*}}void f_cubed_add1_darg0_grad(double a, double b, double *_d_a, double *_d_b){{[ __attribute__((always_inline))]*}};
Expand Down Expand Up @@ -212,6 +206,14 @@ int main() {
TEST3(&Widget::memFn_2, W, 7, 9); // CHECK-EXEC: Result is = {5400.00, 4200.00, 4200.00, 0.00}
TEST2(fn_def_arg, 3, 5); // CHECK-EXEC: Result is = {0.00, 2.00, 2.00, 0.00}

//CHECK:{{[__attribute__((always_inline)) ]*}}double f_cubed_add1_darg0(double a, double b){{[ __attribute__((always_inline))]*}} {
//CHECK-NEXT: double _d_a = 1;
//CHECK-NEXT: double _d_b = 0;
//CHECK-NEXT: double _t0 = a * a;
//CHECK-NEXT: double _t1 = b * b;
//CHECK-NEXT: return (_d_a * a + a * _d_a) * a + _t0 * _d_a + (_d_b * b + b * _d_b) * b + _t1 * _d_b;
//CHECK-NEXT:}

//CHECK:{{[__attribute__((always_inline)) ]*}}void f_cubed_add1_darg0_grad(double a, double b, double *_d_a, double *_d_b){{[ __attribute__((always_inline))]*}} {
//CHECK-NEXT: double _d__d_a = 0;
//CHECK-NEXT: double _d__d_b = 0;
Expand Down
Loading

0 comments on commit b2ed85b

Please sign in to comment.