Skip to content

Commit

Permalink
Constify interfaces. NFC
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Dec 14, 2024
1 parent fbf11f1 commit cebc426
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 21 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ namespace clad {
/// such declaration context is found, then returns `nullptr`.
clang::DeclContext* FindDeclContext(clang::Sema& semaRef,
clang::DeclContext* DC1,
clang::DeclContext* DC2);
const clang::DeclContext* DC2);

/// Finds the qualified name `name` in the declaration context `DC`.
///
Expand Down
6 changes: 3 additions & 3 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ namespace clad {
/// null otherwise.
clang::Expr* BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<clang::Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
clang::Scope* S, const clang::DeclContext* originalFnDC,
bool forCustomDerv = true, bool namespaceShouldExist = true,
clang::Expr* CUDAExecConfig = nullptr);
bool noOverloadExists(clang::Expr* UnresolvedLookup,
Expand Down Expand Up @@ -150,7 +150,7 @@ namespace clad {
/// \returns The lookup result of the custom derivative or numerical
/// differentiation function.
clang::LookupResult LookupCustomDerivativeOrNumericalDiff(
const std::string& Name, clang::DeclContext* originalFnDC,
const std::string& Name, const clang::DeclContext* originalFnDC,
clang::CXXScopeSpec& SS, bool forCustomDerv = true,
bool namespaceShouldExist = true);

Expand All @@ -160,7 +160,7 @@ namespace clad {
/// \returns The custom derivative function if found, nullptr otherwise.
clang::FunctionDecl*
LookupCustomDerivativeDecl(const std::string& Name,
clang::DeclContext* originalFnDC,
const clang::DeclContext* originalFnDC,
clang::QualType functionType);

public:
Expand Down
8 changes: 3 additions & 5 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
clad::utils::ComputeEffectiveFnName(FD) + GetPushForwardFunctionSuffix();
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
FD->getDeclContext());
// Custom derivative templates can be written in a
// general way that works for both vectorized and non-vectorized
// modes. We have to also look for the pushforward with the regular name.
Expand All @@ -1236,7 +1236,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) {
clad::utils::ComputeEffectiveFnName(FD) + "_pushforward";
callDiff = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, customDerivativeArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
FD->getDeclContext());
}
if (!isLambda) {
// Check if it is a recursive call.
Expand Down Expand Up @@ -2315,11 +2315,9 @@ clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall(
std::string customPushforwardName =
clad::utils::ComputeEffectiveFnName(CE->getConstructor()) +
GetPushForwardFunctionSuffix();
// FIXME: We should not use const_cast to get the decl context here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
Expr* pushforwardCall = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforwardName, customPushforwardArgs, getCurrentScope(),
const_cast<DeclContext*>(CE->getConstructor()->getDeclContext()));
CE->getConstructor()->getDeclContext());
return pushforwardCall;
}
} // end namespace clad
6 changes: 3 additions & 3 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ namespace clad {
}

DeclContext* FindDeclContext(clang::Sema& semaRef, clang::DeclContext* DC1,
clang::DeclContext* DC2) {
llvm::SmallVector<clang::DeclContext*, 4> contexts;
const clang::DeclContext* DC2) {
llvm::SmallVector<const clang::DeclContext*, 4> contexts;
assert((isa<NamespaceDecl>(DC1) || isa<TranslationUnitDecl>(DC1)) &&
"DC1 can only be extended if it is a "
"namespace or translation unit decl.");
Expand Down Expand Up @@ -240,7 +240,7 @@ namespace clad {
}
DeclContext* DC = DC1;
for (int i = contexts.size() - 1; i >= 0; --i) {
NamespaceDecl* ND = cast<NamespaceDecl>(contexts[i]);
const auto* ND = cast<NamespaceDecl>(contexts[i]);
if (ND->getIdentifier())
DC = LookupNSD(semaRef, ND->getIdentifier()->getName(),
/*shouldExist=*/false, DC1);
Expand Down
7 changes: 4 additions & 3 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "llvm/Support/SaveAndRestore.h"

#include <algorithm>
#include <string>

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/Compatibility.h"
Expand Down Expand Up @@ -167,7 +168,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
}

LookupResult DerivativeBuilder::LookupCustomDerivativeOrNumericalDiff(
const std::string& Name, clang::DeclContext* originalFnDC,
const std::string& Name, const clang::DeclContext* originalFnDC,
CXXScopeSpec& SS, bool forCustomDerv /*=true*/,
bool namespaceShouldExist /*=true*/) {

Expand Down Expand Up @@ -222,7 +223,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
}

FunctionDecl* DerivativeBuilder::LookupCustomDerivativeDecl(
const std::string& Name, clang::DeclContext* originalFnDC,
const std::string& Name, const clang::DeclContext* originalFnDC,
QualType functionType) {
CXXScopeSpec SS;
LookupResult R =
Expand All @@ -246,7 +247,7 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {

Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
clang::Scope* S, const clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/,
Expr* CUDAExecConfig /*=nullptr*/) {
CXXScopeSpec SS;
Expand Down
10 changes: 4 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1828,7 +1828,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPushforward, pushforwardCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()),
FD->getDeclContext(),
/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (OverloadedDerivedFn)
Expand Down Expand Up @@ -1932,7 +1932,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
OverloadedDerivedFn =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullback, pullbackCallArgs, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()),
FD->getDeclContext(),
/*forCustomDerv=*/true, /*namespaceShouldExist=*/true,
CUDAExecConfig);
if (baseDiff.getExpr())
Expand Down Expand Up @@ -4248,8 +4248,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (Expr* customPullbackCall =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
customPullbackName, pullbackArgs, getCurrentScope(),
const_cast<DeclContext*>(
CE->getConstructor()->getDeclContext()))) {
CE->getConstructor()->getDeclContext())) {
curRevBlock.insert(it, customPullbackCall);
if (m_TrackConstructorPullbackInfo) {
setConstructorPullbackCallInfo(llvm::cast<CallExpr>(customPullbackCall),
Expand Down Expand Up @@ -4585,8 +4584,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
args.append(derivedArgs.begin(), derivedArgs.end());
Expr* customForwPassCE =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
forwPassFnName, args, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
forwPassFnName, args, getCurrentScope(), FD->getDeclContext());
return customForwPassCE;
}

Expand Down

0 comments on commit cebc426

Please sign in to comment.