From df5a48aad613f241b792582ae048b6c857dc6446 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Tue, 10 Sep 2024 21:08:29 +0200 Subject: [PATCH 1/9] Fix the generation of invalid code in some common cases This commit fixes the way Clad generates code. Specifically, it addresses the way operators appear in the generated code in the reverse mode and the way nested name qualifiers are built in both modes. Fixes: 1050, 1087 --- include/clad/Differentiator/VisitorBase.h | 11 +++ lib/Differentiator/BaseForwardModeVisitor.cpp | 6 +- lib/Differentiator/ReverseModeVisitor.cpp | 25 ++++-- lib/Differentiator/VisitorBase.cpp | 37 +++++++++ test/Gradient/Lambdas.C | 2 +- test/ValidCodeGen/ValidCodeGen.C | 80 +++++++++++++++++++ 6 files changed, 152 insertions(+), 9 deletions(-) create mode 100644 test/ValidCodeGen/ValidCodeGen.C diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 917088c42..3bfd65564 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -364,6 +364,17 @@ namespace clad { /// \returns the DeclRefExpr for the given declaration. clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, const clang::CXXScopeSpec* SS = nullptr); + /// Builds a DeclRefExpr to a given Decl, adding proper nested name + /// qualifiers. \param[in] D The declaration to build a DeclRefExpr for. + /// \param[in] NNS The nested name specifier to use + /// \param[in] FoundD Found decl that can later be accessed from the + /// DeclRefExpr with the getFoundDecl() method \param[in] TemplateArgs + /// Template arguments, can be left nullptr \returns the DeclRefExpr for the + /// given declaration. + clang::DeclRefExpr* + BuildDeclRef(clang::DeclaratorDecl* D, clang::NestedNameSpecifier* NNS, + clang::NamedDecl* FoundD, + const clang::TemplateArgumentListInfo* TemplateArgs = nullptr); /// Stores the result of an expression in a temporary variable (of the same /// type as is the result of the expression) and returns a reference to it. diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index fe14227b7..b55bb82d3 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1036,8 +1036,9 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { + NestedNameSpecifier* NNS = DRE->getQualifier(); auto referencedDecl = cast(clonedDRE->getDecl()); - clonedDRE = cast(BuildDeclRef(referencedDecl)); + clonedDRE = BuildDeclRef(referencedDecl, NNS, clonedDRE->getFoundDecl()); } } else clonedDRE = cast(Clone(DRE)); @@ -1052,7 +1053,8 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { if (auto dVarDRE = dyn_cast(dExpr)) { auto dVar = cast(dVarDRE->getDecl()); if (dVar->getDeclContext() != m_Sema.CurContext) - dExpr = BuildDeclRef(dVar); + dExpr = + BuildDeclRef(dVar, DRE->getQualifier(), dVarDRE->getFoundDecl()); } return StmtDiff(clonedDRE, dExpr); } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index e90036593..139f2dac7 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1544,8 +1544,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. - if (VD->getDeclContext() != m_Sema.CurContext) - clonedDRE = cast(BuildDeclRef(VD)); + if (VD->getDeclContext() != m_Sema.CurContext) { + auto* ccDRE = dyn_cast(clonedDRE); + NestedNameSpecifier* NNS = DRE->getQualifier(); + auto referencedDecl = cast(ccDRE->getDecl()); + clonedDRE = BuildDeclRef(referencedDecl, NNS, ccDRE->getFoundDecl()); + } // This case happens when ref-type variables have to become function // global. Ref-type declarations cannot be moved to the function global // scope because they can't be separated from their inits. @@ -1852,9 +1856,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } Expr* OverloadedDerivedFn = nullptr; - // If the function has a single arg and does not returns a reference or take + // If the function has a single arg and does not return a reference or take // arg by reference, we look for a derivative w.r.t. to this arg using the - // forward mode(it is unlikely that we need gradient of a one-dimensional' + // forward mode(it is unlikely that we need gradient of a one-dimensional // function). bool asGrad = true; @@ -2149,8 +2153,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff argDiff = Visit(arg); CallArgs.push_back(argDiff.getExpr_dx()); } - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); + if (Expr* baseE = baseDiff.getExpr()) { call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), CallArgs, Loc); } else { @@ -2167,6 +2170,16 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); return StmtDiff(resValue, resAdjoint, resAdjoint); } // Recreate the original call expression. + + if (const auto* OCE = dyn_cast(CE)) { + CallArgs.insert(CallArgs.begin(), Clone(OCE->getArg(0))); + call = CXXOperatorCallExpr::Create( + m_Context, OCE->getOperator(), Clone(CE->getCallee()), CallArgs, + FD->getCallResultType(), VK_LValue, Loc, + CE->getFPFeaturesInEffect(LangOptions())); + return StmtDiff(call); + } + call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, CallArgs, Loc) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index f739abd2f..16318e1ea 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -243,6 +243,43 @@ namespace clad { m_Sema.BuildDeclRefExpr(D, T, VK_LValue, D->getBeginLoc(), SS))); } + DeclRefExpr* VisitorBase::BuildDeclRef( + DeclaratorDecl* D, NestedNameSpecifier* NNS, NamedDecl* FoundD, + const TemplateArgumentListInfo* TemplateArgs /*=nullptr*/) { + QualType T = D->getType(); + T = T.getNonReferenceType(); + + std::vector NNChain; + CXXScopeSpec CSS; + while (NNS) { + NNChain.push_back(NNS); + NNS = NNS->getPrefix(); + } + + std::reverse(NNChain.begin(), NNChain.end()); + + for (size_t i = 0; i < NNChain.size(); ++i) { + NNS = NNChain[i]; + if (NNS->getKind() == NestedNameSpecifier::Namespace) { + NamespaceDecl* NS = NNS->getAsNamespace(); + CSS.Extend(m_Context, NS, noLoc, noLoc); + } else if (NNS->getKind() == NestedNameSpecifier::TypeSpec) { + const Type* T = NNS->getAsType(); + if (auto* RT = const_cast(T->getAs())) { + RecordDecl* RD = RT->getDecl(); + CSS.Extend(m_Context, RD->getIdentifier(), noLoc, noLoc); + } + } + } + + DeclarationNameInfo NameInfo(D->getDeclName(), D->getBeginLoc()); + auto NNLoc = (CSS.isNotEmpty() && CSS.isValid()) + ? CSS.getWithLocInContext(m_Context) + : NestedNameSpecifierLoc(); + return cast(clad_compat::GetResult( + m_Sema.BuildDeclRefExpr(D, T, VK_LValue, NameInfo, NNLoc, FoundD))); + } + IdentifierInfo* VisitorBase::CreateUniqueIdentifier(llvm::StringRef nameBase) { // For intermediate variables, use numbered names (_t0), for everything diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C index f9b06aeeb..a1e85b8b3 100644 --- a/test/Gradient/Lambdas.C +++ b/test/Gradient/Lambdas.C @@ -39,7 +39,7 @@ double f2(double i, double j) { // CHECK-NEXT: return t + k; // CHECK-NEXT: }{{;?}} // CHECK: double _d_x = 0.; -// CHECK-NEXT: double x = operator()(i + j, i); +// CHECK-NEXT: double x = _f(i + j, i); // CHECK-NEXT: _d_x += 1; // CHECK-NEXT: { // CHECK-NEXT: double _r0 = 0.; diff --git a/test/ValidCodeGen/ValidCodeGen.C b/test/ValidCodeGen/ValidCodeGen.C new file mode 100644 index 000000000..b789a57a5 --- /dev/null +++ b/test/ValidCodeGen/ValidCodeGen.C @@ -0,0 +1,80 @@ +// XFAIL: asserts +// RUN: %cladclang -std=c++14 %s -I%S/../../include -oValidCodeGen.out 2>&1 | %filecheck %s +// RUN: ./ValidCodeGen.out | %filecheck_exec %s +// RUN: %cladclang -std=c++14 -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oValidCodeGenWithTBR.out +// RUN: ./ValidCodeGenWithTBR.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/STLBuiltins.h" +#include "../TestUtils.h" +#include "../PrintOverloads.h" + +namespace TN { + struct Test { + static int multiplier; + }; + int Test::multiplier = 3; + + template + struct Test2 { + T operator[](T x) { + return 4*x; + } + }; +} + +namespace clad { +namespace custom_derivatives { +namespace class_functions { + template + void operator_subscript_pullback(::TN::Test2* obj, T x, T d_u, ::TN::Test2* d_obj, T* d_x) { + (*d_x) += 4*d_u; + } +}}} + +double fn(double x) { + // fwd and rvs mode test + return x*TN::Test::multiplier; // in this test, it's important that this nested name is copied into the generated code properly in both modes +} + +double fn2(double x, double y) { + // rvs mode test + TN::Test2 t; + auto q = t[x]; // in this test, it's important that this operator call is copied into the generated code properly and that the pullback function is called with all the needed namespace prefixes + return q; +} + +int main() { + double dx, dy; + INIT_DIFFERENTIATE(fn, "x"); + INIT_GRADIENT(fn); + INIT_GRADIENT(fn2); + + TEST_GRADIENT(fn, /*numOfDerivativeArgs=*/1, 3, &dx); // CHECK-EXEC: {3.00} + TEST_GRADIENT(fn2, /*numOfDerivativeArgs=*/2, 3, 4, &dx, &dy); // CHECK-EXEC: {4.00, 0.00} + TEST_DIFFERENTIATE(fn, 3) // CHECK-EXEC: {3.00} +} + +//CHECK: double fn_darg0(double x) { +//CHECK-NEXT: double _d_x = 1; +//CHECK-NEXT: return _d_x * TN::Test::multiplier + x * 0; +//CHECK-NEXT: } + +//CHECK: void fn_grad(double x, double *_d_x) { +//CHECK-NEXT: *_d_x += 1 * TN::Test::multiplier; +//CHECK-NEXT: } + +//CHECK: void fn2_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: TN::Test2 _d_t({}); +//CHECK-NEXT: TN::Test2 t; +//CHECK-NEXT: TN::Test2 _t0 = t; +//CHECK-NEXT: double _d_q = 0.; +//CHECK-NEXT: double q = t[x]; +//CHECK-NEXT: _d_q += 1; +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t0, x, _d_q, &_d_t, &_r0); +//CHECK-NEXT: *_d_x += _r0; +//CHECK-NEXT: } +//CHECK-NEXT: } From afec0843b64178ac7e917078254b19ea5b0193c5 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 11 Sep 2024 11:51:44 +0200 Subject: [PATCH 2/9] compatibility update --- lib/Differentiator/BaseForwardModeVisitor.cpp | 2 +- lib/Differentiator/ReverseModeVisitor.cpp | 5 ++--- lib/Differentiator/VisitorBase.cpp | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index b55bb82d3..4a9062f43 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1037,7 +1037,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { // to the underlying struct of a lambda. if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { NestedNameSpecifier* NNS = DRE->getQualifier(); - auto referencedDecl = cast(clonedDRE->getDecl()); + auto* referencedDecl = cast(clonedDRE->getDecl()); clonedDRE = BuildDeclRef(referencedDecl, NNS, clonedDRE->getFoundDecl()); } } else diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 139f2dac7..ceb6303e9 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1547,7 +1547,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (VD->getDeclContext() != m_Sema.CurContext) { auto* ccDRE = dyn_cast(clonedDRE); NestedNameSpecifier* NNS = DRE->getQualifier(); - auto referencedDecl = cast(ccDRE->getDecl()); + auto* referencedDecl = cast(ccDRE->getDecl()); clonedDRE = BuildDeclRef(referencedDecl, NNS, ccDRE->getFoundDecl()); } // This case happens when ref-type variables have to become function @@ -2175,8 +2175,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CallArgs.insert(CallArgs.begin(), Clone(OCE->getArg(0))); call = CXXOperatorCallExpr::Create( m_Context, OCE->getOperator(), Clone(CE->getCallee()), CallArgs, - FD->getCallResultType(), VK_LValue, Loc, - CE->getFPFeaturesInEffect(LangOptions())); + FD->getCallResultType(), VK_LValue, Loc, FPOptions()); return StmtDiff(call); } diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 16318e1ea..47bdfd28b 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -258,8 +258,8 @@ namespace clad { std::reverse(NNChain.begin(), NNChain.end()); - for (size_t i = 0; i < NNChain.size(); ++i) { - NNS = NNChain[i]; + for (auto& n : NNChain) { + NNS = n; if (NNS->getKind() == NestedNameSpecifier::Namespace) { NamespaceDecl* NS = NNS->getAsNamespace(); CSS.Extend(m_Context, NS, noLoc, noLoc); From bc9f5a3e176d59bdb1ce0977d9a53a1a6855094a Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 11 Sep 2024 13:27:05 +0200 Subject: [PATCH 3/9] update compatibility 2 --- lib/Differentiator/ReverseModeVisitor.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index ceb6303e9..e4dd26ff1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2175,7 +2175,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CallArgs.insert(CallArgs.begin(), Clone(OCE->getArg(0))); call = CXXOperatorCallExpr::Create( m_Context, OCE->getOperator(), Clone(CE->getCallee()), CallArgs, - FD->getCallResultType(), VK_LValue, Loc, FPOptions()); + FD->getCallResultType(), OCE->getValueKind(), Loc, + CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride()); return StmtDiff(call); } From d9701624dae6498f04ca80a6088b992bfb166fe3 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 11 Sep 2024 18:04:56 +0200 Subject: [PATCH 4/9] update compatibility 3 --- include/clad/Differentiator/VisitorBase.h | 15 +++++---------- lib/Differentiator/BaseForwardModeVisitor.cpp | 5 ++--- lib/Differentiator/ReverseModeVisitor.cpp | 2 +- lib/Differentiator/VisitorBase.cpp | 15 +++------------ 4 files changed, 11 insertions(+), 26 deletions(-) diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 3bfd65564..bacd8a02c 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -365,16 +365,11 @@ namespace clad { clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, const clang::CXXScopeSpec* SS = nullptr); /// Builds a DeclRefExpr to a given Decl, adding proper nested name - /// qualifiers. \param[in] D The declaration to build a DeclRefExpr for. - /// \param[in] NNS The nested name specifier to use - /// \param[in] FoundD Found decl that can later be accessed from the - /// DeclRefExpr with the getFoundDecl() method \param[in] TemplateArgs - /// Template arguments, can be left nullptr \returns the DeclRefExpr for the - /// given declaration. - clang::DeclRefExpr* - BuildDeclRef(clang::DeclaratorDecl* D, clang::NestedNameSpecifier* NNS, - clang::NamedDecl* FoundD, - const clang::TemplateArgumentListInfo* TemplateArgs = nullptr); + /// qualifiers. + /// \param[in] D The declaration to build a DeclRefExpr for. + /// \param[in] NNS The nested name specifier to use. + clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, + clang::NestedNameSpecifier* NNS); /// Stores the result of an expression in a temporary variable (of the same /// type as is the result of the expression) and returns a reference to it. diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 4a9062f43..8015b8fdb 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1038,7 +1038,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { NestedNameSpecifier* NNS = DRE->getQualifier(); auto* referencedDecl = cast(clonedDRE->getDecl()); - clonedDRE = BuildDeclRef(referencedDecl, NNS, clonedDRE->getFoundDecl()); + clonedDRE = BuildDeclRef(referencedDecl, NNS); } } else clonedDRE = cast(Clone(DRE)); @@ -1053,8 +1053,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { if (auto dVarDRE = dyn_cast(dExpr)) { auto dVar = cast(dVarDRE->getDecl()); if (dVar->getDeclContext() != m_Sema.CurContext) - dExpr = - BuildDeclRef(dVar, DRE->getQualifier(), dVarDRE->getFoundDecl()); + dExpr = BuildDeclRef(dVar, DRE->getQualifier()); } return StmtDiff(clonedDRE, dExpr); } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index e4dd26ff1..b03f49c49 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1548,7 +1548,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto* ccDRE = dyn_cast(clonedDRE); NestedNameSpecifier* NNS = DRE->getQualifier(); auto* referencedDecl = cast(ccDRE->getDecl()); - clonedDRE = BuildDeclRef(referencedDecl, NNS, ccDRE->getFoundDecl()); + clonedDRE = BuildDeclRef(referencedDecl, NNS); } // This case happens when ref-type variables have to become function // global. Ref-type declarations cannot be moved to the function global diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 47bdfd28b..f6b1cab0f 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -243,12 +243,8 @@ namespace clad { m_Sema.BuildDeclRefExpr(D, T, VK_LValue, D->getBeginLoc(), SS))); } - DeclRefExpr* VisitorBase::BuildDeclRef( - DeclaratorDecl* D, NestedNameSpecifier* NNS, NamedDecl* FoundD, - const TemplateArgumentListInfo* TemplateArgs /*=nullptr*/) { - QualType T = D->getType(); - T = T.getNonReferenceType(); - + DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, + NestedNameSpecifier* NNS) { std::vector NNChain; CXXScopeSpec CSS; while (NNS) { @@ -272,12 +268,7 @@ namespace clad { } } - DeclarationNameInfo NameInfo(D->getDeclName(), D->getBeginLoc()); - auto NNLoc = (CSS.isNotEmpty() && CSS.isValid()) - ? CSS.getWithLocInContext(m_Context) - : NestedNameSpecifierLoc(); - return cast(clad_compat::GetResult( - m_Sema.BuildDeclRefExpr(D, T, VK_LValue, NameInfo, NNLoc, FoundD))); + return BuildDeclRef(D, &CSS); } IdentifierInfo* From f8ba731bd52246600ff9009490eee0c8fc573f91 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 11 Sep 2024 18:32:02 +0200 Subject: [PATCH 5/9] change the generated value kind for operator calls --- lib/Differentiator/ReverseModeVisitor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b03f49c49..e3fb19361 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2175,7 +2175,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CallArgs.insert(CallArgs.begin(), Clone(OCE->getArg(0))); call = CXXOperatorCallExpr::Create( m_Context, OCE->getOperator(), Clone(CE->getCallee()), CallArgs, - FD->getCallResultType(), OCE->getValueKind(), Loc, + FD->getCallResultType(), VK_LValue, Loc, CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride()); return StmtDiff(call); } From 53bf73d1099e508a58c8df5545ee366bf3ab0d64 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Tue, 1 Oct 2024 22:23:15 +0200 Subject: [PATCH 6/9] minor adjustments --- include/clad/Differentiator/VisitorBase.h | 6 ++++-- lib/Differentiator/ReverseModeVisitor.cpp | 5 +++-- lib/Differentiator/VisitorBase.cpp | 18 ++++++++++++------ test/Gradient/Lambdas.C | 4 ++-- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index bacd8a02c..84bf42342 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -363,13 +363,15 @@ namespace clad { /// \param[in] SS The scope specifier for the declaration. /// \returns the DeclRefExpr for the given declaration. clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, - const clang::CXXScopeSpec* SS = nullptr); + const clang::CXXScopeSpec* SS = nullptr, + clang::ExprValueKind VK = clang::VK_LValue); /// Builds a DeclRefExpr to a given Decl, adding proper nested name /// qualifiers. /// \param[in] D The declaration to build a DeclRefExpr for. /// \param[in] NNS The nested name specifier to use. clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, - clang::NestedNameSpecifier* NNS); + clang::NestedNameSpecifier* NNS, + clang::ExprValueKind VK = clang::VK_LValue); /// Stores the result of an expression in a temporary variable (of the same /// type as is the result of the expression) and returns a reference to it. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index e3fb19361..620351c61 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1548,7 +1548,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto* ccDRE = dyn_cast(clonedDRE); NestedNameSpecifier* NNS = DRE->getQualifier(); auto* referencedDecl = cast(ccDRE->getDecl()); - clonedDRE = BuildDeclRef(referencedDecl, NNS); + clonedDRE = BuildDeclRef(referencedDecl, NNS, DRE->getValueKind()); } // This case happens when ref-type variables have to become function // global. Ref-type declarations cannot be moved to the function global @@ -2173,9 +2173,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (const auto* OCE = dyn_cast(CE)) { CallArgs.insert(CallArgs.begin(), Clone(OCE->getArg(0))); + // OCE->getArg(0)->dump(); call = CXXOperatorCallExpr::Create( m_Context, OCE->getOperator(), Clone(CE->getCallee()), CallArgs, - FD->getCallResultType(), VK_LValue, Loc, + FD->getCallResultType(), OCE->getValueKind(), Loc, CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride()); return StmtDiff(call); } diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index f6b1cab0f..105d62780 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -236,26 +236,32 @@ namespace clad { } DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, - const CXXScopeSpec* SS /*=nullptr*/) { + const CXXScopeSpec* SS /*=nullptr*/, + ExprValueKind VK /*=VK_LValue*/) { QualType T = D->getType(); T = T.getNonReferenceType(); return cast(clad_compat::GetResult( - m_Sema.BuildDeclRefExpr(D, T, VK_LValue, D->getBeginLoc(), SS))); + m_Sema.BuildDeclRefExpr(D, T, VK, D->getBeginLoc(), SS))); } DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, - NestedNameSpecifier* NNS) { + NestedNameSpecifier* NNS, + ExprValueKind VK /*=VK_LValue*/) { std::vector NNChain; CXXScopeSpec CSS; while (NNS) { + // FIXME: proper support for dependent NNS needs to be added. + //if (!NNS->isDependent()) return BuildDeclRef(D); + NNChain.push_back(NNS); NNS = NNS->getPrefix(); } std::reverse(NNChain.begin(), NNChain.end()); - for (auto& n : NNChain) { - NNS = n; + for (size_t i = 0; i < NNChain.size(); ++i) { + NNS = NNChain[i]; + // FIXME: this needs to be extended to support more NNS kinds. An inspiration can be take from getFullyQualifiedNestedNameSpecifier in llvm-project/clang/lib/AST/QualTypeNames.cpp if (NNS->getKind() == NestedNameSpecifier::Namespace) { NamespaceDecl* NS = NNS->getAsNamespace(); CSS.Extend(m_Context, NS, noLoc, noLoc); @@ -268,7 +274,7 @@ namespace clad { } } - return BuildDeclRef(D, &CSS); + return BuildDeclRef(D, &CSS, VK); } IdentifierInfo* diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C index a1e85b8b3..9fef12d83 100644 --- a/test/Gradient/Lambdas.C +++ b/test/Gradient/Lambdas.C @@ -13,7 +13,7 @@ double f1(double i, double j) { } // CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const; -// CHECK-NEXT: void f1_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK: void f1_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: auto _f = []{{ ?}}(double t) { // CHECK-NEXT: return t * t + 1.; // CHECK-NEXT: }{{;?}} @@ -34,7 +34,7 @@ double f2(double i, double j) { } // CHECK: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const; -// CHECK-NEXT: void f2_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK: void f2_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: auto _f = []{{ ?}}(double t, double k) { // CHECK-NEXT: return t + k; // CHECK-NEXT: }{{;?}} From 0b0e0869fdbd4176e08130698990b552cb838813 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Tue, 1 Oct 2024 22:37:31 +0200 Subject: [PATCH 7/9] format --- include/clad/Differentiator/VisitorBase.h | 13 +++++++------ lib/Differentiator/VisitorBase.cpp | 10 ++++++---- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 84bf42342..dba1540a2 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -362,16 +362,17 @@ namespace clad { /// \param[in] D The declaration to build a DeclRefExpr for. /// \param[in] SS The scope specifier for the declaration. /// \returns the DeclRefExpr for the given declaration. - clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, - const clang::CXXScopeSpec* SS = nullptr, - clang::ExprValueKind VK = clang::VK_LValue); + clang::DeclRefExpr* + BuildDeclRef(clang::DeclaratorDecl* D, + const clang::CXXScopeSpec* SS = nullptr, + clang::ExprValueKind VK = clang::VK_LValue); /// Builds a DeclRefExpr to a given Decl, adding proper nested name /// qualifiers. /// \param[in] D The declaration to build a DeclRefExpr for. /// \param[in] NNS The nested name specifier to use. - clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, - clang::NestedNameSpecifier* NNS, - clang::ExprValueKind VK = clang::VK_LValue); + clang::DeclRefExpr* + BuildDeclRef(clang::DeclaratorDecl* D, clang::NestedNameSpecifier* NNS, + clang::ExprValueKind VK = clang::VK_LValue); /// Stores the result of an expression in a temporary variable (of the same /// type as is the result of the expression) and returns a reference to it. diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 105d62780..56f966392 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -236,7 +236,7 @@ namespace clad { } DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, - const CXXScopeSpec* SS /*=nullptr*/, + const CXXScopeSpec* SS /*=nullptr*/, ExprValueKind VK /*=VK_LValue*/) { QualType T = D->getType(); T = T.getNonReferenceType(); @@ -245,13 +245,13 @@ namespace clad { } DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, - NestedNameSpecifier* NNS, + NestedNameSpecifier* NNS, ExprValueKind VK /*=VK_LValue*/) { std::vector NNChain; CXXScopeSpec CSS; while (NNS) { // FIXME: proper support for dependent NNS needs to be added. - //if (!NNS->isDependent()) return BuildDeclRef(D); + // if (!NNS->isDependent()) return BuildDeclRef(D); NNChain.push_back(NNS); NNS = NNS->getPrefix(); @@ -261,7 +261,9 @@ namespace clad { for (size_t i = 0; i < NNChain.size(); ++i) { NNS = NNChain[i]; - // FIXME: this needs to be extended to support more NNS kinds. An inspiration can be take from getFullyQualifiedNestedNameSpecifier in llvm-project/clang/lib/AST/QualTypeNames.cpp + // FIXME: this needs to be extended to support more NNS kinds. An + // inspiration can be take from getFullyQualifiedNestedNameSpecifier in + // llvm-project/clang/lib/AST/QualTypeNames.cpp if (NNS->getKind() == NestedNameSpecifier::Namespace) { NamespaceDecl* NS = NNS->getAsNamespace(); CSS.Extend(m_Context, NS, noLoc, noLoc); From 238286ffdcb17cbf93251f4e0aaafac1e2039869 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 2 Oct 2024 12:27:26 +0200 Subject: [PATCH 8/9] call operators as methods --- lib/Differentiator/ReverseModeVisitor.cpp | 25 ++++++++++++++++------- test/Gradient/Lambdas.C | 2 +- test/ValidCodeGen/ValidCodeGen.C | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 620351c61..402aacdce 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2171,13 +2171,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(resValue, resAdjoint, resAdjoint); } // Recreate the original call expression. - if (const auto* OCE = dyn_cast(CE)) { - CallArgs.insert(CallArgs.begin(), Clone(OCE->getArg(0))); - // OCE->getArg(0)->dump(); - call = CXXOperatorCallExpr::Create( - m_Context, OCE->getOperator(), Clone(CE->getCallee()), CallArgs, - FD->getCallResultType(), OCE->getValueKind(), Loc, - CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride()); + if (auto* OCE = dyn_cast(CE)) { + CXXMethodDecl* FD = const_cast( + dyn_cast(OCE->getCalleeDecl())); + + NestedNameSpecifierLoc NNS(FD->getQualifier(), + /*Data=*/nullptr); + auto DAP = DeclAccessPair::make(FD, FD->getAccess()); + auto* memberExpr = MemberExpr::Create( + m_Context, Clone(OCE->getArg(0)), /*isArrow=*/false, Loc, NNS, noLoc, + FD, DAP, FD->getNameInfo(), + /*TemplateArgs=*/nullptr, m_Context.BoundMemberTy, + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams( + NOUR_None)); + call = m_Sema + .BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc, + CallArgs, Loc) + .get(); return StmtDiff(call); } diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C index 9fef12d83..35776e2d6 100644 --- a/test/Gradient/Lambdas.C +++ b/test/Gradient/Lambdas.C @@ -39,7 +39,7 @@ double f2(double i, double j) { // CHECK-NEXT: return t + k; // CHECK-NEXT: }{{;?}} // CHECK: double _d_x = 0.; -// CHECK-NEXT: double x = _f(i + j, i); +// CHECK-NEXT: double x = _f.operator()(i + j, i); // CHECK-NEXT: _d_x += 1; // CHECK-NEXT: { // CHECK-NEXT: double _r0 = 0.; diff --git a/test/ValidCodeGen/ValidCodeGen.C b/test/ValidCodeGen/ValidCodeGen.C index b789a57a5..1e638be1c 100644 --- a/test/ValidCodeGen/ValidCodeGen.C +++ b/test/ValidCodeGen/ValidCodeGen.C @@ -70,7 +70,7 @@ int main() { //CHECK-NEXT: TN::Test2 t; //CHECK-NEXT: TN::Test2 _t0 = t; //CHECK-NEXT: double _d_q = 0.; -//CHECK-NEXT: double q = t[x]; +//CHECK-NEXT: double q = t.operator[](x); //CHECK-NEXT: _d_q += 1; //CHECK-NEXT: { //CHECK-NEXT: double _r0 = 0.; From cd3fc604834c08ec7b9e6c0a4c2facb38b6d91f3 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Sun, 6 Oct 2024 22:00:54 +0200 Subject: [PATCH 9/9] correct CXXScopeSpec::Extend usage (no templated types) --- lib/Differentiator/VisitorBase.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 56f966392..bb3dbd615 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -250,9 +250,6 @@ namespace clad { std::vector NNChain; CXXScopeSpec CSS; while (NNS) { - // FIXME: proper support for dependent NNS needs to be added. - // if (!NNS->isDependent()) return BuildDeclRef(D); - NNChain.push_back(NNS); NNS = NNS->getPrefix(); } @@ -271,7 +268,14 @@ namespace clad { const Type* T = NNS->getAsType(); if (auto* RT = const_cast(T->getAs())) { RecordDecl* RD = RT->getDecl(); - CSS.Extend(m_Context, RD->getIdentifier(), noLoc, noLoc); + // FIXME: currently only works for non-templated types + bool isTemplated = false; + if (const auto* CXXDecl = dyn_cast(RD)) + isTemplated = CXXDecl->getDescribedClassTemplate() || + (CXXDecl->getTemplateSpecializationKind() != + clang::TSK_Undeclared); + if (!isTemplated) + CSS.Extend(m_Context, RD->getIdentifier(), noLoc, noLoc); } } }