From 880055671814ac8d50ccdc824be9150e4232ea3d Mon Sep 17 00:00:00 2001 From: kchristin Date: Sun, 27 Oct 2024 09:59:03 +0200 Subject: [PATCH] Optimize check of break branch using first iteration check in reverse pass --- .../clad/Differentiator/ReverseModeVisitor.h | 9 +++ lib/Differentiator/ReverseModeVisitor.cpp | 17 +++++- test/Gradient/Loops.C | 56 +++++++++---------- 3 files changed, 53 insertions(+), 29 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 61723d0f6..4b8d22837 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -518,6 +518,7 @@ namespace clad { clang::Expr *m_Pop = nullptr; clang::Expr *m_Push = nullptr; ReverseModeVisitor& m_RMV; + clang::VarDecl* numRevIterations = nullptr; public: LoopCounter(ReverseModeVisitor& RMV); @@ -550,6 +551,14 @@ namespace clad { m_Ref, clang::Sema::ConditionKind::Boolean); } + + /// Sets the number of reverse iterations to be executed. + clang::VarDecl* setNumRevIterations(clang::VarDecl* numRevIterations) { + return this->numRevIterations = numRevIterations; + } + + /// Returns the number of reverse iterations to be executed. + clang::VarDecl* getNumRevIterations() const { return numRevIterations; } }; /// Helper function to differentiate a loop body. diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index cc0984c7f..139711e17 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1380,8 +1380,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts)); } + Stmt* revInit = loopCounter.getNumRevIterations() + ? BuildDeclStmt(loopCounter.getNumRevIterations()) + : nullptr; Stmt* Reverse = new (m_Context) - ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement, + ForStmt(m_Context, revInit, nullptr, nullptr, CounterDecrement, BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc); addToCurrentBlock(initResult.getStmt_dx(), direction::reverse); @@ -4105,10 +4108,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); PopBreakContStmtHandler(); + Expr* revCounter = loopCounter.getCounterConditionResult().get().second; + if (m_CurrentBreakFlagExpr) { + VarDecl* numRevIterations = BuildVarDecl(m_Context.getSizeType(), + "_numRevIterations", revCounter); + loopCounter.setNumRevIterations(numRevIterations); + } + // Increment statement in the for-loop is executed for every case if (forLoopIncDiff) { Stmt* forLoopIncDiffExpr = forLoopIncDiff; if (m_CurrentBreakFlagExpr) { + m_CurrentBreakFlagExpr = + BuildOp(BinaryOperatorKind::BO_LOr, + BuildOp(BinaryOperatorKind::BO_NE, revCounter, + BuildDeclRef(loopCounter.getNumRevIterations())), + BuildParens(m_CurrentBreakFlagExpr)); forLoopIncDiffExpr = clad_compat::IfStmt_Create( m_Context, noLoc, false, nullptr, nullptr, m_CurrentBreakFlagExpr, noLoc, noLoc, forLoopIncDiff, noLoc, nullptr); diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 9d5dff546..af957e51c 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1315,12 +1315,12 @@ double fn16(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{3U|3UL|3ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) // CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -1439,12 +1439,12 @@ double fn17(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t5) != 0 && clad::back(_t5) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t5) != 0 && clad::back(_t5) != 1)) // CHECK-NEXT: --ii; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -1556,12 +1556,12 @@ double fn18(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{3U|3UL|3ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 2) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 2)) // CHECK-NEXT: --counter; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -1889,9 +1889,9 @@ double fn23(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: _d_res = 0.; @@ -1901,7 +1901,7 @@ double fn23(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2001,9 +2001,9 @@ double fn25(double i, double j) { // CHECK-NEXT: clad::push(_t3, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) { // CHECK-NEXT: _d_res += 0; // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; @@ -2013,7 +2013,7 @@ double fn25(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t3) != 0 && clad::back(_t3) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t3)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2074,9 +2074,9 @@ double fn26(double i, double j) { // CHECK-NEXT: clad::push(_t3, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2149,9 +2149,9 @@ double fn27(double i, double j) { // CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}}); // CHECK-NEXT: } // CHECK-NEXT: _d_res += 1; -// CHECK-NEXT: for (;; _t0--) { +// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { // CHECK-NEXT: { -// CHECK-NEXT: if (!_t0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) { +// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) { // CHECK-NEXT: res = clad::pop(_t1); // CHECK-NEXT: double _r_d0 = _d_res; // CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2160,7 +2160,7 @@ double fn27(double i, double j) { // CHECK-NEXT: if (!_t0) // CHECK-NEXT: break; // CHECK-NEXT: } -// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) // CHECK-NEXT: --c; // CHECK-NEXT: switch (clad::pop(_t2)) { // CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2475,9 +2475,9 @@ double fn32(double i, double j) { //CHECK-NEXT: clad::push(_t8, {{2U|2UL|2ULL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; -//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (clad::size(_t8) != 0 && clad::back(_t8) != 1)) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: *_d_i += _r_d0 * j; @@ -2486,7 +2486,7 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::size(_t8) != 0 && clad::back(_t8) != 1) +//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1)) //CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t8)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2505,9 +2505,9 @@ double fn32(double i, double j) { //CHECK-NEXT: clad::pop(_cond1); //CHECK-NEXT: } //CHECK-NEXT: { -//CHECK-NEXT: for (;; clad::back(_t2)--) { +//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = clad::back(_t2); ; clad::back(_t2)--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!clad::back(_t2) || (clad::size(_t6) != 0 && clad::back(_t6) != 1)) { +//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1))) { //CHECK-NEXT: res = clad::pop(_t4); //CHECK-NEXT: double _r_d1 = _d_res; //CHECK-NEXT: *_d_i += _r_d1 * j; @@ -2516,7 +2516,7 @@ double fn32(double i, double j) { //CHECK-NEXT: if (!clad::back(_t2)) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::size(_t6) != 0 && clad::back(_t6) != 1) +//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1)) //CHECK-NEXT: --d; //CHECK-NEXT: switch (clad::pop(_t6)) { //CHECK-NEXT: case {{2U|2UL|2ULL}}: @@ -2627,9 +2627,9 @@ double fn33(double i, double j) { //CHECK-NEXT: clad::push(_t4, {{3U|3UL|3ULL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; -//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) { //CHECK-NEXT: { -//CHECK-NEXT: if (!_t0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2)) { +//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2))) { //CHECK-NEXT: res = clad::pop(_t1); //CHECK-NEXT: double _r_d0 = _d_res; //CHECK-NEXT: _d_res = 0.; @@ -2639,7 +2639,7 @@ double fn33(double i, double j) { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2) +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2)) //CHECK-NEXT: --c; //CHECK-NEXT: switch (clad::pop(_t4)) { //CHECK-NEXT: case {{3U|3UL|3ULL}}: @@ -3233,12 +3233,12 @@ double fn41(double u, double v) { //CHECK-NEXT: clad::push(_t2, {{2U|2UL}}); //CHECK-NEXT: } //CHECK-NEXT: _d_res += 1; -//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: for (unsigned {{int|long}} _numRevIterations0 = _t0; ; _t0--) { //CHECK-NEXT: { //CHECK-NEXT: if (!_t0) //CHECK-NEXT: break; //CHECK-NEXT: } -//CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1) +//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) //CHECK-NEXT: i--; //CHECK-NEXT: switch (clad::pop(_t2)) { //CHECK-NEXT: case {{2U|2UL}}: