Skip to content

Commit

Permalink
Optimize check of break branch using first iteration check in reverse…
Browse files Browse the repository at this point in the history
… pass
  • Loading branch information
kchristin22 committed Oct 28, 2024
1 parent db73da9 commit 8800556
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 29 deletions.
9 changes: 9 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ namespace clad {
clang::Expr *m_Pop = nullptr;
clang::Expr *m_Push = nullptr;
ReverseModeVisitor& m_RMV;
clang::VarDecl* numRevIterations = nullptr;

public:
LoopCounter(ReverseModeVisitor& RMV);
Expand Down Expand Up @@ -550,6 +551,14 @@ namespace clad {
m_Ref,
clang::Sema::ConditionKind::Boolean);
}

/// Sets the number of reverse iterations to be executed.
clang::VarDecl* setNumRevIterations(clang::VarDecl* numRevIterations) {
return this->numRevIterations = numRevIterations;
}

/// Returns the number of reverse iterations to be executed.
clang::VarDecl* getNumRevIterations() const { return numRevIterations; }
};

/// Helper function to differentiate a loop body.
Expand Down
17 changes: 16 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1380,8 +1380,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BodyDiff.updateStmtDx(utils::unwrapIfSingleStmt(revPassCondStmts));
}

Stmt* revInit = loopCounter.getNumRevIterations()
? BuildDeclStmt(loopCounter.getNumRevIterations())
: nullptr;
Stmt* Reverse = new (m_Context)
ForStmt(m_Context, nullptr, nullptr, nullptr, CounterDecrement,
ForStmt(m_Context, revInit, nullptr, nullptr, CounterDecrement,
BodyDiff.getStmt_dx(), noLoc, noLoc, noLoc);

addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
Expand Down Expand Up @@ -4105,10 +4108,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff);
PopBreakContStmtHandler();

Expr* revCounter = loopCounter.getCounterConditionResult().get().second;
if (m_CurrentBreakFlagExpr) {
VarDecl* numRevIterations = BuildVarDecl(m_Context.getSizeType(),
"_numRevIterations", revCounter);
loopCounter.setNumRevIterations(numRevIterations);
}

