Skip to content

Commit

Permalink
Do not store the RHS of multiplication with no side effects.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 29, 2023
1 parent a3415b8 commit 3ac5587
Show file tree
Hide file tree
Showing 37 changed files with 878 additions and 1,656 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ namespace clad {
/// need to decide what needs to be stored on tape in reverse mode.
void GetInnermostReturnExpr(const clang::Expr* E,
llvm::SmallVectorImpl<clang::Expr*>& Exprs);

bool ContainsFunctionCalls(const clang::Stmt* E);
} // namespace utils
}

Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/StmtClone.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "clang/Sema/Scope.h"

#include "llvm/ADT/DenseMap.h"
#include <unordered_map>

namespace clang {
class Stmt;
Expand Down
16 changes: 16 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/AST/Expr.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/Lookup.h"
#include "clad/Differentiator/Compatibility.h"
Expand Down Expand Up @@ -617,5 +618,20 @@ namespace clad {

return false;
}

bool ContainsFunctionCalls(const clang::Stmt* S) {
class CallExprFinder : public RecursiveASTVisitor<CallExprFinder> {
public:
bool hasCallExpr = false;

bool VisitCallExpr(CallExpr *CE) {
hasCallExpr = true;
return false;
}
};
CallExprFinder finder;
finder.TraverseStmt(const_cast<Stmt*>(S));
return finder.hasCallExpr;
}
} // namespace utils
} // namespace clad
27 changes: 18 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2184,29 +2184,38 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to reduce cloning complexity and only clones once. Storing it in a
// global variable allows to save current result and make it accessible
// in the reverse pass.
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
std::unique_ptr<DelayedStoreResult> RDelayed;
StmtDiff RResult;
// If R has no side effects, it can be just cloned
// (no need to store it).
if (utils::ContainsFunctionCalls(R) || R->HasSideEffects(m_Context)) {
RDelayed = std::unique_ptr<DelayedStoreResult>(new DelayedStoreResult(DelayedGlobalStoreAndRef(R)));
RResult = RDelayed->Result;
} else {
RResult = StmtDiff(Clone(R));
}

Expr* dl = nullptr;
if (dfdx()) {
dl = BuildOp(BO_Mul, dfdx(), RResult.getExpr_dx());
dl = BuildOp(BO_Mul, dfdx(), RResult.getRevSweepAsExpr());
dl = StoreAndRef(dl, direction::reverse);
}
Ldiff = Visit(L, dl);
// dxi/xr = xl
// df/dxr += df/dxi * dxi/xr = df/dxi * xl
// Store left multiplier and assign it with L.
Expr* LStored = Ldiff.getExpr();
// RDelayed.isConstant == true implies that R is a constant expression,
// therefore we can skip visiting it.
if (!RDelayed.isConstant) {
Expr::EvalResult dummy;
if (RDelayed || !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
Expr* dr = nullptr;
if (dfdx()) {
dr = BuildOp(BO_Mul, Ldiff.getRevSweepAsExpr(), dfdx());
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
RDelayed.Finalize(Rdiff.getExpr());
if (RDelayed)
RDelayed->Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr());
} else if (opCode == BO_Div) {
Expand Down Expand Up @@ -2239,7 +2248,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(Ldiff.getRevSweepAsExpr(), RResult.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(Ldiff.getExpr(), RResult.getExpr());
} else if (BinOp->isAssignmentOp()) {
if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning,
Expand Down Expand Up @@ -2966,7 +2975,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return DelayedStoreResult{*this,
StmtDiff{Push, Pop},
StmtDiff{Push, Pop, nullptr, Pop},
/*isConstant*/ false,
/*isInsideLoop*/ true, /*pNeedsUpdate=*/ true};
}
Expand Down
45 changes: 18 additions & 27 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,10 @@ float helper(float x) {
}

// CHECK: void helper_pullback(float x, float _d_y, clad::array_ref<float> _d_x) {
// CHECK-NEXT: float _t0;
// CHECK-NEXT: _t0 = x;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = _d_y * _t0;
// CHECK-NEXT: float _r0 = _d_y * x;
// CHECK-NEXT: float _r1 = 2 * _d_y;
// CHECK-NEXT: * _d_x += _r1;
// CHECK-NEXT: }
Expand Down Expand Up @@ -208,30 +206,26 @@ double func4(double x) {
}

//CHECK: void func4_grad(double x, clad::array_ref<double> _d_x) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: clad::array<double> _d_arr(3UL);
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t2;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: double arr[3] = {x, 2 * _t0, x * _t1};
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[3] = {x, 2 * x, x * x};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t2 = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t2++;
//CHECK-NEXT: clad::push(_t3, sum);
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t2; _t2--) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
Expand All @@ -243,10 +237,10 @@ double func4(double x) {
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * _t0;
//CHECK-NEXT: double _r0 = _d_arr[1] * x;
//CHECK-NEXT: double _r1 = 2 * _d_arr[1];
//CHECK-NEXT: * _d_x += _r1;
//CHECK-NEXT: double _r2 = _d_arr[2] * _t1;
//CHECK-NEXT: double _r2 = _d_arr[2] * x;
//CHECK-NEXT: * _d_x += _r2;
//CHECK-NEXT: double _r3 = x * _d_arr[2];
//CHECK-NEXT: * _d_x += _r3;
Expand Down Expand Up @@ -334,15 +328,14 @@ double func6(double seed) {
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<int> _t1 = {};
//CHECK-NEXT: clad::array<double> _d_arr(3UL);
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; i++) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: double arr[3] = {seed, seed * clad::push(_t1, i), seed + i};
//CHECK-NEXT: clad::push(_t2, sum);
//CHECK-NEXT: double arr[3] = {seed, seed * i, seed + i};
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
Expand All @@ -351,7 +344,7 @@ double func6(double seed) {
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: sum = clad::pop(_t2);
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_sum += _r_d0;
//CHECK-NEXT: int _grad1 = 0;
Expand All @@ -362,7 +355,7 @@ double func6(double seed) {
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_seed += _d_arr[0];
//CHECK-NEXT: double _r0 = _d_arr[1] * clad::pop(_t1);
//CHECK-NEXT: double _r0 = _d_arr[1] * i;
//CHECK-NEXT: * _d_seed += _r0;
//CHECK-NEXT: double _r1 = seed * _d_arr[1];
//CHECK-NEXT: _d_i += _r1;
Expand All @@ -379,15 +372,13 @@ double inv_square(double *params) {

//CHECK: void inv_square_pullback(double *params, double _d_y, clad::array_ref<double> _d_params) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: _t1 = params[0];
//CHECK-NEXT: _t0 = (params[0] * _t1);
//CHECK-NEXT: _t0 = (params[0] * params[0]);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y / _t0;
//CHECK-NEXT: double _r1 = _d_y * -1 / (_t0 * _t0);
//CHECK-NEXT: double _r2 = _r1 * _t1;
//CHECK-NEXT: double _r2 = _r1 * params[0];
//CHECK-NEXT: _d_params[0] += _r2;
//CHECK-NEXT: double _r3 = params[0] * _r1;
//CHECK-NEXT: _d_params[0] += _r3;
Expand Down
12 changes: 3 additions & 9 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,20 @@ double const_dot_product(double x, double y, double z) {
//CHECK: void const_dot_product_grad(double x, double y, double z, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y, clad::array_ref<double> _d_z) {
//CHECK-NEXT: clad::array<double> _d_vars(3UL);
//CHECK-NEXT: clad::array<double> _d_consts(3UL);
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double _t2;
//CHECK-NEXT: double vars[3] = {x, y, z};
//CHECK-NEXT: double consts[3] = {1, 2, 3};
//CHECK-NEXT: _t0 = consts[0];
//CHECK-NEXT: _t1 = consts[1];
//CHECK-NEXT: _t2 = consts[2];
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 1 * _t0;
//CHECK-NEXT: double _r0 = 1 * consts[0];
//CHECK-NEXT: _d_vars[0] += _r0;
//CHECK-NEXT: double _r1 = vars[0] * 1;
//CHECK-NEXT: _d_consts[0] += _r1;
//CHECK-NEXT: double _r2 = 1 * _t1;
//CHECK-NEXT: double _r2 = 1 * consts[1];
//CHECK-NEXT: _d_vars[1] += _r2;
//CHECK-NEXT: double _r3 = vars[1] * 1;
//CHECK-NEXT: _d_consts[1] += _r3;
//CHECK-NEXT: double _r4 = 1 * _t2;
//CHECK-NEXT: double _r4 = 1 * consts[2];
//CHECK-NEXT: _d_vars[2] += _r4;
//CHECK-NEXT: double _r5 = vars[2] * 1;
//CHECK-NEXT: _d_consts[2] += _r5;
Expand Down
16 changes: 5 additions & 11 deletions test/ErrorEstimation/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,21 @@ float func2(float x, int y) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: double _delta_x = 0;
//CHECK-NEXT: float _EERepl_x0 = x;
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: float _EERepl_x1;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: x = y * _t1 + x * _t2;
//CHECK-NEXT: x = y * x + x * x;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: x = _t0;
//CHECK-NEXT: float _r_d0 = * _d_x;
//CHECK-NEXT: float _r0 = _r_d0 * _t1;
//CHECK-NEXT: float _r0 = _r_d0 * x;
//CHECK-NEXT: * _d_y += _r0;
//CHECK-NEXT: float _r1 = y * _r_d0;
//CHECK-NEXT: * _d_x += _r1;
//CHECK-NEXT: float _r2 = _r_d0 * _t2;
//CHECK-NEXT: float _r2 = _r_d0 * x;
//CHECK-NEXT: * _d_x += _r2;
//CHECK-NEXT: float _r3 = x * _r_d0;
//CHECK-NEXT: * _d_x += _r3;
Expand Down Expand Up @@ -194,14 +190,12 @@ float func6(float x) { return x; }
float func7(float x, float y) { return (x * y); }

//CHECK: void func7_grad(float x, float y, clad::array_ref<float> _d_x, clad::array_ref<float> _d_y, double &_final_error) {
//CHECK-NEXT: float _t0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: _t0 = y;
//CHECK-NEXT: _ret_value0 = (x * _t0);
//CHECK-NEXT: _ret_value0 = (x * y);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r0 = 1 * _t0;
//CHECK-NEXT: float _r0 = 1 * y;
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: float _r1 = x * 1;
//CHECK-NEXT: * _d_y += _r1;
Expand Down
Loading

0 comments on commit 3ac5587

Please sign in to comment.