Skip to content

Commit

Permalink
half-progress with actual captures and fixing crashes
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed Nov 9, 2024
1 parent ad6809b commit 5cb158c
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 53 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ namespace clad {
clang::StringLiteral* CreateStringLiteral(clang::ASTContext& C,
llvm::StringRef str);

bool isLambdaQType(clang::QualType QT);

/// Returns true if `QT` is Array or Pointer Type, otherwise returns false.
bool isArrayOrPointerType(clang::QualType QT);

Expand Down
7 changes: 7 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,13 @@ namespace clad {
return false;
}

bool isLambdaQType(QualType QT) {
if (const RecordType* RT = QT->getAs<RecordType>())
if (const CXXRecordDecl* RD = dyn_cast<CXXRecordDecl>(RT->getDecl()))
return RD->isLambda();
return false;
}

bool IsReferenceOrPointerArg(const Expr* arg) {
// The argument is passed by reference if it's passed as an L-value.
// However, if arg is a MaterializeTemporaryExpr, then arg is a
Expand Down
143 changes: 90 additions & 53 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1702,7 +1702,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_Context, const_cast<DeclContext*>(Original->getDeclContext()),
Original->getLambdaTypeInfo(), Original->getBeginLoc(),
CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind(Original),
Original->isGenericLambda(), Original->getLambdaCaptureDefault());
Original->isGenericLambda(), LCD_ByRef);

// Copy the fields if any (FieldDecl)
for (auto* Field : Original->fields()) {
Expand Down Expand Up @@ -1771,7 +1771,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ClonedOpCall->setAccess(OriginalOpCall->getAccess());
Cloned->addDecl(ClonedOpCall);

break; // we might get into an infinite loop otherwise
break; // we get into an infinite loop otherwise
}
}
}
Expand All @@ -1784,47 +1784,81 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitLambdaExpr(const clang::LambdaExpr* LE) {
// ============== CAP

auto children_iterator_range = LE->children();

std::vector<Expr *> children_Exp;
std::vector<Expr *> children_Exp_dx;

for (auto children : children_iterator_range) {
auto children_expr = const_cast<clang::Expr*>(dyn_cast<clang::Expr>(children));
// auto children_expr = const_cast<clang::Expr*>(dyn_cast<clang::Expr>(children));
auto children_expr = dyn_cast<clang::Expr>(children);
if (children_expr) {
children_Exp.push_back(children_expr);

children_Exp_dx.push_back(children_expr);

if(isa<CXXConstructExpr>(children_expr)) {
std::string constructedTypeName = QualType::getAsString(dyn_cast<CXXConstructExpr>(children_expr)->getType().split(), PrintingPolicy{ {} });
// if (!utils::IsKokkosTeamPolicy(constructedTypeName) && !utils::IsKokkosRange(constructedTypeName) && !utils::IsKokkosMember(constructedTypeName)) {
auto children_exprV = Visit(children_expr);
auto children_expr_copy = dyn_cast<CXXConstructExpr>(Clone(children_expr));
children_expr_copy->setArg(0, children_exprV.getExpr_dx());
children_Exp_dx.push_back(children_expr_copy);
// }
}
else if(isa<DeclRefExpr>(children_expr)) {

}
else {
auto children_exprV = Visit(children_expr);
if (children_exprV.getExpr_dx()) {
children_Exp_dx.push_back(children_exprV.getExpr_dx());
}
}
children_Exp.push_back(dyn_cast<clang::Expr>(Clone(children_expr)));

// children_Exp_dx.push_back(children_expr);

// if(isa<CXXConstructExpr>(children_expr)) {
// std::string constructedTypeName = QualType::getAsString(dyn_cast<CXXConstructExpr>(children_expr)->getType().split(), PrintingPolicy{ {} });
// // if (!utils::IsKokkosTeamPolicy(constructedTypeName) && !utils::IsKokkosRange(constructedTypeName) && !utils::IsKokkosMember(constructedTypeName)) {
// auto children_exprV = Visit(children_expr);
// auto children_expr_copy = dyn_cast<CXXConstructExpr>(Clone(children_expr));
// children_expr_copy->setArg(0, children_exprV.getExpr_dx());
// children_Exp_dx.push_back(children_expr_copy);
// // }
// }
// else if(isa<DeclRefExpr>(children_expr)) {

// }
// else {
// auto children_exprV = Visit(children_expr);
// if (children_exprV.getExpr_dx()) {
// children_Exp_dx.push_back(children_exprV.getExpr_dx());
// }
// }
}
}

llvm::ArrayRef<Expr*> childrenRef_Exp =
clad_compat::makeArrayRef(children_Exp.data(), children_Exp.size());

llvm::ArrayRef<Expr*> childrenRef_Exp_dx =
clad_compat::makeArrayRef(children_Exp_dx.data(), children_Exp_dx.size());
llvm::ArrayRef<Expr*> childrenRef_Exp_dx; // =
// clad_compat::makeArrayRef(children_Exp_dx.data(), children_Exp_dx.size());

// ============== CAP

// FIXME: ideally, we need to create a reverse_forw lambda and not copy the original one for the forward pass.
auto forwardLambdaClass = LE->getLambdaClass();

clang::LambdaIntroducer cloneIntro;
cloneIntro.Default = forwardLambdaClass->getLambdaCaptureDefault();
cloneIntro.Range.setBegin(LE->getBeginLoc());
cloneIntro.Range.setEnd(LE->getEndLoc());

clang::AttributeFactory cloneAttrFactory;
const clang::DeclSpec cloneDS(cloneAttrFactory);
clang::Declarator cloneD(
cloneDS, CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam
CLAD_COMPAT_CLANG12_Declarator_LambdaExpr);
clang::sema::LambdaScopeInfo* cloneLSI = m_Sema.PushLambdaScope();
beginScope(clang::Scope::BlockScope | clang::Scope::FnScope |
clang::Scope::DeclScope);
m_Sema.ActOnStartOfLambdaDefinition(
cloneIntro, cloneD,
clad_compat::Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec(
getCurrentScope(), cloneDS));

cloneLSI->CallOperator = forwardLambdaClass->getLambdaCallOperator();

m_Sema.buildLambdaScope(cloneLSI,
cloneLSI->CallOperator,
LE->getIntroducerRange(),
LE->getCaptureDefault(),
LE->getCaptureDefaultLoc(),
LE->hasExplicitParameters(),
LE->hasExplicitResultType(),
true);

auto forwardLE = LambdaExpr::Create(m_Context,
forwardLambdaClass,
LE->getIntroducerRange(),
Expand All @@ -1837,7 +1871,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
false);

clang::LambdaExpr* reverseLE = nullptr;
auto* ClonedCXXRec = diffLambdaCXXRecordDecl(forwardLambdaClass);
CXXRecordDecl* diffedCXXRec = diffLambdaCXXRecordDecl(forwardLambdaClass);

endScope();

clang::LambdaIntroducer Intro;
Intro.Default = forwardLambdaClass->getLambdaCaptureDefault();
Expand All @@ -1857,12 +1893,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
clad_compat::Sema_ActOnStartOfLambdaDefinition_ScopeOrDeclSpec(
getCurrentScope(), DS));

for (auto* Method : ClonedCXXRec->methods()) {
if (CXXMethodDecl* cpb = dyn_cast<CXXMethodDecl>(Method)) {
if (cpb->getOverloadedOperator() == OO_Call)
LSI->CallOperator = cpb;
}
}
LSI->CallOperator = diffedCXXRec->getLambdaCallOperator();

// ============== CAP

std::vector<LambdaCapture> children_LC_Exp_dx;

Expand All @@ -1888,27 +1921,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
}
}
assert(children_Exp_dx.size() == children_LC_Exp_dx.size() && "Wrong number of captures");
// assert(children_Exp_dx.size() == children_LC_Exp_dx.size() && "Wrong number of captures");

llvm::ArrayRef<LambdaCapture> childrenRef_LC_Exp_dx;// =
// clad_compat::makeArrayRef(children_LC_Exp_dx.data(), children_LC_Exp_dx.size());

llvm::ArrayRef<LambdaCapture> childrenRef_LC_Exp_dx =
clad_compat::makeArrayRef(children_LC_Exp_dx.data(), children_LC_Exp_dx.size());
// diffedCXXRec->setCaptures(m_Context, childrenRef_LC_Exp_dx);

// Initialize and attach LambdaDefinitionData to mark this as a lambda.
ClonedCXXRec->setCaptures(m_Context, childrenRef_LC_Exp_dx);
// ============== CAP

m_Sema.buildLambdaScope(LSI,
//bodyV.getStmt_dx(),
m_Sema.buildLambdaScope(LSI,
LSI->CallOperator,
LE->getIntroducerRange(),
LE->getCaptureDefault(),
LCD_ByRef,
LE->getCaptureDefaultLoc(),
LE->hasExplicitParameters(),
LE->hasExplicitResultType(),
LE->isMutable());
true);

reverseLE = LambdaExpr::Create(
m_Context, ClonedCXXRec, LE->getIntroducerRange(),
LE->getCaptureDefault(), LE->getCaptureDefaultLoc(),
m_Context, diffedCXXRec, LE->getIntroducerRange(),
LCD_ByRef, LE->getCaptureDefaultLoc(),
LE->hasExplicitParameters(), LE->hasExplicitResultType(),
childrenRef_Exp_dx, LE->getEndLoc(), false);

Expand Down Expand Up @@ -2146,13 +2179,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
const auto* PVD = FD->getParamDecl(
i - static_cast<unsigned long>(isMethodOperatorCall));
StmtDiff argDiff{};

bool isArgLambda = clad::utils::isLambdaQType(arg->getType()); // is this argument a lambda?

// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
// modified by the derived callee function.
if (utils::IsReferenceOrPointerArg(arg) ||
if (utils::IsReferenceOrPointerArg(arg)||
!m_DiffReq.shouldHaveAdjoint(PVD)) {
argDiff = Visit(arg);
CallArgDx.push_back(argDiff.getExpr_dx());
} else if (isArgLambda) {
// TODO: this block is now the same as the one above, but we might want to actually save the differentiated lambda into a declaration first here. This way we wouldn't create new lambdas for the derivative every time the user passes the same lambda as an argument.
argDiff = Visit(arg);
CallArgDx.push_back(argDiff.getExpr_dx());
} else {
// Create temporary variables corresponding to derivative of each
// argument, so that they can be referred to when arguments is visited.
Expand All @@ -2161,7 +2201,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// same as the call expression as it is the type used to declare the
// _gradX array
QualType dArgTy = getNonConstType(arg->getType(), m_Context, m_Sema);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy), false, nullptr, clang::VarDecl::InitializationStyle::CInit, true);
VarDecl* dArgDecl = BuildVarDecl(dArgTy, "_r", getZeroInit(dArgTy), false, nullptr, clang::VarDecl::InitializationStyle::CInit, isLambda);
PreCallStmts.push_back(BuildDeclStmt(dArgDecl));
CallArgDx.push_back(BuildDeclRef(dArgDecl));
// Visit using uninitialized reference.
Expand Down Expand Up @@ -2298,7 +2338,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::errs() << "i: " << idx << '\n';
QualType paramTy = FD->getParamDecl(idx)->getType();
if (!argDerivative || utils::isArrayOrPointerType(paramTy) ||
isCladArrayType(argDerivative->getType()))
isCladArrayType(argDerivative->getType()) || clad::utils::isLambdaQType(paramTy))
gradArgExpr = argDerivative;
else
gradArgExpr =
Expand Down Expand Up @@ -2393,10 +2433,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackRequest.EnableVariedAnalysis = m_DiffReq.EnableVariedAnalysis;
bool isaMethod = isa<CXXMethodDecl>(FD);
for (size_t i = 0, e = FD->getNumParams(); i < e; ++i)
if (MD && isLambdaCallOperator(MD)) {
if (const auto* paramDecl = FD->getParamDecl(i))
pullbackRequest.DVI.push_back(paramDecl);
} else if (DerivedCallOutputArgs[i + isaMethod])
if (DerivedCallOutputArgs[i + isaMethod])
pullbackRequest.DVI.push_back(FD->getParamDecl(i));

FunctionDecl* pullbackFD = nullptr;
Expand Down

0 comments on commit 5cb158c

Please sign in to comment.