Skip to content

Commit

Permalink
Remove excessive stores for the RHS in *= operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 24, 2023
1 parent 54df7e6 commit c40d4af
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 116 deletions.
52 changes: 26 additions & 26 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2397,37 +2397,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, BuildOp(UO_Minus, oldValue));
valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepAsExpr(), Ldiff.getRevSweepAsExpr());
} else if (opCode == BO_MulAssign) {
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
// Create a reference variable to keep the result of the LHS, since it
// must be used in 2 places: when storing to a global variable
// accessible from the reverse pass, and when rebuilding the original
// expression for the forward pass. This avoids executing the
// same expression with side effects twice. E.g., for
// double r = (x *= y) *= z;
// instead of:
// _t0 = (x *= y);
// double r = (x *= y) *= z;
// which modifies x twice, we get:
// double & _ref0 = (x *= y);
// _t0 = _ref0;
// double r = _ref0 *= z;
if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
/// Capture all the emitted statements while visiting R
/// and insert them after `dl += dl * R`
beginBlock(direction::reverse);
Expr* dr = BuildOp(BO_Mul, LCloned, oldValue);
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
Stmts RBlock = EndBlockWithoutCreatingCS(direction::reverse);
addToCurrentBlock(
BuildOp(BO_AddAssign,
AssignedDiff,
BuildOp(BO_Mul, oldValue, RResult.getExpr_dx())),
BuildOp(BO_Mul, oldValue, Rdiff.getRevSweepAsExpr())),
direction::reverse);
if (!RDelayed.isConstant) {
// Create a reference variable to keep the result of the LHS, since it
// must be used in 2 places: when storing to a global variable
// accessible from the reverse pass, and when rebuilding the original
// expression for the forward pass. This avoids executing the
// same expression with side effects twice. E.g., for
// double r = (x *= y) *= z;
// instead of:
// _t0 = (x *= y);
// double r = (x *= y) *= z;
// which modifies x twice, we get:
// double & _ref0 = (x *= y);
// _t0 = _ref0;
// double r = _ref0 *= z;

if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
Expr* dr = BuildOp(BO_Mul, LCloned, oldValue);
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
for (auto& S : RBlock)
addToCurrentBlock(S, direction::reverse);
valueForRevPass = BuildOp(BO_Mul, Rdiff.getRevSweepAsExpr(), Ldiff.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, Rdiff.getExpr());
} else if (opCode == BO_DivAssign) {
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expand Down
9 changes: 4 additions & 5 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,13 @@ float func(float* a, float* b) {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<float> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: clad::tape<float> _t3 = {};
//CHECK-NEXT: float sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, a[i]);
//CHECK-NEXT: a[i] *= clad::push(_t2, b[i]);
//CHECK-NEXT: clad::push(_t3, sum);
//CHECK-NEXT: a[i] *= b[i];
//CHECK-NEXT: clad::push(_t2, sum);
//CHECK-NEXT: sum += a[i];
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand All @@ -91,7 +90,7 @@ float func(float* a, float* b) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: sum = clad::pop(_t2);
//CHECK-NEXT: float _r_d1 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d1;
//CHECK-NEXT: _d_a[i] += _r_d1;
Expand All @@ -100,7 +99,7 @@ float func(float* a, float* b) {
//CHECK-NEXT: {
//CHECK-NEXT: a[i] = clad::pop(_t1);
//CHECK-NEXT: float _r_d0 = _d_a[i];
//CHECK-NEXT: _d_a[i] += _r_d0 * clad::pop(_t2);
//CHECK-NEXT: _d_a[i] += _r_d0 * b[i];
//CHECK-NEXT: float _r0 = a[i] * _r_d0;
//CHECK-NEXT: _d_b[i] += _r0;
//CHECK-NEXT: _d_a[i] -= _r_d0;
Expand Down
19 changes: 8 additions & 11 deletions test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -188,28 +188,25 @@ float func4(float x, float y) {
//CHECK-NEXT: float _EERepl_x0 = x;
//CHECK-NEXT: float _EERepl_x1;
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: float _EERepl_x2;
//CHECK-NEXT: float _t3;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: _cond0 = !x;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: else {
//CHECK-NEXT: else
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: }
//CHECK-NEXT: _cond0 ? (x += 1) : (x *= _t2);
//CHECK-NEXT: _cond0 ? (x += 1) : (x *= x);
//CHECK-NEXT: _EERepl_x2 = x;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: _t3 = x;
//CHECK-NEXT: _ret_value0 = y / _t3;
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: _ret_value0 = y / _t2;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r1 = 1 / _t3;
//CHECK-NEXT: float _r1 = 1 / _t2;
//CHECK-NEXT: * _d_y += _r1;
//CHECK-NEXT: float _r2 = 1 * -y / (_t3 * _t3);
//CHECK-NEXT: float _r2 = 1 * -y / (_t2 * _t2);
//CHECK-NEXT: * _d_x += _r2;
//CHECK-NEXT: }
//CHECK-NEXT: {
Expand All @@ -222,7 +219,7 @@ float func4(float x, float y) {
//CHECK-NEXT: } else {
//CHECK-NEXT: x = _t1;
//CHECK-NEXT: float _r_d1 = * _d_x;
//CHECK-NEXT: * _d_x += _r_d1 * _t2;
//CHECK-NEXT: * _d_x += _r_d1 * x;
//CHECK-NEXT: float _r0 = x * _r_d1;
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: _delta_x += std::abs(_r_d1 * _EERepl_x2 * {{.+}});
Expand Down
Loading

0 comments on commit c40d4af

Please sign in to comment.