diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index cca1cd5cf..dfb900e1e 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -88,7 +88,7 @@ CUDA_HOST_DEVICE T push(tape& to, ArgsT... val) { /// N. template CUDA_HOST_DEVICE void zero_init(T* x, std::size_t N) { for (std::size_t i = 0; i < N; ++i) - zero_init(x[i]); + zero_init(x[i]); } /// Initialize a const sized array. diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index c9fd1296f..9104ef79d 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -517,6 +517,7 @@ namespace clad { clang::Expr *m_Pop = nullptr; clang::Expr *m_Push = nullptr; ReverseModeVisitor& m_RMV; + clang::VarDecl* m_numRevIterations = nullptr; public: LoopCounter(ReverseModeVisitor& RMV); @@ -549,6 +550,14 @@ namespace clad { m_Ref, clang::Sema::ConditionKind::Boolean); } + + /// Sets the number of reverse iterations to be executed. + void setNumRevIterations(clang::VarDecl* numRevIterations) { + m_numRevIterations = numRevIterations; + } + + /// Returns the number of reverse iterations to be executed. + clang::VarDecl* getNumRevIterations() const { return m_numRevIterations; } }; /// Helper function to differentiate a loop body. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 04f626286..11c0c1981 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1372,8 +1372,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts)); } + Stmt* revInit = loopCounter.getNumRevIterations() + ? BuildDeclStmt(loopCounter.getNumRevIterations()) + : nullptr; Stmt* Reverse = new (m_Context) - ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement, + ForStmt(m_Context, revInit, nullptr, nullptr, CounterDecrement, BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc); addToCurrentBlock(initResult.getStmt_dx(), direction::reverse); @@ -3791,6 +3794,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); isInsideLoop = true; + llvm::SaveAndRestore SaveCurrentBreakFlagExpr( + m_CurrentBreakFlagExpr); + m_CurrentBreakFlagExpr = nullptr; Expr* condClone = (WS->getCond() ? Clone(WS->getCond()) : nullptr); const VarDecl* condVarDecl = WS->getConditionVariable(); @@ -3849,6 +3855,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); isInsideLoop = true; + llvm::SaveAndRestore SaveCurrentBreakFlagExpr( + m_CurrentBreakFlagExpr); + m_CurrentBreakFlagExpr = nullptr; Expr* clonedCond = (DS->getCond() ? Clone(DS->getCond()) : nullptr); @@ -4105,22 +4114,38 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock)); m_LoopBlock.pop_back(); - // Increment statement in the for-loop is only executed if the iteration - // did not end with a break/continue statement. Therefore, forLoopIncDiff - // should be inside the last switch case in the reverse pass. + activeBreakContHandler->EndCFSwitchStmtScope(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + PopBreakContStmtHandler(); + + Expr* revCounter = loopCounter.getCounterConditionResult().get().second; + if (m_CurrentBreakFlagExpr) { + VarDecl* numRevIterations = BuildVarDecl(m_Context.getSizeType(), + "_numRevIterations", revCounter); + loopCounter.setNumRevIterations(numRevIterations); + } + + // Increment statement in the for-loop is executed for every case if (forLoopIncDiff) { + Stmt* forLoopIncDiffExpr = forLoopIncDiff; + if (m_CurrentBreakFlagExpr) { + m_CurrentBreakFlagExpr = + BuildOp(BinaryOperatorKind::BO_LOr, + BuildOp(BinaryOperatorKind::BO_NE, revCounter, + BuildDeclRef(loopCounter.getNumRevIterations())), + BuildParens(m_CurrentBreakFlagExpr)); + forLoopIncDiffExpr = clad_compat::IfStmt_Create( + m_Context, noLoc, false, nullptr, nullptr, m_CurrentBreakFlagExpr, + noLoc, noLoc, forLoopIncDiff, noLoc, nullptr); + } if (bodyDiff.getStmt_dx()) { bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( - m_Context, bodyDiff.getStmt_dx(), forLoopIncDiff)); + m_Context, bodyDiff.getStmt_dx(), forLoopIncDiffExpr)); } else { - bodyDiff.updateStmtDx(forLoopIncDiff); + bodyDiff.updateStmtDx(forLoopIncDiffExpr); } } - activeBreakContHandler->EndCFSwitchStmtScope(); - activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); - PopBreakContStmtHandler(); - Expr* counterDecrement = loopCounter.getCounterDecrement(); // Create reverse pass loop body statements by arranging various @@ -4169,7 +4194,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_CurrentBreakFlagExpr = BuildOp(BinaryOperatorKind::BO_LAnd, m_CurrentBreakFlagExpr, tapeBackExprForCurrentCase); - } else { m_CurrentBreakFlagExpr = tapeBackExprForCurrentCase; } diff --git a/test/Analyses/TBR.cpp b/test/Analyses/TBR.cpp index 21f4c5b9f..b16ee6f12 100644 --- a/test/Analyses/TBR.cpp +++ b/test/Analyses/TBR.cpp @@ -82,10 +82,10 @@ double f2(double val) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } +//CHECK-NEXT: --i; //CHECK-NEXT: switch (clad::pop(_t1)) { //CHECK-NEXT: case {{2U|2UL}}: //CHECK-NEXT: ; -//CHECK-NEXT: --i; //CHECK-NEXT: { //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: _d_i += _r_d0 * val; @@ -167,6 +167,6 @@ double f3 (double x){ int main() { double result[3] = {}; TEST(f1, 3); // CHECK-EXEC: {27.00} - TEST(f2, 3); // CHECK-EXEC: {9.00} + TEST(f2, 3); // CHECK-EXEC: {7.00} TEST(f3, 3); // CHECK-EXEC: {2.00} } diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 8586aac60..5a44d7c50 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1315,15 +1315,16 @@ double fn16(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{3U|3UL|3ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) +// CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --ii; // CHECK-NEXT: { // CHECK-NEXT: res = clad::pop(_t4); // CHECK-NEXT: double _r_d2 = _d_res; @@ -1443,10 +1444,10 @@ double fn17(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --ii; // CHECK-NEXT: { // CHECK-NEXT: while (clad::back(_t3)) // CHECK-NEXT: { @@ -1554,15 +1555,16 @@ double fn18(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{3U|3UL|3ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 2)) +// CHECK-NEXT: --counter; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --counter; // CHECK-NEXT: if (clad::back(_cond0)) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; @@ -1886,9 +1888,9 @@ double fn23(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::back(_t2) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: _d_res = 0.; @@ -1898,10 +1900,11 @@ double fn23(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) +// CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --c; // CHECK-NEXT: { // CHECK-NEXT: if (clad::back(_cond0)) // CHECK-NEXT: case {{1U|1UL|1ULL}}: @@ -1997,9 +2000,9 @@ double fn25(double i, double j) { // CHECK-NEXT: clad::push(_t3, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::back(_t3) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t3) != 1))) { // CHECK-NEXT: _d_res += 0; // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; @@ -2009,10 +2012,11 @@ double fn25(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t3) != 1)) +// CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t3)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --c; // CHECK-NEXT: { // CHECK-NEXT: if (clad::back(_cond0)) { // CHECK-NEXT: case {{1U|1UL|1ULL}}: @@ -2069,9 +2073,9 @@ double fn26(double i, double j) { // CHECK-NEXT: clad::push(_t3, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::back(_t3) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t3) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2080,19 +2084,19 @@ double fn26(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: res = clad::pop(_t2); +// CHECK-NEXT: double _r_d1 = _d_res; +// CHECK-NEXT: _d_res = 0.; +// CHECK-NEXT: *_d_i += 7 * _r_d1 * j; +// CHECK-NEXT: *_d_j += 7 * i * _r_d1; +// CHECK-NEXT: _d_c += 0; +// CHECK-NEXT: --c; +// CHECK-NEXT: } // CHECK-NEXT: switch (clad::pop(_t3)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; // CHECK-NEXT: { -// CHECK-NEXT: res = clad::pop(_t2); -// CHECK-NEXT: double _r_d1 = _d_res; -// CHECK-NEXT: _d_res = 0.; -// CHECK-NEXT: *_d_i += 7 * _r_d1 * j; -// CHECK-NEXT: *_d_j += 7 * i * _r_d1; -// CHECK-NEXT: _d_c += 0; -// CHECK-NEXT: --c; -// CHECK-NEXT: } -// CHECK-NEXT: { // CHECK-NEXT: if (clad::back(_cond0)) // CHECK-NEXT: case {{1U|1UL|1ULL}}: // CHECK-NEXT: ; @@ -2144,9 +2148,9 @@ double fn27(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::back(_t2) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2155,10 +2159,11 @@ double fn27(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) +// CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: // CHECK-NEXT: ; -// CHECK-NEXT: --c; // CHECK-NEXT: { // CHECK-NEXT: res = clad::pop(_t3); // CHECK-NEXT: double _r_d1 = _d_res; @@ -2469,9 +2474,9 @@ double fn32(double i, double j) { //CHECK-NEXT: clad::push(_t8, {{2U|2UL|2ULL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; -//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (clad::back(_t8) != 1)) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::back(_t8) != 1))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2480,10 +2485,11 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } +//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::back(_t8) != 1)) +//CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t8)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: //CHECK-NEXT: ; -//CHECK-NEXT: --c; //CHECK-NEXT: { //CHECK-NEXT: if (clad::back(_cond1)) { //CHECK-NEXT: case {{1U|1UL|1ULL}}: @@ -2498,9 +2504,9 @@ double fn32(double i, double j) { //CHECK-NEXT: clad::pop(_cond1); //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: for (;; clad::back(_t2)--) { +//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = clad::back(_t2); ; clad::back(_t2)--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t6) != 1)) { +//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::back(_t6) != 1))) { //CHECK-NEXT: res = clad::pop(_t4); //CHECK-NEXT: double _r_d1 = _d_res; //CHECK-NEXT: *_d_i += _r_d1 * j; @@ -2509,10 +2515,11 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!clad::back(_t2)) //CHECK-NEXT: break; //CHECK-NEXT: } +//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::back(_t6) != 1)) +//CHECK-NEXT: --d; //CHECK-NEXT: switch (clad::pop(_t6)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: //CHECK-NEXT: ; -//CHECK-NEXT: --d; //CHECK-NEXT: { //CHECK-NEXT: if (clad::back(_cond0)) { //CHECK-NEXT: case {{1U|1UL|1ULL}}: @@ -2619,9 +2626,9 @@ double fn33(double i, double j) { //CHECK-NEXT: clad::push(_t4, {{3U|3UL|3ULL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; -//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2)) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: _d_res = 0.; @@ -2631,10 +2638,11 @@ double fn33(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t4) != 1 && clad::back(_t4) != 2)) +//CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t4)) { //CHECK-NEXT: case {{3U|3UL|3ULL}}: //CHECK-NEXT: ; -//CHECK-NEXT: --c; //CHECK-NEXT: { //CHECK-NEXT: if (clad::back(_cond5)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -3133,6 +3141,123 @@ double fn39(double x) { //CHECK-NEXT: } //CHECK-NEXT: } +double fn40(double u, double v) { + double res = 11 * u; + for (int i = 0; i < 3; i++) { + res += u * i; + continue; + } + return res; +} + +// CHECK: void fn40_grad(double u, double v, double *_d_u, double *_d_v) { +//CHECK-NEXT: int _d_i = 0; +//CHECK-NEXT: int i = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: double _d_res = 0.; +//CHECK-NEXT: double res = 11 * u; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: for (i = 0; ; i++) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!(i < 3)) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, res); +//CHECK-NEXT: res += u * i; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_t2, {{1U|1UL}}); +//CHECK-NEXT: continue; +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t2, {{2U|2UL}}); +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!_t0) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: i--; +//CHECK-NEXT: switch (clad::pop(_t2)) { +//CHECK-NEXT: case {{2U|2UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: case {{1U|1UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: res = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: *_d_u += _r_d0 * i; +//CHECK-NEXT: _d_i += u * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: *_d_u += 11 * _d_res; +//CHECK-NEXT:} + +double fn41(double u, double v) { + double res = 0; + for (int i = 1; i < 3; i++) { + res += i * u; + if (i == 1) + break; + } + return res; +} + +//CHECK: void fn41_grad(double u, double v, double *_d_u, double *_d_v) { +//CHECK-NEXT: int _d_i = 0; +//CHECK-NEXT: int i = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _cond0 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: double _d_res = 0.; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: for (i = 1; ; i++) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!(i < 3)) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, res); +//CHECK-NEXT: res += i * u; +//CHECK-NEXT: { +//CHECK-NEXT: clad::push(_cond0, i == 1); +//CHECK-NEXT: if (clad::back(_cond0)) { +//CHECK-NEXT: clad::push(_t2, {{1U|1UL}}); +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: clad::push(_t2, {{2U|2UL}}); +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: for (unsigned {{int|long}} _numRevIterations0 = _t0; ; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!_t0) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::back(_t2) != 1)) +//CHECK-NEXT: i--; +//CHECK-NEXT: switch (clad::pop(_t2)) { +//CHECK-NEXT: case {{2U|2UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: { +//CHECK-NEXT: if (clad::back(_cond0)) +//CHECK-NEXT: case {{1U|1UL}}: +//CHECK-NEXT: ; +//CHECK-NEXT: clad::pop(_cond0); +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: res = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: _d_i += _r_d0 * u; +//CHECK-NEXT: *_d_u += i * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT:} + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -3189,7 +3314,7 @@ int main() { TEST_2(fn16, 3, 5); // CHECK-EXEC: {10.00, 6.00} TEST_2(fn17, 3, 5); // CHECK-EXEC: {15.00, 9.00} TEST_2(fn18, 3, 5); // CHECK-EXEC: {4.00, 4.00} - + INIT_GRADIENT(fn19, "arr"); double arr[5] = {}; @@ -3223,6 +3348,8 @@ int main() { TEST_2(fn37, 1, 1); // CHECK-EXEC: {1.00, 1.00} TEST_2(fn38, 6, 3); // CHECK-EXEC: {1.00, 1.00} TEST(fn39, 9); // CHECK-EXEC: {6.00} + TEST_2(fn40, 2, 3); // CHECK-EXEC: {14.00, 0.00} + TEST_2(fn41, 2, 3); // CHECK-EXEC: {1.00, 0.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {