Skip to content

Commit

Permalink
Update the tests after allowing if condition differentiation
Browse files Browse the repository at this point in the history
  • Loading branch information
gojakuch committed May 15, 2024
1 parent b26af55 commit 86f5f27
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 25 deletions.
14 changes: 8 additions & 6 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,12 +821,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
condDiff = Visit(condDeclStmt);
else
condDiff = Visit(If->getCond());
auto* RCS = endBlock(direction::reverse);
std::reverse(
RCS->body_begin(),
RCS->body_end()); // it is reversed in the endBlock() but we don't
// actually need this, so we reverse it once again
addToCurrentBlock(RCS, direction::reverse);
CompoundStmt* RCS = endBlock(direction::reverse);
if (!RCS->body_empty()) {
std::reverse(
RCS->body_begin(),
RCS->body_end()); // it is reversed in the endBlock() but we don't
// actually need this, so we reverse it once again
addToCurrentBlock(RCS, direction::reverse);
}

if (isInsideLoop) {
// If we are inside for loop, condDiff will be stored in the following
Expand Down
4 changes: 4 additions & 0 deletions test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ float func(float x, float y) {
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = y;
Expand All @@ -36,6 +37,7 @@ float func(float x, float y) {
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _ret_value0 = x + y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
Expand Down Expand Up @@ -91,6 +93,7 @@ float func2(float x) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: float z = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = z > 9;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _ret_value0 = x + x;
Expand All @@ -99,6 +102,7 @@ float func2(float x) {
//CHECK-NEXT: _ret_value0 = x * x;
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
Expand Down
18 changes: 14 additions & 4 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ double f2(double x, double y) {
//CHECK: void f2_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _cond0 = x < y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
Expand Down Expand Up @@ -160,18 +162,22 @@ double f5(double x, double y) {
//CHECK-NEXT: double z = 0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _cond1 = y < 0;
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: z = t;
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: _d_t += 1;
Expand Down Expand Up @@ -223,18 +229,22 @@ double f6(double x, double y) {
//CHECK-NEXT: double z = 0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _cond1 = y < 0;
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: z = t;
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: _d_t += 1;
Expand Down
4 changes: 4 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -915,9 +915,11 @@ double sq_defined_later(double x) {

// CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, double *_d_x, char *_d_c, char *_d_s) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: {
// CHECK-NEXT: _cond0 = c == 'a' && s[0] == 'a';
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: }
// CHECK-NEXT: goto _label1;
// CHECK-NEXT: _label1:
// CHECK-NEXT: ;
Expand Down Expand Up @@ -957,9 +959,11 @@ double sq_defined_later(double x) {

//CHECK: void recFun_pullback(double x, double y, double _d_y0, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: {
Expand Down
8 changes: 8 additions & 0 deletions test/Gradient/Gradients.C
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,13 @@ double f_cond4(double x, double y) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: double arr[2] = {x, y};
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = y;
//CHECK-NEXT: y = arr[i] * x;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_y += 1;
Expand Down Expand Up @@ -331,11 +333,13 @@ double f_if1(double x, double y) {

//CHECK: void f_if1_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: else
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: }
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
Expand All @@ -358,6 +362,7 @@ double f_if2(double x, double y) {
//CHECK: void f_if2_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: bool _cond1;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: goto _label0;
Expand All @@ -368,6 +373,7 @@ double f_if2(double x, double y) {
//CHECK-NEXT: else
//CHECK-NEXT: goto _label2;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: _label0:
//CHECK-NEXT: *_d_x += 1;
Expand Down Expand Up @@ -584,6 +590,7 @@ void f_decls3_grad(double x, double y, double *_d_x, double *_d_y);
//CHECK-NEXT: double _d_b = 0;
//CHECK-NEXT: double a = 3 * x;
//CHECK-NEXT: double c = 333 * y;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > 1;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: goto _label0;
Expand All @@ -592,6 +599,7 @@ void f_decls3_grad(double x, double y, double *_d_x, double *_d_y);
//CHECK-NEXT: if (_cond1)
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: double b = a * a;
//CHECK-NEXT: goto _label2;
//CHECK-NEXT: _label2:
Expand Down
26 changes: 12 additions & 14 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ double f3(double x) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, t);
//CHECK-NEXT: t *= x;
//CHECK-NEXT: bool _t2 = i == 1;
//CHECK-NEXT: {
//CHECK-NEXT: bool _t2 = i == 1;
//CHECK-NEXT: if (_t2)
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: clad::push(_t3, _t2);
Expand Down Expand Up @@ -999,8 +999,8 @@ double fn14(double i, double j) {
// CHECK-NEXT: while (choice--)
// CHECK-NEXT: {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: bool _t1 = choice > 3;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t1 = choice > 3;
// CHECK-NEXT: if (_t1) {
// CHECK-NEXT: clad::push(_t3, res);
// CHECK-NEXT: res += i;
Expand All @@ -1011,8 +1011,8 @@ double fn14(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t2, _t1);
// CHECK-NEXT: }
// CHECK-NEXT: bool _t5 = choice > 1;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t5 = choice > 1;
// CHECK-NEXT: if (_t5) {
// CHECK-NEXT: clad::push(_t7, res);
// CHECK-NEXT: res += j;
Expand All @@ -1023,8 +1023,8 @@ double fn14(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t6, _t5);
// CHECK-NEXT: }
// CHECK-NEXT: bool _t8 = choice > 0;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t8 = choice > 0;
// CHECK-NEXT: if (_t8) {
// CHECK-NEXT: clad::push(_t10, res);
// CHECK-NEXT: res += i * j;
Expand Down Expand Up @@ -1119,8 +1119,8 @@ double fn15(double i, double j) {
// CHECK-NEXT: while (choice--)
// CHECK-NEXT: {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: bool _t1 = choice > 2;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t1 = choice > 2;
// CHECK-NEXT: if (_t1) {
// CHECK-NEXT: clad::push(_t3, {{1U|1UL}});
// CHECK-NEXT: continue;
Expand All @@ -1132,8 +1132,8 @@ double fn15(double i, double j) {
// CHECK-NEXT: while (another_choice--)
// CHECK-NEXT: {
// CHECK-NEXT: clad::back(_t5)++;
// CHECK-NEXT: bool _t6 = another_choice > 1;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t6 = another_choice > 1;
// CHECK-NEXT: if (_t6) {
// CHECK-NEXT: clad::push(_t8, res);
// CHECK-NEXT: res += i;
Expand All @@ -1144,8 +1144,8 @@ double fn15(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t7, _t6);
// CHECK-NEXT: }
// CHECK-NEXT: bool _t10 = another_choice > 0;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t10 = another_choice > 0;
// CHECK-NEXT: if (_t10) {
// CHECK-NEXT: clad::push(_t12, res);
// CHECK-NEXT: res += j;
Expand Down Expand Up @@ -1237,8 +1237,8 @@ double fn16(double i, double j) {
// CHECK-NEXT: _t0 = 0;
// CHECK-NEXT: for (ii = 0; ii < counter; ++ii) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: bool _t1 = ii == 4;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t1 = ii == 4;
// CHECK-NEXT: if (_t1) {
// CHECK-NEXT: clad::push(_t3, res);
// CHECK-NEXT: res += i * j;
Expand All @@ -1249,8 +1249,8 @@ double fn16(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t2, _t1);
// CHECK-NEXT: }
// CHECK-NEXT: bool _t5 = ii > 2;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t5 = ii > 2;
// CHECK-NEXT: if (_t5) {
// CHECK-NEXT: clad::push(_t7, res);
// CHECK-NEXT: res += 2 * i;
Expand Down Expand Up @@ -1343,8 +1343,8 @@ double fn17(double i, double j) {
// CHECK-NEXT: for (ii = 0; ii < counter; ++ii) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, jj) , jj = ii;
// CHECK-NEXT: bool _t2 = ii < 2;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t2 = ii < 2;
// CHECK-NEXT: if (_t2) {
// CHECK-NEXT: clad::push(_t4, {{1U|1UL}});
// CHECK-NEXT: continue;
Expand All @@ -1355,8 +1355,8 @@ double fn17(double i, double j) {
// CHECK-NEXT: while (jj--)
// CHECK-NEXT: {
// CHECK-NEXT: clad::back(_t5)++;
// CHECK-NEXT: bool _t6 = jj < 3;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t6 = jj < 3;
// CHECK-NEXT: if (_t6) {
// CHECK-NEXT: clad::push(_t8, res);
// CHECK-NEXT: res += i * j;
Expand Down Expand Up @@ -1460,14 +1460,13 @@ double fn18(double i, double j) {
// CHECK-NEXT: _t0 = 0;
// CHECK-NEXT: for (counter = 0; counter < choice; ++counter) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: bool _t1 = counter < 2;
// CHECK-NEXT: {
// CHECK-NEXT: bool _t1 = counter < 2;
// CHECK-NEXT: if (_t1) {
// CHECK-NEXT: clad::push(_t3, res);
// CHECK-NEXT: res += i + j;
// CHECK-NEXT: } else {
// CHECK-NEXT: bool _t4 = counter < 4;
// CHECK-NEXT: {
// CHECK-NEXT: if (_t4) {
// CHECK-NEXT: clad::push(_t6, {{1U|1UL}});
// CHECK-NEXT: continue;
Expand All @@ -1480,7 +1479,6 @@ double fn18(double i, double j) {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t5, _t4);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t2, _t1);
// CHECK-NEXT: }
Expand Down
4 changes: 3 additions & 1 deletion test/Hessian/BuiltinDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -383,20 +383,22 @@ int main() {
// CHECK-NEXT: float _d_val = 0;
// CHECK-NEXT: float _t0;
// CHECK-NEXT: float _d_derivative = 0;
// CHECK-NEXT: float _cond0;
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: float _t1;
// CHECK-NEXT: float _t2;
// CHECK-NEXT: float _t3;
// CHECK-NEXT: float val = ::std::pow(x, exponent);
// CHECK-NEXT: _t0 = ::std::pow(x, exponent - 1);
// CHECK-NEXT: float derivative = (exponent * _t0) * d_x;
// CHECK-NEXT: {
// CHECK-NEXT: _cond0 = d_exponent;
// CHECK-NEXT: if (_cond0) {
// CHECK-NEXT: _t1 = derivative;
// CHECK-NEXT: _t3 = ::std::pow(x, exponent);
// CHECK-NEXT: _t2 = ::std::log(x);
// CHECK-NEXT: derivative += (_t3 * _t2) * d_exponent;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
Expand Down
2 changes: 2 additions & 0 deletions test/NumericalDiff/NumDiff.C
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ double test_3(double x) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _d_constant = 0;
//CHECK-NEXT: double constant = 0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: constant = 11.;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: ;
Expand Down

0 comments on commit 86f5f27

Please sign in to comment.