Skip to content

Commit

Permalink
Fix noloc failure
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 10, 2024
1 parent 785ba1f commit e7d3b1f
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

IdentifierInfo* II = &m_Context.Idents.get(
request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix);
SourceLocation loc{m_Function->getLocation()};
DeclarationNameInfo name(II, loc);
SourceLocation validLoc{m_Function->getLocation()};
DeclarationNameInfo name(II, validLoc);
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result =
m_Builder.cloneFunction(FD, *this, DC, loc, name, FD->getType());
m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType());
FunctionDecl* derivedFD = result.first;
m_Derivative = derivedFD;

Expand All @@ -190,8 +190,8 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
if (PVD->hasDefaultArg())
clonedPVDDefaultArg = Clone(PVD->getDefaultArg());

newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, noLoc, noLoc,
PVD->getIdentifier(), PVD->getType(),
newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, validLoc,
validLoc, PVD->getIdentifier(), PVD->getType(),
PVD->getTypeSourceInfo(),
PVD->getStorageClass(), clonedPVDDefaultArg);

Expand Down Expand Up @@ -322,13 +322,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
m_Context.IntTy, m_Context, dValue);
dArrVal.push_back(dValueLiteral);
}
dInitializer = m_Sema.ActOnInitList(noLoc, dArrVal, noLoc).get();
dInitializer =
m_Sema.ActOnInitList(validLoc, dArrVal, validLoc).get();
} else if (const auto* ptrType =
dyn_cast<PointerType>(fieldType.getTypePtr())) {
if (!ptrType->getPointeeType()->isRealType())
continue;
// Pointer member variables should be initialised by `nullptr`.
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(noLoc).get();
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(validLoc).get();
} else {
int dValue = (fieldDecl == m_IndependentVar);
dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy,
Expand Down Expand Up @@ -403,7 +404,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,

IdentifierInfo* derivedFnII = &m_Context.Idents.get(
originalFnEffectiveName + GetPushForwardFunctionSuffix());
DeclarationNameInfo derivedFnName(derivedFnII, noLoc);
DeclarationNameInfo derivedFnName(derivedFnII, m_Function->getLocation());
llvm::SmallVector<QualType, 16> paramTypes;
llvm::SmallVector<QualType, 16> derivedParamTypes;

Expand Down Expand Up @@ -984,6 +985,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
return StmtDiff(Clone(CE));
}

SourceLocation validLoc{CE->getBeginLoc()};

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
Expand All @@ -993,7 +996,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
noLoc, ClonedArgs, noLoc)
validLoc, ClonedArgs, validLoc)
.get();
// Creating a zero derivative
auto* zero =
Expand All @@ -1008,7 +1011,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
if (m_DerivativeOrder == 1)
s = "";

SourceLocation noLoc;
llvm::SmallVector<Expr*, 4> CallArgs{};
llvm::SmallVector<Expr*, 4> diffArgs;

Expand Down Expand Up @@ -1108,7 +1110,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
callDiff =
m_Sema
.ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext),
derivativeRef, noLoc, pushforwardFnArgs, noLoc)
derivativeRef, validLoc, pushforwardFnArgs, validLoc)
.get();
}

Expand All @@ -1132,8 +1134,9 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
if (allArgsAreConstantLiterals) {
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, llvm::MutableArrayRef<Expr*>(CallArgs),
validLoc)
.get();
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down Expand Up @@ -1173,11 +1176,11 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
Expr* execConfig = nullptr;
if (auto KCE = dyn_cast<CUDAKernelCallExpr>(CE))
execConfig = Clone(KCE->getConfig());
callDiff =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pushforwardFD),
noLoc, pushforwardFnArgs, noLoc, execConfig)
.get();
callDiff = m_Sema
.ActOnCallExpr(getCurrentScope(),
BuildDeclRef(pushforwardFD), validLoc,
pushforwardFnArgs, validLoc, execConfig)
.get();
}
}
}
Expand All @@ -1188,8 +1191,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
Multiplier = diffArgs[0];
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), validLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), validLoc)
.get();
// FIXME: Extend this for multiarg support
// Check if the function is eligible for numerical differentiation.
Expand Down

0 comments on commit e7d3b1f

Please sign in to comment.