Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lambda support in the reverse mode #1126

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

gojakuch
Copy link
Collaborator

@gojakuch gojakuch commented Oct 29, 2024

Potentially fixes: #1054

@gojakuch gojakuch force-pushed the lambda-support-reverse branch from fac23c9 to 1ca617c Compare November 1, 2024 18:47
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-tidy made some suggestions

There were too many comments to post at once. Showing the first 10 out of 21. Check the log or trigger a new build to see more.

@@ -403,6 +403,16 @@ getConstantArrayType(const ASTContext& Ctx, QualType EltTy,
#define CLAD_COMPAT_CLANG15_Declarator_DeclarationAttrs_ExtraParam clang::ParsedAttributesView::none(),
#endif

#if CLANG_VERSION_MAJOR > 12
#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind( \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: function-like macro 'CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind' used; consider a 'constexpr' template function [cppcoreguidelines-macro-usage]

#define CLAD_COMPAT_CXXRecordDecl_CreateLambda_DependencyKind(                 \
        ^

ReverseModeVisitor::diffLambdaCXXRecordDecl(const CXXRecordDecl* Original) {
// Create a new Lambda CXXRecordDecl that is going to represent a pullback
CXXRecordDecl* Cloned = CXXRecordDecl::CreateLambda(
m_Context, const_cast<DeclContext*>(Original->getDeclContext()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: do not use const_cast [cppcoreguidelines-pro-type-const-cast]

        m_Context, const_cast<DeclContext*>(Original->getDeclContext()),
                   ^


// Create operator() as a pullback
for (auto* Method : Original->methods()) {
if (CXXMethodDecl* OriginalOpCall = dyn_cast<CXXMethodDecl>(Method)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use auto when initializing with a template cast to avoid duplicating the type name [modernize-use-auto]

Suggested change
if (CXXMethodDecl* OriginalOpCall = dyn_cast<CXXMethodDecl>(Method)) {
if (auto* OriginalOpCall = dyn_cast<CXXMethodDecl>(Method)) {

std::vector<Expr *> children_Exp;
std::vector<Expr *> children_Exp_dx;

for (auto children : children_iterator_range) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto children' can be declared as 'const auto *children' [llvm-qualified-auto]

Suggested change
for (auto children : children_iterator_range) {
for (const auto *children : children_iterator_range) {


for (auto children : children_iterator_range) {
// auto children_expr = const_cast<clang::Expr*>(dyn_cast<clang::Expr>(children));
auto children_expr = dyn_cast<clang::Expr>(children);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto children_expr' can be declared as 'const auto *children_expr' [llvm-qualified-auto]

Suggested change
auto children_expr = dyn_cast<clang::Expr>(children);
const auto *children_expr = dyn_cast<clang::Expr>(children);

// ============== CAP

// FIXME: ideally, we need to create a reverse_forw lambda and not copy the original one for the forward pass.
auto forwardLambdaClass = LE->getLambdaClass();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto forwardLambdaClass' can be declared as 'auto *forwardLambdaClass' [llvm-qualified-auto]

Suggested change
auto forwardLambdaClass = LE->getLambdaClass();
auto *forwardLambdaClass = LE->getLambdaClass();

LE->getCaptureDefaultLoc(),
LE->hasExplicitParameters(),
LE->hasExplicitResultType(),
true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: too many arguments to function call, expected 7, have 8 [clang-diagnostic-error]

                            true);
                            ^
Additional context

llvm/include/clang/Sema/Sema.h:7176: 'buildLambdaScope' declared here

  void buildLambdaScope(sema::LambdaScopeInfo *LSI, CXXMethodDecl *CallOperator,
       ^

LE->hasExplicitResultType(),
true);

auto forwardLE = LambdaExpr::Create(m_Context,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto forwardLE' can be declared as 'auto *forwardLE' [llvm-qualified-auto]

Suggested change
auto forwardLE = LambdaExpr::Create(m_Context,
auto *forwardLE = LambdaExpr::Create(m_Context,


std::vector<LambdaCapture> children_LC_Exp_dx;

for (auto children_expr : children_Exp_dx) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto children_expr' can be declared as 'auto *children_expr' [llvm-qualified-auto]

Suggested change
for (auto children_expr : children_Exp_dx) {
for (auto *children_expr : children_Exp_dx) {

for (auto children_expr : children_Exp_dx) {
if(isa<CXXConstructExpr>(children_expr)) {

auto tmp = dyn_cast<CXXConstructExpr>(children_expr)->getArg(0)->IgnoreImpCasts();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto tmp' can be declared as 'auto *tmp' [llvm-qualified-auto]

Suggested change
auto tmp = dyn_cast<CXXConstructExpr>(children_expr)->getArg(0)->IgnoreImpCasts();
auto *tmp = dyn_cast<CXXConstructExpr>(children_expr)->getArg(0)->IgnoreImpCasts();

@gojakuch gojakuch force-pushed the lambda-support-reverse branch from 1ca617c to 5cb158c Compare November 9, 2024 15:45
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clang-tidy made some suggestions

There were too many comments to post at once. Showing the first 10 out of 11. Check the log or trigger a new build to see more.

auto tmp = dyn_cast<CXXConstructExpr>(children_expr)->getArg(0)->IgnoreImpCasts();

if (isa<DeclRefExpr>(tmp)) {
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(tmp)->getDecl());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto VD' can be declared as 'auto *VD' [llvm-qualified-auto]

Suggested change
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(tmp)->getDecl());
auto *VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(tmp)->getDecl());


if (isa<DeclRefExpr>(tmp)) {
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(tmp)->getDecl());
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use emplace_back instead of push_back [modernize-use-emplace]

Suggested change
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
children_LC_Exp_dx.emplace_back(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD);

children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
}
if(isa<ParenExpr>(tmp)) {
auto PE = dyn_cast<ParenExpr>(tmp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto PE' can be declared as 'auto *PE' [llvm-qualified-auto]

Suggested change
auto PE = dyn_cast<ParenExpr>(tmp);
auto *PE = dyn_cast<ParenExpr>(tmp);

}
if(isa<ParenExpr>(tmp)) {
auto PE = dyn_cast<ParenExpr>(tmp);
auto OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto OCE' can be declared as 'auto *OCE' [llvm-qualified-auto]

Suggested change
auto OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());
auto *OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());

auto PE = dyn_cast<ParenExpr>(tmp);
auto OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());

auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(OCE->getArg(0))->getDecl());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto VD' can be declared as 'auto *VD' [llvm-qualified-auto]

Suggested change
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(OCE->getArg(0))->getDecl());
auto *VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(OCE->getArg(0))->getDecl());

auto OCE = dyn_cast<CXXOperatorCallExpr>(PE->getSubExpr());

auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(OCE->getArg(0))->getDecl());
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use emplace_back instead of push_back [modernize-use-emplace]

Suggested change
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
children_LC_Exp_dx.emplace_back(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD);

}
}
if (isa<DeclRefExpr>(children_expr)) {
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(children_expr)->getDecl());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto VD' can be declared as 'auto *VD' [llvm-qualified-auto]

Suggested change
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(children_expr)->getDecl());
auto *VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(children_expr)->getDecl());

}
if (isa<DeclRefExpr>(children_expr)) {
auto VD = dyn_cast<VarDecl>(dyn_cast<DeclRefExpr>(children_expr)->getDecl());
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use emplace_back instead of push_back [modernize-use-emplace]

Suggested change
children_LC_Exp_dx.push_back(LambdaCapture(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD));
children_LC_Exp_dx.emplace_back(SourceLocation(), true, LambdaCaptureKind::LCK_ByRef, VD);

LE->getCaptureDefaultLoc(),
LE->hasExplicitParameters(),
LE->hasExplicitResultType(),
true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: too many arguments to function call, expected 7, have 8 [clang-diagnostic-error]

                              true);
                              ^
Additional context

llvm/include/clang/Sema/Sema.h:7176: 'buildLambdaScope' declared here

  void buildLambdaScope(sema::LambdaScopeInfo *LSI, CXXMethodDecl *CallOperator,
       ^

// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
// modified by the derived callee function.
if (utils::IsReferenceOrPointerArg(arg) ||
if (utils::IsReferenceOrPointerArg(arg)||
!m_DiffReq.shouldHaveAdjoint(PVD)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: repeated branch in conditional chain [bugprone-branch-clone]

          !m_DiffReq.shouldHaveAdjoint(PVD)) {
                                             ^
Additional context

lib/Differentiator/ReverseModeVisitor.cpp:2191: end of the original

      } else if (isArgLambda) {
       ^

lib/Differentiator/ReverseModeVisitor.cpp:2191: clone 1 starts here

      } else if (isArgLambda) {
                              ^

@gojakuch gojakuch force-pushed the lambda-support-reverse branch from 5cb158c to e870826 Compare December 9, 2024 17:53
@gojakuch
Copy link
Collaborator Author

gojakuch commented Dec 9, 2024

I rebased this PR (up to some most recent changes) some time ago and tried to continue pushing this through, but to no avail. I don't think this is quickly fixable in the matter weeks anymore, so I'm leaving all of the code for my attempts committed here, in case anyone decides to take this off at some point. it should be possible to implement this in general, but I've been beating around the bush here for too long and I've got no idea how exactly can this be implemented correctly. I could generate some correct code at some point, but Clad would crash after that because the AST was ill-formed. now it's all completely gone. maybe I'll get back to this later and remake this from scratch, once I've got more time on my hands

@vgvassilev
Copy link
Owner

I rebased this PR (up to some most recent changes) and tried to continue pushing this through, but to no avail. I don't think this is quickly fixable in the matter weeks anymore, so I'm leaving all of the code for my attempts committed here, in case anyone decides to take this off at some point. it should be possible to implement this in general, but I've been beating around the bush here for too long and I've got no idea how exactly can this be implemented correctly. I could generate some correct code at some point, but Clad would crash after that because the AST was ill-formed. now it's all completely gone. maybe I'll get back to this later and remake this from scratch, once I've got more time on my hands

Hi @gojakuch. Thanks for the efforts. It is really unfortunate outcome of the project but I hope we can continue from where you are leaving it. Best of luck and we are looking forward to having you back again!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
2 participants