diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6f1df389a..6394ee9dd 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -17,6 +17,7 @@ #include "clad/Differentiator/StmtClone.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/ASTLambda.h" #include "clang/AST/Expr.h" #include "clang/AST/Stmt.h" #include "clang/AST/TemplateBase.h" @@ -1596,13 +1597,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff baseDiff; // If it has more args or f_darg0 was not found, we look for its pullback // function. + const auto* MD = dyn_cast(FD); if (!OverloadedDerivedFn) { size_t idx = 0; /// Add base derivative expression in the derived call output args list if /// `CE` is a call to an instance member function. - if (const auto* MD = dyn_cast(FD)) { - if (MD->isInstance()) { + if (MD) { + if (isLambdaCallOperator(MD)) { + QualType ptrType = m_Context.getPointerType(m_Context.getRecordType( + FD->getDeclContext()->getOuterLexicalRecordContext())); + baseDiff = + StmtDiff(Clone(dyn_cast(CE)->getArg(0)), + new (m_Context) CXXNullPtrLiteralExpr(ptrType, Loc)); + } else if (MD->isInstance()) { const Expr* baseOriginalE = nullptr; if (const auto* MCE = dyn_cast(CE)) baseOriginalE = MCE->getImplicitObjectArgument(); @@ -1700,7 +1708,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.EnableTBRAnalysis = m_DiffReq.EnableTBRAnalysis; bool isaMethod = isa(FD); for (size_t i = 0, e = FD->getNumParams(); i < e; ++i) - if (DerivedCallOutputArgs[i + isaMethod]) + if (MD && isLambdaCallOperator(MD)) { + if (const auto* paramDecl = FD->getParamDecl(i)) + pullbackRequest.DVI.push_back(paramDecl); + } else if (DerivedCallOutputArgs[i + isaMethod]) pullbackRequest.DVI.push_back(FD->getParamDecl(i)); FunctionDecl* pullbackFD = nullptr; @@ -2735,6 +2746,31 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, 
bool promoteToFnScope = !getCurrentScope()->isFunctionScope() && m_DiffReq.Mode != DiffMode::reverse_mode_forward_pass; + + // If the DeclStmt is not empty, check whether its first declaration is a + // lambda. That case is handled separately for now: the declaration is + // copied as-is and no variable is created for its derivative. + bool isLambda = false; + const auto* declsBegin = DS->decls().begin(); + if (declsBegin != DS->decls().end() && isa(*declsBegin)) { + auto* VD = dyn_cast(*declsBegin); + QualType QT = VD->getType(); + if (!QT->isPointerType()) { + auto* typeDecl = QT->getAsCXXRecordDecl(); + // Simply copy the original lambda declaration here. The differentiation + // of lambdas happens in `VisitCallExpr`. For now, only the + // declarations with lambda expressions without captures are supported. + isLambda = typeDecl && typeDecl->isLambda(); + if (isLambda) { + for (auto* D : DS->decls()) + if (auto* VD = dyn_cast(D)) + decls.push_back(VD); + Stmt* DSClone = BuildDeclStmt(decls); + return StmtDiff(DSClone, nullptr); + } + } + } + + // For each variable declaration v, create another declaration _d_v to // store derivatives for potential reassignments. E.g. // double y = x; @@ -2742,7 +2778,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // double _d_y = _d_x; double y = x; for (auto* D : DS->decls()) { if (auto* VD = dyn_cast(D)) { - DeclDiff VDDiff = DifferentiateVarDecl(VD); + DeclDiff VDDiff; + if (!isLambda) + VDDiff = DifferentiateVarDecl(VD); // Check if decl's name is the same as before. The name may be changed // if decl name collides with something in the derivative body. 
@@ -2762,8 +2800,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // double _d_y = x; // copied from original function, collides with // _d_y // } - if (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || - VD->getType() != VDDiff.getDecl()->getType()) + if (!isLambda && + (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || + VD->getType() != VDDiff.getDecl()->getType())) m_DeclReplacements[VD] = VDDiff.getDecl(); // Here, we move the declaration to the function global scope. diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C new file mode 100644 index 000000000..98b5e5536 --- /dev/null +++ b/test/Gradient/Lambdas.C @@ -0,0 +1,79 @@ +// RUN: %cladclang %s -I%S/../../include -oLambdas.out 2>&1 | %filecheck %s +// RUN: ./Lambdas.out | %filecheck_exec %s +// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oLambdas.out +// RUN: ./Lambdas.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" + +double f1(double i, double j) { + auto _f = [] (double t) { + return t*t + 1.0; + }; + return i + _f(j); +} + +// CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const; +// CHECK-NEXT: void f1_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK-NEXT: auto _f = []{{ ?}}(double t) { +// CHECK-NEXT: return t * t + 1.; +// CHECK-NEXT: }{{;?}} +// CHECK: { +// CHECK-NEXT: *_d_i += 1; +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: _f.operator_call_pullback(j, 1, &_r0); +// CHECK-NEXT: *_d_j += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + +double f2(double i, double j) { + auto _f = [] (double t, double k) { + return t + k; + }; + double x = _f(i + j, i); + return x; +} + +// CHECK: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const; +// CHECK-NEXT: void f2_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK-NEXT: double _d_x = 0; +// 
CHECK-NEXT: auto _f = []{{ ?}}(double t, double k) { +// CHECK-NEXT: return t + k; +// CHECK-NEXT: }{{;?}} +// CHECK: double x = operator()(i + j, i); +// CHECK-NEXT: _d_x += 1; +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 0; +// CHECK-NEXT: double _r1 = 0; +// CHECK-NEXT: _f.operator_call_pullback(i + j, i, _d_x, &_r0, &_r1); +// CHECK-NEXT: *_d_i += _r0; +// CHECK-NEXT: *_d_j += _r0; +// CHECK-NEXT: *_d_i += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } + + +int main() { + auto df1 = clad::gradient(f1); + double di = 0, dj = 0; + df1.execute(3, 4, &di, &dj); + printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 1.00 8.00 + + auto df2 = clad::gradient(f2); + di = 0, dj = 0; + df2.execute(3, 4, &di, &dj); + printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 2.00 1.00 +} + +// CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const { +// CHECK-NEXT: { +// CHECK-NEXT: *_d_t += _d_y * t; +// CHECK-NEXT: *_d_t += t * _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const { +// CHECK-NEXT: { +// CHECK-NEXT: *_d_t += _d_y; +// CHECK-NEXT: *_d_k += _d_y; +// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file