From 86f5f27e38460a3674876897239a89f3be09455e Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Wed, 15 May 2024 16:15:58 +0200 Subject: [PATCH] Update the tests after allowing if condition differentiation --- lib/Differentiator/ReverseModeVisitor.cpp | 14 ++++++----- test/ErrorEstimation/ConditonalStatements.C | 4 ++++ test/Gradient/Assignments.C | 18 ++++++++++---- test/Gradient/FunctionCalls.C | 4 ++++ test/Gradient/Gradients.C | 8 +++++++ test/Gradient/Loops.C | 26 ++++++++++----------- test/Hessian/BuiltinDerivatives.C | 4 +++- test/NumericalDiff/NumDiff.C | 2 ++ 8 files changed, 55 insertions(+), 25 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6ddfd3220..bdc3582bd 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -821,12 +821,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, condDiff = Visit(condDeclStmt); else condDiff = Visit(If->getCond()); - auto* RCS = endBlock(direction::reverse); - std::reverse( - RCS->body_begin(), - RCS->body_end()); // it is reversed in the endBlock() but we don't - // actually need this, so we reverse it once again - addToCurrentBlock(RCS, direction::reverse); + CompoundStmt* RCS = endBlock(direction::reverse); + if (!RCS->body_empty()) { + std::reverse( + RCS->body_begin(), + RCS->body_end()); // it is reversed in the endBlock() but we don't + // actually need this, so we reverse it once again + addToCurrentBlock(RCS, direction::reverse); + } if (isInsideLoop) { // If we are inside for loop, condDiff will be stored in the following diff --git a/test/ErrorEstimation/ConditonalStatements.C b/test/ErrorEstimation/ConditonalStatements.C index 1f9108d36..db91578fd 100644 --- a/test/ErrorEstimation/ConditonalStatements.C +++ b/test/ErrorEstimation/ConditonalStatements.C @@ -25,6 +25,7 @@ float func(float x, float y) { //CHECK-NEXT: float _t1; //CHECK-NEXT: float _t2; //CHECK-NEXT: double _ret_value0 = 0; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = x > y; //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: _t0 = y; @@ -36,6 +37,7 @@ float func(float x, float y) { //CHECK-NEXT: _t2 = x; //CHECK-NEXT: x = y; //CHECK-NEXT: } +//CHECK-NEXT: } //CHECK-NEXT: _ret_value0 = x + y; //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -91,6 +93,7 @@ float func2(float x) { //CHECK-NEXT: bool _cond0; //CHECK-NEXT: double _ret_value0 = 0; //CHECK-NEXT: float z = x * x; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = z > 9; //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: _ret_value0 = x + x; @@ -99,6 +102,7 @@ float func2(float x) { //CHECK-NEXT: _ret_value0 = x * x; //CHECK-NEXT: goto _label1; //CHECK-NEXT: } +//CHECK-NEXT: } //CHECK-NEXT: if (_cond0) //CHECK-NEXT: _label0: //CHECK-NEXT: { diff --git a/test/Gradient/Assignments.C b/test/Gradient/Assignments.C index 4dce006c3..c6542b09a 100644 --- a/test/Gradient/Assignments.C +++ b/test/Gradient/Assignments.C @@ -38,10 +38,12 @@ double f2(double x, double y) { //CHECK: void f2_grad(double x, double y, double *_d_x, double *_d_y) { //CHECK-NEXT: bool _cond0; //CHECK-NEXT: double _t0; -//CHECK-NEXT: _cond0 = x < y; -//CHECK-NEXT: if (_cond0) { -//CHECK-NEXT: _t0 = x; -//CHECK-NEXT: x = y; +//CHECK-NEXT: { +//CHECK-NEXT: _cond0 = x < y; +//CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: _t0 = x; +//CHECK-NEXT: x = y; +//CHECK-NEXT: } //CHECK-NEXT: } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: @@ -160,18 +162,22 @@ double f5(double x, double y) { //CHECK-NEXT: double z = 0; //CHECK-NEXT: double _t1; //CHECK-NEXT: double t = x * x; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = x < 0; //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: _t0 = t; //CHECK-NEXT: t = -t; //CHECK-NEXT: goto _label0; //CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: { //CHECK-NEXT: _cond1 = y < 0; //CHECK-NEXT: if (_cond1) { //CHECK-NEXT: z = t; //CHECK-NEXT: _t1 = t; //CHECK-NEXT: t = -t; //CHECK-NEXT: } +//CHECK-NEXT: } //CHECK-NEXT: goto _label1; //CHECK-NEXT: _label1: //CHECK-NEXT: _d_t += 1; @@ -223,18 +229,22 @@ double f6(double x, double y) { //CHECK-NEXT: double z = 0; //CHECK-NEXT: double _t1; //CHECK-NEXT: double t = x * x; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = x < 0; //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: _t0 = t; //CHECK-NEXT: t = -t; //CHECK-NEXT: goto _label0; //CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: { //CHECK-NEXT: _cond1 = y < 0; //CHECK-NEXT: if (_cond1) { //CHECK-NEXT: z = t; //CHECK-NEXT: _t1 = t; //CHECK-NEXT: t = -t; //CHECK-NEXT: } +//CHECK-NEXT: } //CHECK-NEXT: goto _label1; //CHECK-NEXT: _label1: //CHECK-NEXT: _d_t += 1; diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index a3af2a557..cffff026d 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -915,9 +915,11 @@ double sq_defined_later(double x) { // CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, double *_d_x, char *_d_c, char *_d_s) { // CHECK-NEXT: bool _cond0; +// CHECK-NEXT: { // CHECK-NEXT: _cond0 = c == 'a' && s[0] == 'a'; // CHECK-NEXT: if (_cond0) // CHECK-NEXT: goto _label0; +// CHECK-NEXT: } // CHECK-NEXT: goto _label1; // CHECK-NEXT: _label1: // CHECK-NEXT: ; @@ -957,9 +959,11 @@ double sq_defined_later(double x) { //CHECK: void recFun_pullback(double x, double y, double _d_y0, double *_d_x, double *_d_y) { //CHECK-NEXT: bool _cond0; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = x > y; //CHECK-NEXT: if (_cond0) //CHECK-NEXT: goto _label0; +//CHECK-NEXT: } //CHECK-NEXT: goto _label1; //CHECK-NEXT: _label1: //CHECK-NEXT: { diff --git a/test/Gradient/Gradients.C b/test/Gradient/Gradients.C index d2b1ea23f..4445e273a 100644 --- a/test/Gradient/Gradients.C +++ b/test/Gradient/Gradients.C @@ -297,11 +297,13 @@ double f_cond4(double x, double y) { //CHECK-NEXT: double _t0; //CHECK-NEXT: int i = 0; //CHECK-NEXT: double arr[2] = {x, y}; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = x > 0; //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: _t0 = y; //CHECK-NEXT: y = arr[i] * x; //CHECK-NEXT: } +//CHECK-NEXT: } //CHECK-NEXT: goto _label0; //CHECK-NEXT: _label0: //CHECK-NEXT: *_d_y += 1; @@ -331,11 +333,13 @@ double f_if1(double x, double y) { //CHECK: void f_if1_grad(double x, double y, double *_d_x, double *_d_y) { //CHECK-NEXT: bool _cond0; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = x > y; //CHECK-NEXT: if (_cond0) //CHECK-NEXT: goto _label0; //CHECK-NEXT: else //CHECK-NEXT: goto _label1; +//CHECK-NEXT: } //CHECK-NEXT: if (_cond0) //CHECK-NEXT: _label0: //CHECK-NEXT: *_d_x += 1; @@ -358,6 +362,7 @@ double f_if2(double x, double y) { //CHECK: void f_if2_grad(double x, double y, double *_d_x, double *_d_y) { //CHECK-NEXT: bool _cond0; //CHECK-NEXT: bool _cond1; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = x > y; //CHECK-NEXT: if (_cond0) //CHECK-NEXT: goto _label0; @@ -368,6 +373,7 @@ double f_if2(double x, double y) { //CHECK-NEXT: else //CHECK-NEXT: goto _label2; //CHECK-NEXT: } +//CHECK-NEXT: } //CHECK-NEXT: if (_cond0) //CHECK-NEXT: _label0: //CHECK-NEXT: *_d_x += 1; @@ -584,6 +590,7 @@ void f_decls3_grad(double x, double y, double *_d_x, double *_d_y); //CHECK-NEXT: double _d_b = 0; //CHECK-NEXT: double a = 3 * x; //CHECK-NEXT: double c = 333 * y; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = x > 1; //CHECK-NEXT: if (_cond0) //CHECK-NEXT: goto _label0; @@ -592,6 +599,7 @@ void f_decls3_grad(double x, double y, double *_d_x, double *_d_y); //CHECK-NEXT: if (_cond1) //CHECK-NEXT: goto _label1; //CHECK-NEXT: } +//CHECK-NEXT: } //CHECK-NEXT: double b = a * a; //CHECK-NEXT: goto _label2; //CHECK-NEXT: _label2: diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 1051bff9e..0c3690e64 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -115,8 +115,8 @@ double f3(double x) { //CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, t); //CHECK-NEXT: t *= x; -//CHECK-NEXT: bool _t2 = i == 1; //CHECK-NEXT: { +//CHECK-NEXT: bool _t2 = i == 1; //CHECK-NEXT: if (_t2) //CHECK-NEXT: goto _label0; //CHECK-NEXT: clad::push(_t3, _t2); @@ -999,8 +999,8 @@ double fn14(double i, double j) { // CHECK-NEXT: while (choice--) // CHECK-NEXT: { // CHECK-NEXT: _t0++; -// CHECK-NEXT: bool _t1 = choice > 3; // CHECK-NEXT: { +// CHECK-NEXT: bool _t1 = choice > 3; // CHECK-NEXT: if (_t1) { // CHECK-NEXT: clad::push(_t3, res); // CHECK-NEXT: res += i; @@ -1011,8 +1011,8 @@ double fn14(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: clad::push(_t2, _t1); // CHECK-NEXT: } -// CHECK-NEXT: bool _t5 = choice > 1; // CHECK-NEXT: { +// CHECK-NEXT: bool _t5 = choice > 1; // CHECK-NEXT: if (_t5) { // CHECK-NEXT: clad::push(_t7, res); // CHECK-NEXT: res += j; @@ -1023,8 +1023,8 @@ double fn14(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: clad::push(_t6, _t5); // CHECK-NEXT: } -// CHECK-NEXT: bool _t8 = choice > 0; // CHECK-NEXT: { +// CHECK-NEXT: bool _t8 = choice > 0; // CHECK-NEXT: if (_t8) { // CHECK-NEXT: clad::push(_t10, res); // CHECK-NEXT: res += i * j; @@ -1119,8 +1119,8 @@ double fn15(double i, double j) { // CHECK-NEXT: while (choice--) // CHECK-NEXT: { // CHECK-NEXT: _t0++; -// CHECK-NEXT: bool _t1 = choice > 2; // CHECK-NEXT: { +// CHECK-NEXT: bool _t1 = choice > 2; // CHECK-NEXT: if (_t1) { // CHECK-NEXT: clad::push(_t3, {{1U|1UL}}); // CHECK-NEXT: continue; @@ -1132,8 +1132,8 @@ double fn15(double i, double j) { // CHECK-NEXT: while (another_choice--) // CHECK-NEXT: { // CHECK-NEXT: clad::back(_t5)++; -// CHECK-NEXT: bool _t6 = another_choice > 1; // CHECK-NEXT: { +// CHECK-NEXT: bool _t6 = another_choice > 1; // CHECK-NEXT: if (_t6) { // CHECK-NEXT: clad::push(_t8, res); // CHECK-NEXT: res += i; @@ -1144,8 +1144,8 @@ double fn15(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: clad::push(_t7, _t6); // CHECK-NEXT: } -// CHECK-NEXT: bool _t10 = another_choice > 0; // CHECK-NEXT: { +// CHECK-NEXT: bool _t10 = another_choice > 0; // CHECK-NEXT: if (_t10) { // CHECK-NEXT: clad::push(_t12, res); // CHECK-NEXT: res += j; @@ -1237,8 +1237,8 @@ double fn16(double i, double j) { // CHECK-NEXT: _t0 = 0; // CHECK-NEXT: for (ii = 0; ii < counter; ++ii) { // CHECK-NEXT: _t0++; -// CHECK-NEXT: bool _t1 = ii == 4; // CHECK-NEXT: { +// CHECK-NEXT: bool _t1 = ii == 4; // CHECK-NEXT: if (_t1) { // CHECK-NEXT: clad::push(_t3, res); // CHECK-NEXT: res += i * j; @@ -1249,8 +1249,8 @@ double fn16(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: clad::push(_t2, _t1); // CHECK-NEXT: } -// CHECK-NEXT: bool _t5 = ii > 2; // CHECK-NEXT: { +// CHECK-NEXT: bool _t5 = ii > 2; // CHECK-NEXT: if (_t5) { // CHECK-NEXT: clad::push(_t7, res); // CHECK-NEXT: res += 2 * i; @@ -1343,8 +1343,8 @@ double fn17(double i, double j) { // CHECK-NEXT: for (ii = 0; ii < counter; ++ii) { // CHECK-NEXT: _t0++; // CHECK-NEXT: clad::push(_t1, jj) , jj = ii; -// CHECK-NEXT: bool _t2 = ii < 2; // CHECK-NEXT: { +// CHECK-NEXT: bool _t2 = ii < 2; // CHECK-NEXT: if (_t2) { // CHECK-NEXT: clad::push(_t4, {{1U|1UL}}); // CHECK-NEXT: continue; @@ -1355,8 +1355,8 @@ double fn17(double i, double j) { // CHECK-NEXT: while (jj--) // CHECK-NEXT: { // CHECK-NEXT: clad::back(_t5)++; -// CHECK-NEXT: bool _t6 = jj < 3; // CHECK-NEXT: { +// CHECK-NEXT: bool _t6 = jj < 3; // CHECK-NEXT: if (_t6) { // CHECK-NEXT: clad::push(_t8, res); // CHECK-NEXT: res += i * j; @@ -1460,14 +1460,13 @@ double fn18(double i, double j) { // CHECK-NEXT: _t0 = 0; // CHECK-NEXT: for (counter = 0; counter < choice; ++counter) { // CHECK-NEXT: _t0++; -// CHECK-NEXT: bool _t1 = counter < 2; // CHECK-NEXT: { +// CHECK-NEXT: bool _t1 = counter < 2; // CHECK-NEXT: if (_t1) { // CHECK-NEXT: clad::push(_t3, res); // CHECK-NEXT: res += i + j; // CHECK-NEXT: } else { // CHECK-NEXT: bool _t4 = counter < 4; -// CHECK-NEXT: { // CHECK-NEXT: if (_t4) { // CHECK-NEXT: clad::push(_t6, {{1U|1UL}}); // CHECK-NEXT: continue; @@ -1480,7 +1479,6 @@ double fn18(double i, double j) { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: clad::push(_t5, _t4); -// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: clad::push(_t2, _t1); // CHECK-NEXT: } diff --git a/test/Hessian/BuiltinDerivatives.C b/test/Hessian/BuiltinDerivatives.C index 0e0adbdbc..533c80a9b 100644 --- a/test/Hessian/BuiltinDerivatives.C +++ b/test/Hessian/BuiltinDerivatives.C @@ -383,13 +383,14 @@ int main() { // CHECK-NEXT: float _d_val = 0; // CHECK-NEXT: float _t0; // CHECK-NEXT: float _d_derivative = 0; -// CHECK-NEXT: float _cond0; +// CHECK-NEXT: bool _cond0; // CHECK-NEXT: float _t1; // CHECK-NEXT: float _t2; // CHECK-NEXT: float _t3; // CHECK-NEXT: float val = ::std::pow(x, exponent); // CHECK-NEXT: _t0 = ::std::pow(x, exponent - 1); // CHECK-NEXT: float derivative = (exponent * _t0) * d_x; +// CHECK-NEXT: { // CHECK-NEXT: _cond0 = d_exponent; // CHECK-NEXT: if (_cond0) { // CHECK-NEXT: _t1 = derivative; @@ -397,6 +398,7 @@ int main() { // CHECK-NEXT: _t2 = ::std::log(x); // CHECK-NEXT: derivative += (_t3 * _t2) * d_exponent; // CHECK-NEXT: } +// CHECK-NEXT: } // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { diff --git a/test/NumericalDiff/NumDiff.C b/test/NumericalDiff/NumDiff.C index c65ff8a39..a3b118757 100644 --- a/test/NumericalDiff/NumDiff.C +++ b/test/NumericalDiff/NumDiff.C @@ -43,11 +43,13 @@ double test_3(double x) { //CHECK-NEXT: bool _cond0; //CHECK-NEXT: double _d_constant = 0; //CHECK-NEXT: double constant = 0; +//CHECK-NEXT: { //CHECK-NEXT: _cond0 = x > 0; //CHECK-NEXT: if (_cond0) { //CHECK-NEXT: constant = 11.; //CHECK-NEXT: goto _label0; //CHECK-NEXT: } +//CHECK-NEXT: } //CHECK-NEXT: goto _label1; //CHECK-NEXT: _label1: //CHECK-NEXT: ;