Skip to content

Commit

Permalink
Fix noloc failures in debug mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Apr 10, 2024
1 parent 785ba1f commit 509ec63
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ jobs:
cat obj/CMakeCache.txt
cat obj/CMakeFiles/*.log
- name: Setup tmate session
if: ${{ failure() && runner.debug }}
if: ${{ failure() && matrix.debug_build == true }}
uses: mxschmitt/action-tmate@v3
# When debugging increase to a suitable value!
timeout-minutes: 30
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ namespace clad {
clang::IdentifierInfo* II, clang::QualType T,
clang::StorageClass SC = clang::StorageClass::SC_None,
clang::Expr* defArg = nullptr,
clang::TypeSourceInfo* TSI = nullptr);
clang::TypeSourceInfo* TSI = nullptr,
clang::SourceLocation Loc = clang::SourceLocation());

/// If `T` represents an array or a pointer type then returns the
/// corresponding array element or the pointee type. If `T` is a reference
Expand Down
3 changes: 2 additions & 1 deletion include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,8 @@ namespace clad {
clang::ParmVarDecl* CloneParmVarDecl(const clang::ParmVarDecl* PVD,
clang::IdentifierInfo* II,
bool pushOnScopeChains = false,
bool cloneDefaultArg = true);
bool cloneDefaultArg = true,
clang::SourceLocation Loc = noLoc);
/// A function to get the single argument "forward_central_difference"
/// call expression for the given arguments.
///
Expand Down
43 changes: 23 additions & 20 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,

IdentifierInfo* II = &m_Context.Idents.get(
request.BaseFunctionName + "_d" + s + "arg" + argInfo + derivativeSuffix);
SourceLocation loc{m_Function->getLocation()};
DeclarationNameInfo name(II, loc);
SourceLocation validLoc{m_Function->getLocation()};
DeclarationNameInfo name(II, validLoc);
llvm::SaveAndRestore<DeclContext*> SaveContext(m_Sema.CurContext);
llvm::SaveAndRestore<Scope*> SaveScope(getCurrentScope());
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result =
m_Builder.cloneFunction(FD, *this, DC, loc, name, FD->getType());
m_Builder.cloneFunction(FD, *this, DC, validLoc, name, FD->getType());
FunctionDecl* derivedFD = result.first;
m_Derivative = derivedFD;

Expand All @@ -190,8 +190,8 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
if (PVD->hasDefaultArg())
clonedPVDDefaultArg = Clone(PVD->getDefaultArg());

newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, noLoc, noLoc,
PVD->getIdentifier(), PVD->getType(),
newPVD = ParmVarDecl::Create(m_Context, m_Sema.CurContext, validLoc,
validLoc, PVD->getIdentifier(), PVD->getType(),
PVD->getTypeSourceInfo(),
PVD->getStorageClass(), clonedPVDDefaultArg);

Expand Down Expand Up @@ -322,13 +322,14 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
m_Context.IntTy, m_Context, dValue);
dArrVal.push_back(dValueLiteral);
}
dInitializer = m_Sema.ActOnInitList(noLoc, dArrVal, noLoc).get();
dInitializer =
m_Sema.ActOnInitList(validLoc, dArrVal, validLoc).get();
} else if (const auto* ptrType =
dyn_cast<PointerType>(fieldType.getTypePtr())) {
if (!ptrType->getPointeeType()->isRealType())
continue;
// Pointer member variables should be initialised by `nullptr`.
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(noLoc).get();
dInitializer = m_Sema.ActOnCXXNullPtrLiteral(validLoc).get();
} else {
int dValue = (fieldDecl == m_IndependentVar);
dInitializer = ConstantFolder::synthesizeLiteral(m_Context.IntTy,
Expand Down Expand Up @@ -403,7 +404,7 @@ BaseForwardModeVisitor::DerivePushforward(const FunctionDecl* FD,

IdentifierInfo* derivedFnII = &m_Context.Idents.get(
originalFnEffectiveName + GetPushForwardFunctionSuffix());
DeclarationNameInfo derivedFnName(derivedFnII, noLoc);
DeclarationNameInfo derivedFnName(derivedFnII, m_Function->getLocation());
llvm::SmallVector<QualType, 16> paramTypes;
llvm::SmallVector<QualType, 16> derivedParamTypes;

Expand Down Expand Up @@ -984,6 +985,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
return StmtDiff(Clone(CE));
}

SourceLocation validLoc{CE->getBeginLoc()};

// If the function is non_differentiable, return zero derivative.
if (clad::utils::hasNonDifferentiableAttribute(CE)) {
// Calling the function without computing derivatives
Expand All @@ -993,7 +996,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {

Expr* Call = m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
noLoc, ClonedArgs, noLoc)
validLoc, ClonedArgs, validLoc)
.get();
// Creating a zero derivative
auto* zero =
Expand All @@ -1008,7 +1011,6 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
if (m_DerivativeOrder == 1)
s = "";

SourceLocation noLoc;
llvm::SmallVector<Expr*, 4> CallArgs{};
llvm::SmallVector<Expr*, 4> diffArgs;

Expand Down Expand Up @@ -1108,7 +1110,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
callDiff =
m_Sema
.ActOnCallExpr(m_Sema.getScopeForContext(m_Sema.CurContext),
derivativeRef, noLoc, pushforwardFnArgs, noLoc)
derivativeRef, validLoc, pushforwardFnArgs, validLoc)
.get();
}

Expand All @@ -1132,8 +1134,9 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
if (allArgsAreConstantLiterals) {
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
validLoc, llvm::MutableArrayRef<Expr*>(CallArgs),
validLoc)
.get();
auto* zero =
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Expand Down Expand Up @@ -1173,11 +1176,11 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
Expr* execConfig = nullptr;
if (auto KCE = dyn_cast<CUDAKernelCallExpr>(CE))
execConfig = Clone(KCE->getConfig());
callDiff =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pushforwardFD),
noLoc, pushforwardFnArgs, noLoc, execConfig)
.get();
callDiff = m_Sema
.ActOnCallExpr(getCurrentScope(),
BuildDeclRef(pushforwardFD), validLoc,
pushforwardFnArgs, validLoc, execConfig)
.get();
}
}
}
Expand All @@ -1188,8 +1191,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
Multiplier = diffArgs[0];
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), noLoc)
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), validLoc,
llvm::MutableArrayRef<Expr*>(CallArgs), validLoc)
.get();
// FIXME: Extend this for multiarg support
// Check if the function is eligible for numerical differentiation.
Expand Down
6 changes: 4 additions & 2 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,14 @@ namespace clad {
BuildParmVarDecl(clang::Sema& semaRef, clang::DeclContext* DC,
clang::IdentifierInfo* II, clang::QualType T,
clang::StorageClass SC, clang::Expr* defArg,
clang::TypeSourceInfo* TSI) {
clang::TypeSourceInfo* TSI, clang::SourceLocation Loc) {
ASTContext& C = semaRef.getASTContext();
if (!TSI)
TSI = C.getTrivialTypeSourceInfo(T, noLoc);
if (Loc.isInvalid())
Loc = utils::GetValidSLoc(semaRef);
ParmVarDecl* PVD =
ParmVarDecl::Create(C, DC, noLoc, noLoc, II, T, TSI, SC, defArg);
ParmVarDecl::Create(C, DC, Loc, Loc, II, T, TSI, SC, defArg);
return PVD;
}

Expand Down
13 changes: 10 additions & 3 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ using namespace clang;
namespace clad {
clang::CompoundStmt* VisitorBase::MakeCompoundStmt(const Stmts& Stmts) {
auto Stmts_ref = clad_compat::makeArrayRef(Stmts.data(), Stmts.size());
return clad_compat::CompoundStmt_Create(m_Context, Stmts_ref /**/ CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2(FPOptionsOverride()), noLoc, noLoc);
return clad_compat::CompoundStmt_Create(
m_Context,
Stmts_ref /**/ CLAD_COMPAT_CLANG15_CompoundStmt_Create_ExtraParam2(
FPOptionsOverride()),
utils::GetValidSLoc(m_Sema), utils::GetValidSLoc(m_Sema));
}

bool VisitorBase::isUnusedResult(const Expr* E) {
Expand Down Expand Up @@ -769,13 +773,16 @@ namespace clad {
ParmVarDecl* VisitorBase::CloneParmVarDecl(const ParmVarDecl* PVD,
IdentifierInfo* II,
bool pushOnScopeChains,
bool cloneDefaultArg) {
bool cloneDefaultArg,
SourceLocation Loc) {
Expr* newPVDDefaultArg = nullptr;
if (PVD->hasDefaultArg() && cloneDefaultArg) {
newPVDDefaultArg = Clone(PVD->getDefaultArg());
}
if (Loc.isInvalid())
Loc = PVD->getLocation();
auto newPVD = ParmVarDecl::Create(
m_Context, m_Sema.CurContext, noLoc, noLoc, II, PVD->getType(),
m_Context, m_Sema.CurContext, Loc, Loc, II, PVD->getType(),
PVD->getTypeSourceInfo(), PVD->getStorageClass(), newPVDDefaultArg);
if (pushOnScopeChains && newPVD->getIdentifier()) {
m_Sema.PushOnScopeChains(newPVD, getCurrentScope(),
Expand Down

0 comments on commit 509ec63

Please sign in to comment.