diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 05efa5304..cefc5e6a7 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -66,7 +66,7 @@ namespace clad { clang::Expr* findOverloadedDefinition(clang::DeclarationNameInfo DNI, llvm::SmallVectorImpl& CallArgs); - bool overloadExists(clang::Expr* UnresolvedLookup, + bool noOverloadExists(clang::Expr* UnresolvedLookup, llvm::MutableArrayRef ARargs); /// Shorthand to issues a warning or error. template diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 755f812ef..fa65919c3 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -1371,7 +1371,7 @@ namespace clad { // This method is derived from the source code of both // buildOverloadedCallSet() in SemaOverload.cpp // and ActOnCallExpr() in SemaExpr.cpp. - bool DerivativeBuilder::overloadExists(Expr* UnresolvedLookup, + bool DerivativeBuilder::noOverloadExists(Expr* UnresolvedLookup, llvm::MutableArrayRef ARargs) { if (UnresolvedLookup->getType() == m_Context.OverloadTy) { OverloadExpr::FindResult find = OverloadExpr::find(UnresolvedLookup); @@ -1435,7 +1435,7 @@ namespace clad { SourceLocation Loc; Scope* S = m_Sema.getScopeForContext(m_Sema.CurContext); - if (overloadExists(UnresolvedLookup, MARargs)) { + if (noOverloadExists(UnresolvedLookup, MARargs)) { return 0; } @@ -1456,9 +1456,7 @@ namespace clad { std::string s = std::to_string(m_DerivativeOrder); if (m_DerivativeOrder == 1) s = ""; - // FIXME: add gradient-vector products to fix that. - assert((CE->getNumArgs() <= 1) && - "forward differentiation of multi-arg calls is currently broken"); + IdentifierInfo* II = &m_Context.Idents.get(FD->getNameAsString() + "_d" + s + "arg0"); DeclarationName name(II); @@ -1487,6 +1485,11 @@ namespace clad { // Try to find an overloaded derivative in 'custom_derivatives' Expr* callDiff = m_Builder.findOverloadedDefinition(DNInfo, CallArgs); + // FIXME: add gradient-vector products to fix that. + if(!callDiff) + assert((CE->getNumArgs() <= 1) && + "forward differentiation of multi-arg calls is currently broken"); + // Check if it is a recursive call. if (!callDiff && (FD == m_Function)) { // The differentiated function is called recursively. @@ -3229,4 +3232,4 @@ namespace clad { return result; } -} // end namespace clad +}// end namespace clad diff --git a/test/Misc/RunDemos.C b/test/Misc/RunDemos.C index da13c67b1..9135f0a9e 100644 --- a/test/Misc/RunDemos.C +++ b/test/Misc/RunDemos.C @@ -2,6 +2,7 @@ // RUN: %cladclang %S/../../demos/ControlFlow.cpp -I%S/../../include 2>&1 // RUN: %cladclang %S/../../demos/DebuggingClad.cpp -I%S/../../include 2>&1 // RUN: %cladclang %S/../../demos/RosenbrockFunction.cpp -I%S/../../include 2>&1 +// RUN: %cladclang -lstdc++ -lm %S/../../demos/ComputerGraphics/SmallPT.cpp -I%S/../../include 2>&1 //-----------------------------------------------------------------------------/ @@ -95,4 +96,4 @@ //-----------------------------------------------------------------------------/ // Demo: ODE Solver Sensitivity //-----------------------------------------------------------------------------/ -// RUN: %cladclang -lstdc++ %S/../../demos/ODESolverSensitivity.cpp -I%S/../../include -oODESolverSensitivity.out \ No newline at end of file +// RUN: %cladclang -lstdc++ %S/../../demos/ODESolverSensitivity.cpp -I%S/../../include -oODESolverSensitivity.out