diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 2b67f940b..c6689ec6e 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -947,125 +947,6 @@ BaseForwardModeVisitor::VisitFloatingLiteral(const FloatingLiteral* FL) { return StmtDiff(Clone(FL), constant0); } -// This method is derived from the source code of both -// buildOverloadedCallSet() in SemaOverload.cpp -// and ActOnCallExpr() in SemaExpr.cpp. -bool DerivativeBuilder::noOverloadExists(Expr* UnresolvedLookup, - llvm::MutableArrayRef ARargs) { - if (UnresolvedLookup->getType() == m_Context.OverloadTy) { - OverloadExpr::FindResult find = OverloadExpr::find(UnresolvedLookup); - - if (!find.HasFormOfMemberPointer) { - OverloadExpr* ovl = find.Expression; - - if (isa(ovl)) { - ExprResult result; - SourceLocation Loc; - OverloadCandidateSet CandidateSet(Loc, - OverloadCandidateSet::CSK_Normal); - Scope* S = m_Sema.getScopeForContext(m_Sema.CurContext); - UnresolvedLookupExpr* ULE = cast(ovl); - // Populate CandidateSet. - m_Sema.buildOverloadedCallSet(S, UnresolvedLookup, ULE, ARargs, Loc, - &CandidateSet, &result); - OverloadCandidateSet::iterator Best; - OverloadingResult OverloadResult = CandidateSet.BestViableFunction( - m_Sema, UnresolvedLookup->getBeginLoc(), Best); - if (OverloadResult) // No overloads were found. - return true; - } - } - } - return false; -} - -// FIXME: Move this to DerivativeBuilder.cpp -Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( - const std::string& Name, llvm::SmallVectorImpl& CallArgs, - clang::Scope* S, clang::DeclContext* originalFnDC, - bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) { - NamespaceDecl* NSD = nullptr; - std::string namespaceID; - if (forCustomDerv) { - namespaceID = "custom_derivatives"; - NamespaceDecl* cladNS = nullptr; - if (m_BuiltinDerivativesNSD) - NSD = m_BuiltinDerivativesNSD; - else { - cladNS = utils::LookupNSD(m_Sema, "clad", /*shouldExist=*/true); - NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist, cladNS); - m_BuiltinDerivativesNSD = NSD; - } - } else { - NSD = m_NumericalDiffNSD; - namespaceID = "numerical_diff"; - } - if (!NSD) { - NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist); - if (!forCustomDerv && !NSD) { - diag(DiagnosticsEngine::Warning, noLoc, - "Numerical differentiation is diabled using the " - "-DCLAD_NO_NUM_DIFF " - "flag, this means that every try to numerically differentiate a " - "function will fail! Remove the flag to revert to default " - "behaviour."); - return nullptr; - } - } - CXXScopeSpec SS; - DeclContext* DC = NSD; - - // FIXME: Here `if` branch should be removed once we update - // numerical diff to use correct declaration context. - if (forCustomDerv) { - DeclContext* outermostDC = utils::GetOutermostDC(m_Sema, originalFnDC); - // FIXME: We should ideally construct nested name specifier from the - // found custom derivative function. Current way will compute incorrect - // nested name specifier in some cases. - if (outermostDC && - outermostDC->getPrimaryContext() == NSD->getPrimaryContext()) { - utils::BuildNNS(m_Sema, originalFnDC, SS); - DC = originalFnDC; - } else { - if (isa(originalFnDC)) - DC = utils::LookupNSD(m_Sema, "class_functions", - /*shouldExist=*/false, NSD); - else - DC = utils::FindDeclContext(m_Sema, NSD, originalFnDC); - if (DC) - utils::BuildNNS(m_Sema, DC, SS); - } - } else { - SS.Extend(m_Context, NSD, noLoc, noLoc); - } - IdentifierInfo* II = &m_Context.Idents.get(Name); - DeclarationName name(II); - DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema)); - - LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName); - if (DC) - m_Sema.LookupQualifiedName(R, DC); - Expr* OverloadedFn = 0; - if (!R.empty()) { - // FIXME: We should find a way to specify nested name specifier - // after finding the custom derivative. - Expr* UnresolvedLookup = - m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get(); - - llvm::MutableArrayRef MARargs = - llvm::MutableArrayRef(CallArgs); - - SourceLocation Loc; - - if (noOverloadExists(UnresolvedLookup, MARargs)) - return 0; - - OverloadedFn = - m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get(); - } - return OverloadedFn; -} - QualType BaseForwardModeVisitor::GetPushForwardDerivativeType(QualType ParamType) { return ParamType; diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 23417dfda..204a3ecc1 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -134,6 +134,126 @@ namespace clad { return { returnedFD, enclosingNS }; } + // This method is derived from the source code of both + // buildOverloadedCallSet() in SemaOverload.cpp + // and ActOnCallExpr() in SemaExpr.cpp. + bool + DerivativeBuilder::noOverloadExists(Expr* UnresolvedLookup, + llvm::MutableArrayRef ARargs) { + if (UnresolvedLookup->getType() == m_Context.OverloadTy) { + OverloadExpr::FindResult find = OverloadExpr::find(UnresolvedLookup); + + if (!find.HasFormOfMemberPointer) { + OverloadExpr* ovl = find.Expression; + + if (isa(ovl)) { + ExprResult result; + SourceLocation Loc; + OverloadCandidateSet CandidateSet(Loc, + OverloadCandidateSet::CSK_Normal); + Scope* S = m_Sema.getScopeForContext(m_Sema.CurContext); + UnresolvedLookupExpr* ULE = cast(ovl); + // Populate CandidateSet. + m_Sema.buildOverloadedCallSet(S, UnresolvedLookup, ULE, ARargs, Loc, + &CandidateSet, &result); + OverloadCandidateSet::iterator Best; + OverloadingResult OverloadResult = CandidateSet.BestViableFunction( + m_Sema, UnresolvedLookup->getBeginLoc(), Best); + if (OverloadResult) // No overloads were found. + return true; + } + } + } + return false; + } + + Expr* DerivativeBuilder::BuildCallToCustomDerivativeOrNumericalDiff( + const std::string& Name, llvm::SmallVectorImpl& CallArgs, + clang::Scope* S, clang::DeclContext* originalFnDC, + bool forCustomDerv /*=true*/, bool namespaceShouldExist /*=true*/) { + NamespaceDecl* NSD = nullptr; + std::string namespaceID; + if (forCustomDerv) { + namespaceID = "custom_derivatives"; + NamespaceDecl* cladNS = nullptr; + if (m_BuiltinDerivativesNSD) + NSD = m_BuiltinDerivativesNSD; + else { + cladNS = utils::LookupNSD(m_Sema, "clad", /*shouldExist=*/true); + NSD = + utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist, cladNS); + m_BuiltinDerivativesNSD = NSD; + } + } else { + NSD = m_NumericalDiffNSD; + namespaceID = "numerical_diff"; + } + if (!NSD) { + NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist); + if (!forCustomDerv && !NSD) { + diag(DiagnosticsEngine::Warning, noLoc, + "Numerical differentiation is diabled using the " + "-DCLAD_NO_NUM_DIFF " + "flag, this means that every try to numerically differentiate a " + "function will fail! Remove the flag to revert to default " + "behaviour."); + return nullptr; + } + } + CXXScopeSpec SS; + DeclContext* DC = NSD; + + // FIXME: Here `if` branch should be removed once we update + // numerical diff to use correct declaration context. + if (forCustomDerv) { + DeclContext* outermostDC = utils::GetOutermostDC(m_Sema, originalFnDC); + // FIXME: We should ideally construct nested name specifier from the + // found custom derivative function. Current way will compute incorrect + // nested name specifier in some cases. + if (outermostDC && + outermostDC->getPrimaryContext() == NSD->getPrimaryContext()) { + utils::BuildNNS(m_Sema, originalFnDC, SS); + DC = originalFnDC; + } else { + if (isa(originalFnDC)) + DC = utils::LookupNSD(m_Sema, "class_functions", + /*shouldExist=*/false, NSD); + else + DC = utils::FindDeclContext(m_Sema, NSD, originalFnDC); + if (DC) + utils::BuildNNS(m_Sema, DC, SS); + } + } else { + SS.Extend(m_Context, NSD, noLoc, noLoc); + } + IdentifierInfo* II = &m_Context.Idents.get(Name); + DeclarationName name(II); + DeclarationNameInfo DNInfo(name, utils::GetValidSLoc(m_Sema)); + + LookupResult R(m_Sema, DNInfo, Sema::LookupOrdinaryName); + if (DC) + m_Sema.LookupQualifiedName(R, DC); + Expr* OverloadedFn = 0; + if (!R.empty()) { + // FIXME: We should find a way to specify nested name specifier + // after finding the custom derivative. + Expr* UnresolvedLookup = + m_Sema.BuildDeclarationNameExpr(SS, R, /*ADL*/ false).get(); + + llvm::MutableArrayRef MARargs = + llvm::MutableArrayRef(CallArgs); + + SourceLocation Loc; + + if (noOverloadExists(UnresolvedLookup, MARargs)) + return 0; + + OverloadedFn = + m_Sema.ActOnCallExpr(S, UnresolvedLookup, Loc, MARargs, Loc).get(); + } + return OverloadedFn; + } + void DerivativeBuilder::AddErrorEstimationModel( std::unique_ptr estModel) { m_EstModel.push_back(std::move(estModel));