From 7cec7c8aa276d330bf4efbf7aa27443dcc765a93 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Mon, 5 Aug 2024 13:48:09 +0000 Subject: [PATCH 1/9] Return const ref instead of a copy. --- include/clad/Differentiator/DynamicGraph.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index db1113116..f7b5f61b0 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -106,7 +106,7 @@ template class DynamicGraph { bool isProcessingNode() { return m_currentId != -1; } /// Get the nodes in the graph. - std::vector getNodes() { return m_nodes; } + const std::vector& getNodes() { return m_nodes; } /// Print the nodes and edges in the graph. void print() { @@ -140,4 +140,4 @@ template class DynamicGraph { }; } // end namespace clad -#endif // CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H \ No newline at end of file +#endif // CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H From a9074926445c2a339b56634daa260f7a3d481798 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 5 Aug 2024 17:55:09 +0300 Subject: [PATCH 2/9] Support multiple indices in clad::gradient calls. Currently, the user can provide a string as the second argument of ``clad::gradient`` to specify independent parameters as a list of comma-separated names. This commit allows users to specify indices alongside with names. e.g. ``` clad::gradient(fn, "0"); clad::gradient(fn, "1, z"); ... ``` Previously, it was possible to provide a single index as an integer literal. e.g. ``` clad::gradient(fn, 0); ``` Fixes #46. --- README.md | 3 ++- lib/Differentiator/DiffPlanner.cpp | 15 +++++++++++++++ test/FirstDerivative/DiffInterface.C | 8 +++++--- test/Gradient/DiffInterface.C | 10 ++++++++++ 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 4410c8601..bfae57d9b 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,8 @@ Reverse-mode AD allows computing the gradient of `f` using *at most* a constant 1. `f` is a pointer to a function or a method to be differentiated 2. `ARGS` is either: * not provided, then `f` is differentiated w.r.t. its every argument - * a string literal with comma-separated names of independent variables (e.g. `"x"` or `"y"` or `"x, y"` or `"y, x"`) + * a string literal with comma-separated names/indices of independent variables (e.g. `"x"`, `"y"`, `"x, y"`, `"y, x"`, "0, 1", "0, y", etc.) + * a SINGLE number representing the index of the independent variable Since a vector of derivatives must be returned from a function generated by the reverse mode, its signature is slightly different. The generated function has `void` return type and same input arguments. The function has additional `n` arguments (where `n` refers to the number of arguments whose gradient was requested) of type `T*`, where `T` is the type of the corresponding original variable. Each of these variables stores the derivative of the elements as they appear in the orignal function signature. *The caller is responsible for allocating and zeroing-out the gradient storage*. Example: ```cpp diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index c7ea8af14..0f7829a20 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -374,6 +374,21 @@ namespace clad { DiffInputVarInfo dVarInfo; dVarInfo.source = diffSpec.str(); + // Check if diffSpec represents an index of an independent variable. + if ('0' <= diffSpec[0] && diffSpec[0] <= '9') { + unsigned idx = std::stoi(dVarInfo.source); + // Fail if the specified index is invalid. + if (idx >= FD->getNumParams()) { + utils::EmitDiag( + semaRef, DiagnosticsEngine::Error, diffArgs->getEndLoc(), + "Invalid argument index '%0' of '%1' argument(s)", + {std::to_string(idx), std::to_string(FD->getNumParams())}); + return; + } + dVarInfo.param = FD->getParamDecl(idx); + DVI.push_back(dVarInfo); + continue; + } llvm::StringRef pName = computeParamName(diffSpec); auto it = std::find_if(std::begin(candidates), std::end(candidates), [&pName]( diff --git a/test/FirstDerivative/DiffInterface.C b/test/FirstDerivative/DiffInterface.C index df764e493..7e538faa7 100644 --- a/test/FirstDerivative/DiffInterface.C +++ b/test/FirstDerivative/DiffInterface.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | %filecheck %s +// RUN: %cladclang -ferror-limit=100 %s -I%S/../../include -fsyntax-only -Xclang -verify 2>&1 | %filecheck %s #include "clad/Differentiator/Differentiator.h" @@ -131,8 +131,6 @@ int main () { clad::differentiate(f_2, -1); // expected-error {{Invalid argument index '-1' of '3' argument(s)}} - clad::differentiate(f_2, -1); // expected-error {{Invalid argument index '-1' of '3' argument(s)}} - clad::differentiate(f_2, 3); // expected-error {{Invalid argument index '3' of '3' argument(s)}} clad::differentiate(f_2, 9); // expected-error {{Invalid argument index '9' of '3' argument(s)}} @@ -141,6 +139,10 @@ int main () { clad::differentiate(f_2, f_2); // expected-error {{Failed to parse the parameters, must be a string or numeric literal}} + clad::gradient(f_2, -1); // expected-error {{Invalid argument index '-1' of '3' argument(s)}} + + clad::gradient(f_2, "9"); // expected-error {{Invalid argument index '9' of '3' argument(s)}} + clad::differentiate(f_3, 0); // expected-error {{Invalid argument index '0' of '0' argument(s)}} float one = 1.0; diff --git a/test/Gradient/DiffInterface.C b/test/Gradient/DiffInterface.C index 5eca705dc..63e246c74 100644 --- a/test/Gradient/DiffInterface.C +++ b/test/Gradient/DiffInterface.C @@ -98,12 +98,22 @@ int main () { auto f1_grad_y = clad::gradient(f_1, "y"); TEST(f1_grad_y, &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + + auto f1_grad_0 = clad::gradient(f_1, "1"); + TEST(f1_grad_0, &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + auto f1_grad_z = clad::gradient(f_1, "z"); TEST(f1_grad_z, &result[2]); // CHECK-EXEC: {0.00, 0.00, 2.00} auto f1_grad_xy = clad::gradient(f_1, "x, y"); TEST(f1_grad_xy, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + auto f1_grad_0y = clad::gradient(f_1, "0, y"); + TEST(f1_grad_0y, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + + auto f1_grad_10 = clad::gradient(f_1, "1, 0"); + TEST(f1_grad_10, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} + auto f1_grad_yx = clad::gradient(f_1, "y, x"); TEST(f1_grad_yx, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00} From 73b93913a4c7275aa24539670bf4e046f9f793f9 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Mon, 5 Aug 2024 16:59:02 +0300 Subject: [PATCH 3/9] Support CXXScalarValueInitExpr --- include/clad/Differentiator/BaseForwardModeVisitor.h | 2 ++ lib/Differentiator/BaseForwardModeVisitor.cpp | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/include/clad/Differentiator/BaseForwardModeVisitor.h b/include/clad/Differentiator/BaseForwardModeVisitor.h index 3de507d69..fa7c09ad6 100644 --- a/include/clad/Differentiator/BaseForwardModeVisitor.h +++ b/include/clad/Differentiator/BaseForwardModeVisitor.h @@ -102,6 +102,8 @@ class BaseForwardModeVisitor StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE); StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE); StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE); + StmtDiff + VisitCXXScalarValueInitExpr(const clang::CXXScalarValueInitExpr* SVIE); StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* CSE); StmtDiff VisitCXXFunctionalCastExpr(const clang::CXXFunctionalCastExpr* FCE); StmtDiff VisitCXXBindTemporaryExpr(const clang::CXXBindTemporaryExpr* BTE); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 4135c6972..f06120c6c 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -2271,6 +2271,11 @@ StmtDiff BaseForwardModeVisitor::VisitCXXStdInitializerListExpr( return Visit(ILE->getSubExpr()); } +StmtDiff BaseForwardModeVisitor::VisitCXXScalarValueInitExpr( + const CXXScalarValueInitExpr* SVIE) { + return {Clone(SVIE), Clone(SVIE)}; +} + clang::Expr* BaseForwardModeVisitor::BuildCustomDerivativeConstructorPFCall( const clang::CXXConstructExpr* CE, llvm::SmallVectorImpl& clonedArgs, From 33d9441969ba4e9643543cf8ed6037227780ec3d Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 1 Aug 2024 19:50:22 +0300 Subject: [PATCH 4/9] Support bitwise, shift, comparison, remainder, not operators. This commit adds support for bitwise, shift, comparison, remainder, and bitwise not operators. Shift operators are considered differentiable since they essentially represent multiplication by ``2^n`` or ``2^-n``, where ``n`` is the RHS of the shift operators ``<<`` and ``>>``. Not operators are considered differentiable as well because they represent ``2^n - 1 - x`` or ``- 1 - x`` (depending on whether the type is signed) so the derivative is ``-_d_x``. Other operators have unclear differentiable effects and so they are considered non-differentiable. Fixes #381. --- lib/Differentiator/BaseForwardModeVisitor.cpp | 24 ++++++++++++++----- test/FirstDerivative/UnsupportedOpsWarn.C | 17 +++++-------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index f06120c6c..c53757261 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -801,8 +801,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) || condUO) { condDiff = Visit(cond); - if (condDiff.getExpr_dx() && - (!isUnusedResult(condDiff.getExpr_dx()) || condUO)) + if (condDiff.getExpr_dx() && (!isUnusedResult(condDiff.getExpr_dx()))) cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), BuildParens(condDiff.getExpr())); else @@ -1381,7 +1380,15 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_LNot) { - return StmtDiff(op, diff.getExpr_dx()); + Expr* zero = getZeroInit(UnOp->getType()); + if (diff.getExpr_dx() && !isUnusedResult(diff.getExpr_dx())) + return {BuildOp(BO_Comma, BuildParens(diff.getExpr_dx()), op), zero}; + return {op, zero}; + } else if (opKind == UnaryOperatorKind::UO_Not) { + // ~x is 2^n - 1 - x for unsigned types and -x - 1 for the signed ones. + // Either way, taking a derivative gives us -_d_x. + Expr* derivedOp = BuildOp(UO_Minus, diff.getExpr_dx()); + return {op, derivedOp}; } else { unsupportedOpWarn(UnOp->getEndLoc()); auto zero = @@ -1497,7 +1504,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { } else opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()), BuildParens(Rdiff.getExpr_dx())); - } else if (BinOp->isLogicalOp()) { + } else if (BinOp->isLogicalOp() || BinOp->isBitwiseOp() || + BinOp->isComparisonOp() || opCode == BO_Rem) { // For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and // correct derivative execution. auto buildOneSide = [this](StmtDiff& Xdiff) { @@ -1514,8 +1522,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { // Since the both parts are included in the opDiff, there's no point in // including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left - // nullptr is used for treating expressions like ((A && B) && C) correctly. - return StmtDiff(opDiff, nullptr); + // zero is used for treating expressions like ((A && B) && C) correctly. + return StmtDiff(opDiff, getZeroInit(BinOp->getType())); + } else if (BinOp->isShiftOp()) { + // Shifting is essentially multiplicating the LHS by 2^RHS (or 2^-RHS). + // We should do the same to the derivarive. + opDiff = BuildOp(opCode, Ldiff.getExpr_dx(), Rdiff.getExpr()); } else { // FIXME: add support for other binary operators unsupportedOpWarn(BinOp->getEndLoc()); diff --git a/test/FirstDerivative/UnsupportedOpsWarn.C b/test/FirstDerivative/UnsupportedOpsWarn.C index 0f59ac961..551391407 100644 --- a/test/FirstDerivative/UnsupportedOpsWarn.C +++ b/test/FirstDerivative/UnsupportedOpsWarn.C @@ -6,12 +6,10 @@ //CHECK-NOT: {{.*error|warning|note:.*}} int binOpWarn_0(int x){ - return x << 1; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return x << 1; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int binOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void binOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } @@ -23,17 +21,14 @@ int binOpWarn_1(int x){ // CHECK-NEXT: } int unOpWarn_0(int x){ - return ~x; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return ~x; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int unOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void unOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } int main(){ - - clad::differentiate(binOpWarn_0, 0); + clad::gradient(binOpWarn_0); clad::gradient(binOpWarn_1); - clad::differentiate(unOpWarn_0, 0); + clad::gradient(unOpWarn_0); } From fc64644401c0498a5b00427c0860b967f8a308ba Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Tue, 6 Aug 2024 11:17:49 +0000 Subject: [PATCH 5/9] Add location information and improve clarity of diagnostics. This patch is a first step towards diagnostics refactoring in the context of non-differentiable propagators. --- include/clad/Differentiator/VisitorBase.h | 15 +++----- lib/Differentiator/BaseForwardModeVisitor.cpp | 3 +- lib/Differentiator/DerivativeBuilder.cpp | 9 +---- .../ReverseModeForwPassVisitor.cpp | 1 - lib/Differentiator/ReverseModeVisitor.cpp | 5 +-- lib/Differentiator/VisitorBase.cpp | 36 +++++++++++-------- test/Gradient/NonDifferentiableError.C | 7 ++++ test/NumericalDiff/GradientMultiArg.C | 8 ++--- test/NumericalDiff/NoNumDiff.C | 11 +++--- test/NumericalDiff/NumDiff.C | 15 ++++---- test/NumericalDiff/PrintErrorNumDiff.C | 8 ++--- 11 files changed, 57 insertions(+), 61 deletions(-) diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index b68318f7b..c3759e2d4 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -394,14 +394,12 @@ namespace clad { /// to avoid recomputation. static bool UsefulToStore(clang::Expr* E); /// A flag for silencing warnings/errors output by diag function. - bool silenceDiags = false; /// Shorthand to issues a warning or error. template void diag(clang::DiagnosticsEngine::Level level, // Warning or Error clang::SourceLocation loc, const char (&format)[N], llvm::ArrayRef args = {}) { - if (!silenceDiags) - m_Builder.diag(level, loc, format, args); + m_Builder.diag(level, loc, format, args); } /// Creates unique identifier of the form "_nameBase" that is @@ -584,17 +582,14 @@ namespace clad { clang::Expr* GetSingleArgCentralDiffCall( clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos, unsigned numArgs, llvm::SmallVectorImpl& args); + /// Emits diagnostic messages on differentiation (or lack thereof) for /// call expressions. /// - /// \param[in] \c funcName The name of the underlying function of the - /// call expression. + /// \param[in] \c FD - The function declaration. /// \param[in] \c srcLoc Any associated source location information. - /// \param[in] \c isDerived A flag to determine if differentiation of the - /// call expression was successful. - void CallExprDiffDiagnostics(llvm::StringRef funcName, - clang::SourceLocation srcLoc, - bool isDerived); + void CallExprDiffDiagnostics(const clang::FunctionDecl* FD, + clang::SourceLocation srcLoc); clang::QualType DetermineCladArrayValueType(clang::QualType T); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index c53757261..7a70f3bcb 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -63,7 +63,6 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive(const FunctionDecl* FD, const DiffRequest& request) { assert(m_DiffReq == request && "Can't pass two different requests!"); - silenceDiags = !request.VerboseDiags; m_Functor = request.Functor; assert(m_DiffReq.Mode == DiffMode::forward); assert(!m_DerivativeInFlight && @@ -1331,7 +1330,7 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { GetSingleArgCentralDiffCall(fnCallee, CallArgs[0], /*targetPos=*/0, /*numArgs=*/1, CallArgs); } - CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(), callDiff); + CallExprDiffDiagnostics(FD, CE->getBeginLoc()); if (!callDiff) { auto zero = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index a7c657eaa..a4defc87e 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -194,15 +194,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { } if (!NSD) { NSD = utils::LookupNSD(m_Sema, namespaceID, namespaceShouldExist); - if (!forCustomDerv && !NSD) { - diag(DiagnosticsEngine::Warning, noLoc, - "Numerical differentiation is diabled using the " - "-DCLAD_NO_NUM_DIFF " - "flag, this means that every try to numerically differentiate a " - "function will fail! Remove the flag to revert to default " - "behaviour."); + if (!NSD) return R; - } } DeclContext* DC = NSD; diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index 9d892868f..6cf306f2f 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -20,7 +20,6 @@ DerivativeAndOverload ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, const DiffRequest& request) { assert(m_DiffReq == request); - silenceDiags = !request.VerboseDiags; assert(m_DiffReq.Mode == DiffMode::reverse_mode_forward_pass); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 38a7fbe39..97c78d519 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -256,7 +256,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const DiffRequest& request) { if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerive(); - silenceDiags = !request.VerboseDiags; assert(m_DiffReq == request); // FIXME: reverse mode plugins may have request mode other than @@ -479,7 +478,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // for the two 'Derive's being different functions. if (m_ExternalSource) m_ExternalSource->ActOnStartOfDerive(); - silenceDiags = !request.VerboseDiags; // FIXME: We should not use const_cast to get the decl request here. // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) const_cast(m_DiffReq) = request; @@ -1943,8 +1941,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CE->getNumArgs(), dfdx(), PreCallStmts, PostCallStmts, DerivedCallArgs, CallArgDx); } - CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(), - OverloadedDerivedFn); + CallExprDiffDiagnostics(FD, CE->getBeginLoc()); if (!OverloadedDerivedFn) { Stmts& block = getCurrentBlock(direction::reverse); block.insert(block.begin(), PreCallStmts.begin(), diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index 7c4b0abb1..aacfad9cb 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -17,6 +17,7 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/Expr.h" #include "clang/AST/TemplateBase.h" +#include "clang/Lex/Preprocessor.h" #include "clang/Sema/Lookup.h" #include "clang/Sema/Overload.h" #include "clang/Sema/Scope.h" @@ -733,23 +734,28 @@ namespace clad { /*namespaceShouldExist=*/false); } - void VisitorBase::CallExprDiffDiagnostics(llvm::StringRef funcName, - SourceLocation srcLoc, bool isDerived){ - if (!isDerived) { - // Function was not derived => issue a warning. - diag(DiagnosticsEngine::Warning, - srcLoc, - "function '%0' was not differentiated because clad failed to " - "differentiate it and no suitable overload was found in " - "namespace 'custom_derivatives', and function may not be " - "eligible for numerical differentiation.", + void VisitorBase::CallExprDiffDiagnostics(const clang::FunctionDecl* FD, + SourceLocation srcLoc) { + bool NumDiffEnabled = + !m_Sema.getPreprocessor().isMacroDefined("CLAD_NO_NUM_DIFF"); + // FIXME: Switch to the real diagnostics engine and pass FD directly. + std::string funcName = FD->getNameAsString(); + diag(DiagnosticsEngine::Warning, srcLoc, + "function '%0' was not differentiated because clad failed to " + "differentiate it and no suitable overload was found in " + "namespace 'custom_derivatives'", + {funcName}); + if (NumDiffEnabled) { + diag(DiagnosticsEngine::Note, srcLoc, + "falling back to numerical differentiation for '%0' since no " + "suitable overload was found and clad could not derive it; " + "to disable this feature, compile your programs with " + "-DCLAD_NO_NUM_DIFF", {funcName}); } else { - diag(DiagnosticsEngine::Warning, noLoc, - "Falling back to numerical differentiation for '%0' since no " - "suitable overload was found and clad could not derive it. " - "To disable this feature, compile your programs with " - "-DCLAD_NO_NUM_DIFF.", + diag(DiagnosticsEngine::Note, srcLoc, + "fallback to numerical differentiation is disabled by the " + "'CLAD_NO_NUM_DIFF' macro; considering '%0' as 0", {funcName}); } } diff --git a/test/Gradient/NonDifferentiableError.C b/test/Gradient/NonDifferentiableError.C index 501c16268..83558d077 100644 --- a/test/Gradient/NonDifferentiableError.C +++ b/test/Gradient/NonDifferentiableError.C @@ -29,6 +29,12 @@ non_differentiable double fn_s2_mem_fn(double i, double j) { return obj.mem_fn(i, j) + i * j; } +double no_body(double x); + +double fn1(double x) { return no_body(x); } //expected-warning {{function 'no_body' was not differentiated}} +//expected-note@34 {{fallback to numerical differentiation is disabled}} +double fn2(double x) { return fn1(x); } + #define INIT_EXPR(classname) \ classname expr_1(2, 3); \ classname expr_2(3, 5); @@ -48,4 +54,5 @@ int main() { INIT_EXPR(SimpleFunctions2); TEST_CLASS(SimpleFunctions2, mem_fn, 3, 5); TEST_FUNC(fn_s2_mem_fn, 3, 5); // expected-error {{attempted differentiation of function 'fn_s2_mem_fn', which is marked as non-differentiable}} + auto fn2_grad = clad::gradient(fn2); } diff --git a/test/NumericalDiff/GradientMultiArg.C b/test/NumericalDiff/GradientMultiArg.C index 840dc7fb0..721620c8c 100644 --- a/test/NumericalDiff/GradientMultiArg.C +++ b/test/NumericalDiff/GradientMultiArg.C @@ -1,6 +1,6 @@ -// RUN: %cladnumdiffclang %s -I%S/../../include -oGradientMultiArg.out 2>&1 | FileCheck -check-prefix=CHECK %s +// RUN: %cladnumdiffclang %s -I%S/../../include -oGradientMultiArg.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s // RUN: ./GradientMultiArg.out | %filecheck_exec %s -// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oGradientMultiArg.out +// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oGradientMultiArg.out -Xclang -verify // RUN: ./GradientMultiArg.out | %filecheck_exec %s //CHECK-NOT: {{.*error|warning|note:.*}} @@ -11,9 +11,9 @@ #include double test_1(double x, double y){ - return std::hypot(x, y); + return std::hypot(x, y); // expected-warning {{function 'hypot' was not differentiated}} + // expected-note@14 {{falling back to numerical differentiation}} } -// CHECK: warning: Falling back to numerical differentiation for 'hypot' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF. // CHECK: void test_1_grad(double x, double y, double *_d_x, double *_d_y) { // CHECK-NEXT: { // CHECK-NEXT: double _r0 = 0; diff --git a/test/NumericalDiff/NoNumDiff.C b/test/NumericalDiff/NoNumDiff.C index 2dcf0a618..811efb765 100644 --- a/test/NumericalDiff/NoNumDiff.C +++ b/test/NumericalDiff/NoNumDiff.C @@ -1,4 +1,4 @@ -// RUN: %cladclang %s -I%S/../../include -oNoNumDiff.out 2>&1 | FileCheck -check-prefix=CHECK %s +// RUN: %cladclang %s -I%S/../../include -oNoNumDiff.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s //CHECK-NOT: {{.*error|warning|note:.*}} @@ -6,10 +6,9 @@ #include -double func(double x) { return std::tanh(x); } +double func(double x) { return std::tanh(x); } // expected-warning 2{{function 'tanh' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}} +// expected-note@9 2{{fallback to numerical differentiation is disabled by the 'CLAD_NO_NUM_DIFF' macro}} -//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour. -//CHECK: warning: Numerical differentiation is diabled using the -DCLAD_NO_NUM_DIFF flag, this means that every try to numerically differentiate a function will fail! Remove the flag to revert to default behaviour. //CHECK: double func_darg0(double x) { //CHECK-NEXT: double _d_x = 1; //CHECK-NEXT: return 0; @@ -24,6 +23,6 @@ double func(double x) { return std::tanh(x); } int main(){ - clad::differentiate(func, "x"); - clad::gradient(func); + clad::differentiate(func, "x"); + clad::gradient(func); } diff --git a/test/NumericalDiff/NumDiff.C b/test/NumericalDiff/NumDiff.C index c149123eb..1703412ff 100644 --- a/test/NumericalDiff/NumDiff.C +++ b/test/NumericalDiff/NumDiff.C @@ -1,15 +1,14 @@ -// RUN: %cladnumdiffclang %s -I%S/../../include -oNumDiff.out 2>&1 | FileCheck -check-prefix=CHECK %s +// RUN: %cladnumdiffclang %s -I%S/../../include -oNumDiff.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s // RUN: ./NumDiff.out | %filecheck_exec %s -// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oNumDiff.out +// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -enable-tbr -Xclang -verify %s -I%S/../../include -oNumDiff.out // RUN: ./NumDiff.out | %filecheck_exec %s //CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" double test_1(double x){ - return tanh(x); + return tanh(x); // expected-warning {{function 'tanh' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}} + // expected-note@9 {{falling back to numerical differentiation for 'tanh'}} } -//CHECK: warning: Falling back to numerical differentiation for 'tanh' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF. -//CHECK: warning: Falling back to numerical differentiation for 'log10' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF. //CHECK: void test_1_grad(double x, double *_d_x) { //CHECK-NEXT: { @@ -21,7 +20,8 @@ double test_1(double x){ double test_2(double x){ - return std::log10(x); + return std::log10(x);// expected-warning {{function 'log10' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}} + // expected-note@23 {{falling back to numerical differentiation for 'log10'}} } //CHECK: double test_2_darg0(double x) { //CHECK-NEXT: double _d_x = 1; @@ -32,7 +32,8 @@ double test_2(double x){ double test_3(double x) { if (x > 0) { double constant = 11.; - return std::hypot(x, constant); + return std::hypot(x, constant); // expected-warning {{function 'hypot' was not differentiated because clad failed to differentiate it and no suitable overload was found in namespace 'custom_derivatives'}} + // expected-note@35 {{falling back to numerical differentiation for 'hypot'}} } return 0; } diff --git a/test/NumericalDiff/PrintErrorNumDiff.C b/test/NumericalDiff/PrintErrorNumDiff.C index 3ea747374..e52980d32 100644 --- a/test/NumericalDiff/PrintErrorNumDiff.C +++ b/test/NumericalDiff/PrintErrorNumDiff.C @@ -1,6 +1,6 @@ -// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors %s -I%S/../../include -oPrintErrorNumDiff.out 2>&1 | FileCheck -check-prefix=CHECK %s +// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors %s -I%S/../../include -oPrintErrorNumDiff.out -Xclang -verify 2>&1 | FileCheck -check-prefix=CHECK %s // RUN: ./PrintErrorNumDiff.out | %filecheck_exec %s -// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oPrintErrorNumDiff.out +// RUN: %cladnumdiffclang -Xclang -plugin-arg-clad -Xclang -fprint-num-diff-errors -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oPrintErrorNumDiff.out -Xclang -verify // RUN: ./PrintErrorNumDiff.out | %filecheck_exec %s //CHECK-NOT: {{.*error|warning|note:.*}} @@ -12,10 +12,10 @@ extern "C" int printf(const char* fmt, ...); double test_1(double x){ - return tanh(x); + return tanh(x); // expected-warning {{function 'tanh' was not differentiated because}} + // expected-note@15 {{falling back to numerical differentiation for 'tanh}} } -//CHECK: warning: Falling back to numerical differentiation for 'tanh' since no suitable overload was found and clad could not derive it. To disable this feature, compile your programs with -DCLAD_NO_NUM_DIFF. //CHECK: void test_1_grad(double x, double *_d_x) { //CHECK-NEXT: { //CHECK-NEXT: double _r0 = 0; From 0c841b936e49f7c0f07914edbb91a69ecbdb6272 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Thu, 8 Aug 2024 06:16:36 +0000 Subject: [PATCH 6/9] Prepare for release v1.7. --- VERSION | 2 +- docs/internalDocs/ReleaseNotes.md | 47 +++++++++++++++++-------------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/VERSION b/VERSION index 5cc610824..d3bdbdf1f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.7~dev +1.7 diff --git a/docs/internalDocs/ReleaseNotes.md b/docs/internalDocs/ReleaseNotes.md index 70640f602..cdee01af2 100644 --- a/docs/internalDocs/ReleaseNotes.md +++ b/docs/internalDocs/ReleaseNotes.md @@ -26,36 +26,36 @@ External Dependencies Forward Mode & Reverse Mode --------------------------- -* +* Add propagators for `__builtin_pow` and `__builtin_log` +* Support range-based for loops +* Improve diagnostics clarity Forward Mode ------------ -* +* Advance support of frameworks such as Kokkos +* Support `std::array` Reverse Mode ------------ -* +* Support non_differentiable attribute -CUDA ----- -* - -Error Estimation ----------------- -* - -Misc ----- -* Fixed Bugs ---------- -[XXX](https://github.com/vgvassilev/clad/issues/XXX) +[46](https://github.com/vgvassilev/clad/issues/46) +[381](https://github.com/vgvassilev/clad/issues/381) +[479](https://github.com/vgvassilev/clad/issues/479) +[525](https://github.com/vgvassilev/clad/issues/525) +[717](https://github.com/vgvassilev/clad/issues/717) +[723](https://github.com/vgvassilev/clad/issues/723) +[829](https://github.com/vgvassilev/clad/issues/829) +[979](https://github.com/vgvassilev/clad/issues/979) +[983](https://github.com/vgvassilev/clad/issues/983) +[986](https://github.com/vgvassilev/clad/issues/986) +[988](https://github.com/vgvassilev/clad/issues/988) +[1005](https://github.com/vgvassilev/clad/issues/1005) - Special Kudos ============= @@ -67,6 +67,11 @@ FirstName LastName (#commits) A B (N) - +petro.zarytskyi (11) +Vassil Vassilev (11) +Atell Krasnopolski (5) +Vaibhav Thakkar (2) +Mihail Mihov (2) +ovdiiuv (1) +Rohan Julka (1) +Max Andriychuk (1) From b0995ff840403a521fc08513b393ecb72a7b08e4 Mon Sep 17 00:00:00 2001 From: Vassil Vassilev Date: Thu, 8 Aug 2024 06:23:30 +0000 Subject: [PATCH 7/9] Bump clad version to 1.8 --- VERSION | 2 +- docs/internalDocs/ReleaseNotes.md | 51 ++++++++++++++----------------- 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/VERSION b/VERSION index d3bdbdf1f..f21da21f5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.7 +1.8~dev diff --git a/docs/internalDocs/ReleaseNotes.md b/docs/internalDocs/ReleaseNotes.md index cdee01af2..de0a3df6b 100644 --- a/docs/internalDocs/ReleaseNotes.md +++ b/docs/internalDocs/ReleaseNotes.md @@ -2,7 +2,7 @@ Introduction ============ This document contains the release notes for the automatic differentiation -plugin for clang Clad, release 1.7. Clad is built on top of +plugin for clang Clad, release 1.8. Clad is built on top of [Clang](http://clang.llvm.org) and [LLVM](http://llvm.org>) compiler infrastructure. Here we describe the status of Clad in some detail, including major improvements from the previous release and new feature work. @@ -11,7 +11,7 @@ Note that if you are reading this file from a git checkout, this document applies to the *next* release, not the current one. -What's New in Clad 1.7? +What's New in Clad 1.8? ======================== Some of the major new features and improvements to Clad are listed here. Generic @@ -26,36 +26,36 @@ External Dependencies Forward Mode & Reverse Mode --------------------------- -* Add propagators for `__builtin_pow` and `__builtin_log` -* Support range-based for loops -* Improve diagnostics clarity +* Forward Mode ------------ -* Advance support of frameworks such as Kokkos -* Support `std::array` +* Reverse Mode ------------ -* Support non_differentiable attribute +* +CUDA +---- +* + +Error Estimation +---------------- +* + +Misc +---- +* Fixed Bugs ---------- -[46](https://github.com/vgvassilev/clad/issues/46) -[381](https://github.com/vgvassilev/clad/issues/381) -[479](https://github.com/vgvassilev/clad/issues/479) -[525](https://github.com/vgvassilev/clad/issues/525) -[717](https://github.com/vgvassilev/clad/issues/717) -[723](https://github.com/vgvassilev/clad/issues/723) -[829](https://github.com/vgvassilev/clad/issues/829) -[979](https://github.com/vgvassilev/clad/issues/979) -[983](https://github.com/vgvassilev/clad/issues/983) -[986](https://github.com/vgvassilev/clad/issues/986) -[988](https://github.com/vgvassilev/clad/issues/988) -[1005](https://github.com/vgvassilev/clad/issues/1005) +[XXX](https://github.com/vgvassilev/clad/issues/XXX) + Special Kudos ============= @@ -67,11 +67,6 @@ FirstName LastName (#commits) A B (N) -petro.zarytskyi (11) -Vassil Vassilev (11) -Atell Krasnopolski (5) -Vaibhav Thakkar (2) -Mihail Mihov (2) -ovdiiuv (1) -Rohan Julka (1) -Max Andriychuk (1) + From 1b81084bd69b15c776582e24aa5d7bbe411c5b81 Mon Sep 17 00:00:00 2001 From: ovdiiuv <104850830+ovdiiuv@users.noreply.github.com> Date: Thu, 8 Aug 2024 08:57:40 +0200 Subject: [PATCH 8/9] Product of references in different scope fix (#1030) --- lib/Differentiator/ReverseModeVisitor.cpp | 10 +++++++++- test/Gradient/Loops.C | 10 ++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 97c78d519..fb1102d66 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2292,7 +2292,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff RResult; // If R has no side effects, it can be just cloned // (no need to store it). - if (!ShouldRecompute(R)) { + + // Check if the local variable declaration is reference type, since it is + // moved to the global scope and the right side should be recomputed + bool promoteToFnScope = false; + if (auto* RDeclRef = dyn_cast(R->IgnoreImplicit())) + promoteToFnScope = RDeclRef->getDecl()->getType()->isReferenceType() && + !getCurrentScope()->isFunctionScope(); + + if (!ShouldRecompute(R) || promoteToFnScope) { RDelayed = std::unique_ptr( new DelayedStoreResult(DelayedGlobalStoreAndRef(R))); RResult = RDelayed->Result; diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index d648a1149..7dfd0ba9a 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2689,7 +2689,7 @@ double fn34(double x, double y){ double r = 0; double a[] = {y, x*y, x*x + y}; for(auto& i: a){ - r+=i; + r+=i*i; } return r; } @@ -2724,7 +2724,7 @@ double fn34(double x, double y){ //CHECK-NEXT: clad::push(_t3, _d_i); //CHECK-NEXT: } //CHECK-NEXT: clad::push(_t1, r); -//CHECK-NEXT: r += *i; +//CHECK-NEXT: r += *i * *i; //CHECK-NEXT: } //CHECK-NEXT: _d_r += 1; //CHECK-NEXT: for (; _t0; _t0--) { @@ -2737,7 +2737,8 @@ double fn34(double x, double y){ //CHECK-NEXT: { //CHECK-NEXT: r = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_r; -//CHECK-NEXT: *_d_i += _r_d0; +//CHECK-NEXT: *_d_i += _r_d0 * *i; +//CHECK-NEXT: *_d_i += *i * _r_d0; //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: } @@ -2751,6 +2752,7 @@ double fn34(double x, double y){ //CHECK-NEXT: } //CHECK-NEXT: } + double fn35(double x, double y){ double a[] = {1, 2, 3}; double sum = 0; @@ -2901,7 +2903,7 @@ int main() { TEST_2(fn32, 3, 5); // CHECK-EXEC: {45.00, 27.00} TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00} - TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00} + TEST_2(fn34, 2, 2); // CHECK-EXEC: {64.00, 32.00} TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00} } From 6cc83eeca4a42c862493898a892b7c57abf55109 Mon Sep 17 00:00:00 2001 From: ovdiiuv <104850830+ovdiiuv@users.noreply.github.com> Date: Sat, 10 Aug 2024 22:41:42 +0200 Subject: [PATCH 9/9] Redesign of the rangebased for loops body (#1034) Fixes:#1019 Fixes:#1033 --- .../clad/Differentiator/ReverseModeVisitor.h | 3 +- lib/Differentiator/ReverseModeVisitor.cpp | 205 ++++++------ test/Gradient/Loops.C | 292 +++++++++++++++--- 3 files changed, 358 insertions(+), 142 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index a161f1f58..8d899a4ac 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -411,7 +411,8 @@ namespace clad { StmtDiff VisitSwitchStmt(const clang::SwitchStmt* SS); StmtDiff VisitCaseStmt(const clang::CaseStmt* CS); StmtDiff VisitDefaultStmt(const clang::DefaultStmt* DS); - DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); + DeclDiff DifferentiateVarDecl(const clang::VarDecl* VD, + bool AddToBlock = true); StmtDiff VisitSubstNonTypeTemplateParmExpr( const clang::SubstNonTypeTemplateParmExpr* NTTP); StmtDiff diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index fb1102d66..8308da4b7 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -976,91 +976,99 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitCXXForRangeStmt(const CXXForRangeStmt* FRS) { + beginBlock(direction::reverse); + LoopCounter loopCounter(*this); beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | Scope::ContinueScope); - beginBlock(direction::reverse); - LoopCounter loopCounter(*this); + llvm::SaveAndRestore SaveCurrentBreakFlagExpr( + m_CurrentBreakFlagExpr); + m_CurrentBreakFlagExpr = nullptr; + auto* activeBreakContHandler = PushBreakContStmtHandler(); + activeBreakContHandler->BeginCFSwitchStmtScope(); const VarDecl* LoopVD = FRS->getLoopVariable(); - const Stmt* RangeDecl = FRS->getRangeStmt(); - const Stmt* BeginDecl = FRS->getBeginStmt(); - StmtDiff VisitRange = Visit(RangeDecl); - StmtDiff VisitBegin = Visit(BeginDecl); - Expr* BeginExpr = cast(VisitBegin.getStmt())->getLHS(); + llvm::SaveAndRestore SaveIsInside(isInsideLoop, + /*NewValue=*/false); + + const auto* RangeDecl = cast(FRS->getRangeStmt()->getSingleDecl()); + const auto* BeginDecl = cast(FRS->getBeginStmt()->getSingleDecl()); + + DeclDiff VisitRange = DifferentiateVarDecl(RangeDecl, false); + DeclDiff VisitBegin = DifferentiateVarDecl(BeginDecl, false); beginBlock(direction::reverse); // Create all declarations needed. - auto* BeginDeclRef = cast(BeginExpr); - Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()]; - - auto* RangeExpr = - cast(cast(VisitRange.getStmt())->getLHS()); - - Expr* RangeInit = Clone(FRS->getRangeInit()); - Expr* AssignRange = - BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit)); - Expr* AssignBegin = - BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr)); - addToCurrentBlock(AssignRange); - addToCurrentBlock(AssignBegin); - const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); + DeclRefExpr* beginDeclRef = BuildDeclRef(VisitBegin.getDecl()); + Expr* d_beginDeclRef = m_Variables[beginDeclRef->getDecl()]; + DeclRefExpr* rangeDeclRef = BuildDeclRef(VisitRange.getDecl()); + Expr* d_rangeDeclRef = m_Variables[rangeDeclRef->getDecl()]; + + Expr* rangeInit = Clone(FRS->getRangeInit()); + Expr* d_rangeInitDeclRef = + m_Variables[cast(rangeInit)->getDecl()]; + VisitRange.getDecl_dx()->setInit(BuildOp(UO_AddrOf, d_rangeInitDeclRef)); + Expr* assignAdjBegin = BuildOp(BO_Assign, d_beginDeclRef, d_rangeDeclRef); + Expr* assignRange = + BuildOp(BO_Assign, rangeDeclRef, BuildOp(UO_AddrOf, rangeInit)); + + addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl())); + addToCurrentBlock(BuildDeclStmt(VisitRange.getDecl_dx())); + addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl())); + addToCurrentBlock(BuildDeclStmt(VisitBegin.getDecl_dx())); + addToCurrentBlock(assignAdjBegin); + addToCurrentBlock(assignRange); - Expr* EndInit = cast(EndDecl->getInit())->getRHS(); - QualType EndType = CloneType(EndDecl->getType()); - std::string EndName = EndDecl->getNameAsString(); - Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit); - VarDecl* EndVarDecl = - BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false); - DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl); - - addToCurrentBlock(AssignEnd); - auto* AssignEndVarDecl = - cast(cast(AssignEnd)->getSingleDecl()); - DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl); - Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef); + const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); + QualType endType = CloneType(EndDecl->getType()); + std::string endName = EndDecl->getNameAsString(); + Expr* endInit = Visit(EndDecl->getInit()).getExpr(); + VarDecl* endVarDecl = + BuildGlobalVarDecl(endType, endName, endInit, /*DirectInit=*/false); + addToCurrentBlock(BuildDeclStmt(endVarDecl)); + DeclRefExpr* endExpr = BuildDeclRef(endVarDecl); + Expr* incBegin = BuildOp(UO_PreInc, beginDeclRef); beginBlock(direction::forward); DeclDiff LoopVDDiff = DifferentiateVarDecl(LoopVD); - Stmt* AdjLoopVDAddAssign = + Stmt* adjLoopVDAddAssign = utils::unwrapIfSingleStmt(endBlock(direction::forward)); - if ((LoopVDDiff.getDecl()->getDeclName() != LoopVD->getDeclName() || - LoopVD->getType() != LoopVDDiff.getDecl()->getType())) - m_DeclReplacements[LoopVD] = LoopVDDiff.getDecl(); llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop, /*NewValue=*/true); - Expr* d_IncBegin = BuildOp(UO_PreInc, d_BeginDeclRef); - Expr* d_DecBegin = BuildOp(UO_PostDec, d_BeginDeclRef); - Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr); - // Add item assignment statement to the body. + Expr* d_incBegin = BuildOp(UO_PreInc, d_beginDeclRef); + Expr* d_decBegin = BuildOp(UO_PostDec, d_beginDeclRef); + Expr* forwardCond = BuildOp(BO_NE, beginDeclRef, endExpr); const Stmt* body = FRS->getBody(); - StmtDiff bodyDiff = Visit(body); + StmtDiff bodyDiff = + DifferentiateLoopBody(body, loopCounter, nullptr, nullptr, + /*isForLoop=*/true); + + activeBreakContHandler->EndCFSwitchStmtScope(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + PopBreakContStmtHandler(); StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl())); StmtDiff storeAdjLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx())); - addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx())); - Expr* CounterIncrement = loopCounter.getCounterIncrement(); - Expr* LoopInit = LoopVDDiff.getDecl()->getInit(); + Expr* loopInit = LoopVDDiff.getDecl()->getInit(); LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType())); addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl())); - Expr* AssignLoop = - BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), LoopInit); + Expr* assignLoop = + BuildOp(BO_Assign, BuildDeclRef(LoopVDDiff.getDecl()), loopInit); if (!LoopVD->getType()->isReferenceType()) { Expr* d_LoopVD = BuildDeclRef(LoopVDDiff.getDecl_dx()); - AdjLoopVDAddAssign = - BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_BeginDeclRef)); + adjLoopVDAddAssign = + BuildOp(BO_Assign, d_LoopVD, BuildOp(UO_Deref, d_beginDeclRef)); } beginBlock(direction::forward); - addToCurrentBlock(CounterIncrement); - addToCurrentBlock(AdjLoopVDAddAssign); - addToCurrentBlock(AssignLoop); + addToCurrentBlock(adjLoopVDAddAssign); + addToCurrentBlock(assignLoop); addToCurrentBlock(storeLoop.getStmt()); addToCurrentBlock(storeAdjLoop.getStmt()); CompoundStmt* LoopVDForwardDiff = endBlock(direction::forward); @@ -1068,28 +1076,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_Sema.getASTContext(), bodyDiff.getStmt(), LoopVDForwardDiff); beginBlock(direction::forward); - addToCurrentBlock(d_DecBegin); + addToCurrentBlock(d_decBegin); addToCurrentBlock(storeLoop.getStmt_dx()); addToCurrentBlock(storeAdjLoop.getStmt_dx()); CompoundStmt* LoopVDReverseDiff = endBlock(direction::forward); CompoundStmt* bodyReverse = utils::PrependAndCreateCompoundStmt( m_Sema.getASTContext(), bodyDiff.getStmt_dx(), LoopVDReverseDiff); - Expr* Inc = BuildOp(BO_Comma, IncBegin, d_IncBegin); + Expr* inc = BuildOp(BO_Comma, incBegin, d_incBegin); Stmt* Forward = new (m_Context) ForStmt( - m_Context, /*Init=*/nullptr, ForwardCond, /*CondVar=*/nullptr, Inc, + m_Context, /*Init=*/nullptr, forwardCond, /*CondVar=*/nullptr, inc, bodyForward, FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc()); - Expr* CounterCondition = + Expr* counterCondition = loopCounter.getCounterConditionResult().get().second; - Expr* CounterDecrement = loopCounter.getCounterDecrement(); + Expr* counterDecrement = loopCounter.getCounterDecrement(); Stmt* Reverse = bodyReverse; addToCurrentBlock(Reverse, direction::reverse); Reverse = endBlock(direction::reverse); Reverse = new (m_Context) - ForStmt(m_Context, /*Init=*/nullptr, CounterCondition, - /*CondVar=*/nullptr, CounterDecrement, Reverse, + ForStmt(m_Context, /*Init=*/nullptr, counterCondition, + /*CondVar=*/nullptr, counterDecrement, Reverse, FRS->getForLoc(), FRS->getBeginLoc(), FRS->getEndLoc()); addToCurrentBlock(Reverse, direction::reverse); Reverse = endBlock(direction::reverse); @@ -2647,18 +2655,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!BinOp->isComparisonOp() && !BinOp->isLogicalOp()) unsupportedOpWarn(BinOp->getEndLoc()); - // If either LHS or RHS is a declaration reference, visit it to avoid - // naming collision - auto* LDRE = dyn_cast(L); - auto* RDRE = dyn_cast(R); - - if (!LDRE && !RDRE) - return Clone(BinOp); - - Expr* LExpr = LDRE ? Visit(L).getExpr() : L; - Expr* RExpr = RDRE ? Visit(R).getExpr() : R; - - return BuildOp(opCode, LExpr, RExpr); + return BuildOp(opCode, Visit(L).getExpr(), Visit(R).getExpr()); } Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr()); @@ -2685,10 +2682,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } - DeclDiff - ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { + DeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD, + bool addToBlock) { StmtDiff initDiff; Expr* VDDerivedInit = nullptr; + // Local declarations are promoted to the function global scope. This // procedure is done to make declarations visible in the reverse sweep. // The reverse_mode_forward_pass mode does not have a reverse pass so @@ -2863,7 +2861,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, getZeroInit(VDDerivedType)); else assignToZero = GetCladZeroInit(declRef); - addToCurrentBlock(assignToZero, direction::reverse); + if (addToBlock) + addToCurrentBlock(assignToZero, direction::reverse); } } @@ -2879,10 +2878,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getForwSweepExpr_dx())); - addToCurrentBlock(assignDerivativeE); + if (addToBlock) + addToCurrentBlock(assignDerivativeE); if (isInsideLoop) { StmtDiff pushPop = StoreAndRestore(derivedVDE); - addToCurrentBlock(pushPop.getStmt(), direction::forward); + if (addToBlock) + addToCurrentBlock(pushPop.getStmt(), direction::forward); m_LoopBlock.back().push_back(pushPop.getStmt_dx()); } derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE); @@ -2908,10 +2909,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, initDiff.getExpr_dx()); - addToCurrentBlock(assignDerivativeE, direction::forward); + if (addToBlock) + addToCurrentBlock(assignDerivativeE, direction::forward); if (isInsideLoop) { auto tape = MakeCladTapeFor(derivedVDE); - addToCurrentBlock(tape.Push); + if (addToBlock) + addToCurrentBlock(tape.Push); auto* reverseSweepDerivativePointerE = BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop); m_LoopBlock.back().push_back( @@ -2925,6 +2928,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } if (derivedVDE) m_Variables.emplace(VDClone, derivedVDE); + // Check if decl's name is the same as before. The name may be changed + // if decl name collides with something in the derivative body. + // This can happen in rare cases, e.g. when the original function + // has both y and _d_y (here _d_y collides with the name produced by + // the derivation process), e.g. + // double f(double x) { + // double y = x; + // double _d_y = x; + // } + // -> + // double f_darg0(double x) { + // double _d_x = 1; + // double _d_y = _d_x; // produced as a derivative for y + // double y = x; + // double _d__d_y = _d_x; + // double _d_y = x; // copied from original function, collides with + // _d_y + // } + if ((VD->getDeclName() != VDClone->getDeclName() || + VD->getType() != VDClone->getType())) + m_DeclReplacements[VD] = VDClone; return DeclDiff(VDClone, VDDerived); } @@ -3027,29 +3051,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!isLambda) VDDiff = DifferentiateVarDecl(VD); - // Check if decl's name is the same as before. The name may be changed - // if decl name collides with something in the derivative body. - // This can happen in rare cases, e.g. when the original function - // has both y and _d_y (here _d_y collides with the name produced by - // the derivation process), e.g. - // double f(double x) { - // double y = x; - // double _d_y = x; - // } - // -> - // double f_darg0(double x) { - // double _d_x = 1; - // double _d_y = _d_x; // produced as a derivative for y - // double y = x; - // double _d__d_y = _d_x; - // double _d_y = x; // copied from original function, collides with - // _d_y - // } - if (!isLambda && - (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || - VD->getType() != VDDiff.getDecl()->getType())) - m_DeclReplacements[VD] = VDDiff.getDecl(); - // Here, we move the declaration to the function global scope. // Initialization is replaced with an assignment operation at the same // place as the original declaration. This procedure is done to make the diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 7dfd0ba9a..e8d4e6ca2 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2695,11 +2695,6 @@ double fn34(double x, double y){ } //CHECK: void fn34_grad(double x, double y, double *_d_x, double *_d_y) { -//CHECK-NEXT: unsigned {{int|long}} _t0; -//CHECK-NEXT: double (*_d___range1)[3] = 0; -//CHECK-NEXT: double (*__range10)[3] = {}; -//CHECK-NEXT: double *_d___begin1 = 0; -//CHECK-NEXT: double *__begin10 = 0; //CHECK-NEXT: clad::tape _t1 = {}; //CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: clad::tape _t3 = {}; @@ -2707,22 +2702,24 @@ double fn34(double x, double y){ //CHECK-NEXT: double r = 0; //CHECK-NEXT: double _d_a[3] = {0}; //CHECK-NEXT: double a[3] = {y, x * y, x * x + y}; -//CHECK-NEXT: _t0 = {{0U|0UL}}; -//CHECK-NEXT: _d___range1 = &_d_a; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = &a; +//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = 0; //CHECK-NEXT: _d___begin1 = *_d___range1; //CHECK-NEXT: __range10 = &a; -//CHECK-NEXT: __begin10 = *__range10; //CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; //CHECK-NEXT: double *_d_i = 0; //CHECK-NEXT: double *i = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { //CHECK-NEXT: { -//CHECK-NEXT: _t0++; //CHECK-NEXT: _d_i = &*_d___begin1; //CHECK-NEXT: i = &*__begin10; //CHECK-NEXT: clad::push(_t2, i); //CHECK-NEXT: clad::push(_t3, _d_i); //CHECK-NEXT: } +//CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, r); //CHECK-NEXT: r += *i * *i; //CHECK-NEXT: } @@ -2752,72 +2749,287 @@ double fn34(double x, double y){ //CHECK-NEXT: } //CHECK-NEXT: } - double fn35(double x, double y){ + double r = 0; + double a[] = {x, x*y, 0}; + for(auto& i: a){ + for(auto& j:a){ + if(r<=x*x){ + r+=i*j; + }else if(r>x*x){ + break; + } + } + } + return r; +} + +//CHECK: void fn35_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _cond0 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _cond1 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: clad::tape _t4 = {}; +//CHECK-NEXT: clad::tape _t5 = {}; +//CHECK-NEXT: clad::tape _t6 = {}; +//CHECK-NEXT: clad::tape _t7 = {}; +//CHECK-NEXT: double _d_r = 0; +//CHECK-NEXT: double r = 0; +//CHECK-NEXT: double _d_a[3] = {0}; +//CHECK-NEXT: double a[3] = {x, x * y, 0}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = &a; +//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = 0; +//CHECK-NEXT: _d___begin1 = *_d___range1; +//CHECK-NEXT: __range10 = &a; +//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: double *_d_i = 0; +//CHECK-NEXT: double *i = 0; +//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_i = &*_d___begin1; +//CHECK-NEXT: i = &*__begin10; +//CHECK-NEXT: clad::push(_t6, i); +//CHECK-NEXT: clad::push(_t7, _d_i); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, {{0U|0UL}}); +//CHECK-NEXT: double (*__range20)[3] = &a; +//CHECK-NEXT: double (*_d___range2)[3] = &_d_a; +//CHECK-NEXT: double *__begin20 = *__range20; +//CHECK-NEXT: double *_d___begin2 = 0; +//CHECK-NEXT: _d___begin2 = *_d___range2; +//CHECK-NEXT: __range20 = &a; +//CHECK-NEXT: double *__end20 = *__range20 + {{3|3L}}; +//CHECK-NEXT: double *_d_j = 0; +//CHECK-NEXT: double *j = 0; +//CHECK-NEXT: for (; __begin20 != __end20; ++__begin20 , ++_d___begin2) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_j = &*_d___begin2; +//CHECK-NEXT: j = &*__begin20; +//CHECK-NEXT: clad::push(_t4, j); +//CHECK-NEXT: clad::push(_t5, _d_j); +//CHECK-NEXT: } +//CHECK-NEXT: clad::back(_t1)++; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_cond0, r <= x * x); +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: clad::push(_t2, r); +//CHECK-NEXT: r += *i * *j; +//CHECK-NEXT: } else { +//CHECK-NEXT: clad::push(_cond1, r > x * x); +//CHECK-NEXT: if (clad::back(_cond1)) { +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_t3, {{1U|1UL}}); +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t3, {{2U|2UL}}); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: _d_r += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin1--; +//CHECK-NEXT: i = clad::pop(_t6); +//CHECK-NEXT: _d_i = clad::pop(_t7); +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: for (; clad::back(_t1); clad::back(_t1)--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin2--; +//CHECK-NEXT: j = clad::pop(_t4); +//CHECK-NEXT: _d_j = clad::pop(_t5); +//CHECK-NEXT: } +//CHECK-NEXT: switch (clad::pop(_t3)) { +//CHECK-NEXT: case {{2U|2UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: { +//CHECK-NEXT: r = clad::pop(_t2); +//CHECK-NEXT: double _r_d0 = _d_r; +//CHECK-NEXT: *_d_i += _r_d0 * *j; +//CHECK-NEXT: *_d_j += *i * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } else { +//CHECK-NEXT: if (clad::back(_cond1)) { +//CHECK-NEXT: case {{1U|1UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_cond1); +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_cond0); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_t1); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_a[0]; +//CHECK-NEXT: *_d_x += _d_a[1] * y; +//CHECK-NEXT: *_d_y += x * _d_a[1]; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double fn36(double x, double y){ double a[] = {1, 2, 3}; double sum = 0; - for(auto i:a){ - sum += sin(i)*x; + for(auto i: a){ + if(sum > x){ + continue; + }else if(1){ + sum += sin(i)*x; + } } return sum; } -//CHECK: void fn35_grad(double x, double y, double *_d_x, double *_d_y) { -//CHECK-NEXT: unsigned {{int|long}} _t0; -//CHECK-NEXT: double (*_d___range1)[3] = 0; -//CHECK-NEXT: double (*__range10)[3] = {}; -//CHECK-NEXT: double *_d___begin1 = 0; -//CHECK-NEXT: double *__begin10 = 0; -//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK: void fn36_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: clad::tape _cond0 = {}; +//CHECK-NEXT: clad::tape _t1 = {}; //CHECK-NEXT: clad::tape _t2 = {}; //CHECK-NEXT: clad::tape _t3 = {}; //CHECK-NEXT: clad::tape _t4 = {}; +//CHECK-NEXT: clad::tape _t5 = {}; //CHECK-NEXT: double _d_a[3] = {0}; //CHECK-NEXT: double a[3] = {1, 2, 3}; //CHECK-NEXT: double _d_sum = 0; //CHECK-NEXT: double sum = 0; -//CHECK-NEXT: _t0 = {{0U|0UL}}; -//CHECK-NEXT: _d___range1 = &_d_a; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = &a; +//CHECK-NEXT: double (*_d___range1)[3] = &_d_a; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = 0; //CHECK-NEXT: _d___begin1 = *_d___range1; //CHECK-NEXT: __range10 = &a; -//CHECK-NEXT: __begin10 = *__range10; //CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; //CHECK-NEXT: double _d_i = 0; //CHECK-NEXT: double i = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { //CHECK-NEXT: { -//CHECK-NEXT: _t0++; //CHECK-NEXT: _d_i = *_d___begin1; //CHECK-NEXT: i = *__begin10; -//CHECK-NEXT: clad::push(_t3, i); -//CHECK-NEXT: clad::push(_t4, _d_i); +//CHECK-NEXT: clad::push(_t4, i); +//CHECK-NEXT: clad::push(_t5, _d_i); //CHECK-NEXT: } -//CHECK-NEXT: clad::push(_t1, sum); -//CHECK-NEXT: clad::push(_t2, sin(i)); -//CHECK-NEXT: sum += clad::back(_t2) * x; +//CHECK-NEXT: _t0++; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_cond0, sum > x); +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_t1, {{1U|1UL}}); +//CHECK-NEXT: continue; +//CHECK-NEXT: } +//CHECK-NEXT: } else if (1) { +//CHECK-NEXT: clad::push(_t2, sum); +//CHECK-NEXT: clad::push(_t3, sin(i)); +//CHECK-NEXT: sum += clad::back(_t3) * x; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t1, {{2U|2UL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_sum += 1; //CHECK-NEXT: for (; _t0; _t0--) { //CHECK-NEXT: { //CHECK-NEXT: { //CHECK-NEXT: _d___begin1--; -//CHECK-NEXT: i = clad::pop(_t3); -//CHECK-NEXT: _d_i = clad::pop(_t4); +//CHECK-NEXT: i = clad::pop(_t4); +//CHECK-NEXT: _d_i = clad::pop(_t5); //CHECK-NEXT: } -//CHECK-NEXT: { -//CHECK-NEXT: sum = clad::pop(_t1); -//CHECK-NEXT: double _r_d0 = _d_sum; -//CHECK-NEXT: double _r0 = 0; -//CHECK-NEXT: _r0 += _r_d0 * x * clad::custom_derivatives::sin_pushforward(i, 1.).pushforward; -//CHECK-NEXT: _d_i += _r0; -//CHECK-NEXT: *_d_x += clad::back(_t2) * _r_d0; -//CHECK-NEXT: clad::pop(_t2); +//CHECK-NEXT: switch (clad::pop(_t1)) { +//CHECK-NEXT: case {{2U|2UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: case {{1U|1UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: } else if (1) { +//CHECK-NEXT: { +//CHECK-NEXT: sum = clad::pop(_t2); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: double _r0 = 0; +//CHECK-NEXT: _r0 += _r_d0 * x * clad::custom_derivatives::sin_pushforward(i, 1.).pushforward; +//CHECK-NEXT: _d_i += _r0; +//CHECK-NEXT: *_d_x += clad::back(_t3) * _r_d0; +//CHECK-NEXT: clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::pop(_cond0); +//CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: *_d___begin1 += _d_i; //CHECK-NEXT: } //CHECK-NEXT: } +double fn37(double x, double y) { + double range[] = {x, 4., y}; + double sum = 0; + for (auto elem: range) + sum += elem; + return sum; +} + +//CHECK: void fn37_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: double _d_range[3] = {0}; +//CHECK-NEXT: double range[3] = {x, 4., y}; +//CHECK-NEXT: double _d_sum = 0; +//CHECK-NEXT: double sum = 0; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: double (*__range10)[3] = ⦥ +//CHECK-NEXT: double (*_d___range1)[3] = &_d_range; +//CHECK-NEXT: double *__begin10 = *__range10; +//CHECK-NEXT: double *_d___begin1 = 0; +//CHECK-NEXT: _d___begin1 = *_d___range1; +//CHECK-NEXT: __range10 = ⦥ +//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: double _d_elem = 0; +//CHECK-NEXT: double elem = 0; +//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_elem = *_d___begin1; +//CHECK-NEXT: elem = *__begin10; +//CHECK-NEXT: clad::push(_t2, elem); +//CHECK-NEXT: clad::push(_t3, _d_elem); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, sum); +//CHECK-NEXT: sum += elem; +//CHECK-NEXT: } +//CHECK-NEXT: _d_sum += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin1--; +//CHECK-NEXT: elem = clad::pop(_t2); +//CHECK-NEXT: _d_elem = clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: sum = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: _d_elem += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: *_d___begin1 += _d_elem; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_range[0]; +//CHECK-NEXT: *_d_y += _d_range[2]; +//CHECK-NEXT: } +//CHECK-NEXT: } + #define TEST(F, x) { \ result[0] = 0; \ @@ -2904,7 +3116,9 @@ int main() { TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00} TEST_2(fn34, 2, 2); // CHECK-EXEC: {64.00, 32.00} - TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00} + TEST_2(fn35, 2, 2); // CHECK-EXEC: {12.00, 4.00} + TEST_2(fn36, 1, 1); // CHECK-EXEC: {1.75, 0.00} + TEST_2(fn37, 1, 1); // CHECK-EXEC: {1.00, 1.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {