Skip to content

Commit

Permalink
Enable some cases of functor calls in custom pushforwards
Browse files Browse the repository at this point in the history
Previously, if a user wanted to provide a custom pushforward for a
function that uses functors in it, it was impossible to use generated
pushforwards for that functors' call operators. This commit aims to fix
this for basic functors that don't have multiple call operator overloads.

Fixes: vgvassilev#1023
  • Loading branch information
gojakuch committed Aug 13, 2024
1 parent 6cc83ee commit bba8991
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 14 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class BaseForwardModeVisitor

virtual void ExecuteInsidePushforwardFunctionBlock();

virtual void DifferentiateCallOperatorIfFunctor(clang::QualType QT);

static bool IsDifferentiableType(clang::QualType T);

virtual StmtDiff
Expand Down
60 changes: 60 additions & 0 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,8 @@ BaseForwardModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) {
}

StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
DifferentiateCallOperatorIfFunctor(DRE->getType());

DeclRefExpr* clonedDRE = nullptr;
// Check if referenced Decl was "replaced" with another identifier inside
// the derivative
Expand Down Expand Up @@ -1594,6 +1596,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
// If the DeclStmt is not empty, check the first declaration.
if (declsBegin != declsEnd && isa<VarDecl>(*declsBegin)) {
auto* VD = dyn_cast<VarDecl>(*declsBegin);
DifferentiateCallOperatorIfFunctor(VD->getType());
// Check for non-differentiable types.
QualType QT = VD->getType();
if (QT->isPointerType())
Expand Down Expand Up @@ -2057,8 +2060,65 @@ StmtDiff BaseForwardModeVisitor::VisitBreakStmt(const BreakStmt* stmt) {
return StmtDiff(Clone(stmt));
}

void BaseForwardModeVisitor::DifferentiateCallOperatorIfFunctor(
clang::QualType QT) {
// Identify if the constructed type is a functor. For functors, we need to
// differentiate their call operator once an object has been constructed, to
// allow user calls to pushforwards inside user-provided custom derivatives.
// FIXME: A much more scalable solution would be to create pushforwards once
// they're called from user-provided custom derivatives. This could then be
// applied to other operators besides operator() to avoid compilation errors
// in such cases.
if (auto* RD = QT->getAsCXXRecordDecl()) {
CXXRecordDecl* constructedType = RD->getDefinition();
bool isFunctor = constructedType && !constructedType->isLambda();
std::vector<const CXXMethodDecl*> callMethods;
if (isFunctor) {
for (const auto* method : constructedType->methods()) {
if (const auto* cxxMethod = dyn_cast<CXXMethodDecl>(method)) {
if (cxxMethod->isOverloadedOperator() &&
cxxMethod->getOverloadedOperator() == OO_Call) {
callMethods.push_back(cxxMethod);
}
}
}
isFunctor = isFunctor && !callMethods.empty();
}

if (isFunctor) {
for (const auto* FD : callMethods) {
CXXScopeSpec SS;
bool hasCustomDerivative =
!m_Builder
.LookupCustomDerivativeOrNumericalDiff(
clad::utils::ComputeEffectiveFnName(FD) +
GetPushForwardFunctionSuffix(),
const_cast<DeclContext*>(FD->getDeclContext()), SS)
.empty();

if (!hasCustomDerivative) {
// Request Clad to diff it.
DiffRequest pushforwardFnRequest;
pushforwardFnRequest.Function = FD;
pushforwardFnRequest.Mode = GetPushForwardMode();
pushforwardFnRequest.BaseFunctionName =
utils::ComputeEffectiveFnName(FD);
// Silence diag outputs in nested derivation process.
pushforwardFnRequest.VerboseDiags = false;

// Check if request already derived in DerivedFunctions.
m_Builder.HandleNestedDiffRequest(pushforwardFnRequest);
}
}
}
}
}

StmtDiff
BaseForwardModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) {
DifferentiateCallOperatorIfFunctor(CE->getType());

// Now continue differentiating the constructor itself:
llvm::SmallVector<Expr*, 4> clonedArgs, derivedArgs;
for (auto arg : CE->arguments()) {
auto argDiff = Visit(arg);
Expand Down
47 changes: 47 additions & 0 deletions test/ForwardMode/Functors.C
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,36 @@ struct WidgetPointer {
}
};

namespace clad {
namespace custom_derivatives {
template <typename F>
void use_functor_pushforward(double x, F& f, double d_x, F& d_f) {
f.operator_call_pushforward(x, &d_f, d_x);
}
}
}
template <typename F>
void use_functor(double x, F& f) {
f(x);
}

struct Foo {
double &y;
Foo(double &y): y(y) {}

double operator()(double x) {
y = 2*x;

return x;
}
};

double fn0(double x) {
Foo func = Foo{x};
use_functor(x, func);
return x;
}

#define INIT(E, ARG)\
auto d_##E = clad::differentiate(&E, ARG);\
auto d_##E##Ref = clad::differentiate(E, ARG);
Expand Down Expand Up @@ -504,4 +534,21 @@ int main() {
TEST_2(W_Arr_5, 6, 5); // CHECK-EXEC: 6.00 6.00
TEST_2(W_Pointer_3, 6, 5); // CHECK-EXEC: 37.00 37.00
TEST_2(W_Pointer_5, 6, 5); // CHECK-EXEC: 51.00 51.00

auto dfn0 = clad::differentiate(fn0, "x");
printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2
}

// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x);
// CHECK: double fn0_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: Foo _d_func = Foo{{[{]*_d_x[}]*}};
// CHECK-NEXT: Foo func = Foo{{[{]*x[}]*}};
// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT:}
// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x) {
// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x;
// CHECK-NEXT: this->y = 2 * x;
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT:}
31 changes: 29 additions & 2 deletions test/ForwardMode/ReferenceArguments.C
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,31 @@

