diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index ffc90be12..ae2d26a99 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -216,7 +216,8 @@ namespace clad { clang::IdentifierInfo* II, clang::QualType T, clang::StorageClass SC = clang::StorageClass::SC_None, clang::Expr* defArg = nullptr, - clang::TypeSourceInfo* TSI = nullptr); + clang::TypeSourceInfo* TSI = nullptr, + clang::SourceLocation Loc = clang::SourceLocation()); /// If `T` represents an array or a pointer type then returns the /// corresponding array element or the pointee type. If `T` is a reference diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 0bbc13290..8fa556134 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -500,7 +500,8 @@ namespace clad { clang::Expr* BuildCallExprToMemFn(clang::Expr* Base, llvm::StringRef MemberFunctionName, llvm::MutableArrayRef ArgExprs, - clang::ValueDecl* memberDecl = nullptr); + clang::ValueDecl* memberDecl = nullptr, + clang::SourceLocation Loc = noLoc); /// Build a call to member function through this pointer. /// @@ -509,10 +510,9 @@ namespace clad { /// \param[in] useRefQualifiedThisObj If true, then the `this` object is /// perfectly forwarded while calling member functions. /// \returns Built member function call expression - clang::Expr* - BuildCallExprToMemFn(clang::CXXMethodDecl* FD, - llvm::MutableArrayRef argExprs, - bool useRefQualifiedThisObj = false); + clang::Expr* BuildCallExprToMemFn( + clang::CXXMethodDecl* FD, llvm::MutableArrayRef argExprs, + bool useRefQualifiedThisObj = false, clang::SourceLocation Loc = noLoc); /// Build a call to a free function or member function through /// this pointer depending on whether the `FD` argument corresponds to a @@ -577,7 +577,8 @@ namespace clad { clang::ParmVarDecl* CloneParmVarDecl(const clang::ParmVarDecl* PVD, clang::IdentifierInfo* II, bool pushOnScopeChains = false, - bool cloneDefaultArg = true); + bool cloneDefaultArg = true, + clang::SourceLocation Loc = noLoc); /// A function to get the single argument "forward_central_difference" /// call expression for the given arguments. /// diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 13ba31176..471c3083d 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -162,14 +162,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, IdentifierInfo* II = &m_Context.Idents.get( request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix); - SourceLocation loc{m_Function->getLocation()}; - DeclarationNameInfo name(II, loc); + SourceLocation validLoc{m_Function->getLocation()}; + DeclarationNameInfo name(II, validLoc); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope()); DeclContext* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = - m_Builder.cloneFunction(FD, *this, DC, loc, name, FD->getType()); + m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType()); FunctionDecl* derivedFD = result.first; m_Derivative = derivedFD; @@ -190,8 +190,8 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, if (PVD->hasDefaultArg()) clonedPVDDefaultArg = Clone(PVD->getDefaultArg()); - newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, noLoc, noLoc, - PVD->getIdentifier(), PVD->getType(), + newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, validLoc, + validLoc, PVD->getIdentifier(), PVD->getType(), PVD->getTypeSourceInfo(), PVD->getStorageClass(), clonedPVDDefaultArg); @@ -322,13 +322,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, m_Context.IntTy, m_Context, dValue); dArrVal.push_back(dValueLiteral); } - dInitializer = m_Sema.ActOnInitList(noLoc, dArrVal, noLoc).get(); + dInitializer = + m_Sema.ActOnInitList(validLoc, dArrVal, validLoc).get(); } else if (const auto* ptrType = dyn_cast(fieldType.getTypePtr())) { if (!ptrType->getPointeeType()->isRealType()) continue; // Pointer member variables should be initialised by `nullptr`. - dInitializer = m_Sema.ActOnCXXNullPtrLiteral(noLoc).get(); + dInitializer = m_Sema.ActOnCXXNullPtrLiteral(validLoc).get(); } else { int dValue = (fieldDecl == m_IndependentVar); dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy, @@ -403,7 +404,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, IdentifierInfo* derivedFnII = &m_Context.Idents.get( originalFnEffectiveName + GetPushForwardFunctionSuffix()); - DeclarationNameInfo derivedFnName(derivedFnII, noLoc); + DeclarationNameInfo derivedFnName(derivedFnII, m_Function->getLocation()); llvm::SmallVector paramTypes; llvm::SmallVector derivedParamTypes; @@ -984,6 +985,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { return StmtDiff(Clone(CE)); } + SourceLocation validLoc{CE->getBeginLoc()}; + // If the function is non_differentiable, return zero derivative. if (clad::utils::hasNonDifferentiableAttribute(CE)) { // Calling the function without computing derivatives @@ -993,7 +996,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { Expr* Call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), - noLoc, ClonedArgs, noLoc) + validLoc, ClonedArgs, validLoc) .get(); // Creating a zero derivative auto* zero = @@ -1008,7 +1011,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { if (m_DerivativeOrder == 1) s = ""; - SourceLocation noLoc; llvm::SmallVector CallArgs{}; llvm::SmallVector diffArgs; @@ -1108,7 +1110,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { callDiff = m_Sema .ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext), - derivativeRef, noLoc, pushforwardFnArgs, noLoc) + derivativeRef, validLoc, pushforwardFnArgs, validLoc) .get(); } @@ -1132,8 +1134,9 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { if (allArgsAreConstantLiterals) { Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, - llvm::MutableArrayRef(CallArgs), noLoc) + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), + validLoc, llvm::MutableArrayRef(CallArgs), + validLoc) .get(); auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); @@ -1173,11 +1176,11 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { Expr* execConfig = nullptr; if (auto KCE = dyn_cast(CE)) execConfig = Clone(KCE->getConfig()); - callDiff = - m_Sema - .ActOnCallExpr(getCurrentScope(), BuildDeclRef(pushforwardFD), - noLoc, pushforwardFnArgs, noLoc, execConfig) - .get(); + callDiff = m_Sema + .ActOnCallExpr(getCurrentScope(), + BuildDeclRef(pushforwardFD), validLoc, + pushforwardFnArgs, validLoc, execConfig) + .get(); } } } @@ -1188,8 +1191,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { Multiplier = diffArgs[0]; Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, - llvm::MutableArrayRef(CallArgs), noLoc) + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), validLoc, + llvm::MutableArrayRef(CallArgs), validLoc) .get(); // FIXME: Extend this for multiarg support // Check if the function is eligible for numerical differentiation. diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index adaf5d29b..c180d0b1a 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -346,12 +346,14 @@ namespace clad { BuildParmVarDecl(clang::Sema& semaRef, clang::DeclContext* DC, clang::IdentifierInfo* II, clang::QualType T, clang::StorageClass SC, clang::Expr* defArg, - clang::TypeSourceInfo* TSI) { + clang::TypeSourceInfo* TSI, clang::SourceLocation Loc) { ASTContext& C = semaRef.getASTContext(); if (!TSI) TSI = C.getTrivialTypeSourceInfo(T, noLoc); + if (Loc.isInvalid()) + Loc = utils::GetValidSLoc(semaRef); ParmVarDecl* PVD = - ParmVarDecl::Create(C, DC, noLoc, noLoc, II, T, TSI, SC, defArg); + ParmVarDecl::Create(C, DC, Loc, Loc, II, T, TSI, SC, defArg); return PVD; } diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index a482c29ab..92e394dff 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -34,7 +34,11 @@ using namespace clang; namespace clad { clang::CompoundStmt* VisitorBase::MakeCompoundStmt(const Stmts& Stmts) { auto Stmts_ref = clad_compat::makeArrayRef(Stmts.data(), Stmts.size()); - return clad_compat::CompoundStmt_Create(m_Context, Stmts_ref /**/ CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2(FPOptionsOverride()), noLoc, noLoc); + return clad_compat::CompoundStmt_Create( + m_Context, + Stmts_ref /**/ CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2( + FPOptionsOverride()), + utils::GetValidSLoc(m_Sema), utils::GetValidSLoc(m_Sema)); } bool VisitorBase::isUnusedResult(const Expr* E) { @@ -524,7 +528,11 @@ namespace clad { Expr* VisitorBase::BuildCallExprToMemFn(Expr* Base, StringRef MemberFunctionName, - MutableArrayRef ArgExprs, ValueDecl* memberDecl) { + MutableArrayRef ArgExprs, + ValueDecl* memberDecl, + SourceLocation Loc /*=noLoc*/) { + if (Loc.isInvalid()) + Loc = m_Function->getLocation(); UnqualifiedId Member; Member.setIdentifier(&m_Context.Idents.get(MemberFunctionName), noLoc); CXXScopeSpec SS; @@ -549,7 +557,7 @@ namespace clad { // explicitly should assign the correct member function whenever we can. if (memberDecl) ME->setMemberDecl(memberDecl); - return m_Sema.ActOnCallExpr(getCurrentScope(), ME, noLoc, ArgExprs, noLoc) + return m_Sema.ActOnCallExpr(getCurrentScope(), ME, Loc, ArgExprs, Loc) .get(); } @@ -568,9 +576,11 @@ namespace clad { Expr* VisitorBase::BuildCallExprToMemFn( clang::CXXMethodDecl* FD, llvm::MutableArrayRef argExprs, - bool useRefQualifiedThisObj) { + bool useRefQualifiedThisObj, SourceLocation Loc /*=noLoc*/) { Expr* thisExpr = clad_compat::Sema_BuildCXXThisExpr(m_Sema, FD); bool isArrow = true; + if (Loc.isInvalid()) + Loc = m_Function->getLocation(); // C++ does not support perfect forwarding of `*this` object inside // a member function. @@ -604,8 +614,8 @@ namespace clad { ExprObjectKind::OK_Ordinary CLAD_COMPAT_CLANG9_MemberExpr_ExtraParams(NOUR_None)); return m_Sema - .BuildCallToMemberFunction(getCurrentScope(), memberExpr, noLoc, - argExprs, noLoc) + .BuildCallToMemberFunction(getCurrentScope(), memberExpr, Loc, argExprs, + Loc) .get(); } @@ -769,13 +779,16 @@ namespace clad { ParmVarDecl* VisitorBase::CloneParmVarDecl(const ParmVarDecl* PVD, IdentifierInfo* II, bool pushOnScopeChains, - bool cloneDefaultArg) { + bool cloneDefaultArg, + SourceLocation Loc) { Expr* newPVDDefaultArg = nullptr; if (PVD->hasDefaultArg() && cloneDefaultArg) { newPVDDefaultArg = Clone(PVD->getDefaultArg()); } + if (Loc.isInvalid()) + Loc = PVD->getLocation(); auto newPVD = ParmVarDecl::Create( - m_Context, m_Sema.CurContext, noLoc, noLoc, II, PVD->getType(), + m_Context, m_Sema.CurContext, Loc, Loc, II, PVD->getType(), PVD->getTypeSourceInfo(), PVD->getStorageClass(), newPVDDefaultArg); if (pushOnScopeChains && newPVD->getIdentifier()) { m_Sema.PushOnScopeChains(newPVD, getCurrentScope(),