Skip to content

Commit

Permalink
Simplify derivative statements in VisitBinaryOperator.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Dec 29, 2023
1 parent 1921b56 commit 79d7f13
Show file tree
Hide file tree
Showing 36 changed files with 812 additions and 1,895 deletions.
21 changes: 6 additions & 15 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2171,10 +2171,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

Expr* dl = nullptr;
if (dfdx()) {
if (dfdx())
dl = BuildOp(BO_Mul, dfdx(), RResult.getRevSweepAsExpr());
dl = StoreAndRef(dl, direction::reverse);
}
Ldiff = Visit(L, dl);
// dxi/xr = xl
// df/dxr += df/dxi * dxi/xr = df/dxi * xl
Expand All @@ -2184,10 +2182,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (RDelayed ||
!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
Expr* dr = nullptr;
if (dfdx()) {
if (dfdx())
dr = BuildOp(BO_Mul, Ldiff.getRevSweepAsExpr(), dfdx());
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
if (RDelayed)
Expand All @@ -2202,10 +2198,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff RResult = RDelayed.Result;
Expr* RStored = StoreAndRef(RResult.getExpr_dx(), direction::reverse);
Expr* dl = nullptr;
if (dfdx()) {
if (dfdx())
dl = BuildOp(BO_Div, dfdx(), RStored);
dl = StoreAndRef(dl, direction::reverse);
}
Ldiff = Visit(L, dl);
// dxi/xr = -xl / (xr * xr)
// df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr))
Expand Down Expand Up @@ -2357,14 +2351,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, oldValue);
valueForRevPass = Rdiff.getRevSweepAsExpr();
} else if (opCode == BO_AddAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, oldValue);
valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
} else if (opCode == BO_SubAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, BuildOp(UO_Minus, oldValue));
valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
Expand All @@ -2388,7 +2378,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/// and insert them after `dl += dl * R`
beginBlock(direction::reverse);
Expr* dr = BuildOp(BO_Mul, LCloned, oldValue);
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
Stmts RBlock = EndBlockWithoutCreatingCS(direction::reverse);
addToCurrentBlock(
Expand Down Expand Up @@ -2426,7 +2415,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActBeforeFinalisingAssignOp(LCloned, oldValue);

// Update the derivative.
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), direction::reverse);
if (opCode != BO_SubAssign && opCode != BO_AddAssign)
addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue),
direction::reverse);
// Output statements from Visit(L).
for (auto it = Lblock_begin; it != Lblock_end; ++it)
addToCurrentBlock(*it, direction::reverse);
Expand Down
72 changes: 18 additions & 54 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ double addArr(const double *arr, int n) {
//CHECK-NEXT: {
//CHECK-NEXT: ret = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_ret;
//CHECK-NEXT: _d_ret += _r_d0;
//CHECK-NEXT: _d_arr[i] += _r_d0;
//CHECK-NEXT: _d_ret -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down Expand Up @@ -92,16 +90,13 @@ float func(float* a, float* b) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t2);
//CHECK-NEXT: float _r_d1 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d1;
//CHECK-NEXT: _d_a[i] += _r_d1;
//CHECK-NEXT: _d_sum -= _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: a[i] = clad::pop(_t1);
//CHECK-NEXT: float _r_d0 = _d_a[i];
//CHECK-NEXT: _d_a[i] += _r_d0 * b[i];
//CHECK-NEXT: float _r0 = a[i] * _r_d0;
//CHECK-NEXT: _d_b[i] += _r0;
//CHECK-NEXT: _d_b[i] += a[i] * _r_d0;
//CHECK-NEXT: _d_a[i] -= _r_d0;
//CHECK-NEXT: _d_a[i];
//CHECK-NEXT: }
Expand All @@ -115,11 +110,7 @@ float helper(float x) {
// CHECK: void helper_pullback(float x, float _d_y, clad::array_ref<float> _d_x) {
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = _d_y * x;
// CHECK-NEXT: float _r1 = 2 * _d_y;
// CHECK-NEXT: * _d_x += _r1;
// CHECK-NEXT: }
// CHECK-NEXT: * _d_x += 2 * _d_y;
// CHECK-NEXT: }

float func2(float* a) {
Expand Down Expand Up @@ -148,12 +139,10 @@ float func2(float* a) {
//CHECK-NEXT: i--;
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: float _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: float _grad0 = 0.F;
//CHECK-NEXT: helper_pullback(a[i], _r_d0, &_grad0);
//CHECK-NEXT: float _r0 = _grad0;
//CHECK-NEXT: _d_a[i] += _r0;
//CHECK-NEXT: _d_sum -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -185,14 +174,10 @@ float func3(float* a, float* b) {
//CHECK-NEXT: i--;
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: float _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: _d_a[i] += _r_d0;
//CHECK-NEXT: a[i] = clad::pop(_t2);
//CHECK-NEXT: float _r_d1 = _d_a[i];
//CHECK-NEXT: _d_a[i] += _r_d1;
//CHECK-NEXT: _d_b[i] += _r_d1;
//CHECK-NEXT: _d_a[i] -= _r_d1;
//CHECK-NEXT: _d_sum -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -227,23 +212,17 @@ double func4(double x) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(arr, 3, _r_d0, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r4(_d_arr);
//CHECK-NEXT: int _r5 = _grad1;
//CHECK-NEXT: _d_sum -= _r_d0;
//CHECK-NEXT: clad::array<double> _r0(_d_arr);
//CHECK-NEXT: int _r1 = _grad1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * x;
//CHECK-NEXT: double _r1 = 2 * _d_arr[1];
//CHECK-NEXT: * _d_x += _r1;
//CHECK-NEXT: double _r2 = _d_arr[2] * x;
//CHECK-NEXT: * _d_x += _r2;
//CHECK-NEXT: double _r3 = x * _d_arr[2];
//CHECK-NEXT: * _d_x += _r3;
//CHECK-NEXT: * _d_x += 2 * _d_arr[1];
//CHECK-NEXT: * _d_x += _d_arr[2] * x;
//CHECK-NEXT: * _d_x += x * _d_arr[2];
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -293,13 +272,11 @@ double func5(int k) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: double _r_d1 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d1;
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(arr, n, _r_d1, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r0(_d_arr);
//CHECK-NEXT: int _r1 = _grad1;
//CHECK-NEXT: _d_n += _r1;
//CHECK-NEXT: _d_sum -= _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: for (; _t0; _t0--) {
Expand Down Expand Up @@ -346,19 +323,15 @@ double func6(double seed) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(arr, 3, _r_d0, _d_arr, &_grad1);
//CHECK-NEXT: clad::array<double> _r2(_d_arr);
//CHECK-NEXT: int _r3 = _grad1;
//CHECK-NEXT: _d_sum -= _r_d0;
//CHECK-NEXT: clad::array<double> _r0(_d_arr);
//CHECK-NEXT: int _r1 = _grad1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_seed += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * i;
//CHECK-NEXT: * _d_seed += _r0;
//CHECK-NEXT: double _r1 = seed * _d_arr[1];
//CHECK-NEXT: _d_i += _r1;
//CHECK-NEXT: * _d_seed += _d_arr[1] * i;
//CHECK-NEXT: _d_i += seed * _d_arr[1];
//CHECK-NEXT: * _d_seed += _d_arr[2];
//CHECK-NEXT: _d_i += _d_arr[2];
//CHECK-NEXT: _d_arr = {};
Expand All @@ -376,12 +349,9 @@ double inv_square(double *params) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y / _t0;
//CHECK-NEXT: double _r1 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: double _r2 = _r1 * params[0];
//CHECK-NEXT: _d_params[0] += _r2;
//CHECK-NEXT: double _r3 = params[0] * _r1;
//CHECK-NEXT: _d_params[0] += _r3;
//CHECK-NEXT: double _r0 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: _d_params[0] += _r0 * params[0];
//CHECK-NEXT: _d_params[0] += params[0] * _r0;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -436,10 +406,8 @@ double helper2(double i, double *arr, int n) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y * i;
//CHECK-NEXT: _d_arr[0] += _r0;
//CHECK-NEXT: double _r1 = arr[0] * _d_y;
//CHECK-NEXT: * _d_i += _r1;
//CHECK-NEXT: _d_arr[0] += _d_y * i;
//CHECK-NEXT: * _d_i += arr[0] * _d_y;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down Expand Up @@ -572,10 +540,8 @@ double sq(double& elem) {
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = * _d_elem;
//CHECK-NEXT: double _r0 = _r_d0 * elem;
//CHECK-NEXT: * _d_elem += _r0;
//CHECK-NEXT: double _r1 = elem * _r_d0;
//CHECK-NEXT: * _d_elem += _r1;
//CHECK-NEXT: * _d_elem += _r_d0 * elem;
//CHECK-NEXT: * _d_elem += elem * _r_d0;
//CHECK-NEXT: * _d_elem -= _r_d0;
//CHECK-NEXT: * _d_elem;
//CHECK-NEXT: }
Expand Down Expand Up @@ -612,12 +578,10 @@ double func10(double *arr, int n) {
//CHECK-NEXT: {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_res += _r_d0;
//CHECK-NEXT: double _r1 = clad::pop(_t2);
//CHECK-NEXT: arr[i] = _r1;
//CHECK-NEXT: sq_pullback(_r1, _r_d0, &_d_arr[i]);
//CHECK-NEXT: double _r0 = _d_arr[i];
//CHECK-NEXT: _d_res -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand Down
18 changes: 6 additions & 12 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,12 @@ double const_dot_product(double x, double y, double z) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 1 * consts[0];
//CHECK-NEXT: _d_vars[0] += _r0;
//CHECK-NEXT: double _r1 = vars[0] * 1;
//CHECK-NEXT: _d_consts[0] += _r1;
//CHECK-NEXT: double _r2 = 1 * consts[1];
//CHECK-NEXT: _d_vars[1] += _r2;
//CHECK-NEXT: double _r3 = vars[1] * 1;
//CHECK-NEXT: _d_consts[1] += _r3;
//CHECK-NEXT: double _r4 = 1 * consts[2];
//CHECK-NEXT: _d_vars[2] += _r4;
//CHECK-NEXT: double _r5 = vars[2] * 1;
//CHECK-NEXT: _d_consts[2] += _r5;
//CHECK-NEXT: _d_vars[0] += 1 * consts[0];
//CHECK-NEXT: _d_consts[0] += vars[0] * 1;
//CHECK-NEXT: _d_vars[1] += 1 * consts[1];
//CHECK-NEXT: _d_consts[1] += vars[1] * 1;
//CHECK-NEXT: _d_vars[2] += 1 * consts[2];
//CHECK-NEXT: _d_consts[2] += vars[2] * 1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += _d_vars[0];
Expand Down
49 changes: 17 additions & 32 deletions test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,52 +56,37 @@ auto gauss_g = clad::gradient(gauss, "p");
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r8 = 1 * _t4;
//CHECK-NEXT: double _r9 = _r8 * _t5;
//CHECK-NEXT: double _grad0 = 0.;
//CHECK-NEXT: double _grad1 = 0.;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(2 * 3.1415926535897931, -dim / _t6, _r9, &_grad0, &_grad1);
//CHECK-NEXT: double _r10 = _grad0;
//CHECK-NEXT: double _r11 = _r10 * 3.1415926535897931;
//CHECK-NEXT: double _r12 = _grad1;
//CHECK-NEXT: double _r13 = _r12 / _t6;
//CHECK-NEXT: _d_dim += -_r13;
//CHECK-NEXT: double _r14 = _r12 * --dim / (_t6 * _t6);
//CHECK-NEXT: double _r15 = std::pow(2 * 3.1415926535897931, -dim / _t6) * _r8;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(2 * 3.1415926535897931, -dim / _t6, 1 * _t4 * _t5, &_grad0, &_grad1);
//CHECK-NEXT: double _r1 = _grad0;
//CHECK-NEXT: double _r2 = _grad1;
//CHECK-NEXT: _d_dim += -_r2 / _t6;
//CHECK-NEXT: double _r3 = _r2 * --dim / (_t6 * _t6);
//CHECK-NEXT: double _grad2 = 0.;
//CHECK-NEXT: double _grad3 = 0.;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(sigma, -0.5, _r15, &_grad2, &_grad3);
//CHECK-NEXT: double _r16 = _grad2;
//CHECK-NEXT: _d_sigma += _r16;
//CHECK-NEXT: double _r17 = _grad3;
//CHECK-NEXT: double _r18 = std::pow(2 * 3.1415926535897931, -dim / _t6) * _t5 * 1;
//CHECK-NEXT: double _r19 = _r18 * clad::custom_derivatives::exp_pushforward(t, 1.).pushforward;
//CHECK-NEXT: _d_t += _r19;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(sigma, -0.5, std::pow(2 * 3.1415926535897931, -dim / _t6) * 1 * _t4, &_grad2, &_grad3);
//CHECK-NEXT: double _r4 = _grad2;
//CHECK-NEXT: _d_sigma += _r4;
//CHECK-NEXT: double _r5 = _grad3;
//CHECK-NEXT: double _r6 = std::pow(2 * 3.1415926535897931, -dim / _t6) * _t5 * 1 * clad::custom_derivatives::exp_pushforward(t, 1.).pushforward;
//CHECK-NEXT: _d_t += _r6;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: t = _t2;
//CHECK-NEXT: double _r_d1 = _d_t;
//CHECK-NEXT: double _r2 = _r_d1 / _t3;
//CHECK-NEXT: _d_t += -_r2;
//CHECK-NEXT: double _r3 = _r_d1 * --t / (_t3 * _t3);
//CHECK-NEXT: double _r4 = _r3 * sigma;
//CHECK-NEXT: double _r5 = _r4 * sigma;
//CHECK-NEXT: double _r6 = 2 * _r4;
//CHECK-NEXT: _d_sigma += _r6;
//CHECK-NEXT: double _r7 = 2 * sigma * _r3;
//CHECK-NEXT: _d_sigma += _r7;
//CHECK-NEXT: _d_t += -_r_d1 / _t3;
//CHECK-NEXT: double _r0 = _r_d1 * --t / (_t3 * _t3);
//CHECK-NEXT: _d_sigma += 2 * _r0 * sigma;
//CHECK-NEXT: _d_sigma += 2 * sigma * _r0;
//CHECK-NEXT: _d_t -= _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: t = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_t;
//CHECK-NEXT: _d_t += _r_d0;
//CHECK-NEXT: double _r0 = _r_d0 * (x[i] - p[i]);
//CHECK-NEXT: _d_p[i] += -_r0;
//CHECK-NEXT: double _r1 = (x[i] - p[i]) * _r_d0;
//CHECK-NEXT: _d_p[i] += -_r1;
//CHECK-NEXT: _d_t -= _r_d0;
//CHECK-NEXT: _d_p[i] += -_r_d0 * (x[i] - p[i]);
//CHECK-NEXT: _d_p[i] += -(x[i] - p[i]) * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

Expand Down
6 changes: 2 additions & 4 deletions test/Enzyme/DifferentCladEnzymeDerivatives.C
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ double foo(double x, double y){
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: double _r0 = 1 * y;
// CHECK-NEXT: * _d_x += _r0;
// CHECK-NEXT: double _r1 = x * 1;
// CHECK-NEXT: * _d_y += _r1;
// CHECK-NEXT: * _d_x += 1 * y;
// CHECK-NEXT: * _d_y += x * 1;
// CHECK-NEXT: }
// CHECK-NEXT: }

Expand Down
Loading

0 comments on commit 79d7f13

Please sign in to comment.