diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 917088c42..dba1540a2 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -362,8 +362,17 @@ namespace clad { /// \param[in] D The declaration to build a DeclRefExpr for. /// \param[in] SS The scope specifier for the declaration. /// \returns the DeclRefExpr for the given declaration. - clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, - const clang::CXXScopeSpec* SS = nullptr); + clang::DeclRefExpr* + BuildDeclRef(clang::DeclaratorDecl* D, + const clang::CXXScopeSpec* SS = nullptr, + clang::ExprValueKind VK = clang::VK_LValue); + /// Builds a DeclRefExpr to a given Decl, adding proper nested name + /// qualifiers. + /// \param[in] D The declaration to build a DeclRefExpr for. + /// \param[in] NNS The nested name specifier to use. + clang::DeclRefExpr* + BuildDeclRef(clang::DeclaratorDecl* D, clang::NestedNameSpecifier* NNS, + clang::ExprValueKind VK = clang::VK_LValue); /// Stores the result of an expression in a temporary variable (of the same /// type as is the result of the expression) and returns a reference to it. diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index fe14227b7..8015b8fdb 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1036,8 +1036,9 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { - auto referencedDecl = cast(clonedDRE->getDecl()); - clonedDRE = cast(BuildDeclRef(referencedDecl)); + NestedNameSpecifier* NNS = DRE->getQualifier(); + auto* referencedDecl = cast(clonedDRE->getDecl()); + clonedDRE = BuildDeclRef(referencedDecl, NNS); } } else clonedDRE = cast(Clone(DRE)); @@ -1052,7 +1053,7 @@ StmtDiff BaseForwardModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { if (auto dVarDRE = dyn_cast(dExpr)) { auto dVar = cast(dVarDRE->getDecl()); if (dVar->getDeclContext() != m_Sema.CurContext) - dExpr = BuildDeclRef(dVar); + dExpr = BuildDeclRef(dVar, DRE->getQualifier()); } return StmtDiff(clonedDRE, dExpr); } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index e90036593..2a7a16aa4 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1544,8 +1544,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. - if (VD->getDeclContext() != m_Sema.CurContext) - clonedDRE = cast(BuildDeclRef(VD)); + if (VD->getDeclContext() != m_Sema.CurContext) { + auto* ccDRE = dyn_cast(clonedDRE); + NestedNameSpecifier* NNS = DRE->getQualifier(); + auto* referencedDecl = cast(ccDRE->getDecl()); + clonedDRE = BuildDeclRef(referencedDecl, NNS, DRE->getValueKind()); + } // This case happens when ref-type variables have to become function // global. Ref-type declarations cannot be moved to the function global // scope because they can't be separated from their inits. @@ -1852,9 +1856,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } Expr* OverloadedDerivedFn = nullptr; - // If the function has a single arg and does not returns a reference or take + // If the function has a single arg and does not return a reference or take // arg by reference, we look for a derivative w.r.t. to this arg using the - // forward mode(it is unlikely that we need gradient of a one-dimensional' + // forward mode(it is unlikely that we need gradient of a one-dimensional // function). bool asGrad = true; @@ -2149,8 +2153,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff argDiff = Visit(arg); CallArgs.push_back(argDiff.getExpr_dx()); } - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); + if (Expr* baseE = baseDiff.getExpr()) { call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), CallArgs, Loc); } else { @@ -2167,6 +2170,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); return StmtDiff(resValue, resAdjoint, resAdjoint); } // Recreate the original call expression. + + if (const auto* OCE = dyn_cast(CE)) { + auto* FD = const_cast( + dyn_cast(OCE->getCalleeDecl())); + + NestedNameSpecifierLoc NNS(FD->getQualifier(), + /*Data=*/nullptr); + auto DAP = DeclAccessPair::make(FD, FD->getAccess()); + auto* memberExpr = MemberExpr::Create( + m_Context, Clone(OCE->getArg(0)), /*isArrow=*/false, Loc, NNS, noLoc, + FD, DAP, FD->getNameInfo(), + /*TemplateArgs=*/nullptr, m_Context.BoundMemberTy, + CLAD_COMPAT_ExprValueKind_R_or_PR_Value, + ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams( + NOUR_None)); + call = m_Sema + .BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc, + CallArgs, Loc) + .get(); + return StmtDiff(call); + } + call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc, CallArgs, Loc) diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index f739abd2f..bb3dbd615 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -236,11 +236,51 @@ namespace clad { } DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, - const CXXScopeSpec* SS /*=nullptr*/) { + const CXXScopeSpec* SS /*=nullptr*/, + ExprValueKind VK /*=VK_LValue*/) { QualType T = D->getType(); T = T.getNonReferenceType(); return cast(clad_compat::GetResult( - m_Sema.BuildDeclRefExpr(D, T, VK_LValue, D->getBeginLoc(), SS))); + m_Sema.BuildDeclRefExpr(D, T, VK, D->getBeginLoc(), SS))); + } + + DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, + NestedNameSpecifier* NNS, + ExprValueKind VK /*=VK_LValue*/) { + std::vector NNChain; + CXXScopeSpec CSS; + while (NNS) { + NNChain.push_back(NNS); + NNS = NNS->getPrefix(); + } + + std::reverse(NNChain.begin(), NNChain.end()); + + for (size_t i = 0; i < NNChain.size(); ++i) { + NNS = NNChain[i]; + // FIXME: this needs to be extended to support more NNS kinds. An + // inspiration can be take from getFullyQualifiedNestedNameSpecifier in + // llvm-project/clang/lib/AST/QualTypeNames.cpp + if (NNS->getKind() == NestedNameSpecifier::Namespace) { + NamespaceDecl* NS = NNS->getAsNamespace(); + CSS.Extend(m_Context, NS, noLoc, noLoc); + } else if (NNS->getKind() == NestedNameSpecifier::TypeSpec) { + const Type* T = NNS->getAsType(); + if (auto* RT = const_cast(T->getAs())) { + RecordDecl* RD = RT->getDecl(); + // FIXME: currently only works for non-templated types + bool isTemplated = false; + if (const auto* CXXDecl = dyn_cast(RD)) + isTemplated = CXXDecl->getDescribedClassTemplate() || + (CXXDecl->getTemplateSpecializationKind() != + clang::TSK_Undeclared); + if (!isTemplated) + CSS.Extend(m_Context, RD->getIdentifier(), noLoc, noLoc); + } + } + } + + return BuildDeclRef(D, &CSS, VK); } IdentifierInfo* diff --git a/test/Gradient/Lambdas.C b/test/Gradient/Lambdas.C index f9b06aeeb..35776e2d6 100644 --- a/test/Gradient/Lambdas.C +++ b/test/Gradient/Lambdas.C @@ -13,7 +13,7 @@ double f1(double i, double j) { } // CHECK: inline void operator_call_pullback(double t, double _d_y, double *_d_t) const; -// CHECK-NEXT: void f1_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK: void f1_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: auto _f = []{{ ?}}(double t) { // CHECK-NEXT: return t * t + 1.; // CHECK-NEXT: }{{;?}} @@ -34,12 +34,12 @@ double f2(double i, double j) { } // CHECK: inline void operator_call_pullback(double t, double k, double _d_y, double *_d_t, double *_d_k) const; -// CHECK-NEXT: void f2_grad(double i, double j, double *_d_i, double *_d_j) { +// CHECK: void f2_grad(double i, double j, double *_d_i, double *_d_j) { // CHECK-NEXT: auto _f = []{{ ?}}(double t, double k) { // CHECK-NEXT: return t + k; // CHECK-NEXT: }{{;?}} // CHECK: double _d_x = 0.; -// CHECK-NEXT: double x = operator()(i + j, i); +// CHECK-NEXT: double x = _f.operator()(i + j, i); // CHECK-NEXT: _d_x += 1; // CHECK-NEXT: { // CHECK-NEXT: double _r0 = 0.; diff --git a/test/ValidCodeGen/ValidCodeGen.C b/test/ValidCodeGen/ValidCodeGen.C new file mode 100644 index 000000000..90d4e96c3 --- /dev/null +++ b/test/ValidCodeGen/ValidCodeGen.C @@ -0,0 +1,79 @@ +// RUN: %cladclang -std=c++14 %s -I%S/../../include -oValidCodeGen.out 2>&1 | %filecheck %s +// RUN: ./ValidCodeGen.out | %filecheck_exec %s +// RUN: %cladclang -std=c++14 -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oValidCodeGenWithTBR.out +// RUN: ./ValidCodeGenWithTBR.out | %filecheck_exec %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Differentiator.h" +#include "clad/Differentiator/STLBuiltins.h" +#include "../TestUtils.h" +#include "../PrintOverloads.h" + +namespace TN { + struct Test { + static int multiplier; + }; + int Test::multiplier = 3; + + template + struct Test2 { + T operator[](T x) { + return 4*x; + } + }; +} + +namespace clad { +namespace custom_derivatives { +namespace class_functions { + template + void operator_subscript_pullback(::TN::Test2* obj, T x, T d_u, ::TN::Test2* d_obj, T* d_x) { + (*d_x) += 4*d_u; + } +}}} + +double fn(double x) { + // fwd and rvs mode test + return x*TN::Test::multiplier; // in this test, it's important that this nested name is copied into the generated code properly in both modes +} + +double fn2(double x, double y) { + // rvs mode test + TN::Test2 t; + auto q = t[x]; // in this test, it's important that this operator call is copied into the generated code properly and that the pullback function is called with all the needed namespace prefixes + return q; +} + +int main() { + double dx, dy; + INIT_DIFFERENTIATE(fn, "x"); + INIT_GRADIENT(fn); + INIT_GRADIENT(fn2); + + TEST_GRADIENT(fn, /*numOfDerivativeArgs=*/1, 3, &dx); // CHECK-EXEC: {3.00} + TEST_GRADIENT(fn2, /*numOfDerivativeArgs=*/2, 3, 4, &dx, &dy); // CHECK-EXEC: {4.00, 0.00} + TEST_DIFFERENTIATE(fn, 3) // CHECK-EXEC: {3.00} +} + +//CHECK: double fn_darg0(double x) { +//CHECK-NEXT: double _d_x = 1; +//CHECK-NEXT: return _d_x * TN::Test::multiplier + x * 0; +//CHECK-NEXT: } + +//CHECK: void fn_grad(double x, double *_d_x) { +//CHECK-NEXT: *_d_x += 1 * TN::Test::multiplier; +//CHECK-NEXT: } + +//CHECK: void fn2_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: TN::Test2 _d_t({}); +//CHECK-NEXT: TN::Test2 t; +//CHECK-NEXT: TN::Test2 _t0 = t; +//CHECK-NEXT: double _d_q = 0.; +//CHECK-NEXT: double q = t.operator[](x); +//CHECK-NEXT: _d_q += 1; +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: clad::custom_derivatives::class_functions::operator_subscript_pullback(&_t0, x, _d_q, &_d_t, &_r0); +//CHECK-NEXT: *_d_x += _r0; +//CHECK-NEXT: } +//CHECK-NEXT: }