Skip to content

Commit

Permalink
Emit clad::pop after clad::back when differentiating binary operators…
Browse files Browse the repository at this point in the history
… in the reverse mode. Fixes vgvassilev#927
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Jun 14, 2024
1 parent 1cdb738 commit 4ebd1af
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
18 changes: 18 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2096,9 +2096,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// df/dxr += df/dxi * dxi/xr = df/dxi * xl
// Store left multiplier and assign it with L.
StmtDiff LStored = Ldiff;
// Catch the pop statement and emit it after
// the LStored value is used.
// This workaround is necessary because GlobalStoreAndRef
// is designed to work with the reversed order of statements
// in the reverse sweep and in RMV::VisitBinaryOperator
// the order is not reversed.
beginBlock(direction::reverse);
if (!ShouldRecompute(LStored.getExpr()))
LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t",
/*force=*/true);
Stmt* LPop = endBlock(direction::reverse);
Expr::EvalResult dummy;
if (RDelayed ||
!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
Expand All @@ -2110,6 +2118,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (RDelayed)
RDelayed->Finalize(Rdiff.getExpr());
}
addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse);
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
} else if (opCode == BO_Div) {
Expand All @@ -2125,9 +2134,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
dl = BuildOp(BO_Div, dfdx(), RStored);
Ldiff = Visit(L, dl);
StmtDiff LStored = Ldiff;
// Catch the pop statement and emit it after
// the LStored value is used.
// This workaround is necessary because GlobalStoreAndRef
// is designed to work with the reversed order of statements
// in the reverse sweep and in RMV::VisitBinaryOperator
// the order is not reversed.
beginBlock(direction::reverse);
if (!ShouldRecompute(LStored.getExpr()))
LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t",
/*force=*/true);
Stmt* LPop = endBlock(direction::reverse);
// dxi/xr = -xl / (xr * xr)
// df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr))
// Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is
Expand All @@ -2145,6 +2162,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse);
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
} else if (BinOp->isAssignmentOp()) {
Expand Down
48 changes: 48 additions & 0 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -1725,6 +1725,53 @@ double fn21(double x) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn22(double param) {
double out = 0.0;
for (int i = 0; i < 1; i++) {
double arr[] = {1.};
out += arr[0] * param;
}
return out;
}


// CHECK: void fn22_grad(double param, double *_d_param) {
// CHECK-NEXT: double _d_out = 0;
// CHECK-NEXT: unsigned {{int|long}} _t0;
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: int i = 0;
// CHECK-NEXT: clad::tape<clad::array<double> > _t1 = {};
// CHECK-NEXT: double _d_arr[1] = {0};
// CHECK-NEXT: clad::array<double> arr({{1U|1UL}});
// CHECK-NEXT: clad::tape<double> _t2 = {};
// CHECK-NEXT: clad::tape<double> _t3 = {};
// CHECK-NEXT: double out = 0.;
// CHECK-NEXT: _t0 = {{0U|0UL}};
// CHECK-NEXT: for (i = 0; i < 1; i++) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: clad::push(_t1, arr) , arr = {1.};
// CHECK-NEXT: clad::push(_t2, out);
// CHECK-NEXT: clad::push(_t3, arr[0]);
// CHECK-NEXT: out += clad::back(_t3) * param;
// CHECK-NEXT: }
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_out += 1;
// CHECK-NEXT: for (; _t0; _t0--) {
// CHECK-NEXT: i--;
// CHECK-NEXT: {
// CHECK-NEXT: out = clad::pop(_t2);
// CHECK-NEXT: double _r_d0 = _d_out;
// CHECK-NEXT: _d_arr[0] += _r_d0 * param;
// CHECK-NEXT: *_d_param += clad::back(_t3) * _r_d0;
// CHECK-NEXT: clad::pop(_t3);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::zero_init(_d_arr);
// CHECK-NEXT: arr = clad::pop(_t1);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

#define TEST(F, x) { \
result[0] = 0; \
Expand Down Expand Up @@ -1797,6 +1844,7 @@ int main() {
printf("{%.2f, %.2f, %.2f, %.2f, %.2f}\n", result[0], result[1], result[2], result[3], result[4]); // CHECK-EXEC: {5.00, 5.00, 5.00, 5.00, 5.00}

TEST(fn21, 5); // CHECK-EXEC: {5.00}
TEST(fn22, 5); // CHECK-EXEC: {1.00}
}

//CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {
Expand Down

0 comments on commit 4ebd1af

Please sign in to comment.