Skip to content

Commit

Permalink
Fix failures in debug mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Apr 12, 2024
1 parent 91c0dec commit 4f8292c
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 94 deletions.
3 changes: 2 additions & 1 deletion include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ namespace clad {
clang::IdentifierInfo* II, clang::QualType T,
clang::StorageClass SC = clang::StorageClass::SC_None,
clang::Expr* defArg = nullptr,
clang::TypeSourceInfo* TSI = nullptr);
clang::TypeSourceInfo* TSI = nullptr,
clang::SourceLocation Loc = clang::SourceLocation());

/// If `T` represents an array or a pointer type then returns the
/// corresponding array element or the pointee type. If `T` is a reference
Expand Down
12 changes: 6 additions & 6 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ namespace clad {
clang::Expr*
BuildCallExprToMemFn(clang::Expr* Base, llvm::StringRef MemberFunctionName,
llvm::MutableArrayRef<clang::Expr*> ArgExprs,
clang::ValueDecl* memberDecl = nullptr);
clang::SourceLocation Loc = noLoc);

/// Build a call to member function through this pointer.
///
Expand All @@ -509,10 +509,9 @@ namespace clad {
/// \param[in] useRefQualifiedThisObj If true, then the `this` object is
/// perfectly forwarded while calling member functions.
/// \returns Built member function call expression
clang::Expr*
BuildCallExprToMemFn(clang::CXXMethodDecl* FD,
llvm::MutableArrayRef<clang::Expr*> argExprs,
bool useRefQualifiedThisObj = false);
clang::Expr* BuildCallExprToMemFn(
clang::CXXMethodDecl* FD, llvm::MutableArrayRef<clang::Expr*> argExprs,
bool useRefQualifiedThisObj = false, clang::SourceLocation Loc = noLoc);

/// Build a call to a free function or member function through
/// this pointer depending on whether the `FD` argument corresponds to a
Expand Down Expand Up @@ -577,7 +576,8 @@ namespace clad {
clang::ParmVarDecl* CloneParmVarDecl(const clang::ParmVarDecl* PVD,
clang::IdentifierInfo* II,
bool pushOnScopeChains = false,
bool cloneDefaultArg = true);
bool cloneDefaultArg = true,
clang::SourceLocation Loc = noLoc);
/// A function to get the single argument "forward_central_difference"
/// call expression for the given arguments.
///
Expand Down
45 changes: 24 additions & 21 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

IdentifierInfo* II = &m_Context.Idents.get(
request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix);
SourceLocation loc{m_Function->getLocation()};
DeclarationNameInfo name(II, loc);
SourceLocation validLoc{m_Function->getLocation()};
DeclarationNameInfo name(II, validLoc);
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result =
m_Builder.cloneFunction(FD, *this, DC, loc, name, FD->getType());
m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType());
FunctionDecl* derivedFD = result.first;
m_Derivative = derivedFD;

Expand All @@ -190,8 +190,8 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
if (PVD->hasDefaultArg())
clonedPVDDefaultArg = Clone(PVD->getDefaultArg());

newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, noLoc, noLoc,
PVD->getIdentifier(), PVD->getType(),
newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, validLoc,
validLoc, PVD->getIdentifier(), PVD->getType(),
PVD->getTypeSourceInfo(),
PVD->getStorageClass(), clonedPVDDefaultArg);

Expand Down Expand Up @@ -322,13 +322,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
m_Context.IntTy, m_Context, dValue);
dArrVal.push_back(dValueLiteral);
}
dInitializer = m_Sema.ActOnInitList(noLoc, dArrVal, noLoc).get();
dInitializer =
m_Sema.ActOnInitList(validLoc, dArrVal, validLoc).get();
} else if (const auto* ptrType =
dyn_cast<PointerType>(fieldType.getTypePtr())) {
if (!ptrType->getPointeeType()->isRealType())
continue;
// Pointer member variables should be initialised by `nullptr`.
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(noLoc).get();
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(validLoc).get();
} else {
int dValue = (fieldDecl == m_IndependentVar);
dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy,
Expand Down Expand Up @@ -403,7 +404,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,

IdentifierInfo* derivedFnII = &m_Context.Idents.get(
originalFnEffectiveName + GetPushForwardFunctionSuffix());
DeclarationNameInfo derivedFnName(derivedFnII, noLoc);
DeclarationNameInfo derivedFnName(derivedFnII, m_Function->getLocation());
llvm::SmallVector<QualType, 16> paramTypes;
llvm::SmallVector<QualType, 16> derivedParamTypes;

Expand Down Expand Up @@ -984,6 +985,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
return StmtDiff(Clone(CE));
}

SourceLocation validLoc{CE->getBeginLoc()};

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
Expand All @@ -993,7 +996,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
noLoc, ClonedArgs, noLoc)
validLoc, ClonedArgs, validLoc)
.get();
// Creating a zero derivative
auto* zero =
Expand All @@ -1008,7 +1011,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
if (m_DerivativeOrder == 1)
s = "";

SourceLocation noLoc;
llvm::SmallVector<Expr*, 4> CallArgs{};
llvm::SmallVector<Expr*, 4> diffArgs;

Expand Down Expand Up @@ -1108,7 +1110,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
callDiff =
m_Sema
.ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext),
derivativeRef, noLoc, pushforwardFnArgs, noLoc)
derivativeRef, validLoc, pushforwardFnArgs, validLoc)
.get();
}

Expand All @@ -1132,8 +1134,9 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
if (allArgsAreConstantLiterals) {
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, llvm::MutableArrayRef<Expr*>(CallArgs),
validLoc)
.get();
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down Expand Up @@ -1168,16 +1171,16 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
if (baseDiff.getExpr()) {
callDiff =
BuildCallExprToMemFn(baseDiff.getExpr(), pushforwardFD->getName(),
pushforwardFnArgs, pushforwardFD);
pushforwardFnArgs, CE->getBeginLoc());
} else {
Expr* execConfig = nullptr;
if (auto KCE = dyn_cast<CUDAKernelCallExpr>(CE))
execConfig = Clone(KCE->getConfig());
callDiff =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pushforwardFD),
noLoc, pushforwardFnArgs, noLoc, execConfig)
.get();
callDiff = m_Sema
.ActOnCallExpr(getCurrentScope(),
BuildDeclRef(pushforwardFD), validLoc,
pushforwardFnArgs, validLoc, execConfig)
.get();
}
}
}
Expand All @@ -1188,8 +1191,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
Multiplier = diffArgs[0];
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), validLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), validLoc)
.get();
// FIXME: Extend this for multiarg support
// Check if the function is eligible for numerical differentiation.
Expand Down
6 changes: 4 additions & 2 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,14 @@ namespace clad {
BuildParmVarDecl(clang::Sema& semaRef, clang::DeclContext* DC,
clang::IdentifierInfo* II, clang::QualType T,
clang::StorageClass SC, clang::Expr* defArg,
clang::TypeSourceInfo* TSI) {
clang::TypeSourceInfo* TSI, clang::SourceLocation Loc) {
ASTContext& C = semaRef.getASTContext();
if (!TSI)
TSI = C.getTrivialTypeSourceInfo(T, noLoc);
if (Loc.isInvalid())
Loc = utils::GetValidSLoc(semaRef);
ParmVarDecl* PVD =
ParmVarDecl::Create(C, DC, noLoc, noLoc, II, T, TSI, SC, defArg);
ParmVarDecl::Create(C, DC, Loc, Loc, II, T, TSI, SC, defArg);
return PVD;
}

Expand Down
9 changes: 6 additions & 3 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
m_Sema.CurContext = const_cast<DeclContext*>(m_Function->getDeclContext());

SourceLocation validLoc{m_Function->getLocation()};
DeclWithContext fnBuildRes = m_Builder.cloneFunction(
m_Function, *this, m_Sema.CurContext, noLoc, fnDNI, fnType);
m_Function, *this, m_Sema.CurContext, validLoc, fnDNI, fnType);
m_Derivative = fnBuildRes.first;

beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Expand Down Expand Up @@ -240,8 +241,10 @@ ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
auto returnDiff = Visit(value);
llvm::SmallVector<Expr*, 2> returnArgs = {returnDiff.getExpr(),
returnDiff.getExpr_dx()};
Expr* returnInitList = m_Sema.ActOnInitList(noLoc, returnArgs, noLoc).get();
Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get();
SourceLocation validLoc{RS->getBeginLoc()};
Expr* returnInitList =
m_Sema.ActOnInitList(validLoc, returnArgs, validLoc).get();
Stmt* newRS = m_Sema.BuildReturnStmt(validLoc, returnInitList).get();
return {newRS};
}

Expand Down
57 changes: 28 additions & 29 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
getEnclosingNamespaceOrTUScope());
m_Sema.CurContext = const_cast<DeclContext*>(m_Function->getDeclContext());

SourceLocation validLoc{m_Function->getLocation()};
DeclWithContext fnBuildRes = m_Builder.cloneFunction(
m_Function, *this, m_Sema.CurContext, noLoc, DNI, pullbackFnType);
m_Function, *this, m_Sema.CurContext, validLoc, DNI, pullbackFnType);
m_Derivative = fnBuildRes.first;

if (m_ExternalSource)
Expand Down Expand Up @@ -1436,6 +1437,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Clone(CE), Clone(CE));
}

SourceLocation Loc = CE->getExprLoc();

// Stores the call arguments for the function to be derived
llvm::SmallVector<Expr*, 16> CallArgs{};
// Stores the dx of the call arguments for the function to be derived
Expand All @@ -1460,14 +1463,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
.get();
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs),
noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
return StmtDiff(call, call_dx);
}
Expand All @@ -1483,14 +1485,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
.get();
Expr* call_dx =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs),
noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
m_DeallocExprs.push_back(call);
m_DeallocExprs.push_back(call_dx);
Expand All @@ -1510,13 +1511,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff ArgDiff = Visit(Arg, dfdx());
CallArgs.push_back(ArgDiff.getExpr());
}
Expr* call = m_Sema
.ActOnCallExpr(getCurrentScope(),
Clone(CE->getCallee()),
noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs),
noLoc)
.get();
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(CallArgs), Loc)
.get();
return call;
}

Expand Down Expand Up @@ -1779,7 +1778,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

OverloadedDerivedFn = m_Sema
.ActOnCallExpr(getCurrentScope(), selfRef,
noLoc, pullbackCallArgs, noLoc)
Loc, pullbackCallArgs, Loc)
.get();
} else {
if (m_ExternalSource)
Expand Down Expand Up @@ -1852,12 +1851,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
OverloadedDerivedFn = BuildCallExprToMemFn(
baseE, pullbackFD->getName(), pullbackCallArgs, pullbackFD);
baseE, pullbackFD->getName(), pullbackCallArgs, Loc);
} else {
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD),
noLoc, pullbackCallArgs, noLoc)
Loc, pullbackCallArgs, Loc)
.get();
}
}
Expand Down Expand Up @@ -1939,7 +1938,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// else
// Currently derivedBase `*d_this` can never be CladArrayType
CallArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc));
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, Loc));
}

for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
Expand All @@ -1958,19 +1957,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// CallArgs.push_back(derivedArg);
// else
CallArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, noLoc));
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, Loc));
} else
CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get());
CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(Loc).get());
}
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(),
CallArgs, calleeFnForwPassFD);
CallArgs, Loc);
} else {
call = m_Sema
.ActOnCallExpr(getCurrentScope(),
BuildDeclRef(calleeFnForwPassFD), noLoc,
CallArgs, noLoc)
BuildDeclRef(calleeFnForwPassFD), Loc,
CallArgs, Loc)
.get();
}
auto* callRes = StoreAndRef(call);
Expand All @@ -1981,8 +1980,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(resValue, nullptr, resAdjoint);
} // Recreate the original call expression.
call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
CallArgs, noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
CallArgs, Loc)
.get();
return StmtDiff(call);

Expand Down
Loading

0 comments on commit 4f8292c

Please sign in to comment.