From 22b2590c0df1b322af81c100857dd7dc935bca4c Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Fri, 21 Jun 2024 23:01:21 +0200 Subject: [PATCH] Add support for simple lambda expressions in reverse mode This commit provides support for primitive lambda expressions with no captures in reverse mode in the same way they are currently supported in the forward mode (#937). That is, the lambda expressions are not visited yet. Instead, the lambda functions are treated as a special case of functors. Fixes: #789 --- lib/Differentiator/ReverseModeVisitor.cpp | 51 +++++++++++++-- test/Gradient/Lambdas.C | 79 +++++++++++++++++++++++ 2 files changed, 124 insertions(+), 6 deletions(-) create mode 100644 test/Gradient/Lambdas.C diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6f1df389a..6394ee9dd 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -17,6 +17,7 @@ #include "clad/Differentiator/StmtClone.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/ASTLambda.h" #include "clang/AST/Expr.h" #include "clang/AST/Stmt.h" #include "clang/AST/TemplateBase.h" @@ -1596,13 +1597,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff baseDiff; // If it has more args or f_darg0 was not found, we look for its pullback // function. + const auto* MD = dyn_cast(FD); if (!OverloadedDerivedFn) { size_t idx = 0; /// Add base derivative expression in the derived call output args list if /// `CE` is a call to an instance member function. - if (const auto* MD = dyn_cast(FD)) { - if (MD->isInstance()) { + if (MD) { + if (isLambdaCallOperator(MD)) { + QualType ptrType = m_Context.getPointerType(m_Context.getRecordType( + FD->getDeclContext()->getOuterLexicalRecordContext())); + baseDiff = + StmtDiff(Clone(dyn_cast(CE)->getArg(0)), + new (m_Context) CXXNullPtrLiteralExpr(ptrType, Loc)); + } else if (MD->isInstance()) { const Expr* baseOriginalE = nullptr; if (const auto* MCE = dyn_cast(CE)) baseOriginalE = MCE->getImplicitObjectArgument(); @@ -1700,7 +1708,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) - if (DerivedCallOutputArgs[i + isaMethod]) + if (MD && isLambdaCallOperator(MD)) { + if (const auto* paramDecl = FD->getParamDecl(i)) + pullbackRequest.DVI.push_back(paramDecl); + } else if (DerivedCallOutputArgs[i + isaMethod]) pullbackRequest.DVI.push_back(FD->getParamDecl(i)); FunctionDecl* pullbackFD = nullptr; @@ -2735,6 +2746,31 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bool promoteToFnScope = !getCurrentScope()->isFunctionScope() && m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass; + + // If the DeclStmt is not empty, check the first declaration in case it is a + // lambda function. This case it is treated separately for now and we don't + // create a variable for its derivative. + bool isLambda = false; + const auto* declsBegin = DS->decls().begin(); + if (declsBegin != DS->decls().end() && isa(*declsBegin)) { + auto* VD = dyn_cast(*declsBegin); + QualType QT = VD->getType(); + if (!QT->isPointerType()) { + auto* typeDecl = QT->getAsCXXRecordDecl(); + // We should also simply copy the original lambda. The differentiation + // of lambdas is happening in the `VisitCallExpr`. For now, only the + // declarations with lambda expressions without captures are supported. + isLambda = typeDecl && typeDecl->isLambda(); + if (isLambda) { + for (auto* D : DS->decls()) + if (auto* VD = dyn_cast(D)) + decls.push_back(VD); + Stmt* DSClone = BuildDeclStmt(decls); + return StmtDiff(DSClone, nullptr); + } + } + } + // For each variable declaration v, create another declaration _d_v to // store derivatives for potential reassignments. E.g. // double y = x; @@ -2742,7 +2778,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // double _d_y = _d_x; double y = x; for (auto* D : DS->decls()) { if (auto* VD = dyn_cast(D)) { - DeclDiff VDDiff = DifferentiateVarDecl(VD); + DeclDiff VDDiff; + if (!isLambda) + VDDiff = DifferentiateVarDecl(VD); // Check if decl's name is the same as before. The name may be changed // if decl name collides with something in the derivative body. @@ -2762,8 +2800,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // double _d_y = x; // copied from original function, collides with // _d_y // } - if (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || - VD->getType() != VDDiff.getDecl()->getType()) + if (!isLambda && + (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || + VD->getType() != VDDiff.getDecl()->getType())) m_DeclReplacements[VD] = VDDiff.getDecl(); // Here, we move the declaration to the function global scope. diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C new file mode 100644 index 000000000..98b5e5536 --- /dev/null +++ b/test/Gradient/Lambdas.C @@ -0,0 +1,79 @@ +// RUN: %cladclang %s -I%S/../../include -oLambdas.out 2>&1 | %filecheck %s +// RUN: ./Lambdas.out | %filecheck_exec %s +// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oLambdas.out +// RUN: ./Lambdas.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" + +double f1(double i, double j) { + auto _f = [] (double t) { + return t*t + 1.0; + }; + return i + _f(j); +} + +// CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const; +// CHECK-NEXT: void f1_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK-NEXT: auto _f = []{{ ?}}(double t) { +// CHECK-NEXT: return t * t + 1.; +// CHECK-NEXT: }{{;?}} +// CHECK: { +// CHECK-NEXT: *_d_i += 1; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: _f.operator_call_pullback(j, 1, &_r0); +// CHECK-NEXT: *_d_j += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double f2(double i, double j) { + auto _f = [] (double t, double k) { + return t + k; + }; + double x = _f(i + j, i); + return x; +} + +// CHECK: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const; +// CHECK-NEXT: void f2_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK-NEXT: double _d_x = 0; +// CHECK-NEXT: auto _f = []{{ ?}}(double t, double k) { +// CHECK-NEXT: return t + k; +// CHECK-NEXT: }{{;?}} +// CHECK: double x = operator()(i + j, i); +// CHECK-NEXT: _d_x += 1; +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: _f.operator_call_pullback(i + j, i, _d_x, &_r0, &_r1); +// CHECK-NEXT: *_d_i += _r0; +// CHECK-NEXT: *_d_j += _r0; +// CHECK-NEXT: *_d_i += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } + + +int main() { + auto df1 = clad::gradient(f1); + double di = 0, dj = 0; + df1.execute(3, 4, &di, &dj); + printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 1.00 8.00 + + auto df2 = clad::gradient(f2); + di = 0, dj = 0; + df2.execute(3, 4, &di, &dj); + printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 2.00 1.00 +} + +// CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const { +// CHECK-NEXT: { +// CHECK-NEXT: *_d_t += _d_y * t; +// CHECK-NEXT: *_d_t += t * _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const { +// CHECK-NEXT: { +// CHECK-NEXT: *_d_t += _d_y; +// CHECK-NEXT: *_d_k += _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file