diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 071a0e516..9c19969cb 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -158,6 +158,8 @@ namespace clad { clang::StringLiteral* CreateStringLiteral(clang::ASTContext& C, llvm::StringRef str); + bool isLambdaQType(clang::QualType QT); + /// Returns true if `QT` is Array or Pointer Type, otherwise returns false. bool isArrayOrPointerType(clang::QualType QT); diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index fc3bb0b02..8464f2be6 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -316,6 +316,13 @@ namespace clad { return false; } + bool isLambdaQType(QualType QT) { + if (const RecordType* RT = QT->getAs()) + if (const CXXRecordDecl* RD = dyn_cast(RT->getDecl())) + return RD->isLambda(); + return false; + } + bool IsReferenceOrPointerArg(const Expr* arg) { // The argument is passed by reference if it's passed as an L-value. // However, if arg is a MaterializeTemporaryExpr, then arg is a diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 100a25298..f51730dbc 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1485,7 +1485,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Context, const_cast(Original->getDeclContext()), Original->getLambdaTypeInfo(), Original->getBeginLoc(), CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind(Original), - Original->isGenericLambda(), Original->getLambdaCaptureDefault()); + Original->isGenericLambda(), LCD_ByRef); // Copy the fields if any (FieldDecl) for (auto* Field : Original->fields()) { @@ -1554,7 +1554,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ClonedOpCall->setAccess(OriginalOpCall->getAccess()); Cloned->addDecl(ClonedOpCall); - break; // we might get into an infinite loop otherwise + break; // we get into an infinite loop otherwise } } } @@ -1567,47 +1567,81 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitLambdaExpr(const clang::LambdaExpr* LE) { + // ============== CAP + auto children_iterator_range = LE->children(); - std::vector children_Exp; std::vector children_Exp_dx; for (auto children : children_iterator_range) { - auto children_expr = const_cast(dyn_cast(children)); + // auto children_expr = const_cast(dyn_cast(children)); + auto children_expr = dyn_cast(children); if (children_expr) { - children_Exp.push_back(children_expr); - - children_Exp_dx.push_back(children_expr); - - if(isa(children_expr)) { - std::string constructedTypeName = QualType::getAsString(dyn_cast(children_expr)->getType().split(), PrintingPolicy{ {} }); - // if (!utils::IsKokkosTeamPolicy(constructedTypeName) && !utils::IsKokkosRange(constructedTypeName) && !utils::IsKokkosMember(constructedTypeName)) { - auto children_exprV = Visit(children_expr); - auto children_expr_copy = dyn_cast(Clone(children_expr)); - children_expr_copy->setArg(0, children_exprV.getExpr_dx()); - children_Exp_dx.push_back(children_expr_copy); - // } - } - else if(isa(children_expr)) { - - } - else { - auto children_exprV = Visit(children_expr); - if (children_exprV.getExpr_dx()) { - children_Exp_dx.push_back(children_exprV.getExpr_dx()); - } - } + children_Exp.push_back(dyn_cast(Clone(children_expr))); + + // children_Exp_dx.push_back(children_expr); + + // if(isa(children_expr)) { + // std::string constructedTypeName = QualType::getAsString(dyn_cast(children_expr)->getType().split(), PrintingPolicy{ {} }); + // // if (!utils::IsKokkosTeamPolicy(constructedTypeName) && !utils::IsKokkosRange(constructedTypeName) && !utils::IsKokkosMember(constructedTypeName)) { + // auto children_exprV = Visit(children_expr); + // auto children_expr_copy = dyn_cast(Clone(children_expr)); + // children_expr_copy->setArg(0, children_exprV.getExpr_dx()); + // children_Exp_dx.push_back(children_expr_copy); + // // } + // } + // else if(isa(children_expr)) { + + // } + // else { + // auto children_exprV = Visit(children_expr); + // if (children_exprV.getExpr_dx()) { + // children_Exp_dx.push_back(children_exprV.getExpr_dx()); + // } + // } } } llvm::ArrayRef childrenRef_Exp = clad_compat::makeArrayRef(children_Exp.data(), children_Exp.size()); - llvm::ArrayRef childrenRef_Exp_dx = - clad_compat::makeArrayRef(children_Exp_dx.data(), children_Exp_dx.size()); + llvm::ArrayRef childrenRef_Exp_dx; // = + // clad_compat::makeArrayRef(children_Exp_dx.data(), children_Exp_dx.size()); + // ============== CAP + + // FIXME: ideally, we need to create a reverse_forw lambda and not copy the original one for the forward pass. auto forwardLambdaClass = LE->getLambdaClass(); + clang::LambdaIntroducer cloneIntro; + cloneIntro.Default = forwardLambdaClass->getLambdaCaptureDefault(); + cloneIntro.Range.setBegin(LE->getBeginLoc()); + cloneIntro.Range.setEnd(LE->getEndLoc()); + + clang::AttributeFactory cloneAttrFactory; + const clang::DeclSpec cloneDS(cloneAttrFactory); + clang::Declarator cloneD( + cloneDS, CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam + CLAD_COMPAT_CLANG12_Declarator_LambdaExpr); + clang::sema::LambdaScopeInfo* cloneLSI = m_Sema.PushLambdaScope(); + beginScope(clang::Scope::BlockScope | clang::Scope::FnScope | + clang::Scope::DeclScope); + m_Sema.ActOnStartOfLambdaDefinition( + cloneIntro, cloneD, + clad_compat::Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec( + getCurrentScope(), cloneDS)); + + cloneLSI->CallOperator = forwardLambdaClass->getLambdaCallOperator(); + + m_Sema.buildLambdaScope(cloneLSI, + cloneLSI->CallOperator, + LE->getIntroducerRange(), + LE->getCaptureDefault(), + LE->getCaptureDefaultLoc(), + LE->hasExplicitParameters(), + LE->hasExplicitResultType(), + true); + auto forwardLE = LambdaExpr::Create(m_Context, forwardLambdaClass, LE->getIntroducerRange(), @@ -1620,7 +1654,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, false); clang::LambdaExpr* reverseLE = nullptr; - auto* ClonedCXXRec = diffLambdaCXXRecordDecl(forwardLambdaClass); + CXXRecordDecl* diffedCXXRec = diffLambdaCXXRecordDecl(forwardLambdaClass); + + endScope(); clang::LambdaIntroducer Intro; Intro.Default = forwardLambdaClass->getLambdaCaptureDefault(); @@ -1640,12 +1676,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, clad_compat::Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec( getCurrentScope(), DS)); - for (auto* Method : ClonedCXXRec->methods()) { - if (CXXMethodDecl* cpb = dyn_cast(Method)) { - if (cpb->getOverloadedOperator() == OO_Call) - LSI->CallOperator = cpb; - } - } + LSI->CallOperator = diffedCXXRec->getLambdaCallOperator(); + + // ============== CAP std::vector children_LC_Exp_dx; @@ -1671,27 +1704,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD)); } } - assert(children_Exp_dx.size() == children_LC_Exp_dx.size() && "Wrong number of captures"); + // assert(children_Exp_dx.size() == children_LC_Exp_dx.size() && "Wrong number of captures"); + + llvm::ArrayRef childrenRef_LC_Exp_dx;// = + // clad_compat::makeArrayRef(children_LC_Exp_dx.data(), children_LC_Exp_dx.size()); - llvm::ArrayRef childrenRef_LC_Exp_dx = - clad_compat::makeArrayRef(children_LC_Exp_dx.data(), children_LC_Exp_dx.size()); + // diffedCXXRec->setCaptures(m_Context, childrenRef_LC_Exp_dx); - // Initialize and attach LambdaDefinitionData to mark this as a lambda. - ClonedCXXRec->setCaptures(m_Context, childrenRef_LC_Exp_dx); + // ============== CAP - m_Sema.buildLambdaScope(LSI, - //bodyV.getStmt_dx(), + m_Sema.buildLambdaScope(LSI, LSI->CallOperator, LE->getIntroducerRange(), - LE->getCaptureDefault(), + LCD_ByRef, LE->getCaptureDefaultLoc(), LE->hasExplicitParameters(), LE->hasExplicitResultType(), - LE->isMutable()); + true); reverseLE = LambdaExpr::Create( - m_Context, ClonedCXXRec, LE->getIntroducerRange(), - LE->getCaptureDefault(), LE->getCaptureDefaultLoc(), + m_Context, diffedCXXRec, LE->getIntroducerRange(), + LCD_ByRef, LE->getCaptureDefaultLoc(), LE->hasExplicitParameters(), LE->hasExplicitResultType(), childrenRef_Exp_dx, LE->getEndLoc(), false); @@ -1910,13 +1943,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const auto* PVD = FD->getParamDecl( i - static_cast(isMethodOperatorCall)); StmtDiff argDiff{}; + + bool isArgLambda = clad::utils::isLambdaQType(arg->getType()); // is this argument a lambda? + // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly // modified by the derived callee function. - if (utils::IsReferenceOrPointerArg(arg) || + if (utils::IsReferenceOrPointerArg(arg)|| !m_DiffReq.shouldHaveAdjoint(PVD)) { argDiff = Visit(arg); CallArgDx.push_back(argDiff.getExpr_dx()); + } else if (isArgLambda) { + // TODO: this block is now the same as the one above, but we might want to actually save the differentiated lambda into a declaration first here. This way we wouldn't create new lambdas for the derivative every time the user passes the same lambda as an argument. + argDiff = Visit(arg); + CallArgDx.push_back(argDiff.getExpr_dx()); } else { // Create temporary variables corresponding to derivative of each // argument, so that they can be referred to when arguments is visited. @@ -1925,7 +1965,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // same as the call expression as it is the type used to declare the // _gradX array QualType dArgTy = getNonConstType(arg->getType(), m_Context, m_Sema); - VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy), false, nullptr, clang::VarDecl::InitializationStyle::CInit, true); + VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy), false, nullptr, clang::VarDecl::InitializationStyle::CInit, isLambda); PreCallStmts.push_back(BuildDeclStmt(dArgDecl)); DeclRefExpr* dArgRef = BuildDeclRef(dArgDecl); if (isa(CE)) {