Skip to content

Commit

Permalink
Simplify derivative statements in VisitBinaryOperator.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 26, 2023
1 parent 6a42b46 commit 7b57f7c
Show file tree
Hide file tree
Showing 34 changed files with 789 additions and 1,853 deletions.
11 changes: 2 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2175,7 +2175,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* dl = nullptr;
if (dfdx()) {
dl = BuildOp(BO_Mul, dfdx(), RResult.getRevSweepAsExpr());
dl = StoreAndRef(dl, direction::reverse);
}
Ldiff = Visit(L, dl);
// dxi/xr = xl
Expand All @@ -2188,7 +2187,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* dr = nullptr;
if (dfdx()) {
dr = BuildOp(BO_Mul, Ldiff.getRevSweepAsExpr(), dfdx());
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
Expand All @@ -2206,7 +2204,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* dl = nullptr;
if (dfdx()) {
dl = BuildOp(BO_Div, dfdx(), RStored);
dl = StoreAndRef(dl, direction::reverse);
}
Ldiff = Visit(L, dl);
// dxi/xr = -xl / (xr * xr)
Expand Down Expand Up @@ -2359,14 +2356,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, oldValue);
valueForRevPass = Rdiff.getRevSweepAsExpr();
} else if (opCode == BO_AddAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, oldValue);
valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
} else if (opCode == BO_SubAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, BuildOp(UO_Minus, oldValue));
valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
Expand All @@ -2390,7 +2383,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/// and insert them after `dl += dl * R`
beginBlock(direction::reverse);
Expr* dr = BuildOp(BO_Mul, LCloned, oldValue);
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
Stmts RBlock = EndBlockWithoutCreatingCS(direction::reverse);
addToCurrentBlock(
Expand Down Expand Up @@ -2428,7 +2420,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActBeforeFinalisingAssignOp(LCloned, oldValue);

// Update the derivative.
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), direction::reverse);
if (opCode != BO_SubAssign && opCode != BO_AddAssign)
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), direction::reverse);
// Output statements from Visit(L).
for (auto it = Lblock_begin; it != Lblock_end; ++it)
addToCurrentBlock(*it, direction::reverse);
Expand Down
72 changes: 18 additions & 54 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ double addArr(const double *arr, int n) {
//CHECK-NEXT: {
//CHECK-NEXT: ret = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_ret;
//CHECK-NEXT: _d_ret += _r_d0;
//CHECK-NEXT: _d_arr[i] += _r_d0;
//CHECK-NEXT: _d_ret -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -92,16 +90,13 @@ float func(float* a, float* b) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t2);
//CHECK-NEXT: float _r_d1 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d1;
//CHECK-NEXT: _d_a[i] += _r_d1;
//CHECK-NEXT: _d_sum -= _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: a[i] = clad::pop(_t1);
//CHECK-NEXT: float _r_d0 = _d_a[i];
//CHECK-NEXT: _d_a[i] += _r_d0 * b[i];
//CHECK-NEXT: float _r0 = a[i] * _r_d0;
//CHECK-NEXT: _d_b[i] += _r0;
//CHECK-NEXT: _d_b[i] += a[i] * _r_d0;
//CHECK-NEXT: _d_a[i] -= _r_d0;
//CHECK-NEXT: _d_a[i];
//CHECK-NEXT: }
Expand All @@ -115,11 +110,7 @@ float helper(float x) {
// CHECK: void helper_pullback(float x, float _d_y, clad::array_ref<float> _d_x) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = _d_y * x;
// CHECK-NEXT: float _r1 = 2 * _d_y;
// CHECK-NEXT: * _d_x += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: * _d_x += 2 * _d_y;
// CHECK-NEXT: }

float func2(float* a) {
Expand Down Expand Up @@ -148,12 +139,10 @@ float func2(float* a) {
//CHECK-NEXT: i--;
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: float _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: float _grad0 = 0.F;
//CHECK-NEXT: helper_pullback(a[i], _r_d0, &_grad0);
//CHECK-NEXT: float _r0 = _grad0;
//CHECK-NEXT: _d_a[i] += _r0;
//CHECK-NEXT: _d_sum -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -185,14 +174,10 @@ float func3(float* a, float* b) {
//CHECK-NEXT: i--;
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: float _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: _d_a[i] += _r_d0;
//CHECK-NEXT: a[i] = clad::pop(_t2);
//CHECK-NEXT: float _r_d1 = _d_a[i];
//CHECK-NEXT: _d_a[i] += _r_d1;
//CHECK-NEXT: _d_b[i] += _r_d1;
//CHECK-NEXT: _d_a[i] -= _r_d1;
//CHECK-NEXT: _d_sum -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -227,23 +212,17 @@ double func4(double x) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(arr, 3, _r_d0, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r4(_d_arr);
//CHECK-NEXT: int _r5 = _grad1;
//CHECK-NEXT: _d_sum -= _r_d0;
//CHECK-NEXT: clad::array<double> _r0(_d_arr);
//CHECK-NEXT: int _r1 = _grad1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * x;
//CHECK-NEXT: double _r1 = 2 * _d_arr[1];
//CHECK-NEXT: * _d_x += _r1;
//CHECK-NEXT: double _r2 = _d_arr[2] * x;
//CHECK-NEXT: * _d_x += _r2;
//CHECK-NEXT: double _r3 = x * _d_arr[2];
//CHECK-NEXT: * _d_x += _r3;
//CHECK-NEXT: * _d_x += 2 * _d_arr[1];
//CHECK-NEXT: * _d_x += _d_arr[2] * x;
//CHECK-NEXT: * _d_x += x * _d_arr[2];
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -293,13 +272,11 @@ double func5(int k) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: double _r_d1 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d1;
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(arr, n, _r_d1, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r0(_d_arr);
//CHECK-NEXT: int _r1 = _grad1;
//CHECK-NEXT: _d_n += _r1;
//CHECK-NEXT: _d_sum -= _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: for (; _t0; _t0--) {
Expand Down Expand Up @@ -346,19 +323,15 @@ double func6(double seed) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(arr, 3, _r_d0, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r2(_d_arr);
//CHECK-NEXT: int _r3 = _grad1;
//CHECK-NEXT: _d_sum -= _r_d0;
//CHECK-NEXT: clad::array<double> _r0(_d_arr);
//CHECK-NEXT: int _r1 = _grad1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_seed += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * i;
//CHECK-NEXT: * _d_seed += _r0;
//CHECK-NEXT: double _r1 = seed * _d_arr[1];
//CHECK-NEXT: _d_i += _r1;
//CHECK-NEXT: * _d_seed += _d_arr[1] * i;
//CHECK-NEXT: _d_i += seed * _d_arr[1];
//CHECK-NEXT: * _d_seed += _d_arr[2];
//CHECK-NEXT: _d_i += _d_arr[2];
//CHECK-NEXT: _d_arr = {};
Expand All @@ -376,12 +349,9 @@ double inv_square(double *params) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y / _t0;
//CHECK-NEXT: double _r1 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: double _r2 = _r1 * params[0];
//CHECK-NEXT: _d_params[0] += _r2;
//CHECK-NEXT: double _r3 = params[0] * _r1;
//CHECK-NEXT: _d_params[0] += _r3;
//CHECK-NEXT: double _r0 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: _d_params[0] += _r0 * params[0];
//CHECK-NEXT: _d_params[0] += params[0] * _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -436,10 +406,8 @@ double helper2(double i, double *arr, int n) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y * i;
//CHECK-NEXT: _d_arr[0] += _r0;
//CHECK-NEXT: double _r1 = arr[0] * _d_y;
//CHECK-NEXT: * _d_i += _r1;
//CHECK-NEXT: _d_arr[0] += _d_y * i;
//CHECK-NEXT: * _d_i += arr[0] * _d_y;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -572,10 +540,8 @@ double sq(double& elem) {
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = * _d_elem;
//CHECK-NEXT: double _r0 = _r_d0 * elem;
//CHECK-NEXT: * _d_elem += _r0;
//CHECK-NEXT: double _r1 = elem * _r_d0;
//CHECK-NEXT: * _d_elem += _r1;
//CHECK-NEXT: * _d_elem += _r_d0 * elem;
//CHECK-NEXT: * _d_elem += elem * _r_d0;
//CHECK-NEXT: * _d_elem -= _r_d0;
//CHECK-NEXT: * _d_elem;
//CHECK-NEXT: }
Expand Down Expand Up @@ -612,12 +578,10 @@ double func10(double *arr, int n) {
//CHECK-NEXT: {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_res += _r_d0;
//CHECK-NEXT: double _r1 = clad::pop(_t2);
//CHECK-NEXT: arr[i] = _r1;
//CHECK-NEXT: sq_pullback(_r1, _r_d0, &_d_arr[i]);
//CHECK-NEXT: double _r0 = _d_arr[i];
//CHECK-NEXT: _d_res -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down
18 changes: 6 additions & 12 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,12 @@ double const_dot_product(double x, double y, double z) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 1 * consts[0];
//CHECK-NEXT: _d_vars[0] += _r0;
//CHECK-NEXT: double _r1 = vars[0] * 1;
//CHECK-NEXT: _d_consts[0] += _r1;
//CHECK-NEXT: double _r2 = 1 * consts[1];
//CHECK-NEXT: _d_vars[1] += _r2;
//CHECK-NEXT: double _r3 = vars[1] * 1;
//CHECK-NEXT: _d_consts[1] += _r3;
//CHECK-NEXT: double _r4 = 1 * consts[2];
//CHECK-NEXT: _d_vars[2] += _r4;
//CHECK-NEXT: double _r5 = vars[2] * 1;
//CHECK-NEXT: _d_consts[2] += _r5;
//CHECK-NEXT: _d_vars[0] += 1 * consts[0];
//CHECK-NEXT: _d_consts[0] += vars[0] * 1;
//CHECK-NEXT: _d_vars[1] += 1 * consts[1];
//CHECK-NEXT: _d_consts[1] += vars[1] * 1;
//CHECK-NEXT: _d_vars[2] += 1 * consts[2];
//CHECK-NEXT: _d_consts[2] += vars[2] * 1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += _d_vars[0];
Expand Down
18 changes: 6 additions & 12 deletions test/ErrorEstimation/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,10 @@ float func2(float x, int y) {
//CHECK-NEXT: {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = * _d_x;
//CHECK-NEXT: float _r0 = _r_d0 * x;
//CHECK-NEXT: * _d_y += _r0;
//CHECK-NEXT: float _r1 = y * _r_d0;
//CHECK-NEXT: * _d_x += _r1;
//CHECK-NEXT: float _r2 = _r_d0 * x;
//CHECK-NEXT: * _d_x += _r2;
//CHECK-NEXT: float _r3 = x * _r_d0;
//CHECK-NEXT: * _d_x += _r3;
//CHECK-NEXT: * _d_y += _r_d0 * x;
//CHECK-NEXT: * _d_x += y * _r_d0;
//CHECK-NEXT: * _d_x += _r_d0 * x;
//CHECK-NEXT: * _d_x += x * _r_d0;
//CHECK-NEXT: _delta_x += std::abs(_r_d0 * _EERepl_x1 * {{.+}});
//CHECK-NEXT: * _d_x -= _r_d0;
//CHECK-NEXT: * _d_x;
Expand Down Expand Up @@ -195,10 +191,8 @@ float func7(float x, float y) { return (x * y); }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r0 = 1 * y;
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: float _r1 = x * 1;
//CHECK-NEXT: * _d_y += _r1;
//CHECK-NEXT: * _d_x += 1 * y;
//CHECK-NEXT: * _d_y += x * 1;
//CHECK-NEXT: }
//CHECK-NEXT: double _delta_x = 0;
//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}});
Expand Down
Loading

0 comments on commit 7b57f7c

Please sign in to comment.