Skip to content

Commit

Permalink
Add support for simple lambda functions in forward mode
Browse files Browse the repository at this point in the history
This commit provides support for the simplest lambda functions,
that is, those with no captures in forward mode. The original
lambda function is copied into the derivative, but the
corresponding lambda class gets extended to also have a pushforward
method for the call operator overload.
  • Loading branch information
gojakuch committed Jun 15, 2024
1 parent 4ebd1af commit b25d6d1
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 30 deletions.
85 changes: 55 additions & 30 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,10 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
"Differentiation of only direct calls is supported. Ignored");
return StmtDiff(Clone(CE));
}
// Calls to lambda functions are processed differently
bool isLambda =
(FD->getDeclContext()->isRecord() &&
FD->getDeclContext()->getOuterLexicalRecordContext()->isLambda());

SourceLocation validLoc{CE->getBeginLoc()};

Expand Down Expand Up @@ -1064,18 +1068,32 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// a direct or indirect (operator overload) call to member function.
StmtDiff baseDiff;
// Add derivative of the implicit `this` pointer to the `diffArgs`.
if (auto MD = dyn_cast<CXXMethodDecl>(FD)) {
if (MD->isInstance()) {
const Expr* baseOriginalE = nullptr;
if (auto MCE = dyn_cast<CXXMemberCallExpr>(CE))
baseOriginalE = MCE->getImplicitObjectArgument();
else if (auto OCE = dyn_cast<CXXOperatorCallExpr>(CE))
baseOriginalE = OCE->getArg(0);
baseDiff = Visit(baseOriginalE);
Expr* baseDerivative = baseDiff.getExpr_dx();
if (!baseDerivative->getType()->isPointerType())
baseDerivative = BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative);
diffArgs.push_back(baseDerivative);
if (const auto* MD = dyn_cast<CXXMethodDecl>(FD)) {
if (isLambda) {
if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE)) {
QualType ptrType = m_Context.getPointerType(m_Context.getRecordType(
FD->getDeclContext()->getOuterLexicalRecordContext()));
// For now, only lambdas with no captures are supported, so we just pass
// a nullptr instead of the diff object.
baseDiff =
StmtDiff(Clone(OCE->getArg(0)),
new (m_Context) CXXNullPtrLiteralExpr(ptrType, validLoc));
diffArgs.push_back(baseDiff.getExpr_dx());
}
} else { // isLambda == false
if (MD->isInstance()) {
const Expr* baseOriginalE = nullptr;
if (const auto* MCE = dyn_cast<CXXMemberCallExpr>(CE))
baseOriginalE = MCE->getImplicitObjectArgument();
else if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE))
baseOriginalE = OCE->getArg(0);
baseDiff = Visit(baseOriginalE);
Expr* baseDerivative = baseDiff.getExpr_dx();
if (!baseDerivative->getType()->isPointerType())
baseDerivative =
BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative);
diffArgs.push_back(baseDerivative);
}
}
}

Expand Down Expand Up @@ -1131,34 +1149,36 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

auto customDerivativeArgs = pushforwardFnArgs;

if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
if (Expr* baseE = baseDiff.getExpr()) {
if (!baseE->getType()->isPointerType())
baseE = BuildOp(UnaryOperatorKind::UO_AddrOf, baseE);
customDerivativeArgs.insert(customDerivativeArgs.begin(), baseE);
}

// Try to find a user-defined overloaded derivative.
Expr* callDiff = nullptr;
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));

// Check if it is a recursive call.
if (!callDiff && (FD == m_DiffReq.Function) &&
m_Mode == GetPushForwardMode()) {
// The differentiated function is called recursively.
Expr* derivativeRef =
m_Sema
.BuildDeclarationNameExpr(CXXScopeSpec(),
m_Derivative->getNameInfo(), m_Derivative)
.get();
callDiff =
m_Sema
.ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext),
derivativeRef, validLoc, pushforwardFnArgs, validLoc)
.get();
if (!isLambda) {
// Check if it is a recursive call.
if (!callDiff && (FD == m_DiffReq.Function) &&
m_Mode == GetPushForwardMode()) {
// The differentiated function is called recursively.
Expr* derivativeRef =
m_Sema
.BuildDeclarationNameExpr(
CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative)
.get();
callDiff = m_Sema
.ActOnCallExpr(
m_Sema.getScopeForContext(m_Sema.CurContext),
derivativeRef, validLoc, pushforwardFnArgs, validLoc)
.get();
}
}

// If all arguments are constant literals, then this does not contribute to
Expand Down Expand Up @@ -1499,7 +1519,12 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
if (QT->isPointerType())
QT = QT->getPointeeType();
auto* typeDecl = QT->getAsCXXRecordDecl();
if (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl)) {
// For lambda functions, we should also simply copy the original lambda. The
// differentiation of lambdas is happening in the `VisitCallExpr`. For now,
// only the declarations with lambda expressions without captures are
// supported.
if (typeDecl && (clad::utils::hasNonDifferentiableAttribute(typeDecl) ||
typeDecl->isLambda())) {
for (auto* D : DS->decls()) {
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Expand Down
29 changes: 29 additions & 0 deletions test/ForwardMode/Lambdas.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: %cladclang %s -I%S/../../include -oLambdas.out 2>&1 | %filecheck %s
// RUN: ./Lambdas.out | %filecheck_exec %s
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"

double fn0(double x) {
auto _f = [](double _x) {
return _x*_x;
};
return _f(x) + 1;
}

double fn1(double x, double y) {
auto _f = [](double _x, double _y) {
return _x + _y;
};
return _f(x*x, x+2) + y;
}

int main() {
auto fn0_dx = clad::differentiate(fn0, 0);
printf("Result is = %.2f\n", fn0_dx.execute(7)); // CHECK-EXEC: Result is = 14.00
printf("Result is = %.2f\n", fn0_dx.execute(-1)); // CHECK-EXEC: Result is = -2.00

auto fn1_dx = clad::differentiate(fn1, 0);
printf("Result is = %.2f\n", fn1_dx.execute(7, 1)); // CHECK-EXEC: Result is = 15.00
printf("Result is = %.2f\n", fn1_dx.execute(-1, 1)); // CHECK-EXEC: Result is = -1.00
}

0 comments on commit b25d6d1

Please sign in to comment.