Skip to content

Commit

Permalink
Simplify if-conditions. NFC.
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 18, 2024
1 parent b59992f commit 1a1b2ce
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1976,13 +1976,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
else
pullbackFD = m_Builder.HandleNestedDiffRequest(pullbackRequest);

// Clad failed to derive it.
// FIXME: Add support for reference arguments to the numerical diff. If
// it already correctly support reference arguments then confirm the
// support and add tests for the same.
if (!pullbackFD && !utils::HasAnyReferenceOrPointerArgument(FD) &&
!isa<CXXMethodDecl>(FD)) {
// Try numerically deriving it.
if (pullbackFD) {
if (MD) {
Expr* baseE = baseDiff.getExpr();
OverloadedDerivedFn = BuildCallExprToMemFn(
baseE, pullbackFD->getName(), pullbackCallArgs, Loc);
} else {
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD),
Loc, pullbackCallArgs, Loc, CUDAExecConfig)
.get();
}
} else if (!utils::HasAnyReferenceOrPointerArgument(FD) && !MD) {
// FIXME: Add support for reference arguments to the numerical diff. If
// it already correctly support reference arguments then confirm the
// support and add tests for the same.
//
// Clad failed to derive it. Try numerically deriving it.
if (NArgs == 1) {
OverloadedDerivedFn = GetSingleArgCentralDiffCall(
Clone(CE->getCallee()), DerivedCallArgs[0],
Expand All @@ -2002,18 +2013,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
block.insert(block.begin(), PreCallStmts.begin(), PreCallStmts.end());
return StmtDiff(Clone(CE));
}
} else if (pullbackFD) {
if (baseDiff.getExpr()) {
Expr* baseE = baseDiff.getExpr();
OverloadedDerivedFn = BuildCallExprToMemFn(
baseE, pullbackFD->getName(), pullbackCallArgs, Loc);
} else {
OverloadedDerivedFn =
m_Sema
.ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD),
Loc, pullbackCallArgs, Loc, CUDAExecConfig)
.get();
}
}
}

Expand Down

0 comments on commit 1a1b2ce

Please sign in to comment.