diff --git a/include/clad/Differentiator/DerivedFnCollector.h b/include/clad/Differentiator/DerivedFnCollector.h index b0d297176..b20160a44 100644 --- a/include/clad/Differentiator/DerivedFnCollector.h +++ b/include/clad/Differentiator/DerivedFnCollector.h @@ -6,6 +6,7 @@ #include "clang/AST/Decl.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" namespace clad { @@ -14,10 +15,13 @@ namespace clad { /// making it possible to reuse previously computed derivatives. class DerivedFnCollector { using DerivedFns = llvm::SmallVector; + using DerivativeSet = llvm::SmallSet; /// Mapping to efficiently find out information about all the derivatives of /// a function. llvm::DenseMap m_DerivedFnInfoCollection; + /// Set to keep track of all the functions that are derivatives. + DerivativeSet m_DerivativeSet; public: /// Adds a derived function to the collection. diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index c7418eb6f..b079f554d 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -43,14 +43,6 @@ inline const char* DiffModeToString(DiffMode mode) { return "unknown"; } } - -/// Returns true if the given mode is a pullback/pushforward mode. -inline bool IsPullbackOrPushforwardMode(DiffMode mode) { - return mode == DiffMode::experimental_pushforward || - mode == DiffMode::experimental_pullback || - mode == DiffMode::experimental_vector_pushforward || - mode == DiffMode::reverse_mode_forward_pass; -} } #endif diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 35a252c60..70079ade9 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -294,12 +294,9 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { assert(FD && "Must not be null."); // If FD is only a declaration, try to find its definition. if (!FD->getDefinition()) { - // If only declaration is requested, allow this for non - // pullback/pushforward modes. For ex, this is required for Hessian - - // where we have forward mode followed by reverse mode, but we only need - // the declaration of the forward mode initially. - if (!request.DeclarationOnly || - IsPullbackOrPushforwardMode(request.Mode)) { + // If only declaration is requested, allow this for clad-generated + // functions. + if (!request.DeclarationOnly || !m_DFC.IsDerivative(FD)) { if (request.VerboseDiags) diag(DiagnosticsEngine::Error, request.CallContext ? request.CallContext->getBeginLoc() : noLoc, diff --git a/lib/Differentiator/DerivedFnCollector.cpp b/lib/Differentiator/DerivedFnCollector.cpp index 1e9c9a837..f32883689 100644 --- a/lib/Differentiator/DerivedFnCollector.cpp +++ b/lib/Differentiator/DerivedFnCollector.cpp @@ -8,6 +8,7 @@ void DerivedFnCollector::Add(const DerivedFnInfo& DFI) { "`DerivedFnCollector::Add` more than once for the same derivative " ". Ideally, we shouldn't do either."); m_DerivedFnInfoCollection[DFI.OriginalFn()].push_back(DFI); + m_DerivativeSet.insert(DFI.DerivedFn()); } bool DerivedFnCollector::AlreadyExists(const DerivedFnInfo& DFI) const { @@ -36,4 +37,8 @@ DerivedFnInfo DerivedFnCollector::Find(const DiffRequest& request) const { return DerivedFnInfo(); return *it; } + +bool DerivedFnCollector::IsDerivative(const clang::FunctionDecl* FD) const { + return m_DerivativeSet.count(FD); +} } // namespace clad \ No newline at end of file diff --git a/test/Hessian/BuiltinDerivatives.C b/test/Hessian/BuiltinDerivatives.C index 9d9106350..de3a37adc 100644 --- a/test/Hessian/BuiltinDerivatives.C +++ b/test/Hessian/BuiltinDerivatives.C @@ -13,12 +13,7 @@ float f1(float x) { return sin(x) + cos(x); } -// CHECK: float f1_darg0(float x) { -// CHECK-NEXT: float _d_x = 1; -// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, _d_x); -// CHECK-NEXT: ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, _d_x); -// CHECK-NEXT: return _t0.pushforward + _t1.pushforward; -// CHECK-NEXT: } +// CHECK: float f1_darg0(float x); // CHECK: void f1_darg0_grad(float x, float *_d_x); @@ -30,11 +25,7 @@ float f2(float x) { return exp(x); } -// CHECK: float f2_darg0(float x) { -// CHECK-NEXT: float _d_x = 1; -// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x); -// CHECK-NEXT: return _t0.pushforward; -// CHECK-NEXT: } +// CHECK: float f2_darg0(float x); // CHECK: void f2_darg0_grad(float x, float *_d_x); @@ -47,11 +38,7 @@ float f3(float x) { return log(x); } -// CHECK: float f3_darg0(float x) { -// CHECK-NEXT: float _d_x = 1; -// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::log_pushforward(x, _d_x); -// CHECK-NEXT: return _t0.pushforward; -// CHECK-NEXT: } +// CHECK: float f3_darg0(float x); // CHECK: void f3_darg0_grad(float x, float *_d_x); @@ -64,11 +51,7 @@ float f4(float x) { return pow(x, 4.0F); } -// CHECK: float f4_darg0(float x) { -// CHECK-NEXT: float _d_x = 1; -// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x, 0.F); -// CHECK-NEXT: return _t0.pushforward; -// CHECK-NEXT: } +// CHECK: float f4_darg0(float x); // CHECK: void f4_darg0_grad(float x, float *_d_x); @@ -81,11 +64,7 @@ float f5(float x) { return pow(2.0F, x); } -// CHECK: float f5_darg0(float x) { -// CHECK-NEXT: float _d_x = 1; -// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x); -// CHECK-NEXT: return _t0.pushforward; -// CHECK-NEXT: } +// CHECK: float f5_darg0(float x); // CHECK: void f5_darg0_grad(float x, float *_d_x); @@ -98,21 +77,11 @@ float f6(float x, float y) { return pow(x, y); } -// CHECK: float f6_darg0(float x, float y) { -// CHECK-NEXT: float _d_x = 1; -// CHECK-NEXT: float _d_y = 0; -// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y); -// CHECK-NEXT: return _t0.pushforward; -// CHECK-NEXT: } +// CHECK: float f6_darg0(float x, float y); // CHECK: void f6_darg0_grad(float x, float y, float *_d_x, float *_d_y); -// CHECK: float f6_darg1(float x, float y) { -// CHECK-NEXT: float _d_x = 0; -// CHECK-NEXT: float _d_y = 1; -// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y); -// CHECK-NEXT: return _t0.pushforward; -// CHECK-NEXT: } +// CHECK: float f6_darg1(float x, float y); // CHECK: void f6_darg1_grad(float x, float y, float *_d_x, float *_d_y); @@ -147,6 +116,13 @@ int main() { TEST1(f5, 3); // CHECK-EXEC: Result is = {3.84} TEST2(f6, 3, 4); // CHECK-EXEC: Result is = {108.00, 145.65, 145.65, 97.76} +// CHECK: float f1_darg0(float x) { +// CHECK-NEXT: float _d_x = 1; +// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::sin_pushforward(x, _d_x); +// CHECK-NEXT: ValueAndPushforward _t1 = clad::custom_derivatives{{(::std)?}}::cos_pushforward(x, _d_x); +// CHECK-NEXT: return _t0.pushforward + _t1.pushforward; +// CHECK-NEXT: } + // CHECK: void sin_pushforward_pullback(float x, float d_x, ValueAndPushforward _d_y, float *_d_x, float *_d_d_x); // CHECK: void cos_pushforward_pullback(float x, float d_x, ValueAndPushforward _d_y, float *_d_x, float *_d_d_x); @@ -180,6 +156,12 @@ int main() { // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: float f2_darg0(float x) { +// CHECK-NEXT: float _d_x = 1; +// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::exp_pushforward(x, _d_x); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } + // CHECK: void exp_pushforward_pullback(float x, float d_x, ValueAndPushforward _d_y, float *_d_x, float *_d_d_x); // CHECK: void f2_darg0_grad(float x, float *_d_x) { @@ -199,6 +181,12 @@ int main() { // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: float f3_darg0(float x) { +// CHECK-NEXT: float _d_x = 1; +// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::log_pushforward(x, _d_x); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } + // CHECK: void log_pushforward_pullback(float x, float d_x, ValueAndPushforward _d_y, float *_d_x, float *_d_d_x); // CHECK: void f3_darg0_grad(float x, float *_d_x) { @@ -218,6 +206,12 @@ int main() { // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: float f4_darg0(float x) { +// CHECK-NEXT: float _d_x = 1; +// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, 4.F, _d_x, 0.F); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } + // CHECK: void pow_pushforward_pullback(float x, float exponent, float d_x, float d_exponent, ValueAndPushforward _d_y, float *_d_x, float *_d_exponent, float *_d_d_x, float *_d_d_exponent); // CHECK: void f4_darg0_grad(float x, float *_d_x) { @@ -239,6 +233,12 @@ int main() { // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: float f5_darg0(float x) { +// CHECK-NEXT: float _d_x = 1; +// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(2.F, x, 0.F, _d_x); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } + // CHECK: void f5_darg0_grad(float x, float *_d_x) { // CHECK-NEXT: float _d__d_x = 0; // CHECK-NEXT: ValueAndPushforward _d__t0 = {}; @@ -258,6 +258,13 @@ int main() { // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: float f6_darg0(float x, float y) { +// CHECK-NEXT: float _d_x = 1; +// CHECK-NEXT: float _d_y = 0; +// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } + // CHECK: void f6_darg0_grad(float x, float y, float *_d_x, float *_d_y) { // CHECK-NEXT: float _d__d_x = 0; // CHECK-NEXT: float _d__d_y = 0; @@ -281,6 +288,13 @@ int main() { // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: float f6_darg1(float x, float y) { +// CHECK-NEXT: float _d_x = 0; +// CHECK-NEXT: float _d_y = 1; +// CHECK-NEXT: ValueAndPushforward _t0 = clad::custom_derivatives{{(::std)?}}::pow_pushforward(x, y, _d_x, _d_y); +// CHECK-NEXT: return _t0.pushforward; +// CHECK-NEXT: } + // CHECK: void f6_darg1_grad(float x, float y, float *_d_x, float *_d_y) { // CHECK-NEXT: float _d__d_x = 0; // CHECK-NEXT: float _d__d_y = 0; diff --git a/test/Hessian/Hessians.C b/test/Hessian/Hessians.C index a854bb9de..485c06e28 100644 --- a/test/Hessian/Hessians.C +++ b/test/Hessian/Hessians.C @@ -10,13 +10,7 @@ __attribute__((always_inline)) double f_cubed_add1(double a, double b) { return a * a * a + b * b * b; } -//CHECK:{{[__attribute__((always_inline)) ]*}}double f_cubed_add1_darg0(double a, double b){{[ __attribute__((always_inline))]*}} { -//CHECK-NEXT: double _d_a = 1; -//CHECK-NEXT: double _d_b = 0; -//CHECK-NEXT: double _t0 = a * a; -//CHECK-NEXT: double _t1 = b * b; -//CHECK-NEXT: return (_d_a * a + a * _d_a) * a + _t0 * _d_a + (_d_b * b + b * _d_b) * b + _t1 * _d_b; -//CHECK-NEXT:} +//CHECK:{{[__attribute__((always_inline)) ]*}}double f_cubed_add1_darg0(double a, double b){{[ __attribute__((always_inline))]*}}; void f_cubed_add1_darg0_grad(double a, double b, double *_d_a, double *_d_b); //CHECK:{{[__attribute__((always_inline)) ]*}}void f_cubed_add1_darg0_grad(double a, double b, double *_d_a, double *_d_b){{[ __attribute__((always_inline))]*}}; @@ -212,6 +206,14 @@ int main() { TEST3(&Widget::memFn_2, W, 7, 9); // CHECK-EXEC: Result is = {5400.00, 4200.00, 4200.00, 0.00} TEST2(fn_def_arg, 3, 5); // CHECK-EXEC: Result is = {0.00, 2.00, 2.00, 0.00} +//CHECK:{{[__attribute__((always_inline)) ]*}}double f_cubed_add1_darg0(double a, double b){{[ __attribute__((always_inline))]*}} { +//CHECK-NEXT: double _d_a = 1; +//CHECK-NEXT: double _d_b = 0; +//CHECK-NEXT: double _t0 = a * a; +//CHECK-NEXT: double _t1 = b * b; +//CHECK-NEXT: return (_d_a * a + a * _d_a) * a + _t0 * _d_a + (_d_b * b + b * _d_b) * b + _t1 * _d_b; +//CHECK-NEXT:} + //CHECK:{{[__attribute__((always_inline)) ]*}}void f_cubed_add1_darg0_grad(double a, double b, double *_d_a, double *_d_b){{[ __attribute__((always_inline))]*}} { //CHECK-NEXT: double _d__d_a = 0; //CHECK-NEXT: double _d__d_b = 0; diff --git a/test/Hessian/NestedFunctionCalls.C b/test/Hessian/NestedFunctionCalls.C index 65021b12c..e23077ec7 100644 --- a/test/Hessian/NestedFunctionCalls.C +++ b/test/Hessian/NestedFunctionCalls.C @@ -17,30 +17,9 @@ double f2(double x, double y){ return ans; } -// CHECK: clad::ValueAndPushforward f_pushforward(double x, double y, double _d_x, double _d_y); - -// CHECK: double f2_darg0(double x, double y) { -// CHECK-NEXT: double _d_x = 1; -// CHECK-NEXT: double _d_y = 0; -// CHECK-NEXT: clad::ValueAndPushforward _t0 = f_pushforward(x, y, _d_x, _d_y); -// CHECK-NEXT: double _d_ans = _t0.pushforward; -// CHECK-NEXT: double ans = _t0.value; -// CHECK-NEXT: return _d_ans; -// CHECK-NEXT: } - - +// CHECK: double f2_darg0(double x, double y); // CHECK: void f2_darg0_grad(double x, double y, double *_d_x, double *_d_y); - - -// CHECK: double f2_darg1(double x, double y) { -// CHECK-NEXT: double _d_x = 0; -// CHECK-NEXT: double _d_y = 1; -// CHECK-NEXT: clad::ValueAndPushforward _t0 = f_pushforward(x, y, _d_x, _d_y); -// CHECK-NEXT: double _d_ans = _t0.pushforward; -// CHECK-NEXT: double ans = _t0.value; -// CHECK-NEXT: return _d_ans; -// CHECK-NEXT: } - +// CHECK: double f2_darg1(double x, double y); // CHECK: void f2_darg1_grad(double x, double y, double *_d_x, double *_d_y); // CHECK: void f2_hessian(double x, double y, double *hessianMatrix) { @@ -48,12 +27,17 @@ double f2(double x, double y){ // CHECK-NEXT: f2_darg1_grad(x, y, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}}); // CHECK-NEXT: } +// CHECK: clad::ValueAndPushforward f_pushforward(double x, double y, double _d_x, double _d_y); -// CHECK: clad::ValueAndPushforward f_pushforward(double x, double y, double _d_x, double _d_y) { -// CHECK-NEXT: return {x * x + y * y, _d_x * x + x * _d_x + _d_y * y + y * _d_y}; +// CHECK: double f2_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = f_pushforward(x, y, _d_x, _d_y); +// CHECK-NEXT: double _d_ans = _t0.pushforward; +// CHECK-NEXT: double ans = _t0.value; +// CHECK-NEXT: return _d_ans; // CHECK-NEXT: } - // CHECK: void f_pushforward_pullback(double x, double y, double _d_x, double _d_y, clad::ValueAndPushforward _d_y0, double *_d_x, double *_d_y, double *_d__d_x, double *_d__d_y); // CHECK: void f2_darg0_grad(double x, double y, double *_d_x, double *_d_y) { @@ -85,6 +69,14 @@ double f2(double x, double y){ // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: double f2_darg1(double x, double y) { +// CHECK-NEXT: double _d_x = 0; +// CHECK-NEXT: double _d_y = 1; +// CHECK-NEXT: clad::ValueAndPushforward _t0 = f_pushforward(x, y, _d_x, _d_y); +// CHECK-NEXT: double _d_ans = _t0.pushforward; +// CHECK-NEXT: double ans = _t0.value; +// CHECK-NEXT: return _d_ans; +// CHECK-NEXT: } // CHECK: void f2_darg1_grad(double x, double y, double *_d_x, double *_d_y) { // CHECK-NEXT: double _d__d_x = 0; @@ -115,6 +107,10 @@ double f2(double x, double y){ // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: clad::ValueAndPushforward f_pushforward(double x, double y, double _d_x, double _d_y) { +// CHECK-NEXT: return {x * x + y * y, _d_x * x + x * _d_x + _d_y * y + y * _d_y}; +// CHECK-NEXT: } + // CHECK: void f_pushforward_pullback(double x, double y, double _d_x, double _d_y, clad::ValueAndPushforward _d_y0, double *_d_x, double *_d_y, double *_d__d_x, double *_d__d_y) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: diff --git a/test/Hessian/Pointers.C b/test/Hessian/Pointers.C index 06ffe91c8..9cc959baf 100644 --- a/test/Hessian/Pointers.C +++ b/test/Hessian/Pointers.C @@ -10,20 +10,9 @@ double nonMemFn(double i, double j) { return i*j; } -// CHECK: double nonMemFn_darg0(double i, double j) { -// CHECK-NEXT: double _d_i = 1; -// CHECK-NEXT: double _d_j = 0; -// CHECK-NEXT: return _d_i * j + i * _d_j; -// CHECK-NEXT: } - +// CHECK: double nonMemFn_darg0(double i, double j); // CHECK: void nonMemFn_darg0_grad(double i, double j, double *_d_i, double *_d_j); - -// CHECK: double nonMemFn_darg1(double i, double j) { -// CHECK-NEXT: double _d_i = 0; -// CHECK-NEXT: double _d_j = 1; -// CHECK-NEXT: return _d_i * j + i * _d_j; -// CHECK-NEXT: } - +// CHECK: double nonMemFn_darg1(double i, double j); // CHECK: void nonMemFn_darg1_grad(double i, double j, double *_d_i, double *_d_j); // CHECK: void nonMemFn_hessian(double i, double j, double *hessianMatrix) { @@ -31,6 +20,12 @@ double nonMemFn(double i, double j) { // CHECK-NEXT: nonMemFn_darg1_grad(i, j, hessianMatrix + {{2U|2UL}}, hessianMatrix + {{3U|3UL}}); // CHECK-NEXT: } +// CHECK: double nonMemFn_darg0(double i, double j) { +// CHECK-NEXT: double _d_i = 1; +// CHECK-NEXT: double _d_j = 0; +// CHECK-NEXT: return _d_i * j + i * _d_j; +// CHECK-NEXT: } + // CHECK: void nonMemFn_darg0_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: double _d__d_i = 0; // CHECK-NEXT: double _d__d_j = 0; @@ -46,6 +41,12 @@ double nonMemFn(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK: double nonMemFn_darg1(double i, double j) { +// CHECK-NEXT: double _d_i = 0; +// CHECK-NEXT: double _d_j = 1; +// CHECK-NEXT: return _d_i * j + i * _d_j; +// CHECK-NEXT: } + // CHECK: void nonMemFn_darg1_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: double _d__d_i = 0; // CHECK-NEXT: double _d__d_j = 0;