diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index fa7c09ad6..6e5deeea9 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -47,6 +47,8 @@ class BaseForwardModeVisitor virtual void ExecuteInsidePushforwardFunctionBlock(); + virtual void DifferentiateCallOperatorIfFunctor(clang::QualType QT); + static bool IsDifferentiableType(clang::QualType T); virtual StmtDiff diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 7a70f3bcb..d9c5bb504 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1029,6 +1029,8 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { } StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { + DifferentiateCallOperatorIfFunctor(DRE->getType()); + DeclRefExpr* clonedDRE = nullptr; // Check if referenced Decl was "replaced" with another identifier inside // the derivative @@ -1594,6 +1596,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { // If the DeclStmt is not empty, check the first declaration. if (declsBegin != declsEnd && isa(*declsBegin)) { auto* VD = dyn_cast(*declsBegin); + DifferentiateCallOperatorIfFunctor(VD->getType()); // Check for non-differentiable types. QualType QT = VD->getType(); if (QT->isPointerType()) @@ -2057,8 +2060,65 @@ StmtDiff BaseForwardModeVisitor::VisitBreakStmt(const BreakStmt* stmt) { return StmtDiff(Clone(stmt)); } +void BaseForwardModeVisitor::DifferentiateCallOperatorIfFunctor( + clang::QualType QT) { + // Identify if the constructed type is a functor. For functors, we need to + // differentiate their call operator once an object has been constructed, to + // allow user calls to pushforwards inside user-provided custom derivatives. + // FIXME: A much more scalable solution would be to create pushforwards once + // they're called from user-provided custom derivatives. This could then be + // applied to other operators besides operator() to avoid compilation errors + // in such cases. + if (auto* RD = QT->getAsCXXRecordDecl()) { + CXXRecordDecl* constructedType = RD->getDefinition(); + bool isFunctor = constructedType && !constructedType->isLambda(); + std::vector callMethods; + if (isFunctor) { + for (const auto* method : constructedType->methods()) { + if (const auto* cxxMethod = dyn_cast(method)) { + if (cxxMethod->isOverloadedOperator() && + cxxMethod->getOverloadedOperator() == OO_Call) { + callMethods.push_back(cxxMethod); + } + } + } + isFunctor = isFunctor && !callMethods.empty(); + } + + if (isFunctor) { + for (const auto* FD : callMethods) { + CXXScopeSpec SS; + bool hasCustomDerivative = + !m_Builder + .LookupCustomDerivativeOrNumericalDiff( + clad::utils::ComputeEffectiveFnName(FD) + + GetPushForwardFunctionSuffix(), + const_cast(FD->getDeclContext()), SS) + .empty(); + + if (!hasCustomDerivative) { + // Request Clad to diff it. + DiffRequest pushforwardFnRequest; + pushforwardFnRequest.Function = FD; + pushforwardFnRequest.Mode = GetPushForwardMode(); + pushforwardFnRequest.BaseFunctionName = + utils::ComputeEffectiveFnName(FD); + // Silence diag outputs in nested derivation process. + pushforwardFnRequest.VerboseDiags = false; + + // Check if request already derived in DerivedFunctions. + m_Builder.HandleNestedDiffRequest(pushforwardFnRequest); + } + } + } + } +} + StmtDiff BaseForwardModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) { + DifferentiateCallOperatorIfFunctor(CE->getType()); + + // Now continue differentiating the constructor itself: llvm::SmallVector clonedArgs, derivedArgs; for (auto arg : CE->arguments()) { auto argDiff = Visit(arg); diff --git a/test/ForwardMode/Functors.C b/test/ForwardMode/Functors.C index fdfcab92b..05d45e7f2 100644 --- a/test/ForwardMode/Functors.C +++ b/test/ForwardMode/Functors.C @@ -386,6 +386,36 @@ struct WidgetPointer { } }; +namespace clad { +namespace custom_derivatives { + template + void use_functor_pushforward(double x, F& f, double d_x, F& d_f) { + f.operator_call_pushforward(x, &d_f, d_x); + } +} +} +template +void use_functor(double x, F& f) { + f(x); +} + +struct Foo { + double &y; + Foo(double &y): y(y) {} + + double operator()(double x) { + y = 2*x; + + return x; + } +}; + +double fn0(double x) { + Foo func = Foo{x}; + use_functor(x, func); + return x; +} + #define INIT(E, ARG)\ auto d_##E = clad::differentiate(&E, ARG);\ auto d_##E##Ref = clad::differentiate(E, ARG); @@ -504,4 +534,21 @@ int main() { TEST_2(W_Arr_5, 6, 5); // CHECK-EXEC: 6.00 6.00 TEST_2(W_Pointer_3, 6, 5); // CHECK-EXEC: 37.00 37.00 TEST_2(W_Pointer_5, 6, 5); // CHECK-EXEC: 51.00 51.00 + + auto dfn0 = clad::differentiate(fn0, "x"); + printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2 } + +// CHECK: clad::ValueAndPushforward operator_call_pushforward(double x, Foo *_d_this, double _d_x); +// CHECK: double fn0_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: Foo _d_func = Foo{{[{]*_d_x[}]*}}; +// CHECK-NEXT: Foo func = Foo{{[{]*x[}]*}}; +// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func); +// CHECK-NEXT: return _d_x; +// CHECK-NEXT:} +// CHECK: clad::ValueAndPushforward operator_call_pushforward(double x, Foo *_d_this, double _d_x) { +// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x; +// CHECK-NEXT: this->y = 2 * x; +// CHECK-NEXT: return {x, _d_x}; +// CHECK-NEXT:} \ No newline at end of file diff --git a/test/ForwardMode/ReferenceArguments.C b/test/ForwardMode/ReferenceArguments.C index 0800bdacb..4d2be654c 100644 --- a/test/ForwardMode/ReferenceArguments.C +++ b/test/ForwardMode/ReferenceArguments.C @@ -4,6 +4,31 @@ #include "clad/Differentiator/Differentiator.h" +namespace clad { +namespace custom_derivatives { + template + void use_functor_pushforward(double &x, F& f, double &d_x, F& d_f) { + f.operator_call_pushforward(x, &d_f, d_x); + } +} +} +template +void use_functor(double &x, F& f) { + f(x); +} + +struct Foo { + double operator()(double& x) { + x = 2*x*x; + return x; + } +}; + +double fn0(double x, Foo& func) { + use_functor(x, func); + return x; +} + double fn1(double& i, double& j) { double res = i * i * j; return res; @@ -21,12 +46,14 @@ double fn1(double& i, double& j) { #define INIT(fn, ...) auto d_##fn = clad::differentiate(fn, __VA_ARGS__); #define TEST(fn, ...) \ - auto res = d_##fn.execute(__VA_ARGS__); \ - printf("{%.2f}\n", res) + printf("{%.2f}\n", d_##fn.execute(__VA_ARGS__)) int main() { + INIT(fn0, "x"); INIT(fn1, "i"); double i = 3, j = 5; TEST(fn1, i, j); // CHECK-EXEC: {30.00} + Foo fff; + TEST(fn0, i, fff); // CHECK-EXEC: {12.00} } diff --git a/test/ForwardMode/UserDefinedTypes.C b/test/ForwardMode/UserDefinedTypes.C index 82ef5f54f..70340577d 100644 --- a/test/ForwardMode/UserDefinedTypes.C +++ b/test/ForwardMode/UserDefinedTypes.C @@ -420,6 +420,8 @@ Tensor fn5(double i, double j) { return T; } +// CHECK: void operator_call_pushforward(double val, Tensor *_d_this, double _d_val); + // CHECK: Tensor fn5_darg0(double i, double j) { // CHECK-NEXT: double _d_i = 1; // CHECK-NEXT: double _d_j = 0; @@ -593,8 +595,6 @@ TensorD5 fn11(double i, double j) { return res1; } -// CHECK: void operator_call_pushforward(double val, Tensor *_d_this, double _d_val); - // CHECK: clad::ValueAndPushforward operator_subscript_pushforward(std::size_t idx, Tensor *_d_this, std::size_t _d_idx); // CHECK: clad::ValueAndPushforward, Tensor > operator_plus_pushforward(const Tensor &a, const Tensor &b, const Tensor &_d_a, const Tensor &_d_b); @@ -965,6 +965,16 @@ double fn18(double i, double j) { // CHECK-NEXT: return _d_v[0].mem; // CHECK-NEXT: } +// CHECK: void operator_call_pushforward(double val, Tensor *_d_this, double _d_val) { +// CHECK-NEXT: { +// CHECK-NEXT: unsigned int _d_i = 0; +// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) { +// CHECK-NEXT: _d_this->data[i] = _d_val; +// CHECK-NEXT: this->data[i] = val; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + template void print(const Tensor& t) { for (int i=0; i[[_M_value:[a-zA-Z_]+]],{{( __imag)?}} _d_this->[[_M_value:[a-zA-Z_]+]]}; // CHECK-NEXT: } -// CHECK: void operator_call_pushforward(double val, Tensor *_d_this, double _d_val) { -// CHECK-NEXT: { -// CHECK-NEXT: unsigned int _d_i = 0; -// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) { -// CHECK-NEXT: _d_this->data[i] = _d_val; -// CHECK-NEXT: this->data[i] = val; -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } - // CHECK: clad::ValueAndPushforward operator_subscript_pushforward(std::size_t idx, Tensor *_d_this, std::size_t _d_idx) { // CHECK-NEXT: return {(double &)this->data[idx], (double &)_d_this->data[idx]}; // CHECK-NEXT: } diff --git a/test/Functors/Simple.C b/test/Functors/Simple.C index d24dd7d17..21e455d31 100644 --- a/test/Functors/Simple.C +++ b/test/Functors/Simple.C @@ -58,6 +58,50 @@ float f(float x) { return x; } +namespace clad { +namespace custom_derivatives { + template + void use_functor_pushforward(double x, F& f, double d_x, F& d_f) { + f.operator_call_pushforward(x, &d_f, d_x); + } +} +} +template +void use_functor(double x, F& f) { + f(x); +} + +struct Foo { + double &y; + Foo(double &y): y(y) {} + + double operator()(double x) { + y = 2*x; + + return x; + } +}; + +double fn0(double x) { + Foo func = Foo({x}); + use_functor(x, func); + return x; +} + +// CHECK: clad::ValueAndPushforward operator_call_pushforward(double x, Foo *_d_this, double _d_x); +// CHECK: double fn0_darg0(double x) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: Foo _d_func = Foo({_d_x}); +// CHECK-NEXT: Foo func = Foo({x}); +// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func); +// CHECK-NEXT: return _d_x; +// CHECK-NEXT:} +// CHECK: clad::ValueAndPushforward operator_call_pushforward(double x, Foo *_d_this, double _d_x) { +// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x; +// CHECK-NEXT: this->y = 2 * x; +// CHECK-NEXT: return {x, _d_x}; +// CHECK-NEXT:} + int main() { AFunctor doubler; int x = doubler(5); @@ -73,5 +117,8 @@ int main() { auto f1_darg1 = clad::differentiate(&SimpleExpression::operator(), 1); printf("Result is = %f\n", f1_darg1.execute(expr, 3.5, 4.5)); // CHECK-EXEC: Result is = 9 + auto dfn0 = clad::differentiate(fn0, "x"); + printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2 + return 0; }