diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index f93daee7d..90dee815c 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -13,6 +13,7 @@ #include "clad/Differentiator/ErrorEstimator.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/ASTLambda.h" #include "clang/AST/Expr.h" #include "clang/AST/TemplateBase.h" #include "clang/Sema/Lookup.h" @@ -1036,9 +1037,12 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { "Differentiation of only direct calls is supported. Ignored"); return StmtDiff(Clone(CE)); } - + const auto* MD = dyn_cast(FD); SourceLocation validLoc{CE->getBeginLoc()}; + // Calls to lambda functions are processed differently + bool isLambda = MD && isLambdaCallOperator(MD); + // If the function is non_differentiable, return zero derivative. if (clad::utils::hasNonDifferentiableAttribute(CE)) { // Calling the function without computing derivatives @@ -1070,12 +1074,23 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // a direct or indirect (operator overload) call to member function. StmtDiff baseDiff; // Add derivative of the implicit `this` pointer to the `diffArgs`. - if (auto MD = dyn_cast(FD)) { + if (isLambda) { + if (const auto* OCE = dyn_cast(CE)) { + QualType ptrType = m_Context.getPointerType(m_Context.getRecordType( + FD->getDeclContext()->getOuterLexicalRecordContext())); + // For now, only lambdas with no captures are supported, so we just pass + // a nullptr instead of the diff object. + baseDiff = + StmtDiff(Clone(OCE->getArg(0)), + new (m_Context) CXXNullPtrLiteralExpr(ptrType, validLoc)); + diffArgs.push_back(baseDiff.getExpr_dx()); + } + } else if (MD) { // isLambda == false if (MD->isInstance()) { const Expr* baseOriginalE = nullptr; - if (auto MCE = dyn_cast(CE)) + if (const auto* MCE = dyn_cast(CE)) baseOriginalE = MCE->getImplicitObjectArgument(); - else if (auto OCE = dyn_cast(CE)) + else if (const auto* OCE = dyn_cast(CE)) baseOriginalE = OCE->getArg(0); baseDiff = Visit(baseOriginalE); Expr* baseDerivative = baseDiff.getExpr_dx(); @@ -1137,34 +1152,36 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { auto customDerivativeArgs = pushforwardFnArgs; - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); + if (Expr* baseE = baseDiff.getExpr()) { if (!baseE->getType()->isPointerType()) baseE = BuildOp(UnaryOperatorKind::UO_AddrOf, baseE); customDerivativeArgs.insert(customDerivativeArgs.begin(), baseE); } // Try to find a user-defined overloaded derivative. + Expr* callDiff = nullptr; std::string customPushforward = clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix(); - Expr* callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( + callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( customPushforward, customDerivativeArgs, getCurrentScope(), const_cast(FD->getDeclContext())); - // Check if it is a recursive call. - if (!callDiff && (FD == m_DiffReq.Function) && - m_Mode == GetPushForwardMode()) { - // The differentiated function is called recursively. - Expr* derivativeRef = - m_Sema - .BuildDeclarationNameExpr(CXXScopeSpec(), - m_Derivative->getNameInfo(), m_Derivative) - .get(); - callDiff = - m_Sema - .ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext), - derivativeRef, validLoc, pushforwardFnArgs, validLoc) - .get(); + if (!isLambda) { + // Check if it is a recursive call. + if (!callDiff && (FD == m_DiffReq.Function) && + m_Mode == GetPushForwardMode()) { + // The differentiated function is called recursively. + Expr* derivativeRef = + m_Sema + .BuildDeclarationNameExpr( + CXXScopeSpec(), m_Derivative->getNameInfo(), m_Derivative) + .get(); + callDiff = m_Sema + .ActOnCallExpr( + m_Sema.getScopeForContext(m_Sema.CurContext), + derivativeRef, validLoc, pushforwardFnArgs, validLoc) + .get(); + } } // If all arguments are constant literals, then this does not contribute to @@ -1493,7 +1510,12 @@ StmtDiff BaseForwardModeVisitor::VisitDeclStmt(const DeclStmt* DS) { if (QT->isPointerType()) QT = QT->getPointeeType(); auto* typeDecl = QT->getAsCXXRecordDecl(); - if (typeDecl && clad::utils::hasNonDifferentiableAttribute(typeDecl)) { + // For lambda functions, we should also simply copy the original lambda. The + // differentiation of lambdas is happening in the `VisitCallExpr`. For now, + // only the declarations with lambda expressions without captures are + // supported. + if (typeDecl && (clad::utils::hasNonDifferentiableAttribute(typeDecl) || + typeDecl->isLambda())) { for (auto* D : DS->decls()) { if (auto* VD = dyn_cast(D)) decls.push_back(VD); diff --git a/test/ForwardMode/Lambdas.C b/test/ForwardMode/Lambdas.C new file mode 100644 index 000000000..f9c203a22 --- /dev/null +++ b/test/ForwardMode/Lambdas.C @@ -0,0 +1,29 @@ +// RUN: %cladclang %s -I%S/../../include -oLambdas.out 2>&1 | %filecheck %s +// RUN: ./Lambdas.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" + +double fn0(double x) { + auto _f = [](double _x) { + return _x*_x; + }; + return _f(x) + 1; +} + +double fn1(double x, double y) { + auto _f = [](double _x, double _y) { + return _x + _y; + }; + return _f(x*x, x+2) + y; +} + +int main() { + auto fn0_dx = clad::differentiate(fn0, 0); + printf("Result is = %.2f\n", fn0_dx.execute(7)); // CHECK-EXEC: Result is = 14.00 + printf("Result is = %.2f\n", fn0_dx.execute(-1)); // CHECK-EXEC: Result is = -2.00 + + auto fn1_dx = clad::differentiate(fn1, 0); + printf("Result is = %.2f\n", fn1_dx.execute(7, 1)); // CHECK-EXEC: Result is = 15.00 + printf("Result is = %.2f\n", fn1_dx.execute(-1, 1)); // CHECK-EXEC: Result is = -1.00 +}