From e7d3b1fe3cd83328f71440d627466bb174c1e344 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Wed, 10 Apr 2024 13:15:03 +0200 Subject: [PATCH] Fix noloc failure --- lib/Differentiator/BaseForwardModeVisitor.cpp | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 13ba31176..471c3083d 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -162,14 +162,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, IdentifierInfo* II = &m_Context.Idents.get( request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix); - SourceLocation loc{m_Function->getLocation()}; - DeclarationNameInfo name(II, loc); + SourceLocation validLoc{m_Function->getLocation()}; + DeclarationNameInfo name(II, validLoc); llvm::SaveAndRestore SaveContext(m_Sema.CurContext); llvm::SaveAndRestore SaveScope(getCurrentScope()); DeclContext* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = - m_Builder.cloneFunction(FD, *this, DC, loc, name, FD->getType()); + m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType()); FunctionDecl* derivedFD = result.first; m_Derivative = derivedFD; @@ -190,8 +190,8 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, if (PVD->hasDefaultArg()) clonedPVDDefaultArg = Clone(PVD->getDefaultArg()); - newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, noLoc, noLoc, - PVD->getIdentifier(), PVD->getType(), + newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, validLoc, + validLoc, PVD->getIdentifier(), PVD->getType(), PVD->getTypeSourceInfo(), PVD->getStorageClass(), clonedPVDDefaultArg); @@ -322,13 +322,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, m_Context.IntTy, m_Context, dValue); dArrVal.push_back(dValueLiteral); } - dInitializer = m_Sema.ActOnInitList(noLoc, dArrVal, noLoc).get(); + dInitializer = + m_Sema.ActOnInitList(validLoc, dArrVal, validLoc).get(); } else if (const auto* ptrType = dyn_cast(fieldType.getTypePtr())) { if (!ptrType->getPointeeType()->isRealType()) continue; // Pointer member variables should be initialised by `nullptr`. - dInitializer = m_Sema.ActOnCXXNullPtrLiteral(noLoc).get(); + dInitializer = m_Sema.ActOnCXXNullPtrLiteral(validLoc).get(); } else { int dValue = (fieldDecl == m_IndependentVar); dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy, @@ -403,7 +404,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD, IdentifierInfo* derivedFnII = &m_Context.Idents.get( originalFnEffectiveName + GetPushForwardFunctionSuffix()); - DeclarationNameInfo derivedFnName(derivedFnII, noLoc); + DeclarationNameInfo derivedFnName(derivedFnII, m_Function->getLocation()); llvm::SmallVector paramTypes; llvm::SmallVector derivedParamTypes; @@ -984,6 +985,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { return StmtDiff(Clone(CE)); } + SourceLocation validLoc{CE->getBeginLoc()}; + // If the function is non_differentiable, return zero derivative. if (clad::utils::hasNonDifferentiableAttribute(CE)) { // Calling the function without computing derivatives @@ -993,7 +996,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { Expr* Call = m_Sema .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), - noLoc, ClonedArgs, noLoc) + validLoc, ClonedArgs, validLoc) .get(); // Creating a zero derivative auto* zero = @@ -1008,7 +1011,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { if (m_DerivativeOrder == 1) s = ""; - SourceLocation noLoc; llvm::SmallVector CallArgs{}; llvm::SmallVector diffArgs; @@ -1108,7 +1110,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { callDiff = m_Sema .ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext), - derivativeRef, noLoc, pushforwardFnArgs, noLoc) + derivativeRef, validLoc, pushforwardFnArgs, validLoc) .get(); } @@ -1132,8 +1134,9 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { if (allArgsAreConstantLiterals) { Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, - llvm::MutableArrayRef(CallArgs), noLoc) + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), + validLoc, llvm::MutableArrayRef(CallArgs), + validLoc) .get(); auto* zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); @@ -1173,11 +1176,11 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { Expr* execConfig = nullptr; if (auto KCE = dyn_cast(CE)) execConfig = Clone(KCE->getConfig()); - callDiff = - m_Sema - .ActOnCallExpr(getCurrentScope(), BuildDeclRef(pushforwardFD), - noLoc, pushforwardFnArgs, noLoc, execConfig) - .get(); + callDiff = m_Sema + .ActOnCallExpr(getCurrentScope(), + BuildDeclRef(pushforwardFD), validLoc, + pushforwardFnArgs, validLoc, execConfig) + .get(); } } } @@ -1188,8 +1191,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { Multiplier = diffArgs[0]; Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, - llvm::MutableArrayRef(CallArgs), noLoc) + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), validLoc, + llvm::MutableArrayRef(CallArgs), validLoc) .get(); // FIXME: Extend this for multiarg support // Check if the function is eligible for numerical differentiation.