diff --git a/README.md b/README.md index 4410c8601..bfae57d9b 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,8 @@ Reverse-mode AD allows computing the gradient of `f` using *at most* a constant 1. `f` is a pointer to a function or a method to be differentiated 2. `ARGS` is either: * not provided, then `f` is differentiated w.r.t. its every argument - * a string literal with comma-separated names of independent variables (e.g. `"x"` or `"y"` or `"x, y"` or `"y, x"`) + * a string literal with comma-separated names/indices of independent variables (e.g. `"x"`, `"y"`, `"x, y"`, `"y, x"`, "0, 1", "0, y", etc.) + * a SINGLE number representing the index of the independent variable Since a vector of derivatives must be returned from a function generated by the reverse mode, its signature is slightly different. The generated function has `void` return type and same input arguments. The function has additional `n` arguments (where `n` refers to the number of arguments whose gradient was requested) of type `T*`, where `T` is the type of the corresponding original variable. Each of these variables stores the derivative of the elements as they appear in the orignal function signature. *The caller is responsible for allocating and zeroing-out the gradient storage*. Example: ```cpp diff --git a/VERSION b/VERSION index 5cc610824..f21da21f5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.7~dev +1.8~dev diff --git a/docs/internalDocs/ReleaseNotes.md b/docs/internalDocs/ReleaseNotes.md index 70640f602..de0a3df6b 100644 --- a/docs/internalDocs/ReleaseNotes.md +++ b/docs/internalDocs/ReleaseNotes.md @@ -2,7 +2,7 @@ Introduction ============ This document contains the release notes for the automatic differentiation -plugin for clang Clad, release 1.7. Clad is built on top of +plugin for clang Clad, release 1.8. Clad is built on top of [Clang](http://clang.llvm.org) and [LLVM](http://llvm.org>) compiler infrastructure. Here we describe the status of Clad in some detail, including major improvements from the previous release and new feature work. @@ -11,7 +11,7 @@ Note that if you are reading this file from a git checkout, this document applies to the *next* release, not the current one. -What's New in Clad 1.7? +What's New in Clad 1.8? ======================== Some of the major new features and improvements to Clad are listed here. Generic @@ -54,7 +54,7 @@ Fixed Bugs [XXX](https://github.com/vgvassilev/clad/issues/XXX) Special Kudos @@ -68,5 +68,5 @@ FirstName LastName (#commits) A B (N) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 3de507d69..fa7c09ad6 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -102,6 +102,8 @@ class BaseForwardModeVisitor StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE); StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE); StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE); + StmtDiff + VisitCXXScalarValueInitExpr(const clang::CXXScalarValueInitExpr* SVIE); StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* CSE); StmtDiff VisitCXXFunctionalCastExpr(const clang::CXXFunctionalCastExpr* FCE); StmtDiff VisitCXXBindTemporaryExpr(const clang::CXXBindTemporaryExpr* BTE); diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index db1113116..f7b5f61b0 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -106,7 +106,7 @@ template class DynamicGraph { bool isProcessingNode() { return m_currentId != -1; } /// Get the nodes in the graph. - std::vector getNodes() { return m_nodes; } + const std::vector& getNodes() { return m_nodes; } /// Print the nodes and edges in the graph. void print() { @@ -140,4 +140,4 @@ template class DynamicGraph { }; } // end namespace clad -#endif // CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H \ No newline at end of file +#endif // CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index a161f1f58..8d899a4ac 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -411,7 +411,8 @@ namespace clad { StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS); StmtDiff VisitCaseStmt(const clang::CaseStmt* CS); StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS); - DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); + DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD, + bool AddToBlock = true); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index b68318f7b..c3759e2d4 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -394,14 +394,12 @@ namespace clad { /// to avoid recomputation. static bool UsefulToStore(clang::Expr* E); /// A flag for silencing warnings/errors output by diag function. - bool silenceDiags = false; /// Shorthand to issues a warning or error. template void diag(clang::DiagnosticsEngine::Level level, // Warning or Error clang::SourceLocation loc, const char (&format)[N], llvm::ArrayRef args = {}) { - if (!silenceDiags) - m_Builder.diag(level, loc, format, args); + m_Builder.diag(level, loc, format, args); } /// Creates unique identifier of the form "_nameBase" that is @@ -584,17 +582,14 @@ namespace clad { clang::Expr* GetSingleArgCentralDiffCall( clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos, unsigned numArgs, llvm::SmallVectorImpl& args); + /// Emits diagnostic messages on differentiation (or lack thereof) for /// call expressions. /// - /// \param[in] \c funcName The name of the underlying function of the - /// call expression. + /// \param[in] \c FD - The function declaration. /// \param[in] \c srcLoc Any associated source location information. - /// \param[in] \c isDerived A flag to determine if differentiation of the - /// call expression was successful. - void CallExprDiffDiagnostics(llvm::StringRef funcName, - clang::SourceLocation srcLoc, - bool isDerived); + void CallExprDiffDiagnostics(const clang::FunctionDecl* FD, + clang::SourceLocation srcLoc); clang::QualType DetermineCladArrayValueType(clang::QualType T); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 6942b2d4d..dbd8ca440 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -63,7 +63,6 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive(const FunctionDecl* FD, const DiffRequest& request) { assert(m_DiffReq == request && "Can't pass two different requests!"); - silenceDiags = !request.VerboseDiags; m_Functor = request.Functor; assert(m_DiffReq.Mode == DiffMode::forward); assert(!m_DerivativeInFlight && @@ -801,8 +800,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) || condUO) { condDiff = Visit(cond); - if (condDiff.getExpr_dx() && - (!isUnusedResult(condDiff.getExpr_dx()) || condUO)) + if (condDiff.getExpr_dx() && (!isUnusedResult(condDiff.getExpr_dx()))) cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), BuildParens(condDiff.getExpr())); else @@ -1335,7 +1333,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { GetSingleArgCentralDiffCall(fnCallee, CallArgs[0], /*targetPos=*/0, /*numArgs=*/1, CallArgs); } - CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(), callDiff); + CallExprDiffDiagnostics(FD, CE->getBeginLoc()); if (!callDiff) { auto zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); @@ -1388,7 +1386,15 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_LNot) { - return StmtDiff(op, diff.getExpr_dx()); + Expr* zero = getZeroInit(UnOp->getType()); + if (diff.getExpr_dx() && !isUnusedResult(diff.getExpr_dx())) + return {BuildOp(BO_Comma, BuildParens(diff.getExpr_dx()), op), zero}; + return {op, zero}; + } else if (opKind == UnaryOperatorKind::UO_Not) { + // ~x is 2^n - 1 - x for unsigned types and -x - 1 for the signed ones. + // Either way, taking a derivative gives us -_d_x. + Expr* derivedOp = BuildOp(UO_Minus, diff.getExpr_dx()); + return {op, derivedOp}; } else { unsupportedOpWarn(UnOp->getEndLoc()); auto zero = @@ -1504,7 +1510,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { } else opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()), BuildParens(Rdiff.getExpr_dx())); - } else if (BinOp->isLogicalOp()) { + } else if (BinOp->isLogicalOp() || BinOp->isBitwiseOp() || + BinOp->isComparisonOp() || opCode == BO_Rem) { // For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and // correct derivative execution. auto buildOneSide = [this](StmtDiff& Xdiff) { @@ -1521,8 +1528,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { // Since the both parts are included in the opDiff, there's no point in // including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left - // nullptr is used for treating expressions like ((A && B) && C) correctly. - return StmtDiff(opDiff, nullptr); + // zero is used for treating expressions like ((A && B) && C) correctly. + return StmtDiff(opDiff, getZeroInit(BinOp->getType())); + } else if (BinOp->isShiftOp()) { + // Shifting is essentially multiplicating the LHS by 2^RHS (or 2^-RHS). + // We should do the same to the derivarive. + opDiff = BuildOp(opCode, Ldiff.getExpr_dx(), Rdiff.getExpr()); } else { // FIXME: add support for other binary operators unsupportedOpWarn(BinOp->getEndLoc()); @@ -2278,6 +2289,11 @@ StmtDiff BaseForwardModeVisitor::VisitCXXStdInitializerListExpr( return Visit(ILE->getSubExpr()); } +StmtDiff BaseForwardModeVisitor::VisitCXXScalarValueInitExpr( + const CXXScalarValueInitExpr* SVIE) { + return {Clone(SVIE), Clone(SVIE)}; +} + clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall( const clang::CXXConstructExpr* CE, llvm::SmallVectorImpl& clonedArgs, diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index a7c657eaa..a4defc87e 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -194,15 +194,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { } if (!NSD) { NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist); - if (!forCustomDerv && !NSD) { - diag(DiagnosticsEngine::Warning, noLoc, - "Numerical differentiation is diabled using the " - "-DCLAD_NO_NUM_DIFF " - "flag, this means that every try to numerically differentiate a " - "function will fail! Remove the flag to revert to default " - "behaviour."); + if (!NSD) return R; - } } DeclContext* DC = NSD; diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index c7ea8af14..0f7829a20 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -374,6 +374,21 @@ namespace clad { DiffInputVarInfo dVarInfo; dVarInfo.source = diffSpec.str(); + // Check if diffSpec represents an index of an independent variable. + if ('0' <= diffSpec[0] && diffSpec[0] <= '9') { + unsigned idx = std::stoi(dVarInfo.source); + // Fail if the specified index is invalid. + if (idx >= FD->getNumParams()) { + utils::EmitDiag( + semaRef, DiagnosticsEngine::Error, diffArgs->getEndLoc(), + "Invalid argument index '%0' of '%1' argument(s)", + {std::to_string(idx), std::to_string(FD->getNumParams())}); + return; + } + dVarInfo.param = FD->getParamDecl(idx); + DVI.push_back(dVarInfo); + continue; + } llvm::StringRef pName = computeParamName(diffSpec); auto it = std::find_if(std::begin(candidates), std::end(candidates), [&pName]( diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index 9d892868f..6cf306f2f 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -20,7 +20,6 @@ DerivativeAndOverload ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, const DiffRequest& request) { assert(m_DiffReq == request); - silenceDiags = !request.VerboseDiags; assert(m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 38a7fbe39..8308da4b7 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -256,7 +256,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const DiffRequest& request) { if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerive(); - silenceDiags = !request.VerboseDiags; assert(m_DiffReq == request); // FIXME: reverse mode plugins may have request mode other than @@ -479,7 +478,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // for the two 'Derive's being different functions. if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerive(); - silenceDiags = !request.VerboseDiags; // FIXME: We should not use const_cast to get the decl request here. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(m_DiffReq) = request; @@ -978,91 +976,99 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) { + beginBlock(direction::reverse); + LoopCounter loopCounter(*this); beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | Scope::ContinueScope); - beginBlock(direction::reverse); - LoopCounter loopCounter(*this); + llvm::SaveAndRestore SaveCurrentBreakFlagExpr( + m_CurrentBreakFlagExpr); + m_CurrentBreakFlagExpr = nullptr; + auto* activeBreakContHandler = PushBreakContStmtHandler(); + activeBreakContHandler->BeginCFSwitchStmtScope(); const VarDecl* LoopVD = FRS->getLoopVariable(); - const Stmt* RangeDecl = FRS->getRangeStmt(); - const Stmt* BeginDecl = FRS->getBeginStmt(); - StmtDiff VisitRange = Visit(RangeDecl); - StmtDiff VisitBegin = Visit(BeginDecl); - Expr* BeginExpr = cast(VisitBegin.getStmt())->getLHS(); + llvm::SaveAndRestore SaveIsInside(isInsideLoop, + /*NewValue=*/false); + + const auto* RangeDecl = cast(FRS->getRangeStmt()->getSingleDecl()); + const auto* BeginDecl = cast(FRS->getBeginStmt()->getSingleDecl()); + + DeclDiff VisitRange = DifferentiateVarDecl(RangeDecl, false); + DeclDiff VisitBegin = DifferentiateVarDecl(BeginDecl, false); beginBlock(direction::reverse); // Create all declarations needed. - auto* BeginDeclRef = cast(BeginExpr); - Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()]; - - auto* RangeExpr = - cast(cast(VisitRange.getStmt())->getLHS()); - - Expr* RangeInit = Clone(FRS->getRangeInit()); - Expr* AssignRange = - BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit)); - Expr* AssignBegin = - BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr)); - addToCurrentBlock(AssignRange); - addToCurrentBlock(AssignBegin); - const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); - - Expr* EndInit = cast(EndDecl->getInit())->getRHS(); - QualType EndType = CloneType(EndDecl->getType()); - std::string EndName = EndDecl->getNameAsString(); - Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit); - VarDecl* EndVarDecl = - BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false); - DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl); + DeclRefExpr* beginDeclRef = BuildDeclRef(VisitBegin.getDecl()); + Expr* d_beginDeclRef = m_Variables[beginDeclRef->getDecl()]; + DeclRefExpr* rangeDeclRef = BuildDeclRef(VisitRange.getDecl()); + Expr* d_rangeDeclRef = m_Variables[rangeDeclRef->getDecl()]; + + Expr* rangeInit = Clone(FRS->getRangeInit()); + Expr* d_rangeInitDeclRef = + m_Variables[cast(rangeInit)->getDecl()]; + VisitRange.getDecl_dx()->setInit(BuildOp(UO_AddrOf, d_rangeInitDeclRef)); + Expr* assignAdjBegin = BuildOp(BO_Assign, d_beginDeclRef, d_rangeDeclRef); + Expr* assignRange = + BuildOp(BO_Assign, rangeDeclRef, BuildOp(UO_AddrOf, rangeInit)); + + addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl())); + addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl_dx())); + addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl())); + addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl_dx())); + addToCurrentBlock(assignAdjBegin); + addToCurrentBlock(assignRange); - addToCurrentBlock(AssignEnd); - auto* AssignEndVarDecl = - cast(cast(AssignEnd)->getSingleDecl()); - DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl); - Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef); + const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); + QualType endType = CloneType(EndDecl->getType()); + std::string endName = EndDecl->getNameAsString(); + Expr* endInit = Visit(EndDecl->getInit()).getExpr(); + VarDecl* endVarDecl = + BuildGlobalVarDecl(endType, endName, endInit, /*DirectInit=*/false); + addToCurrentBlock(BuildDeclStmt(endVarDecl)); + DeclRefExpr* endExpr = BuildDeclRef(endVarDecl); + Expr* incBegin = BuildOp(UO_PreInc, beginDeclRef); beginBlock(direction::forward); DeclDiff LoopVDDiff = DifferentiateVarDecl(LoopVD); - Stmt* AdjLoopVDAddAssign = + Stmt* adjLoopVDAddAssign = utils::unwrapIfSingleStmt(endBlock(direction::forward)); - if ((LoopVDDiff.getDecl()->getDeclName() != LoopVD->getDeclName() || - LoopVD->getType() != LoopVDDiff.getDecl()->getType())) - m_DeclReplacements[LoopVD] = LoopVDDiff.getDecl(); llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop, /*NewValue=*/true); - Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginDeclRef); - Expr* d_DecBegin = BuildOp(UO_PostDec, d_BeginDeclRef); - Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr); - // Add item assignment statement to the body. + Expr* d_incBegin = BuildOp(UO_PreInc, d_beginDeclRef); + Expr* d_decBegin = BuildOp(UO_PostDec, d_beginDeclRef); + Expr* forwardCond = BuildOp(BO_NE, beginDeclRef, endExpr); const Stmt* body = FRS->getBody(); - StmtDiff bodyDiff = Visit(body); + StmtDiff bodyDiff = + DifferentiateLoopBody(body, loopCounter, nullptr, nullptr, + /*isForLoop=*/true); + + activeBreakContHandler->EndCFSwitchStmtScope(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + PopBreakContStmtHandler(); StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl())); StmtDiff storeAdjLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx())); - addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx())); - Expr* CounterIncrement = loopCounter.getCounterIncrement(); - Expr* LoopInit = LoopVDDiff.getDecl()->getInit(); + Expr* loopInit = LoopVDDiff.getDecl()->getInit(); LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType())); addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl())); - Expr* AssignLoop = - BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), LoopInit); + Expr* assignLoop = + BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), loopInit); if (!LoopVD->getType()->isReferenceType()) { Expr* d_LoopVD = BuildDeclRef(LoopVDDiff.getDecl_dx()); - AdjLoopVDAddAssign = - BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_BeginDeclRef)); + adjLoopVDAddAssign = + BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_beginDeclRef)); } beginBlock(direction::forward); - addToCurrentBlock(CounterIncrement); - addToCurrentBlock(AdjLoopVDAddAssign); - addToCurrentBlock(AssignLoop); + addToCurrentBlock(adjLoopVDAddAssign); + addToCurrentBlock(assignLoop); addToCurrentBlock(storeLoop.getStmt()); addToCurrentBlock(storeAdjLoop.getStmt()); CompoundStmt* LoopVDForwardDiff = endBlock(direction::forward); @@ -1070,28 +1076,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Sema.getASTContext(), bodyDiff.getStmt(), LoopVDForwardDiff); beginBlock(direction::forward); - addToCurrentBlock(d_DecBegin); + addToCurrentBlock(d_decBegin); addToCurrentBlock(storeLoop.getStmt_dx()); addToCurrentBlock(storeAdjLoop.getStmt_dx()); CompoundStmt* LoopVDReverseDiff = endBlock(direction::forward); CompoundStmt* bodyReverse = utils::PrependAndCreateCompoundStmt( m_Sema.getASTContext(), bodyDiff.getStmt_dx(), LoopVDReverseDiff); - Expr* Inc = BuildOp(BO_Comma, IncBegin, d_IncBegin); + Expr* inc = BuildOp(BO_Comma, incBegin, d_incBegin); Stmt* Forward = new (m_Context) ForStmt( - m_Context, /*Init=*/nullptr, ForwardCond, /*CondVar=*/nullptr, Inc, + m_Context, /*Init=*/nullptr, forwardCond, /*CondVar=*/nullptr, inc, bodyForward, FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc()); - Expr* CounterCondition = + Expr* counterCondition = loopCounter.getCounterConditionResult().get().second; - Expr* CounterDecrement = loopCounter.getCounterDecrement(); + Expr* counterDecrement = loopCounter.getCounterDecrement(); Stmt* Reverse = bodyReverse; addToCurrentBlock(Reverse, direction::reverse); Reverse = endBlock(direction::reverse); Reverse = new (m_Context) - ForStmt(m_Context, /*Init=*/nullptr, CounterCondition, - /*CondVar=*/nullptr, CounterDecrement, Reverse, + ForStmt(m_Context, /*Init=*/nullptr, counterCondition, + /*CondVar=*/nullptr, counterDecrement, Reverse, FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc()); addToCurrentBlock(Reverse, direction::reverse); Reverse = endBlock(direction::reverse); @@ -1943,8 +1949,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts, DerivedCallArgs, CallArgDx); } - CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(), - OverloadedDerivedFn); + CallExprDiffDiagnostics(FD, CE->getBeginLoc()); if (!OverloadedDerivedFn) { Stmts& block = getCurrentBlock(direction::reverse); block.insert(block.begin(), PreCallStmts.begin(), @@ -2295,7 +2300,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff RResult; // If R has no side effects, it can be just cloned // (no need to store it). - if (!ShouldRecompute(R)) { + + // Check if the local variable declaration is reference type, since it is + // moved to the global scope and the right side should be recomputed + bool promoteToFnScope = false; + if (auto* RDeclRef = dyn_cast(R->IgnoreImplicit())) + promoteToFnScope = RDeclRef->getDecl()->getType()->isReferenceType() && + !getCurrentScope()->isFunctionScope(); + + if (!ShouldRecompute(R) || promoteToFnScope) { RDelayed = std::unique_ptr( new DelayedStoreResult(DelayedGlobalStoreAndRef(R))); RResult = RDelayed->Result; @@ -2642,18 +2655,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!BinOp->isComparisonOp() && !BinOp->isLogicalOp()) unsupportedOpWarn(BinOp->getEndLoc()); - // If either LHS or RHS is a declaration reference, visit it to avoid - // naming collision - auto* LDRE = dyn_cast(L); - auto* RDRE = dyn_cast(R); - - if (!LDRE && !RDRE) - return Clone(BinOp); - - Expr* LExpr = LDRE ? Visit(L).getExpr() : L; - Expr* RExpr = RDRE ? Visit(R).getExpr() : R; - - return BuildOp(opCode, LExpr, RExpr); + return BuildOp(opCode, Visit(L).getExpr(), Visit(R).getExpr()); } Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr()); @@ -2680,10 +2682,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } - DeclDiff - ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { + DeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD, + bool addToBlock) { StmtDiff initDiff; Expr* VDDerivedInit = nullptr; + // Local declarations are promoted to the function global scope. This // procedure is done to make declarations visible in the reverse sweep. // The reverse_mode_forward_pass mode does not have a reverse pass so @@ -2858,7 +2861,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, getZeroInit(VDDerivedType)); else assignToZero = GetCladZeroInit(declRef); - addToCurrentBlock(assignToZero, direction::reverse); + if (addToBlock) + addToCurrentBlock(assignToZero, direction::reverse); } } @@ -2874,10 +2878,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getForwSweepExpr_dx())); - addToCurrentBlock(assignDerivativeE); + if (addToBlock) + addToCurrentBlock(assignDerivativeE); if (isInsideLoop) { StmtDiff pushPop = StoreAndRestore(derivedVDE); - addToCurrentBlock(pushPop.getStmt(), direction::forward); + if (addToBlock) + addToCurrentBlock(pushPop.getStmt(), direction::forward); m_LoopBlock.back().push_back(pushPop.getStmt_dx()); } derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE); @@ -2903,10 +2909,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, initDiff.getExpr_dx()); - addToCurrentBlock(assignDerivativeE, direction::forward); + if (addToBlock) + addToCurrentBlock(assignDerivativeE, direction::forward); if (isInsideLoop) { auto tape = MakeCladTapeFor(derivedVDE); - addToCurrentBlock(tape.Push); + if (addToBlock) + addToCurrentBlock(tape.Push); auto* reverseSweepDerivativePointerE = BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop); m_LoopBlock.back().push_back( @@ -2920,6 +2928,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } if (derivedVDE) m_Variables.emplace(VDClone, derivedVDE); + // Check if decl's name is the same as before. The name may be changed + // if decl name collides with something in the derivative body. + // This can happen in rare cases, e.g. when the original function + // has both y and _d_y (here _d_y collides with the name produced by + // the derivation process), e.g. + // double f(double x) { + // double y = x; + // double _d_y = x; + // } + // -> + // double f_darg0(double x) { + // double _d_x = 1; + // double _d_y = _d_x; // produced as a derivative for y + // double y = x; + // double _d__d_y = _d_x; + // double _d_y = x; // copied from original function, collides with + // _d_y + // } + if ((VD->getDeclName() != VDClone->getDeclName() || + VD->getType() != VDClone->getType())) + m_DeclReplacements[VD] = VDClone; return DeclDiff(VDClone, VDDerived); } @@ -3022,29 +3051,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!isLambda) VDDiff = DifferentiateVarDecl(VD); - // Check if decl's name is the same as before. The name may be changed - // if decl name collides with something in the derivative body. - // This can happen in rare cases, e.g. when the original function - // has both y and _d_y (here _d_y collides with the name produced by - // the derivation process), e.g. - // double f(double x) { - // double y = x; - // double _d_y = x; - // } - // -> - // double f_darg0(double x) { - // double _d_x = 1; - // double _d_y = _d_x; // produced as a derivative for y - // double y = x; - // double _d__d_y = _d_x; - // double _d_y = x; // copied from original function, collides with - // _d_y - // } - if (!isLambda && - (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || - VD->getType() != VDDiff.getDecl()->getType())) - m_DeclReplacements[VD] = VDDiff.getDecl(); - // Here, we move the declaration to the function global scope. // Initialization is replaced with an assignment operation at the same // place as the original declaration. This procedure is done to make the diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 407c1d7e1..414e24f8d 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -17,6 +17,7 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" #include "clang/AST/TemplateBase.h" +#include "clang/Lex/Preprocessor.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" #include "clang/Sema/Scope.h" @@ -733,23 +734,28 @@ namespace clad { /*namespaceShouldExist=*/false); } - void VisitorBase::CallExprDiffDiagnostics(llvm::StringRef funcName, - SourceLocation srcLoc, bool isDerived){ - if (!isDerived) { - // Function was not derived => issue a warning. - diag(DiagnosticsEngine::Warning, - srcLoc, - "function '%0' was not differentiated because clad failed to " - "differentiate it and no suitable overload was found in " - "namespace 'custom_derivatives', and function may not be " - "eligible for numerical differentiation.", + void VisitorBase::CallExprDiffDiagnostics(const clang::FunctionDecl* FD, + SourceLocation srcLoc) { + bool NumDiffEnabled = + !m_Sema.getPreprocessor().isMacroDefined("CLAD_NO_NUM_DIFF"); + // FIXME: Switch to the real diagnostics engine and pass FD directly. + std::string funcName = FD->getNameAsString(); + diag(DiagnosticsEngine::Warning, srcLoc, + "function '%0' was not differentiated because clad failed to " + "differentiate it and no suitable overload was found in " + "namespace 'custom_derivatives'", + {funcName}); + if (NumDiffEnabled) { + diag(DiagnosticsEngine::Note, srcLoc, + "falling back to numerical differentiation for '%0' since no " + "suitable overload was found and clad could not derive it; " + "to disable this feature, compile your programs with " + "-DCLAD_NO_NUM_DIFF", {funcName}); } else { - diag(DiagnosticsEngine::Warning, noLoc, - "Falling back to numerical differentiation for '%0' since no " - "suitable overload was found and clad could not derive it. " - "To disable this feature, compile your programs with " - "-DCLAD_NO_NUM_DIFF.", + diag(DiagnosticsEngine::Note, srcLoc, + "fallback to numerical differentiation is disabled by the " + "'CLAD_NO_NUM_DIFF' macro; considering '%0' as 0", {funcName}); } } diff --git a/test/FirstDerivative/DiffInterface.C b/test/FirstDerivative/DiffInterface.C index df764e493..7e538faa7 100644 --- a/test/FirstDerivative/DiffInterface.C +++ b/test/FirstDerivative/DiffInterface.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | %filecheck %s +// RUN: %cladclang -ferror-limit=100 %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | %filecheck %s #include "clad/Differentiator/Differentiator.h" @@ -131,8 +131,6 @@ int main () { clad::differentiate(f_2, -1); // expected-error {{Invalid argument index '-1' of '3' argument(s)}} - clad::differentiate(f_2, -1); // expected-error {{Invalid argument index '-1' of '3' argument(s)}} - clad::differentiate(f_2, 3); // expected-error {{Invalid argument index '3' of '3' argument(s)}} clad::differentiate(f_2, 9); // expected-error {{Invalid argument index '9' of '3' argument(s)}} @@ -141,6 +139,10 @@ int main () { clad::differentiate(f_2, f_2); // expected-error {{Failed to parse the parameters, must be a string or numeric literal}} + clad::gradient(f_2, -1); // expected-error {{Invalid argument index '-1' of '3' argument(s)}} + + clad::gradient(f_2, "9"); // expected-error {{Invalid argument index '9' of '3' argument(s)}} + clad::differentiate(f_3, 0); // expected-error {{Invalid argument index '0' of '0' argument(s)}} float one = 1.0; diff --git a/test/FirstDerivative/UnsupportedOpsWarn.C b/test/FirstDerivative/UnsupportedOpsWarn.C index 0f59ac961..551391407 100644 --- a/test/FirstDerivative/UnsupportedOpsWarn.C +++ b/test/FirstDerivative/UnsupportedOpsWarn.C @@ -6,12 +6,10 @@ //CHECK-NOT: {{.*error|warning|note:.*}} int binOpWarn_0(int x){ - return x << 1; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return x << 1; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int binOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void binOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } @@ -23,17 +21,14 @@ int binOpWarn_1(int x){ // CHECK-NEXT: } int unOpWarn_0(int x){ - return ~x; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return ~x; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int unOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void unOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } int main(){ - - clad::differentiate(binOpWarn_0, 0); + clad::gradient(binOpWarn_0); clad::gradient(binOpWarn_1); - clad::differentiate(unOpWarn_0, 0); + clad::gradient(unOpWarn_0); } diff --git a/test/Gradient/DiffInterface.C b/test/Gradient/DiffInterface.C index 8fd3c279c..1ea3804b1 100644 --- a/test/Gradient/DiffInterface.C +++ b/test/Gradient/DiffInterface.C @@ -98,12 +98,22 @@ int main () { auto f1_grad_y = clad::gradient(f_1, "y"); TEST(f1_grad_y, &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + + auto f1_grad_0 = clad::gradient(f_1, "1"); + TEST(f1_grad_0, &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + auto f1_grad_z = clad::gradient(f_1, "z"); TEST(f1_grad_z, &result[2]); // CHECK-EXEC: {0.00, 0.00, 2.00} auto f1_grad_xy = clad::gradient(f_1, "x, y"); TEST(f1_grad_xy, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + auto f1_grad_0y = clad::gradient(f_1, "0, y"); + TEST(f1_grad_0y, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + + auto f1_grad_10 = clad::gradient(f_1, "1, 0"); + TEST(f1_grad_10, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + auto f1_grad_yx = clad::gradient(f_1, "y, x"); TEST(f1_grad_yx, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 397ec653f..9c3d892aa 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2689,17 +2689,12 @@ double fn34(double x, double y){ double r = 0; double a[] = {y, x*y, x*x + y}; for(auto& i: a){ - r+=i; + r+=i*i; } return r; } //CHECK: void fn34_grad(double x, double y, double *_d_x, double *_d_y) { -//CHECK-NEXT: unsigned {{int|long}} _t0; -//CHECK-NEXT: double (*_d___range1)[3] = {}; -//CHECK-NEXT: double (*__range10)[3] = {}; -//CHECK-NEXT: double *_d___begin1 = {}; -//CHECK-NEXT: double *__begin10 = {}; //CHECK-NEXT: clad::tape _t1 = {}; //CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: clad::tape _t3 = {}; @@ -2707,24 +2702,26 @@ double fn34(double x, double y){ //CHECK-NEXT: double r = 0; //CHECK-NEXT: double _d_a[3] = {0}; //CHECK-NEXT: double a[3] = {y, x * y, x * x + y}; -//CHECK-NEXT: _t0 = {{0U|0UL}}; -//CHECK-NEXT: _d___range1 = &_d_a; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = &a; +//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = {}; //CHECK-NEXT: _d___begin1 = *_d___range1; //CHECK-NEXT: __range10 = &a; -//CHECK-NEXT: __begin10 = *__range10; //CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; //CHECK-NEXT: double *_d_i = {}; //CHECK-NEXT: double *i = {}; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { //CHECK-NEXT: { -//CHECK-NEXT: _t0++; //CHECK-NEXT: _d_i = &*_d___begin1; //CHECK-NEXT: i = &*__begin10; //CHECK-NEXT: clad::push(_t2, i); //CHECK-NEXT: clad::push(_t3, _d_i); //CHECK-NEXT: } +//CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, r); -//CHECK-NEXT: r += *i; +//CHECK-NEXT: r += *i * *i; //CHECK-NEXT: } //CHECK-NEXT: _d_r += 1; //CHECK-NEXT: for (; _t0; _t0--) { @@ -2737,7 +2734,8 @@ double fn34(double x, double y){ //CHECK-NEXT: { //CHECK-NEXT: r = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_r; -//CHECK-NEXT: *_d_i += _r_d0; +//CHECK-NEXT: *_d_i += _r_d0 * *i; +//CHECK-NEXT: *_d_i += *i * _r_d0; //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: } @@ -2752,70 +2750,286 @@ double fn34(double x, double y){ //CHECK-NEXT: } double fn35(double x, double y){ + double r = 0; + double a[] = {x, x*y, 0}; + for(auto& i: a){ + for(auto& j:a){ + if(r<=x*x){ + r+=i*j; + }else if(r>x*x){ + break; + } + } + } + return r; +} + +//CHECK: void fn35_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _cond0 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _cond1 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: clad::tape _t4 = {}; +//CHECK-NEXT: clad::tape _t5 = {}; +//CHECK-NEXT: clad::tape _t6 = {}; +//CHECK-NEXT: clad::tape _t7 = {}; +//CHECK-NEXT: double _d_r = 0.; +//CHECK-NEXT: double r = 0; +//CHECK-NEXT: double _d_a[3] = {0}; +//CHECK-NEXT: double a[3] = {x, x * y, 0}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = &a; +//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = {}; +//CHECK-NEXT: _d___begin1 = *_d___range1; +//CHECK-NEXT: __range10 = &a; +//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: double *_d_i = {}; +//CHECK-NEXT: double *i = {}; +//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_i = &*_d___begin1; +//CHECK-NEXT: i = &*__begin10; +//CHECK-NEXT: clad::push(_t6, i); +//CHECK-NEXT: clad::push(_t7, _d_i); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, {{0U|0UL}}); +//CHECK-NEXT: double (*__range20)[3] = &a; +//CHECK-NEXT: double (*_d___range2)[3] = &_d_a; +//CHECK-NEXT: double *__begin20 = *__range20; +//CHECK-NEXT: double *_d___begin2 = {}; +//CHECK-NEXT: _d___begin2 = *_d___range2; +//CHECK-NEXT: __range20 = &a; +//CHECK-NEXT: double *__end20 = *__range20 + {{3|3L}}; +//CHECK-NEXT: double *_d_j = {}; +//CHECK-NEXT: double *j = {}; +//CHECK-NEXT: for (; __begin20 != __end20; ++__begin20 , ++_d___begin2) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_j = &*_d___begin2; +//CHECK-NEXT: j = &*__begin20; +//CHECK-NEXT: clad::push(_t4, j); +//CHECK-NEXT: clad::push(_t5, _d_j); +//CHECK-NEXT: } +//CHECK-NEXT: clad::back(_t1)++; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_cond0, r <= x * x); +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: clad::push(_t2, r); +//CHECK-NEXT: r += *i * *j; +//CHECK-NEXT: } else { +//CHECK-NEXT: clad::push(_cond1, r > x * x); +//CHECK-NEXT: if (clad::back(_cond1)) { +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_t3, {{1U|1UL}}); +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t3, {{2U|2UL}}); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: _d_r += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin1--; +//CHECK-NEXT: i = clad::pop(_t6); +//CHECK-NEXT: _d_i = clad::pop(_t7); +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: for (; clad::back(_t1); clad::back(_t1)--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin2--; +//CHECK-NEXT: j = clad::pop(_t4); +//CHECK-NEXT: _d_j = clad::pop(_t5); +//CHECK-NEXT: } +//CHECK-NEXT: switch (clad::pop(_t3)) { +//CHECK-NEXT: case {{2U|2UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: { +//CHECK-NEXT: r = clad::pop(_t2); +//CHECK-NEXT: double _r_d0 = _d_r; +//CHECK-NEXT: *_d_i += _r_d0 * *j; +//CHECK-NEXT: *_d_j += *i * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } else { +//CHECK-NEXT: if (clad::back(_cond1)) { +//CHECK-NEXT: case {{1U|1UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_cond1); +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_cond0); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_t1); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_a[0]; +//CHECK-NEXT: *_d_x += _d_a[1] * y; +//CHECK-NEXT: *_d_y += x * _d_a[1]; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double fn36(double x, double y){ double a[] = {1, 2, 3}; double sum = 0; - for(auto i:a){ - sum += sin(i)*x; + for(auto i: a){ + if(sum > x){ + continue; + }else if(1){ + sum += sin(i)*x; + } } return sum; } -//CHECK: void fn35_grad(double x, double y, double *_d_x, double *_d_y) { -//CHECK-NEXT: unsigned {{int|long}} _t0; -//CHECK-NEXT: double (*_d___range1)[3] = {}; -//CHECK-NEXT: double (*__range10)[3] = {}; -//CHECK-NEXT: double *_d___begin1 = {}; -//CHECK-NEXT: double *__begin10 = {}; -//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK: void fn36_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: clad::tape _cond0 = {}; +//CHECK-NEXT: clad::tape _t1 = {}; //CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: clad::tape _t3 = {}; //CHECK-NEXT: clad::tape _t4 = {}; +//CHECK-NEXT: clad::tape _t5 = {}; //CHECK-NEXT: double _d_a[3] = {0}; //CHECK-NEXT: double a[3] = {1, 2, 3}; //CHECK-NEXT: double _d_sum = 0.; //CHECK-NEXT: double sum = 0; -//CHECK-NEXT: _t0 = {{0U|0UL}}; -//CHECK-NEXT: _d___range1 = &_d_a; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = &a; +//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = {}; //CHECK-NEXT: _d___begin1 = *_d___range1; //CHECK-NEXT: __range10 = &a; -//CHECK-NEXT: __begin10 = *__range10; //CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; //CHECK-NEXT: double _d_i = 0.; //CHECK-NEXT: double i = 0.; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { //CHECK-NEXT: { -//CHECK-NEXT: _t0++; //CHECK-NEXT: _d_i = *_d___begin1; //CHECK-NEXT: i = *__begin10; -//CHECK-NEXT: clad::push(_t3, i); -//CHECK-NEXT: clad::push(_t4, _d_i); +//CHECK-NEXT: clad::push(_t4, i); +//CHECK-NEXT: clad::push(_t5, _d_i); //CHECK-NEXT: } -//CHECK-NEXT: clad::push(_t1, sum); -//CHECK-NEXT: clad::push(_t2, sin(i)); -//CHECK-NEXT: sum += clad::back(_t2) * x; +//CHECK-NEXT: _t0++; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_cond0, sum > x); +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_t1, {{1U|1UL}}); +//CHECK-NEXT: continue; +//CHECK-NEXT: } +//CHECK-NEXT: } else if (1) { +//CHECK-NEXT: clad::push(_t2, sum); +//CHECK-NEXT: clad::push(_t3, sin(i)); +//CHECK-NEXT: sum += clad::back(_t3) * x; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t1, {{2U|2UL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_sum += 1; //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: { //CHECK-NEXT: { //CHECK-NEXT: _d___begin1--; -//CHECK-NEXT: i = clad::pop(_t3); -//CHECK-NEXT: _d_i = clad::pop(_t4); +//CHECK-NEXT: i = clad::pop(_t4); +//CHECK-NEXT: _d_i = clad::pop(_t5); //CHECK-NEXT: } -//CHECK-NEXT: { -//CHECK-NEXT: sum = clad::pop(_t1); -//CHECK-NEXT: double _r_d0 = _d_sum; -//CHECK-NEXT: double _r0 = 0.; -//CHECK-NEXT: _r0 += _r_d0 * x * clad::custom_derivatives::sin_pushforward(i, 1.).pushforward; -//CHECK-NEXT: _d_i += _r0; -//CHECK-NEXT: *_d_x += clad::back(_t2) * _r_d0; -//CHECK-NEXT: clad::pop(_t2); +//CHECK-NEXT: switch (clad::pop(_t1)) { +//CHECK-NEXT: case {{2U|2UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: case {{1U|1UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: } else if (1) { +//CHECK-NEXT: { +//CHECK-NEXT: sum = clad::pop(_t2); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: double _r0 = 0.; +//CHECK-NEXT: _r0 += _r_d0 * x * clad::custom_derivatives::sin_pushforward(i, 1.).pushforward; +//CHECK-NEXT: _d_i += _r0; +//CHECK-NEXT: *_d_x += clad::back(_t3) * _r_d0; +//CHECK-NEXT: clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_cond0); +//CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: *_d___begin1 += _d_i; //CHECK-NEXT: } //CHECK-NEXT: } +double fn37(double x, double y) { + double range[] = {x, 4., y}; + double sum = 0; + for (auto elem: range) + sum += elem; + return sum; +} + +//CHECK: void fn37_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: double _d_range[3] = {0}; +//CHECK-NEXT: double range[3] = {x, 4., y}; +//CHECK-NEXT: double _d_sum = 0.; +//CHECK-NEXT: double sum = 0; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = ⦥ +//CHECK-NEXT: double (*_d___range1)[3] = &_d_range; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = {}; +//CHECK-NEXT: _d___begin1 = *_d___range1; +//CHECK-NEXT: __range10 = ⦥ +//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: double _d_elem = 0.; +//CHECK-NEXT: double elem = 0.; +//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_elem = *_d___begin1; +//CHECK-NEXT: elem = *__begin10; +//CHECK-NEXT: clad::push(_t2, elem); +//CHECK-NEXT: clad::push(_t3, _d_elem); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, sum); +//CHECK-NEXT: sum += elem; +//CHECK-NEXT: } +//CHECK-NEXT: _d_sum += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin1--; +//CHECK-NEXT: elem = clad::pop(_t2); +//CHECK-NEXT: _d_elem = clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: sum = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: _d_elem += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: *_d___begin1 += _d_elem; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_range[0]; +//CHECK-NEXT: *_d_y += _d_range[2]; +//CHECK-NEXT: } +//CHECK-NEXT: } + #define TEST(F, x) { \ result[0] = 0; \ @@ -2901,8 +3115,10 @@ int main() { TEST_2(fn32, 3, 5); // CHECK-EXEC: {45.00, 27.00} TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00} - TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00} - TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00} + TEST_2(fn34, 2, 2); // CHECK-EXEC: {64.00, 32.00} + TEST_2(fn35, 2, 2); // CHECK-EXEC: {12.00, 4.00} + TEST_2(fn36, 1, 1); // CHECK-EXEC: {1.75, 0.00} + TEST_2(fn37, 1, 1); // CHECK-EXEC: {1.00, 1.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) { @@ -2910,4 +3126,4 @@ int main() { //CHECK-NEXT: *_d_x += _d_y * x; //CHECK-NEXT: *_d_x += x * _d_y; //CHECK-NEXT: } -//CHECK-NEXT: } +//CHECK-NEXT: } \ No newline at end of file diff --git a/test/Gradient/NonDifferentiableError.C b/test/Gradient/NonDifferentiableError.C index 501c16268..83558d077 100644 --- a/test/Gradient/NonDifferentiableError.C +++ b/test/Gradient/NonDifferentiableError.C @@ -29,6 +29,12 @@ non_differentiable double fn_s2_mem_fn(double i, double j) { return obj.mem_fn(i, j) + i * j; } +double no_body(double x); + +double fn1(double x) { return no_body(x); } //expected-warning {{function 'no_body' was not differentiated}} +//expected-note@34 {{fallback to numerical differentiation is disabled}} +double fn2(double x) { return fn1(x); } + #define INIT_EXPR(classname) \ classname expr_1(2, 3); \ classname expr_2(3, 5); @@ -48,4 +54,5 @@ int main() { INIT_EXPR(SimpleFunctions2); TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5); TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}} + auto fn2_grad = clad::gradient(fn2); } diff --git a/test/NumericalDiff/GradientMultiArg.C b/test/NumericalDiff/GradientMultiArg.C index 30a993463..5e9a07dff 100644 --- a/test/NumericalDiff/GradientMultiArg.C +++ b/test/NumericalDiff/GradientMultiArg.C @@ -1,6 +1,6 @@ -// RUN: %cladnumdiffclang %s -I%S/../../include -oGradientMultiArg.out 2>&1 | FileCheck -check-prefix=CHECK %s +// RUN: %cladnumdiffclang %s -I%S/../../include -oGradientMultiArg.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s // RUN: ./GradientMultiArg.out | %filecheck_exec %s -// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oGradientMultiArg.out +// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oGradientMultiArg.out -Xclang -verify // RUN: ./GradientMultiArg.out | %filecheck_exec %s //CHECK-NOT: {{.*error|warning|note:.*}} @@ -11,9 +11,9 @@ #include double test_1(double x, double y){ - return std::hypot(x, y); + return std::hypot(x, y); // expected-warning {{function 'hypot' was not differentiated}} + // expected-note@14 {{falling back to numerical differentiation}} } -// CHECK: warning: Falling back to numerical differentiation for 'hypot' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF. // CHECK: void test_1_grad(double x, double y, double *_d_x, double *_d_y) { // CHECK-NEXT: { // CHECK-NEXT: double _r0 = 0.; diff --git a/test/NumericalDiff/NoNumDiff.C b/test/NumericalDiff/NoNumDiff.C index 0a14de784..7f514926a 100644 --- a/test/NumericalDiff/NoNumDiff.C +++ b/test/NumericalDiff/NoNumDiff.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oNoNumDiff.out 2>&1 | FileCheck -check-prefix=CHECK %s +// RUN: %cladclang %s -I%S/../../include -oNoNumDiff.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s //CHECK-NOT: {{.*error|warning|note:.*}} @@ -6,10 +6,9 @@ #include -double func(double x) { return std::tanh(x); } +double func(double x) { return std::tanh(x); } // expected-warning 2{{function 'tanh' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}} +// expected-note@9 2{{fallback to numerical differentiation is disabled by the 'CLAD_NO_NUM_DIFF' macro}} -//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour. -//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour. //CHECK: double func_darg0(double x) { //CHECK-NEXT: double _d_x = 1; //CHECK-NEXT: return 0; @@ -24,6 +23,6 @@ double func(double x) { return std::tanh(x); } int main(){ - clad::differentiate(func, "x"); - clad::gradient(func); + clad::differentiate(func, "x"); + clad::gradient(func); } diff --git a/test/NumericalDiff/NumDiff.C b/test/NumericalDiff/NumDiff.C index baeb7391b..d5cd62dc2 100644 --- a/test/NumericalDiff/NumDiff.C +++ b/test/NumericalDiff/NumDiff.C @@ -1,15 +1,14 @@ -// RUN: %cladnumdiffclang %s -I%S/../../include -oNumDiff.out 2>&1 | FileCheck -check-prefix=CHECK %s +// RUN: %cladnumdiffclang %s -I%S/../../include -oNumDiff.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s // RUN: ./NumDiff.out | %filecheck_exec %s -// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oNumDiff.out +// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr -Xclang -verify %s -I%S/../../include -oNumDiff.out // RUN: ./NumDiff.out | %filecheck_exec %s //CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" double test_1(double x){ - return tanh(x); + return tanh(x); // expected-warning {{function 'tanh' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}} + // expected-note@9 {{falling back to numerical differentiation for 'tanh'}} } -//CHECK: warning: Falling back to numerical differentiation for 'tanh' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF. -//CHECK: warning: Falling back to numerical differentiation for 'log10' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF. //CHECK: void test_1_grad(double x, double *_d_x) { //CHECK-NEXT: { @@ -21,7 +20,8 @@ double test_1(double x){ double test_2(double x){ - return std::log10(x); + return std::log10(x);// expected-warning {{function 'log10' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}} + // expected-note@23 {{falling back to numerical differentiation for 'log10'}} } //CHECK: double test_2_darg0(double x) { //CHECK-NEXT: double _d_x = 1; @@ -32,7 +32,8 @@ double test_2(double x){ double test_3(double x) { if (x > 0) { double constant = 11.; - return std::hypot(x, constant); + return std::hypot(x, constant); // expected-warning {{function 'hypot' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}} + // expected-note@35 {{falling back to numerical differentiation for 'hypot'}} } return 0; } diff --git a/test/NumericalDiff/PrintErrorNumDiff.C b/test/NumericalDiff/PrintErrorNumDiff.C index 942f04c8f..35957d7e1 100644 --- a/test/NumericalDiff/PrintErrorNumDiff.C +++ b/test/NumericalDiff/PrintErrorNumDiff.C @@ -1,6 +1,6 @@ -// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors %s -I%S/../../include -oPrintErrorNumDiff.out 2>&1 | FileCheck -check-prefix=CHECK %s +// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors %s -I%S/../../include -oPrintErrorNumDiff.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s // RUN: ./PrintErrorNumDiff.out | %filecheck_exec %s -// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oPrintErrorNumDiff.out +// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oPrintErrorNumDiff.out -Xclang -verify // RUN: ./PrintErrorNumDiff.out | %filecheck_exec %s //CHECK-NOT: {{.*error|warning|note:.*}} @@ -12,10 +12,10 @@ extern "C" int printf(const char* fmt, ...); double test_1(double x){ - return tanh(x); + return tanh(x); // expected-warning {{function 'tanh' was not differentiated because}} + // expected-note@15 {{falling back to numerical differentiation for 'tanh}} } -//CHECK: warning: Falling back to numerical differentiation for 'tanh' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF. //CHECK: void test_1_grad(double x, double *_d_x) { //CHECK-NEXT: { //CHECK-NEXT: double _r0 = 0.;