diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 5beab0cb3..361434231 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -668,18 +668,37 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { const Stmt* init = FS->getInit(); StmtDiff initDiff = init ? Visit(init) : StmtDiff{}; addToCurrentBlock(initDiff.getStmt_dx()); + + // declaration in the condition (if any) needs to be differentiated VarDecl* condVarDecl = FS->getConditionVariable(); VarDecl* condVarClone = nullptr; + DeclDiff condVarResult; + DeclStmt* condVarDeclStmt_dx = nullptr; if (condVarDecl) { - DeclDiff condVarResult = DifferentiateVarDecl(condVarDecl); + condVarResult = DifferentiateVarDecl(condVarDecl); condVarClone = condVarResult.getDecl(); if (condVarResult.getDecl_dx()) - addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx())); + condVarDeclStmt_dx = BuildDeclStmt(condVarResult.getDecl_dx()); + } + + // condition + StmtDiff condDiff = Clone(FS->getCond()); + // this adds support for assignments in conditions + if (Expr* cond = condDiff.getExpr()) { + while (CastExpr* condCast = dyn_cast(cond)) + cond = condCast->getSubExpr(); + while (ParenExpr* condParen = dyn_cast(cond)) + cond = condParen->getSubExpr(); + if (BinaryOperator* condBO = dyn_cast(cond)) { + // if it's an assignment operator we wrap it back into + // parentheses (as it is expected to be) and then visit + if (condBO->isAssignmentOp()) + condDiff = Visit(new (m_Context) ParenExpr(noLoc, noLoc, cond)); + } } - Expr* cond = FS->getCond() ? Clone(FS->getCond()) : nullptr; - const Expr* inc = FS->getInc(); // Differentiate the increment expression of the for loop + const Expr* inc = FS->getInc(); beginBlock(); StmtDiff incDiff = inc ? Visit(inc) : StmtDiff{}; CompoundStmt* decls = endBlock(); @@ -714,27 +733,28 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { incResult = incDiff.getExpr(); } + // build the derived for loop body const Stmt* body = FS->getBody(); beginScope(Scope::DeclScope); Stmt* bodyResult = nullptr; - if (isa(body)) { - bodyResult = Visit(body).getStmt(); - } else { - beginBlock(); - StmtDiff Result = Visit(body); - for (Stmt* S : Result.getBothStmts()) - addToCurrentBlock(S); - CompoundStmt* Block = endBlock(); - if (Block->size() == 1) - bodyResult = Block->body_front(); - else - bodyResult = Block; - } + beginBlock(); + StmtDiff bodyVisited = Visit(body); + if (condVarDeclStmt_dx) + addToCurrentBlock(condVarDeclStmt_dx); + if (condDiff.getStmt_dx()) + addToCurrentBlock(condDiff.getStmt_dx()); + for (Stmt* S : bodyVisited.getBothStmts()) + addToCurrentBlock(S); + CompoundStmt* bodyResultCmpd = endBlock(); + if (bodyResultCmpd->size() == 1) + bodyResult = bodyResultCmpd->body_front(); + else + bodyResult = bodyResultCmpd; endScope(); - Stmt* forStmtDiff = - new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, condVarClone, - incResult, bodyResult, noLoc, noLoc, noLoc); + Stmt* forStmtDiff = new (m_Context) + ForStmt(m_Context, initDiff.getStmt(), condDiff.getExpr(), condVarClone, + incResult, bodyResult, noLoc, noLoc, noLoc); addToCurrentBlock(forStmtDiff); CompoundStmt* Block = endBlock(); diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 03257ff34..a981d412e 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -377,8 +377,9 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: double res = 0; // CHECK-NEXT: { // CHECK-NEXT: size_t _d_count = 0; -// CHECK-NEXT: size_t _d_max_count = _d_n; // CHECK-NEXT: for (size_t count = 0; {{.*}}max_count{{.*}}; ++count) { +// CHECK-NEXT: size_t _d_max_count = _d_n; +// CHECK-NEXT: { // CHECK-NEXT: if (count >= max_count) // CHECK-NEXT: break; // CHECK-NEXT: { @@ -388,11 +389,75 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: res += y * y; // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +double fn11(double x, double y) { + double r = 0; + for (int i = 0; (r = x); ++i) { + if (i == 3) break; + r += x; + } + return r; +} // fn11(x,y) == x + +double fn11_darg0(double x, double y); +// CHECK: double fn11_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: for (int i = 0; (r = x); ++i) { +// CHECK-NEXT: (_d_r = _d_x); +// CHECK-NEXT: { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_r += _d_x; +// CHECK-NEXT: r += x; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn12(double x, double y) { + double r = 0; + for (int i = 0; double c = x; ++i) { + if (i == 3) break; + c += x; + r = c; + } + return r; +} // fn11(x,y) == 2*x + +double fn12_darg0(double x, double y); +// CHECK: double fn12_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: for (int i = 0; {{.*}}c{{.*}}; ++i) { +// CHECK-NEXT: double _d_c = _d_x; +// CHECK-NEXT: { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_c += _d_x; +// CHECK-NEXT: c += x; +// CHECK-NEXT: _d_r = _d_c; +// CHECK-NEXT: r = c; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -430,4 +495,15 @@ int main() { clad::differentiate(fn10, 0); printf("Result is = %.2f\n", fn10_darg0(3, 5)); // CHECK-EXEC: Result is = 30.00 + + clad::differentiate(fn11, 0); + printf("Result is = %.2f\n", fn11_darg0(3, 5)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(-3, 6)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(1, 5)); // CHECK-EXEC: Result is = 1.00 + + clad::differentiate(fn12, 0); + printf("Result is = %.2f\n", fn12_darg0(3, 5)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(-3, 6)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(1, 5)); // CHECK-EXEC: Result is = 2.00 + }