From 4ebd1af09b5eee8d8e1fa5590d2f9cc112ca3362 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 12 Jun 2024 00:43:40 +0300 Subject: [PATCH] Emit clad::pop after clad::back when differentiating binary operators in the reverse mode. Fixes #927 --- lib/Differentiator/ReverseModeVisitor.cpp | 18 +++++++++ test/Gradient/Loops.C | 48 +++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 759abe28a..03412af37 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2096,9 +2096,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // df/dxr += df/dxi * dxi/xr = df/dxi * xl // Store left multiplier and assign it with L. StmtDiff LStored = Ldiff; + // Catch the pop statement and emit it after + // the LStored value is used. + // This workaround is necessary because GlobalStoreAndRef + // is designed to work with the reversed order of statements + // in the reverse sweep and in RMV::VisitBinaryOperator + // the order is not reversed. 
+ beginBlock(direction::reverse); if (!ShouldRecompute(LStored.getExpr())) LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t", /*force=*/true); + Stmt* LPop = endBlock(direction::reverse); Expr::EvalResult dummy; if (RDelayed || !clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) { @@ -2110,6 +2118,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (RDelayed) RDelayed->Finalize(Rdiff.getExpr()); } + addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse); std::tie(Ldiff, Rdiff) = std::make_pair(LStored.getExpr(), RResult.getExpr()); } else if (opCode == BO_Div) { @@ -2125,9 +2134,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, dl = BuildOp(BO_Div, dfdx(), RStored); Ldiff = Visit(L, dl); StmtDiff LStored = Ldiff; + // Catch the pop statement and emit it after + // the LStored value is used. + // This workaround is necessary because GlobalStoreAndRef + // is designed to work with the reversed order of statements + // in the reverse sweep and in RMV::VisitBinaryOperator + // the order is not reversed. + beginBlock(direction::reverse); if (!ShouldRecompute(LStored.getExpr())) LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t", /*force=*/true); + Stmt* LPop = endBlock(direction::reverse); // dxi/xr = -xl / (xr * xr) // df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr)) // Wrap R * R in parentheses: (R * R). 
otherwise code like 1 / R * R is @@ -2145,6 +2162,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Rdiff = Visit(R, dr); RDelayed.Finalize(Rdiff.getExpr()); } + addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse); std::tie(Ldiff, Rdiff) = std::make_pair(LStored.getExpr(), RResult.getExpr()); } else if (BinOp->isAssignmentOp()) { diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index c2869818a..05c8c755c 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -1725,6 +1725,53 @@ double fn21(double x) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn22(double param) { + double out = 0.0; + for (int i = 0; i < 1; i++) { + double arr[] = {1.}; + out += arr[0] * param; + } + return out; +} + + +// CHECK: void fn22_grad(double param, double *_d_param) { +// CHECK-NEXT: double _d_out = 0; +// CHECK-NEXT: unsigned {{int|long}} _t0; +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: int i = 0; +// CHECK-NEXT: clad::tape<clad::array<double> > _t1 = {}; +// CHECK-NEXT: double _d_arr[1] = {0}; +// CHECK-NEXT: clad::array<double> arr({{1U|1UL}}); +// CHECK-NEXT: clad::tape<double> _t2 = {}; +// CHECK-NEXT: clad::tape<double> _t3 = {}; +// CHECK-NEXT: double out = 0.; +// CHECK-NEXT: _t0 = {{0U|0UL}}; +// CHECK-NEXT: for (i = 0; i < 1; i++) { +// CHECK-NEXT: _t0++; +// CHECK-NEXT: clad::push(_t1, arr) , arr = {1.}; +// CHECK-NEXT: clad::push(_t2, out); +// CHECK-NEXT: clad::push(_t3, arr[0]); +// CHECK-NEXT: out += clad::back(_t3) * param; +// CHECK-NEXT: } +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: _d_out += 1; +// CHECK-NEXT: for (; _t0; _t0--) { +// CHECK-NEXT: i--; +// CHECK-NEXT: { +// CHECK-NEXT: out = clad::pop(_t2); +// CHECK-NEXT: double _r_d0 = _d_out; +// CHECK-NEXT: _d_arr[0] += _r_d0 * param; +// CHECK-NEXT: *_d_param += clad::back(_t3) * _r_d0; +// CHECK-NEXT: clad::pop(_t3); +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: clad::zero_init(_d_arr); +// CHECK-NEXT: arr = clad::pop(_t1); +// CHECK-NEXT: } +// CHECK-NEXT: } +// 
CHECK-NEXT: } #define TEST(F, x) { \ result[0] = 0; \ @@ -1797,6 +1844,7 @@ int main() { printf("{%.2f, %.2f, %.2f, %.2f, %.2f}\n", result[0], result[1], result[2], result[3], result[4]); // CHECK-EXEC: {5.00, 5.00, 5.00, 5.00, 5.00} TEST(fn21, 5); // CHECK-EXEC: {5.00} + TEST(fn22, 5); // CHECK-EXEC: {1.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {