From 16183dde85faaa79928be3928d3e917c0eca8b31 Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Thu, 23 May 2024 19:56:04 +0200 Subject: [PATCH] Fix condition declarations and assignments in for loops These changes fix the differentiation of variable declarations in for loop conditions that used to result into wrong derivatives. The commit also tackles the problem of having an assignment operator that affects the derivative in the for loop condition. Fixes: #273 --- lib/Differentiator/BaseForwardModeVisitor.cpp | 62 ++++++++++----- test/FirstDerivative/Loops.C | 78 ++++++++++++++++++- 2 files changed, 119 insertions(+), 21 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 5beab0cb3..7d6e0de85 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -668,18 +668,39 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { const Stmt* init = FS->getInit(); StmtDiff initDiff = init ? Visit(init) : StmtDiff{}; addToCurrentBlock(initDiff.getStmt_dx()); + + // declaration in the condition (if any) needs to be differentiated VarDecl* condVarDecl = FS->getConditionVariable(); VarDecl* condVarClone = nullptr; + DeclDiff condVarResult; + DeclStmt* condVarDeclStmt_dx = nullptr; if (condVarDecl) { - DeclDiff condVarResult = DifferentiateVarDecl(condVarDecl); + condVarResult = DifferentiateVarDecl(condVarDecl); condVarClone = condVarResult.getDecl(); if (condVarResult.getDecl_dx()) - addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx())); + condVarDeclStmt_dx = BuildDeclStmt(condVarResult.getDecl_dx()); + } + + // condition + StmtDiff condDiff = Clone(FS->getCond()); + if (Expr* cond = + condDiff + .getExpr()) { // this adds support for assignments in conditions + while (CastExpr* condCast = dyn_cast(cond)) + cond = condCast->getSubExpr(); + while (ParenExpr* condParen = dyn_cast(cond)) + cond = condParen->getSubExpr(); + if (BinaryOperator* condBO = dyn_cast(cond)) { + if (condBO->isAssignmentOp()) + condDiff = Visit(new (m_Context) ParenExpr( + noLoc, noLoc, + cond)); // if it's an assignment operator we wrap it back into + // parentheses (as it is expected to be) and then visit + } } - Expr* cond = FS->getCond() ? Clone(FS->getCond()) : nullptr; - const Expr* inc = FS->getInc(); // Differentiate the increment expression of the for loop + const Expr* inc = FS->getInc(); beginBlock(); StmtDiff incDiff = inc ? Visit(inc) : StmtDiff{}; CompoundStmt* decls = endBlock(); @@ -714,27 +735,28 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { incResult = incDiff.getExpr(); } + // build the derived for loop body const Stmt* body = FS->getBody(); beginScope(Scope::DeclScope); Stmt* bodyResult = nullptr; - if (isa(body)) { - bodyResult = Visit(body).getStmt(); - } else { - beginBlock(); - StmtDiff Result = Visit(body); - for (Stmt* S : Result.getBothStmts()) - addToCurrentBlock(S); - CompoundStmt* Block = endBlock(); - if (Block->size() == 1) - bodyResult = Block->body_front(); - else - bodyResult = Block; - } + beginBlock(); + StmtDiff bodyVisited = Visit(body); + if (condVarDeclStmt_dx) + addToCurrentBlock(condVarDeclStmt_dx); + if (condDiff.getStmt_dx()) + addToCurrentBlock(condDiff.getStmt_dx()); + for (Stmt* S : bodyVisited.getBothStmts()) + addToCurrentBlock(S); + CompoundStmt* bodyResultCmpd = endBlock(); + if (bodyResultCmpd->size() == 1) + bodyResult = bodyResultCmpd->body_front(); + else + bodyResult = bodyResultCmpd; endScope(); - Stmt* forStmtDiff = - new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, condVarClone, - incResult, bodyResult, noLoc, noLoc, noLoc); + Stmt* forStmtDiff = new (m_Context) + ForStmt(m_Context, initDiff.getStmt(), condDiff.getExpr(), condVarClone, + incResult, bodyResult, noLoc, noLoc, noLoc); addToCurrentBlock(forStmtDiff); CompoundStmt* Block = endBlock(); diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 03257ff34..a981d412e 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -377,8 +377,9 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: double res = 0; // CHECK-NEXT: { // CHECK-NEXT: size_t _d_count = 0; -// CHECK-NEXT: size_t _d_max_count = _d_n; // CHECK-NEXT: for (size_t count = 0; {{.*}}max_count{{.*}}; ++count) { +// CHECK-NEXT: size_t _d_max_count = _d_n; +// CHECK-NEXT: { // CHECK-NEXT: if (count >= max_count) // CHECK-NEXT: break; // CHECK-NEXT: { @@ -388,11 +389,75 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: res += y * y; // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +double fn11(double x, double y) { + double r = 0; + for (int i = 0; (r = x); ++i) { + if (i == 3) break; + r += x; + } + return r; +} // fn11(x,y) == x + +double fn11_darg0(double x, double y); +// CHECK: double fn11_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: for (int i = 0; (r = x); ++i) { +// CHECK-NEXT: (_d_r = _d_x); +// CHECK-NEXT: { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_r += _d_x; +// CHECK-NEXT: r += x; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn12(double x, double y) { + double r = 0; + for (int i = 0; double c = x; ++i) { + if (i == 3) break; + c += x; + r = c; + } + return r; +} // fn11(x,y) == 2*x + +double fn12_darg0(double x, double y); +// CHECK: double fn12_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: for (int i = 0; {{.*}}c{{.*}}; ++i) { +// CHECK-NEXT: double _d_c = _d_x; +// CHECK-NEXT: { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_c += _d_x; +// CHECK-NEXT: c += x; +// CHECK-NEXT: _d_r = _d_c; +// CHECK-NEXT: r = c; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -430,4 +495,15 @@ int main() { clad::differentiate(fn10, 0); printf("Result is = %.2f\n", fn10_darg0(3, 5)); // CHECK-EXEC: Result is = 30.00 + + clad::differentiate(fn11, 0); + printf("Result is = %.2f\n", fn11_darg0(3, 5)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(-3, 6)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(1, 5)); // CHECK-EXEC: Result is = 1.00 + + clad::differentiate(fn12, 0); + printf("Result is = %.2f\n", fn12_darg0(3, 5)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(-3, 6)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(1, 5)); // CHECK-EXEC: Result is = 2.00 + }