Skip to content

Commit

Permalink
Add support for simple lambda functions in forward mode
Browse files Browse the repository at this point in the history
This commit provides support for the simplest lambda functions,
that is, those with no captures in forward mode. The original
lambda function is copied into the derivative, but the
corresponding lambda class gets extended to also have a pushforward
method for the call operator overload.
  • Loading branch information
gojakuch committed Jun 17, 2024
1 parent de8a6f6 commit 3161ee1
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 21 deletions.
65 changes: 44 additions & 21 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "clad/Differentiator/ErrorEstimator.h"

#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTLambda.h"
#include "clang/AST/Expr.h"
#include "clang/AST/TemplateBase.h"
#include "clang/Sema/Lookup.h"
Expand Down Expand Up @@ -1039,6 +1040,9 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

SourceLocation validLoc{CE->getBeginLoc()};

// Calls to lambda functions are processed differently
bool isLambda = isLambdaCallOperator(FD);

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
Expand Down Expand Up @@ -1070,12 +1074,24 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
// a direct or indirect (operator overload) call to member function.
StmtDiff baseDiff;
// Add derivative of the implicit `this` pointer to the `diffArgs`.
if (auto MD = dyn_cast<CXXMethodDecl>(FD)) {
if (isLambda) {
if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE)) {
QualType ptrType = m_Context.getPointerType(m_Context.getRecordType(
FD->getDeclContext()->getOuterLexicalRecordContext()));
// For now, only lambdas with no captures are supported, so we just pass
// a nullptr instead of the diff object.
baseDiff =
StmtDiff(Clone(OCE->getArg(0)),
new (m_Context) CXXNullPtrLiteralExpr(ptrType, validLoc));
diffArgs.push_back(baseDiff.getExpr_dx());
}
} else if (const auto* MD =
dyn_cast<CXXMethodDecl>(FD)) { // isLambda == false
if (MD->isInstance()) {
const Expr* baseOriginalE = nullptr;
if (auto MCE = dyn_cast<CXXMemberCallExpr>(CE))
if (const auto* MCE = dyn_cast<CXXMemberCallExpr>(CE))
baseOriginalE = MCE->getImplicitObjectArgument();
else if (auto OCE = dyn_cast<CXXOperatorCallExpr>(CE))
else if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE))
baseOriginalE = OCE->getArg(0);
baseDiff = Visit(baseOriginalE);
Expr* baseDerivative = baseDiff.getExpr_dx();
Expand Down Expand Up @@ -1137,34 +1153,36 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

auto customDerivativeArgs = pushforwardFnArgs;

if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
if (Expr* baseE = baseDiff.getExpr()) {
if (!baseE->getType()->isPointerType())
baseE = BuildOp(UnaryOperatorKind::UO_AddrOf, baseE);
customDerivativeArgs.insert(customDerivativeArgs.begin(), baseE);
}

// Try to find a user-defined overloaded derivative.
Expr* callDiff = nullptr;
std::string customPushforward =
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));

// Check if it is a recursive call.
if (!callDiff && (FD == m_DiffReq.Function) &&
m_Mode == GetPushForwardMode()) {
// The differentiated function is called recursively.
Expr* derivativeRef =
m_Sema
.BuildDeclarationNameExpr(CXXScopeSpec(),
m_Derivative->getNameInfo(), m_Derivative)
.get();
callDiff =
m_Sema
.ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext),
derivativeRef, validLoc, pushforwardFnArgs, validLoc)
.get();
if (!isLambda) {
// Check if it is a recursive call.
if (!callDiff && (FD == m_DiffReq.Function) &&
m_Mode == GetPushForwardMode()) {
// The differentiated function is called recursively.
Expr* derivativeRef =
m_Sema
.BuildDeclarationNameExpr(
CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative)
.get();
callDiff = m_Sema
.ActOnCallExpr(
m_Sema.getScopeForContext(m_Sema.CurContext),
derivativeRef, validLoc, pushforwardFnArgs, validLoc)
.get();
}
}

// If all arguments are constant literals, then this does not contribute to
Expand Down Expand Up @@ -1493,7 +1511,12 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
if (QT->isPointerType())
QT = QT->getPointeeType();
auto* typeDecl = QT->getAsCXXRecordDecl();
if (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl)) {
// For lambda functions, we should also simply copy the original lambda. The
// differentiation of lambdas is happening in the `VisitCallExpr`. For now,
// only the declarations with lambda expressions without captures are
// supported.
if (typeDecl && (clad::utils::hasNonDifferentiableAttribute(typeDecl) ||
typeDecl->isLambda())) {
for (auto* D : DS->decls()) {
if (auto* VD = dyn_cast<VarDecl>(D))
decls.push_back(VD);
Expand Down
29 changes: 29 additions & 0 deletions test/ForwardMode/Lambdas.C
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: %cladclang %s -I%S/../../include -oLambdas.out 2>&1 | %filecheck %s
// RUN: ./Lambdas.out | %filecheck_exec %s
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"

double fn0(double x) {
auto _f = [](double _x) {
return _x*_x;
};
return _f(x) + 1;
}

double fn1(double x, double y) {
auto _f = [](double _x, double _y) {
return _x + _y;
};
return _f(x*x, x+2) + y;
}

int main() {
auto fn0_dx = clad::differentiate(fn0, 0);
printf("Result is = %.2f\n", fn0_dx.execute(7)); // CHECK-EXEC: Result is = 14.00
printf("Result is = %.2f\n", fn0_dx.execute(-1)); // CHECK-EXEC: Result is = -2.00

auto fn1_dx = clad::differentiate(fn1, 0);
printf("Result is = %.2f\n", fn1_dx.execute(7, 1)); // CHECK-EXEC: Result is = 15.00
printf("Result is = %.2f\n", fn1_dx.execute(-1, 1)); // CHECK-EXEC: Result is = -1.00
}

0 comments on commit 3161ee1

Please sign in to comment.