#include "clad/Differentiator/Differentiator.h"

namespace clad {
namespace custom_derivatives {
template <typename F>
void use_functor_pushforward(double &x, F& f, double &d_x, F& d_f) {
f.operator_call_pushforward(x, &d_f, d_x);
}
}
}
template <typename F>
void use_functor(double &x, F& f) {
f(x);
}

struct Foo {
double operator()(double& x) {
x = 2*x*x;
return x;
}
};

double fn0(double x, Foo& func) {
use_functor(x, func);
return x;
}

double fn1(double& i, double& j) {
double res = i * i * j;
return res;
Expand All @@ -21,12 +46,14 @@ double fn1(double& i, double& j) {
#define INIT(fn, ...) auto d_##fn = clad::differentiate(fn, __VA_ARGS__);

#define TEST(fn, ...) \
auto res = d_##fn.execute(__VA_ARGS__); \
printf("{%.2f}\n", res)
printf("{%.2f}\n", d_##fn.execute(__VA_ARGS__))

int main() {
INIT(fn0, "x");
INIT(fn1, "i");

double i = 3, j = 5;
TEST(fn1, i, j); // CHECK-EXEC: {30.00}
Foo fff;
TEST(fn0, i, fff); // CHECK-EXEC: {12.00}
}
24 changes: 12 additions & 12 deletions test/ForwardMode/UserDefinedTypes.C
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ Tensor<double, 5> fn5(double i, double j) {
return T;
}

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val);

// CHECK: Tensor<double, 5> fn5_darg0(double i, double j) {
// CHECK-NEXT: double _d_i = 1;
// CHECK-NEXT: double _d_j = 0;
Expand Down Expand Up @@ -593,8 +595,6 @@ TensorD5 fn11(double i, double j) {
return res1;
}

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val);

// CHECK: clad::ValueAndPushforward<double &, double &> operator_subscript_pushforward(std::size_t idx, Tensor<double, 5> *_d_this, std::size_t _d_idx);

// CHECK: clad::ValueAndPushforward<Tensor<double, 5U>, Tensor<double, 5U> > operator_plus_pushforward(const Tensor<double, 5U> &a, const Tensor<double, 5U> &b, const Tensor<double, 5U> &_d_a, const Tensor<double, 5U> &_d_b);
Expand Down Expand Up @@ -965,6 +965,16 @@ double fn18(double i, double j) {
// CHECK-NEXT: return _d_v[0].mem;
// CHECK-NEXT: }

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val) {
// CHECK-NEXT: {
// CHECK-NEXT: unsigned int _d_i = 0;
// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) {
// CHECK-NEXT: _d_this->data[i] = _d_val;
// CHECK-NEXT: this->data[i] = val;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

template<unsigned N>
void print(const Tensor<double, N>& t) {
for (int i=0; i<N; ++i) {
Expand Down Expand Up @@ -1071,16 +1081,6 @@ int main() {
// CHECK-NEXT: return {{[{](__imag )?}}this->[[_M_value:[a-zA-Z_]+]],{{( __imag)?}} _d_this->[[_M_value:[a-zA-Z_]+]]};
// CHECK-NEXT: }

// CHECK: void operator_call_pushforward(double val, Tensor<double, 5> *_d_this, double _d_val) {
// CHECK-NEXT: {
// CHECK-NEXT: unsigned int _d_i = 0;
// CHECK-NEXT: for (unsigned int i = 0; i < 5U; ++i) {
// CHECK-NEXT: _d_this->data[i] = _d_val;
// CHECK-NEXT: this->data[i] = val;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: clad::ValueAndPushforward<double &, double &> operator_subscript_pushforward(std::size_t idx, Tensor<double, 5> *_d_this, std::size_t _d_idx) {
// CHECK-NEXT: return {(double &)this->data[idx], (double &)_d_this->data[idx]};
// CHECK-NEXT: }
Expand Down
47 changes: 47 additions & 0 deletions test/Functors/Simple.C
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,50 @@ float f(float x) {
return x;
}

namespace clad {
namespace custom_derivatives {
template <typename F>
void use_functor_pushforward(double x, F& f, double d_x, F& d_f) {
f.operator_call_pushforward(x, &d_f, d_x);
}
}
}
template <typename F>
void use_functor(double x, F& f) {
f(x);
}

struct Foo {
double &y;
Foo(double &y): y(y) {}

double operator()(double x) {
y = 2*x;

return x;
}
};

double fn0(double x) {
Foo func = Foo({x});
use_functor(x, func);
return x;
}

// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x);
// CHECK: double fn0_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: Foo _d_func = Foo({_d_x});
// CHECK-NEXT: Foo func = Foo({x});
// CHECK-NEXT: clad::custom_derivatives::use_functor_pushforward(x, func, _d_x, _d_func);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT:}
// CHECK: clad::ValueAndPushforward<double, double> operator_call_pushforward(double x, Foo *_d_this, double _d_x) {
// CHECK-NEXT: _d_this->y = 0 * x + 2 * _d_x;
// CHECK-NEXT: this->y = 2 * x;
// CHECK-NEXT: return {x, _d_x};
// CHECK-NEXT:}

int main() {
AFunctor doubler;
int x = doubler(5);
Expand All @@ -73,5 +117,8 @@ int main() {
auto f1_darg1 = clad::differentiate(&SimpleExpression::operator(), 1);
printf("Result is = %f\n", f1_darg1.execute(expr, 3.5, 4.5)); // CHECK-EXEC: Result is = 9

auto dfn0 = clad::differentiate(fn0, "x");
printf("RES: %f\n", dfn0.execute(3.0)); // CHECK-EXEC: RES: 2

return 0;
}

0 comments on commit bba8991

Please sign in to comment.