From 42c45ede8a46be0d4093c9b9915e96a8fde0cb26 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 12 Jun 2024 21:43:30 +0200 Subject: [PATCH] Add support for simple lambda functions in forward mode This commit provides support for the simplest lambda functions, that is, those with no captures in forward mode. The original lambda function is copied into the derivative, but the corresponding lambda class gets extended to also have a pushforward method for the call operator overload. --- lib/Differentiator/BaseForwardModeVisitor.cpp | 66 ++++++++++++------- test/ForwardMode/Lambdas.C | 29 ++++++++ 2 files changed, 73 insertions(+), 22 deletions(-) create mode 100644 test/ForwardMode/Lambdas.C diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index f93daee7d..90dee815c 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -13,6 +13,7 @@ #include "clad/Differentiator/ErrorEstimator.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/ASTLambda.h" #include "clang/AST/Expr.h" #include "clang/AST/TemplateBase.h" #include "clang/Sema/Lookup.h" @@ -1036,9 +1037,12 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { "Differentiation of only direct calls is supported. Ignored"); return StmtDiff(Clone(CE)); } - + const auto* MD = dyn_cast(FD); SourceLocation validLoc{CE->getBeginLoc()}; + // Calls to lambda functions are processed differently + bool isLambda = MD && isLambdaCallOperator(MD); + // If the function is non_differentiable, return zero derivative. if (clad::utils::hasNonDifferentiableAttribute(CE)) { // Calling the function without computing derivatives @@ -1070,12 +1074,23 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // a direct or indirect (operator overload) call to member function. StmtDiff baseDiff; // Add derivative of the implicit `this` pointer to the `diffArgs`. - if (auto MD = dyn_cast(FD)) { + if (isLambda) { + if (const auto* OCE = dyn_cast(CE)) { + QualType ptrType = m_Context.getPointerType(m_Context.getRecordType( + FD->getDeclContext()->getOuterLexicalRecordContext())); + // For now, only lambdas with no captures are supported, so we just pass + // a nullptr instead of the diff object. + baseDiff = + StmtDiff(Clone(OCE->getArg(0)), + new (m_Context) CXXNullPtrLiteralExpr(ptrType, validLoc)); + diffArgs.push_back(baseDiff.getExpr_dx()); + } + } else if (MD) { // isLambda == false if (MD->isInstance()) { const Expr* baseOriginalE = nullptr; - if (auto MCE = dyn_cast(CE)) + if (const auto* MCE = dyn_cast(CE)) baseOriginalE = MCE->getImplicitObjectArgument(); - else if (auto OCE = dyn_cast(CE)) + else if (const auto* OCE = dyn_cast(CE)) baseOriginalE = OCE->getArg(0); baseDiff = Visit(baseOriginalE); Expr* baseDerivative = baseDiff.getExpr_dx(); @@ -1137,34 +1152,36 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { auto customDerivativeArgs = pushforwardFnArgs; - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); + if (Expr* baseE = baseDiff.getExpr()) { if (!baseE->getType()->isPointerType()) baseE = BuildOp(UnaryOperatorKind::UO_AddrOf, baseE); customDerivativeArgs.insert(customDerivativeArgs.begin(), baseE); } // Try to find a user-defined overloaded derivative. + Expr* callDiff = nullptr; std::string customPushforward = clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix(); - Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( + callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforward, customDerivativeArgs, getCurrentScope(), const_cast(FD->getDeclContext())); - // Check if it is a recursive call. - if (!callDiff && (FD == m_DiffReq.Function) && - m_Mode == GetPushForwardMode()) { - // The differentiated function is called recursively. - Expr* derivativeRef = - m_Sema - .BuildDeclarationNameExpr(CXXScopeSpec(), - m_Derivative->getNameInfo(), m_Derivative) - .get(); - callDiff = - m_Sema - .ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext), - derivativeRef, validLoc, pushforwardFnArgs, validLoc) - .get(); + if (!isLambda) { + // Check if it is a recursive call. + if (!callDiff && (FD == m_DiffReq.Function) && + m_Mode == GetPushForwardMode()) { + // The differentiated function is called recursively. + Expr* derivativeRef = + m_Sema + .BuildDeclarationNameExpr( + CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative) + .get(); + callDiff = m_Sema + .ActOnCallExpr( + m_Sema.getScopeForContext(m_Sema.CurContext), + derivativeRef, validLoc, pushforwardFnArgs, validLoc) + .get(); + } } // If all arguments are constant literals, then this does not contribute to @@ -1493,7 +1510,12 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { if (QT->isPointerType()) QT = QT->getPointeeType(); auto* typeDecl = QT->getAsCXXRecordDecl(); - if (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl)) { + // For lambda functions, we should also simply copy the original lambda. The + // differentiation of lambdas is happening in the `VisitCallExpr`. For now, + // only the declarations with lambda expressions without captures are + // supported. + if (typeDecl && (clad::utils::hasNonDifferentiableAttribute(typeDecl) || + typeDecl->isLambda())) { for (auto* D : DS->decls()) { if (auto* VD = dyn_cast(D)) decls.push_back(VD); diff --git a/test/ForwardMode/Lambdas.C b/test/ForwardMode/Lambdas.C new file mode 100644 index 000000000..f9c203a22 --- /dev/null +++ b/test/ForwardMode/Lambdas.C @@ -0,0 +1,29 @@ +// RUN: %cladclang %s -I%S/../../include -oLambdas.out 2>&1 | %filecheck %s +// RUN: ./Lambdas.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" + +double fn0(double x) { + auto _f = [](double _x) { + return _x*_x; + }; + return _f(x) + 1; +} + +double fn1(double x, double y) { + auto _f = [](double _x, double _y) { + return _x + _y; + }; + return _f(x*x, x+2) + y; +} + +int main() { + auto fn0_dx = clad::differentiate(fn0, 0); + printf("Result is = %.2f\n", fn0_dx.execute(7)); // CHECK-EXEC: Result is = 14.00 + printf("Result is = %.2f\n", fn0_dx.execute(-1)); // CHECK-EXEC: Result is = -2.00 + + auto fn1_dx = clad::differentiate(fn1, 0); + printf("Result is = %.2f\n", fn1_dx.execute(7, 1)); // CHECK-EXEC: Result is = 15.00 + printf("Result is = %.2f\n", fn1_dx.execute(-1, 1)); // CHECK-EXEC: Result is = -1.00 +}