Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Computer graphics issue fix #203

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ namespace clad {

clang::Expr* findOverloadedDefinition(clang::DeclarationNameInfo DNI,
llvm::SmallVectorImpl<clang::Expr*>& CallArgs);
bool overloadExists(clang::Expr* UnresolvedLookup,
bool noOverloadExists(clang::Expr* UnresolvedLookup,
llvm::MutableArrayRef<clang::Expr*> ARargs);
/// Shorthand to issues a warning or error.
template <std::size_t N>
Expand Down
15 changes: 9 additions & 6 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ namespace clad {
// This method is derived from the source code of both
// buildOverloadedCallSet() in SemaOverload.cpp
// and ActOnCallExpr() in SemaExpr.cpp.
bool DerivativeBuilder::overloadExists(Expr* UnresolvedLookup,
bool DerivativeBuilder::noOverloadExists(Expr* UnresolvedLookup,
llvm::MutableArrayRef<Expr*> ARargs) {
if (UnresolvedLookup->getType() == m_Context.OverloadTy) {
OverloadExpr::FindResult find = OverloadExpr::find(UnresolvedLookup);
Expand Down Expand Up @@ -1435,7 +1435,7 @@ namespace clad {
SourceLocation Loc;
Scope* S = m_Sema.getScopeForContext(m_Sema.CurContext);

if (overloadExists(UnresolvedLookup, MARargs)) {
if (noOverloadExists(UnresolvedLookup, MARargs)) {
return 0;
}

Expand All @@ -1456,9 +1456,7 @@ namespace clad {
std::string s = std::to_string(m_DerivativeOrder);
if (m_DerivativeOrder == 1)
s = "";
// FIXME: add gradient-vector products to fix that.
assert((CE->getNumArgs() <= 1) &&
"forward differentiation of multi-arg calls is currently broken");

IdentifierInfo* II = &m_Context.Idents.get(FD->getNameAsString() + "_d" +
s + "arg0");
DeclarationName name(II);
Expand Down Expand Up @@ -1487,6 +1485,11 @@ namespace clad {
// Try to find an overloaded derivative in 'custom_derivatives'
Expr* callDiff = m_Builder.findOverloadedDefinition(DNInfo, CallArgs);

// FIXME: add gradient-vector products to fix that.
if(!callDiff)
assert((CE->getNumArgs() <= 1) &&
"forward differentiation of multi-arg calls is currently broken");

// Check if it is a recursive call.
if (!callDiff && (FD == m_Function)) {
// The differentiated function is called recursively.
Expand Down Expand Up @@ -3229,4 +3232,4 @@ namespace clad {

return result;
}
} // end namespace clad
}// end namespace clad
3 changes: 2 additions & 1 deletion test/Misc/RunDemos.C
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// RUN: %cladclang %S/../../demos/ControlFlow.cpp -I%S/../../include 2>&1
// RUN: %cladclang %S/../../demos/DebuggingClad.cpp -I%S/../../include 2>&1
// RUN: %cladclang %S/../../demos/RosenbrockFunction.cpp -I%S/../../include 2>&1
// RUN: %cladclang -lstdc++ -lm %S/../../demos/ComputerGraphics/SmallPT.cpp -I%S/../../include 2>&1


//-----------------------------------------------------------------------------/
Expand Down Expand Up @@ -95,4 +96,4 @@
//-----------------------------------------------------------------------------/
// Demo: ODE Solver Sensitivity
//-----------------------------------------------------------------------------/
// RUN: %cladclang -lstdc++ %S/../../demos/ODESolverSensitivity.cpp -I%S/../../include -oODESolverSensitivity.out
// RUN: %cladclang -lstdc++ %S/../../demos/ODESolverSensitivity.cpp -I%S/../../include -oODESolverSensitivity.out