Skip to content

Commit

Permalink
Moving DerivaitiveBuilder methods from ForwardModeVisitor file
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jan 18, 2024
1 parent 97f4c7f commit 6887dca
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 119 deletions.
119 changes: 0 additions & 119 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -947,125 +947,6 @@ BaseForwardModeVisitor::VisitFloatingLiteral(const FloatingLiteral* FL) {
return StmtDiff(Clone(FL), constant0);
}

// This method is derived from the source code of both
// buildOverloadedCallSet() in SemaOverload.cpp
// and ActOnCallExpr() in SemaExpr.cpp.
bool DerivativeBuilder::noOverloadExists(Expr* UnresolvedLookup,
llvm::MutableArrayRef<Expr*> ARargs) {
if (UnresolvedLookup->getType() == m_Context.OverloadTy) {
OverloadExpr::FindResult find = OverloadExpr::find(UnresolvedLookup);

if (!find.HasFormOfMemberPointer) {
OverloadExpr* ovl = find.Expression;

if (isa<UnresolvedLookupExpr>(ovl)) {
ExprResult result;
SourceLocation Loc;
OverloadCandidateSet CandidateSet(Loc,
OverloadCandidateSet::CSK_Normal);
Scope* S = m_Sema.getScopeForContext(m_Sema.CurContext);
UnresolvedLookupExpr* ULE = cast<UnresolvedLookupExpr>(ovl);
// Populate CandidateSet.
m_Sema.buildOverloadedCallSet(S, UnresolvedLookup, ULE, ARargs, Loc,
&CandidateSet, &result);
OverloadCandidateSet::iterator Best;
OverloadingResult OverloadResult = CandidateSet.BestViableFunction(
m_Sema, UnresolvedLookup->getBeginLoc(), Best);
if (OverloadResult) // No overloads were found.
return true;
}
}
}
return false;
}

// FIXME: Move this to DerivativeBuilder.cpp
Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {
NamespaceDecl* NSD = nullptr;
std::string namespaceID;
if (forCustomDerv) {
namespaceID = "custom_derivatives";
NamespaceDecl* cladNS = nullptr;
if (m_BuiltinDerivativesNSD)
NSD = m_BuiltinDerivativesNSD;
else {
cladNS = utils::LookupNSD(m_Sema, "clad", /*shouldExist=*/true);
NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist, cladNS);
m_BuiltinDerivativesNSD = NSD;
}
} else {
NSD = m_NumericalDiffNSD;
namespaceID = "numerical_diff";
}
if (!NSD) {
NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist);
if (!forCustomDerv && !NSD) {
diag(DiagnosticsEngine::Warning, noLoc,
"Numerical differentiation is diabled using the "
"-DCLAD_NO_NUM_DIFF "
"flag, this means that every try to numerically differentiate a "
"function will fail! Remove the flag to revert to default "
"behaviour.");
return nullptr;
}
}
CXXScopeSpec SS;
DeclContext* DC = NSD;

// FIXME: Here `if` branch should be removed once we update
// numerical diff to use correct declaration context.
if (forCustomDerv) {
DeclContext* outermostDC = utils::GetOutermostDC(m_Sema, originalFnDC);
// FIXME: We should ideally construct nested name specifier from the
// found custom derivative function. Current way will compute incorrect
// nested name specifier in some cases.
if (outermostDC &&
outermostDC->getPrimaryContext() == NSD->getPrimaryContext()) {
utils::BuildNNS(m_Sema, originalFnDC, SS);
DC = originalFnDC;
} else {
if (isa<RecordDecl>(originalFnDC))
DC = utils::LookupNSD(m_Sema, "class_functions",
/*shouldExist=*/false, NSD);
else
DC = utils::FindDeclContext(m_Sema, NSD, originalFnDC);
if (DC)
utils::BuildNNS(m_Sema, DC, SS);
}
} else {
SS.Extend(m_Context, NSD, noLoc, noLoc);
}
IdentifierInfo* II = &m_Context.Idents.get(Name);
DeclarationName name(II);
DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema));

LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName);
if (DC)
m_Sema.LookupQualifiedName(R, DC);
Expr* OverloadedFn = 0;
if (!R.empty()) {
// FIXME: We should find a way to specify nested name specifier
// after finding the custom derivative.
Expr* UnresolvedLookup =
m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get();

llvm::MutableArrayRef<Expr*> MARargs =
llvm::MutableArrayRef<Expr*>(CallArgs);

SourceLocation Loc;

if (noOverloadExists(UnresolvedLookup, MARargs))
return 0;

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();
}
return OverloadedFn;
}