// Increment statement in the for-loop is executed for every case
if (forLoopIncDiff) {
Stmt* forLoopIncDiffExpr = forLoopIncDiff;
if (m_CurrentBreakFlagExpr) {
m_CurrentBreakFlagExpr =
BuildOp(BinaryOperatorKind::BO_LOr,
BuildOp(BinaryOperatorKind::BO_NE, revCounter,
BuildDeclRef(loopCounter.getNumRevIterations())),
BuildParens(m_CurrentBreakFlagExpr));
forLoopIncDiffExpr = clad_compat::IfStmt_Create(
m_Context, noLoc, false, nullptr, nullptr, m_CurrentBreakFlagExpr,
noLoc, noLoc, forLoopIncDiff, noLoc, nullptr);
Expand Down
56 changes: 28 additions & 28 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -1315,12 +1315,12 @@ double fn16(double i, double j) {
// CHECK-NEXT: clad::push(_t2, {{3U|3UL|3ULL}});
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1)
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))
// CHECK-NEXT: --ii;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{3U|3UL|3ULL}}:
Expand Down Expand Up @@ -1439,12 +1439,12 @@ double fn17(double i, double j) {
// CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}});
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (clad::size(_t5) != 0 && clad::back(_t5) != 1)
// CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t5) != 0 && clad::back(_t5) != 1))
// CHECK-NEXT: --ii;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand Down Expand Up @@ -1556,12 +1556,12 @@ double fn18(double i, double j) {
// CHECK-NEXT: clad::push(_t2, {{3U|3UL|3ULL}});
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 2)
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 2))
// CHECK-NEXT: --counter;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{3U|3UL|3ULL}}:
Expand Down Expand Up @@ -1889,9 +1889,9 @@ double fn23(double i, double j) {
// CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}});
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) {
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: _d_res = 0.;
Expand All @@ -1901,7 +1901,7 @@ double fn23(double i, double j) {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1)
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))
// CHECK-NEXT: --c;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand Down Expand Up @@ -2001,9 +2001,9 @@ double fn25(double i, double j) {
// CHECK-NEXT: clad::push(_t3, {{2U|2UL|2ULL}});
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) {
// CHECK-NEXT: _d_res += 0;
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
Expand All @@ -2013,7 +2013,7 @@ double fn25(double i, double j) {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (clad::size(_t3) != 0 && clad::back(_t3) != 1)
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))
// CHECK-NEXT: --c;
// CHECK-NEXT: switch (clad::pop(_t3)) {
// CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand Down Expand Up @@ -2074,9 +2074,9 @@ double fn26(double i, double j) {
// CHECK-NEXT: clad::push(_t3, {{2U|2UL|2ULL}});
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1)) {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t3) != 0 && clad::back(_t3) != 1))) {
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: *_d_i += _r_d0 * j;
Expand Down Expand Up @@ -2149,9 +2149,9 @@ double fn27(double i, double j) {
// CHECK-NEXT: clad::push(_t2, {{2U|2UL|2ULL}});
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: for (;; _t0--) {
// CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
// CHECK-NEXT: {
// CHECK-NEXT: if (!_t0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1)) {
// CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))) {
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: *_d_i += _r_d0 * j;
Expand All @@ -2160,7 +2160,7 @@ double fn27(double i, double j) {
// CHECK-NEXT: if (!_t0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1)
// CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))
// CHECK-NEXT: --c;
// CHECK-NEXT: switch (clad::pop(_t2)) {
// CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand Down Expand Up @@ -2475,9 +2475,9 @@ double fn32(double i, double j) {
//CHECK-NEXT: clad::push(_t8, {{2U|2UL|2ULL}});
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (;; _t0--) {
//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations1 = _t0; ; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0 || (clad::size(_t8) != 0 && clad::back(_t8) != 1)) {
//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1))) {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: *_d_i += _r_d0 * j;
Expand All @@ -2486,7 +2486,7 @@ double fn32(double i, double j) {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: if (clad::size(_t8) != 0 && clad::back(_t8) != 1)
//CHECK-NEXT: if (_t0 != _numRevIterations1 || (clad::size(_t8) != 0 && clad::back(_t8) != 1))
//CHECK-NEXT: --c;
//CHECK-NEXT: switch (clad::pop(_t8)) {
//CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand All @@ -2505,9 +2505,9 @@ double fn32(double i, double j) {
//CHECK-NEXT: clad::pop(_cond1);
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: for (;; clad::back(_t2)--) {
//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = clad::back(_t2); ; clad::back(_t2)--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!clad::back(_t2) || (clad::size(_t6) != 0 && clad::back(_t6) != 1)) {
//CHECK-NEXT: if (!clad::back(_t2) || (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1))) {
//CHECK-NEXT: res = clad::pop(_t4);
//CHECK-NEXT: double _r_d1 = _d_res;
//CHECK-NEXT: *_d_i += _r_d1 * j;
Expand All @@ -2516,7 +2516,7 @@ double fn32(double i, double j) {
//CHECK-NEXT: if (!clad::back(_t2))
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: if (clad::size(_t6) != 0 && clad::back(_t6) != 1)
//CHECK-NEXT: if (clad::back(_t2) != _numRevIterations0 || (clad::size(_t6) != 0 && clad::back(_t6) != 1))
//CHECK-NEXT: --d;
//CHECK-NEXT: switch (clad::pop(_t6)) {
//CHECK-NEXT: case {{2U|2UL|2ULL}}:
Expand Down Expand Up @@ -2627,9 +2627,9 @@ double fn33(double i, double j) {
//CHECK-NEXT: clad::push(_t4, {{3U|3UL|3ULL}});
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (;; _t0--) {
//CHECK-NEXT: for (unsigned {{int|long|long long}} _numRevIterations0 = _t0; ; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2)) {
//CHECK-NEXT: if (!_t0 || (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2))) {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_res = 0.;
Expand All @@ -2639,7 +2639,7 @@ double fn33(double i, double j) {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: if (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2)
//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t4) != 0 && clad::back(_t4) != 1 && clad::back(_t4) != 2))
//CHECK-NEXT: --c;
//CHECK-NEXT: switch (clad::pop(_t4)) {
//CHECK-NEXT: case {{3U|3UL|3ULL}}:
Expand Down Expand Up @@ -3233,12 +3233,12 @@ double fn41(double u, double v) {
//CHECK-NEXT: clad::push(_t2, {{2U|2UL}});
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (;; _t0--) {
//CHECK-NEXT: for (unsigned {{int|long}} _numRevIterations0 = _t0; ; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: if (clad::size(_t2) != 0 && clad::back(_t2) != 1)
//CHECK-NEXT: if (_t0 != _numRevIterations0 || (clad::size(_t2) != 0 && clad::back(_t2) != 1))
//CHECK-NEXT: i--;
//CHECK-NEXT: switch (clad::pop(_t2)) {
//CHECK-NEXT: case {{2U|2UL}}:
Expand Down

0 comments on commit 8800556

Please sign in to comment.