Skip to content

Commit

Permalink
Improve storing of LHS/RHS in multiplication/division operators.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 28, 2023
1 parent 19db17d commit 63d9290
Show file tree
Hide file tree
Showing 18 changed files with 245 additions and 149 deletions.
4 changes: 4 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ namespace clad {
/// to avoid recomputiation.
bool UsefulToStoreGlobal(clang::Expr* E);

/// For an expr E, decides if we should recompute it or store it.
/// This is the central point for checkpointing.
bool ShouldRecompute(const clang::Expr* E);

/// Builds a variable declaration and stores it in the function
/// global scope.
///
Expand Down
4 changes: 4 additions & 0 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,10 @@ namespace clad {
bool forceDeclCreation = false,
clang::VarDecl::InitializationStyle IS =
clang::VarDecl::InitializationStyle::CInit);
/// For an expr E, decides if it is useful to store it in a temporary
/// variable and replace E's further usage by a reference to that variable
/// to avoid recomputation.
static bool UsefulToStore(clang::Expr* E);
/// A flag for silencing warnings/errors output by diag function.
bool silenceDiags = false;
/// Shorthand to issues a warning or error.
Expand Down
46 changes: 30 additions & 16 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2164,7 +2164,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff RResult;
// If R has no side effects, it can be just cloned
// (no need to store it).
if (utils::ContainsFunctionCalls(R) || R->HasSideEffects(m_Context)) {
if (!ShouldRecompute(R)) {
RDelayed = std::unique_ptr<DelayedStoreResult>(
new DelayedStoreResult(DelayedGlobalStoreAndRef(R)));
RResult = RDelayed->Result;
Expand All @@ -2179,30 +2179,39 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// dxi/xr = xl
// df/dxr += df/dxi * dxi/xr = df/dxi * xl
// Store left multiplier and assign it with L.
Expr* LStored = Ldiff.getExpr();
StmtDiff LStored = Ldiff;
if (!ShouldRecompute(LStored.getExpr()))
LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t",
/*force=*/true);
Expr::EvalResult dummy;
if (RDelayed ||
!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
Expr* dr = nullptr;
if (dfdx())
dr = BuildOp(BO_Mul, Ldiff.getRevSweepAsExpr(), dfdx());
dr = BuildOp(BO_Mul, LStored.getRevSweepAsExpr(), dfdx());
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
if (RDelayed)
RDelayed->Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr());
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
} else if (opCode == BO_Div) {
// xi = xl / xr
// dxi/xl = 1 / xr
// df/dxl += df/dxi * dxi/xl = df/dxi * (1/xr)
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expr* RStored = StoreAndRef(RResult.getExpr_dx(), direction::reverse);
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
Expr* dl = nullptr;
if (dfdx())
dl = BuildOp(BO_Div, dfdx(), RStored);
Ldiff = Visit(L, dl);
StmtDiff LStored = Ldiff;
if (!ShouldRecompute(LStored.getExpr()))
LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t",
/*force=*/true);
// dxi/xr = -xl / (xr * xr)
// df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr))
// Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is
Expand All @@ -2211,17 +2220,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* dr = nullptr;
if (dfdx()) {
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
dr =
BuildOp(BO_Mul, dfdx(),
BuildOp(UO_Minus,
BuildOp(BO_Div, Ldiff.getRevSweepAsExpr(), RxR)));
dr = BuildOp(
BO_Mul, dfdx(),
BuildOp(UO_Minus,
BuildOp(BO_Div, LStored.getRevSweepAsExpr(), RxR)));
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) =
std::make_pair(Ldiff.getExpr(), RResult.getExpr());
std::make_pair(LStored.getExpr(), RResult.getExpr());
} else if (BinOp->isAssignmentOp()) {
if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning,
Expand Down Expand Up @@ -2394,7 +2403,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else if (opCode == BO_DivAssign) {
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expr* RStored = StoreAndRef(RResult.getExpr_dx(), direction::reverse);
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff,
BuildOp(BO_Div, oldValue, RStored)),
direction::reverse);
Expand Down Expand Up @@ -2787,6 +2797,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx());
}

bool ReverseModeVisitor::ShouldRecompute(const Expr* E) {
return !(utils::ContainsFunctionCalls(E) || E->HasSideEffects(m_Context));
}

bool ReverseModeVisitor::UsefulToStoreGlobal(Expr* E) {
if (!E)
return false;
Expand Down Expand Up @@ -2946,12 +2960,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E,
llvm::StringRef prefix) {
assert(E && "must be provided");
if (isa<DeclRefExpr>(E) /*!UsefulToStoreGlobal(E)*/) {
Expr* Cloned = Clone(E);
if (!UsefulToStore(E)) {
StmtDiff Ediff = Visit(E);
Expr::EvalResult evalRes;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context);
return DelayedStoreResult{*this, StmtDiff{Cloned, Cloned},
return DelayedStoreResult{*this, Ediff,
/*isConstant*/ isConst,
/*isInsideLoop*/ false,
/*pNeedsUpdate=*/false};
Expand All @@ -2961,14 +2975,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto CladTape = MakeCladTapeFor(dummy);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return DelayedStoreResult{*this, StmtDiff{Push, Pop, nullptr, Pop},
return DelayedStoreResult{*this, StmtDiff{Push, nullptr, nullptr, Pop},
/*isConstant*/ false,
/*isInsideLoop*/ true, /*pNeedsUpdate=*/true};
}
Expr* Ref = BuildDeclRef(GlobalStoreImpl(
getNonConstType(E->getType(), m_Context, m_Sema), prefix));
// Return reference to the declaration instead of original expression.
return DelayedStoreResult{*this, StmtDiff{Ref, Ref},
return DelayedStoreResult{*this, StmtDiff{Ref, nullptr, nullptr, Ref},
/*isConstant*/ false,
/*isInsideLoop*/ false, /*pNeedsUpdate=*/true};
}
Expand Down
5 changes: 1 addition & 4 deletions lib/Differentiator/VisitorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,7 @@ namespace clad {
return StoreAndRef(E, Type, block, prefix, forceDeclCreation, IS);
}

/// For an expr E, decides if it is useful to store it in a temporary variable
/// and replace E's further usage by a reference to that variable to avoid
/// recomputiation.
static bool UsefulToStore(Expr* E) {
bool VisitorBase::UsefulToStore(Expr* E) {
if (!E)
return false;
Expr* B = E->IgnoreParenImpCasts();
Expand Down
19 changes: 9 additions & 10 deletions test/CUDA/GradientCuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,26 @@ auto gauss_g = clad::gradient(gauss, "p");
//CHECK-NEXT: _t2 = t;
//CHECK-NEXT: _t3 = (2 * sigma * sigma);
//CHECK-NEXT: t = -t / _t3;
//CHECK-NEXT: _t6 = 2.;
//CHECK-NEXT: _t6 = std::pow(2 * 3.1415926535897931, -dim / 2.);
//CHECK-NEXT: _t5 = std::pow(sigma, -0.5);
//CHECK-NEXT: _t4 = std::exp(t);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _grad0 = 0.;
//CHECK-NEXT: double _grad1 = 0.;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(2 * 3.1415926535897931, -dim / _t6, 1 * _t4 * _t5, &_grad0, &_grad1);
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(2 * 3.1415926535897931, -dim / 2., 1 * _t4 * _t5, &_grad0, &_grad1);
//CHECK-NEXT: double _r1 = _grad0;
//CHECK-NEXT: double _r2 = _grad1;
//CHECK-NEXT: _d_dim += -_r2 / _t6;
//CHECK-NEXT: double _r3 = _r2 * --dim / (_t6 * _t6);
//CHECK-NEXT: _d_dim += -_r2 / 2.;
//CHECK-NEXT: double _grad2 = 0.;
//CHECK-NEXT: double _grad3 = 0.;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(sigma, -0.5, std::pow(2 * 3.1415926535897931, -dim / _t6) * 1 * _t4, &_grad2, &_grad3);
//CHECK-NEXT: double _r4 = _grad2;
//CHECK-NEXT: _d_sigma += _r4;
//CHECK-NEXT: double _r5 = _grad3;
//CHECK-NEXT: double _r6 = std::pow(2 * 3.1415926535897931, -dim / _t6) * _t5 * 1 * clad::custom_derivatives::exp_pushforward(t, 1.).pushforward;
//CHECK-NEXT: _d_t += _r6;
//CHECK-NEXT: clad::custom_derivatives{{(::std)?}}::pow_pullback(sigma, -0.5, _t6 * 1 * _t4, &_grad2, &_grad3);
//CHECK-NEXT: double _r3 = _grad2;
//CHECK-NEXT: _d_sigma += _r3;
//CHECK-NEXT: double _r4 = _grad3;
//CHECK-NEXT: double _r5 = _t6 * _t5 * 1 * clad::custom_derivatives::exp_pushforward(t, 1.).pushforward;
//CHECK-NEXT: _d_t += _r5;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: t = _t2;
Expand Down
26 changes: 13 additions & 13 deletions test/ErrorEstimation/BasicOps.C
Original file line number Diff line number Diff line change
Expand Up @@ -82,22 +82,20 @@ float func2(float x, float y) {
//CHECK-NEXT: double _delta_x = 0;
//CHECK-NEXT: float _EERepl_x0 = x;
//CHECK-NEXT: float _EERepl_x1;
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _d_z = 0;
//CHECK-NEXT: double _delta_z = 0;
//CHECK-NEXT: float _EERepl_z0;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = x - y - y * y;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: float z = y / _t1;
//CHECK-NEXT: float z = y / x;
//CHECK-NEXT: _EERepl_z0 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
//CHECK-NEXT: * _d_y += _d_z / _t1;
//CHECK-NEXT: float _r0 = _d_z * -y / (_t1 * _t1);
//CHECK-NEXT: * _d_y += _d_z / x;
//CHECK-NEXT: float _r0 = _d_z * -y / (x * x);
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: _delta_z += std::abs(_d_z * _EERepl_z0 * {{.+}});
//CHECK-NEXT: }
Expand Down Expand Up @@ -392,16 +390,18 @@ float func9(float x, float y) {
//CHECK-NEXT: float _t3;
//CHECK-NEXT: double _t4;
//CHECK-NEXT: float _t5;
//CHECK-NEXT: float _t7;
//CHECK-NEXT: double _t7;
//CHECK-NEXT: float _t8;
//CHECK-NEXT: float _EERepl_z1;
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: float z = helper(x, y) + helper2(x);
//CHECK-NEXT: _EERepl_z0 = z;
//CHECK-NEXT: _t3 = z;
//CHECK-NEXT: _t5 = x;
//CHECK-NEXT: _t7 = y;
//CHECK-NEXT: _t7 = helper2(x);
//CHECK-NEXT: _t8 = y;
//CHECK-NEXT: _t4 = helper2(y);
//CHECK-NEXT: z += helper2(x) * _t4;
//CHECK-NEXT: z += _t7 * _t4;
//CHECK-NEXT: _EERepl_z1 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
Expand All @@ -413,12 +413,12 @@ float func9(float x, float y) {
//CHECK-NEXT: double _t6 = 0;
//CHECK-NEXT: helper2_pullback(_t5, _r_d0 * _t4, &* _d_x, _t6);
//CHECK-NEXT: float _r3 = * _d_x;
//CHECK-NEXT: y = _t7;
//CHECK-NEXT: double _t8 = 0;
//CHECK-NEXT: helper2_pullback(_t7, helper2(x) * _r_d0, &* _d_y, _t8);
//CHECK-NEXT: y = _t8;
//CHECK-NEXT: double _t9 = 0;
//CHECK-NEXT: helper2_pullback(_t8, _t7 * _r_d0, &* _d_y, _t9);
//CHECK-NEXT: float _r4 = * _d_y;
//CHECK-NEXT: _delta_z += _t6 + _t8;
//CHECK-NEXT: _final_error += std::abs(_r4 * _t7 * {{.+}});
//CHECK-NEXT: _delta_z += _t6 + _t9;
//CHECK-NEXT: _final_error += std::abs(_r4 * _t8 * {{.+}});
//CHECK-NEXT: _final_error += std::abs(_r3 * _t5 * {{.+}});
//CHECK-NEXT: }
//CHECK-NEXT: {
Expand Down
8 changes: 3 additions & 5 deletions test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ float func4(float x, float y) {
//CHECK-NEXT: float _EERepl_x1;
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _EERepl_x2;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: _cond0 = !x;
//CHECK-NEXT: if (_cond0)
Expand All @@ -178,13 +177,12 @@ float func4(float x, float y) {
//CHECK-NEXT: _cond0 ? (x += 1) : (x *= x);
//CHECK-NEXT: _EERepl_x2 = x;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: _ret_value0 = y / _t2;
//CHECK-NEXT: _ret_value0 = y / x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: * _d_y += 1 / _t2;
//CHECK-NEXT: float _r0 = 1 * -y / (_t2 * _t2);
//CHECK-NEXT: * _d_y += 1 / x;
//CHECK-NEXT: float _r0 = 1 * -y / (x * x);
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: }
//CHECK-NEXT: {
Expand Down
26 changes: 12 additions & 14 deletions test/ErrorEstimation/LoopsAndArraysExec.C
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,14 @@ double divSum(float* a, float* b, int n) {
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<float> _t2 = {};
//CHECK-NEXT: clad::tape<double> _EERepl_sum1 = {};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _EERepl_sum0 = sum;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < n; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += a[i] / clad::push(_t2, b[i]);
//CHECK-NEXT: sum += a[i] / b[i];
//CHECK-NEXT: clad::push(_EERepl_sum1, sum);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand All @@ -165,28 +164,27 @@ double divSum(float* a, float* b, int n) {
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: float _r0 = clad::pop(_t2);
//CHECK-NEXT: _d_a[i] += _r_d0 / _r0;
//CHECK-NEXT: double _r1 = _r_d0 * -a[i] / (_r0 * _r0);
//CHECK-NEXT: _d_b[i] += _r1;
//CHECK-NEXT: double _r2 = clad::pop(_EERepl_sum1);
//CHECK-NEXT: _delta_sum += std::abs(_r_d0 * _r2 * {{.+}});
//CHECK-NEXT: _d_a[i] += _r_d0 / b[i];
//CHECK-NEXT: double _r0 = _r_d0 * -a[i] / (b[i] * b[i]);
//CHECK-NEXT: _d_b[i] += _r0;
//CHECK-NEXT: double _r1 = clad::pop(_EERepl_sum1);
//CHECK-NEXT: _delta_sum += std::abs(_r_d0 * _r1 * {{.+}});
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _delta_sum += std::abs(_d_sum * _EERepl_sum0 * {{.+}});
//CHECK-NEXT: clad::array<float> _delta_a(_d_a.size());
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: for (; i < _d_a.size(); i++) {
//CHECK-NEXT: double _t3 = std::abs(_d_a[i] * a[i] * {{.+}});
//CHECK-NEXT: _delta_a[i] += _t3;
//CHECK-NEXT: _final_error += _t3;
//CHECK-NEXT: double _t2 = std::abs(_d_a[i] * a[i] * {{.+}});
//CHECK-NEXT: _delta_a[i] += _t2;
//CHECK-NEXT: _final_error += _t2;
//CHECK-NEXT: }
//CHECK-NEXT: clad::array<float> _delta_b(_d_b.size());
//CHECK-NEXT: i = 0;
//CHECK-NEXT: for (; i < _d_b.size(); i++) {
//CHECK-NEXT: double _t4 = std::abs(_d_b[i] * b[i] * {{.+}});
//CHECK-NEXT: _delta_b[i] += _t4;
//CHECK-NEXT: _final_error += _t4;
//CHECK-NEXT: double _t3 = std::abs(_d_b[i] * b[i] * {{.+}});
//CHECK-NEXT: _delta_b[i] += _t3;
//CHECK-NEXT: _final_error += _t3;
//CHECK-NEXT: }
//CHECK-NEXT: _final_error += _delta_sum;
//CHECK-NEXT: }
Expand Down
16 changes: 7 additions & 9 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ double f7(double x, double y) {
//CHECK-NEXT: double _t4;
//CHECK-NEXT: double _t5;
//CHECK-NEXT: double _t6;
//CHECK-NEXT: double _t7;
//CHECK-NEXT: double t[3] = {1, x, x * x};
//CHECK-NEXT: t[0]++;
//CHECK-NEXT: t[0]--;
Expand All @@ -309,34 +308,33 @@ double f7(double x, double y) {
//CHECK-NEXT: _t3 = t[0];
//CHECK-NEXT: t[0] *= t[1];
//CHECK-NEXT: _t4 = t[0];
//CHECK-NEXT: _t5 = t[1];
//CHECK-NEXT: t[0] /= _t5;
//CHECK-NEXT: _t6 = t[0];
//CHECK-NEXT: t[0] /= t[1];
//CHECK-NEXT: _t5 = t[0];
//CHECK-NEXT: t[0] -= t[1];
//CHECK-NEXT: _t7 = x;
//CHECK-NEXT: _t6 = x;
//CHECK-NEXT: x = ++t[0];
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_t[0] += 1;
//CHECK-NEXT: {
//CHECK-NEXT: x = _t7;
//CHECK-NEXT: x = _t6;
//CHECK-NEXT: double _r_d6 = * _d_x;
//CHECK-NEXT: _d_t[0] += _r_d6;
//CHECK-NEXT: --t[0];
//CHECK-NEXT: * _d_x -= _r_d6;
//CHECK-NEXT: * _d_x;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: t[0] = _t6;
//CHECK-NEXT: t[0] = _t5;
//CHECK-NEXT: double _r_d5 = _d_t[0];
//CHECK-NEXT: _d_t[1] += -_r_d5;
//CHECK-NEXT: _d_t[0];
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: t[0] = _t4;
//CHECK-NEXT: double _r_d4 = _d_t[0];
//CHECK-NEXT: _d_t[0] += _r_d4 / _t5;
//CHECK-NEXT: double _r0 = _r_d4 * -t[0] / (_t5 * _t5);
//CHECK-NEXT: _d_t[0] += _r_d4 / t[1];
//CHECK-NEXT: double _r0 = _r_d4 * -t[0] / (t[1] * t[1]);
//CHECK-NEXT: _d_t[1] += _r0;
//CHECK-NEXT: _d_t[0] -= _r_d4;
//CHECK-NEXT: _d_t[0];
Expand Down
Loading

0 comments on commit 63d9290

Please sign in to comment.