QualType
BaseForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) {
return ParamType;
Expand Down
120 changes: 120 additions & 0 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,126 @@ namespace clad {
return { returnedFD, enclosingNS };
}

// This method is derived from the source code of both
// buildOverloadedCallSet() in SemaOverload.cpp
// and ActOnCallExpr() in SemaExpr.cpp.
bool
DerivativeBuilder::noOverloadExists(Expr* UnresolvedLookup,
llvm::MutableArrayRef<Expr*> ARargs) {
if (UnresolvedLookup->getType() == m_Context.OverloadTy) {
OverloadExpr::FindResult find = OverloadExpr::find(UnresolvedLookup);

if (!find.HasFormOfMemberPointer) {
OverloadExpr* ovl = find.Expression;

if (isa<UnresolvedLookupExpr>(ovl)) {
ExprResult result;
SourceLocation Loc;
OverloadCandidateSet CandidateSet(Loc,
OverloadCandidateSet::CSK_Normal);
Scope* S = m_Sema.getScopeForContext(m_Sema.CurContext);
UnresolvedLookupExpr* ULE = cast<UnresolvedLookupExpr>(ovl);
// Populate CandidateSet.
m_Sema.buildOverloadedCallSet(S, UnresolvedLookup, ULE, ARargs, Loc,
&CandidateSet, &result);
OverloadCandidateSet::iterator Best;
OverloadingResult OverloadResult = CandidateSet.BestViableFunction(
m_Sema, UnresolvedLookup->getBeginLoc(), Best);
if (OverloadResult) // No overloads were found.
return true;
}
}
}
return false;
}

Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff(
const std::string& Name, llvm::SmallVectorImpl<Expr*>& CallArgs,
clang::Scope* S, clang::DeclContext* originalFnDC,
bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) {
NamespaceDecl* NSD = nullptr;
std::string namespaceID;
if (forCustomDerv) {
namespaceID = "custom_derivatives";
NamespaceDecl* cladNS = nullptr;
if (m_BuiltinDerivativesNSD)
NSD = m_BuiltinDerivativesNSD;
else {
cladNS = utils::LookupNSD(m_Sema, "clad", /*shouldExist=*/true);
NSD =
utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist, cladNS);
m_BuiltinDerivativesNSD = NSD;
}
} else {
NSD = m_NumericalDiffNSD;
namespaceID = "numerical_diff";
}
if (!NSD) {
NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist);
if (!forCustomDerv && !NSD) {
diag(DiagnosticsEngine::Warning, noLoc,
"Numerical differentiation is diabled using the "
"-DCLAD_NO_NUM_DIFF "
"flag, this means that every try to numerically differentiate a "
"function will fail! Remove the flag to revert to default "
"behaviour.");
return nullptr;
}
}
CXXScopeSpec SS;
DeclContext* DC = NSD;

// FIXME: Here `if` branch should be removed once we update
// numerical diff to use correct declaration context.
if (forCustomDerv) {
DeclContext* outermostDC = utils::GetOutermostDC(m_Sema, originalFnDC);
// FIXME: We should ideally construct nested name specifier from the
// found custom derivative function. Current way will compute incorrect
// nested name specifier in some cases.
if (outermostDC &&
outermostDC->getPrimaryContext() == NSD->getPrimaryContext()) {
utils::BuildNNS(m_Sema, originalFnDC, SS);
DC = originalFnDC;
} else {
if (isa<RecordDecl>(originalFnDC))
DC = utils::LookupNSD(m_Sema, "class_functions",
/*shouldExist=*/false, NSD);
else
DC = utils::FindDeclContext(m_Sema, NSD, originalFnDC);
if (DC)
utils::BuildNNS(m_Sema, DC, SS);
}
} else {
SS.Extend(m_Context, NSD, noLoc, noLoc);
}
IdentifierInfo* II = &m_Context.Idents.get(Name);
DeclarationName name(II);
DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema));

LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName);
if (DC)
m_Sema.LookupQualifiedName(R, DC);
Expr* OverloadedFn = 0;
if (!R.empty()) {
// FIXME: We should find a way to specify nested name specifier
// after finding the custom derivative.
Expr* UnresolvedLookup =
m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get();

llvm::MutableArrayRef<Expr*> MARargs =
llvm::MutableArrayRef<Expr*>(CallArgs);

SourceLocation Loc;

if (noOverloadExists(UnresolvedLookup, MARargs))
return 0;

OverloadedFn =
m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get();
}
return OverloadedFn;
}

void DerivativeBuilder::AddErrorEstimationModel(
std::unique_ptr<FPErrorEstimationModel> estModel) {
m_EstModel.push_back(std::move(estModel));
Expand Down

0 comments on commit 6887dca

Please sign in to comment.