diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 5beab0cb3..9d80b1c54 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -668,18 +668,50 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { const Stmt* init = FS->getInit(); StmtDiff initDiff = init ? Visit(init) : StmtDiff{}; addToCurrentBlock(initDiff.getStmt_dx()); - VarDecl* condVarDecl = FS->getConditionVariable(); - VarDecl* condVarClone = nullptr; - if (condVarDecl) { - DeclDiff condVarResult = DifferentiateVarDecl(condVarDecl); - condVarClone = condVarResult.getDecl(); - if (condVarResult.getDecl_dx()) - addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx())); + + StmtDiff condDiff = Clone(FS->getCond()); + Expr* cond = condDiff.getExpr(); + + // Declaration in the condition (if any) needs to be differentiated. + if (VarDecl* condVarDecl = FS->getConditionVariable()) { + // FIXME: here DifferentiateVarDecl is only used to clone the declaration + // properly because other ways have failed. + VarDecl* condVarClone = DifferentiateVarDecl(condVarDecl).getDecl(); + // Here we create a fictional cond that is equal to the assignment used in + // the declaration. The declaration itself is thrown before the for-loop + // without any init value. The fictional condition is then differentiated as + // normal condition (see below). + cond = + BuildOp(BO_Assign, BuildDeclRef(condVarClone), condVarClone->getInit()); + condVarClone->setInit(nullptr); + addToCurrentBlock(BuildDeclStmt(condVarClone)); + } + + // Condition differentiation. + // This adds support for assignments in conditions. + if (cond) { + while (CastExpr* condCast = dyn_cast(cond)) + cond = condCast->getSubExpr(); + while (ParenExpr* condParen = dyn_cast(cond)) + cond = condParen->getSubExpr(); + if (BinaryOperator* condBO = dyn_cast(cond)) { + // If it's an assignment operator or a binary logical operator we wrap it + // back into parentheses (as it is expected to be) and then visit. To + // ensure the correctness, a comma operator expression (cond_dx, cond) is + // generated and put instead of the condition. + if (condBO->isAssignmentOp() || condBO->getOpcode() == BO_LAnd || + condBO->getOpcode() == BO_LOr) { + condDiff = Visit(new (m_Context) ParenExpr(noLoc, noLoc, cond)); + cond = (condDiff.getExpr_dx()) + ? BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), + BuildParens(condDiff.getExpr())) + : condDiff.getExpr(); + } + } } - Expr* cond = FS->getCond() ? Clone(FS->getCond()) : nullptr; - const Expr* inc = FS->getInc(); // Differentiate the increment expression of the for loop + const Expr* inc = FS->getInc(); beginBlock(); StmtDiff incDiff = inc ? Visit(inc) : StmtDiff{}; CompoundStmt* decls = endBlock(); @@ -714,26 +746,23 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { incResult = incDiff.getExpr(); } + // Build the derived for loop body. const Stmt* body = FS->getBody(); beginScope(Scope::DeclScope); Stmt* bodyResult = nullptr; - if (isa(body)) { - bodyResult = Visit(body).getStmt(); - } else { - beginBlock(); - StmtDiff Result = Visit(body); - for (Stmt* S : Result.getBothStmts()) - addToCurrentBlock(S); - CompoundStmt* Block = endBlock(); - if (Block->size() == 1) - bodyResult = Block->body_front(); - else - bodyResult = Block; - } + beginBlock(); + StmtDiff bodyVisited = Visit(body); + for (Stmt* S : bodyVisited.getBothStmts()) + addToCurrentBlock(S); + CompoundStmt* bodyResultCmpd = endBlock(); + if (bodyResultCmpd->size() == 1) + bodyResult = bodyResultCmpd->body_front(); + else + bodyResult = bodyResultCmpd; endScope(); Stmt* forStmtDiff = - new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, condVarClone, + new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, nullptr, incResult, bodyResult, noLoc, noLoc, noLoc); addToCurrentBlock(forStmtDiff); @@ -1366,6 +1395,24 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { } else opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()), BuildParens(Rdiff.getExpr_dx())); + } else if (opCode == BO_LAnd || opCode == BO_LOr) { + // For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and + // correct derivative execution. dLL = (dL, L) + Expr* dLL = BuildParens( + (Ldiff.getExpr_dx() ? BuildOp(BO_Comma, BuildParens(Ldiff.getExpr_dx()), + BuildParens(Ldiff.getExpr())) + : Ldiff.getExpr())); + // dRR = (dR, R) + Expr* dRR = BuildParens( + (Rdiff.getExpr_dx() ? BuildOp(BO_Comma, BuildParens(Rdiff.getExpr_dx()), + BuildParens(Rdiff.getExpr())) + : Rdiff.getExpr())); + opDiff = BuildOp(opCode, dLL, dRR); + + // Since the both parts are included in the opDiff, there's no point in + // including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left + // nullptr is used for treating expressions like ((A && B) && C) correctly. + return StmtDiff(opDiff, nullptr); } if (!opDiff) { // FIXME: add support for other binary operators diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 03257ff34..c1bdd8252 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -377,8 +377,8 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: double res = 0; // CHECK-NEXT: { // CHECK-NEXT: size_t _d_count = 0; -// CHECK-NEXT: size_t _d_max_count = _d_n; -// CHECK-NEXT: for (size_t count = 0; {{.*}}max_count{{.*}}; ++count) { +// CHECK-NEXT: size_t max_count; +// CHECK-NEXT: for (size_t count = 0; (_d_max_count = _d_n) , (max_count = n); ++count) { // CHECK-NEXT: if (count >= max_count) // CHECK-NEXT: break; // CHECK-NEXT: { @@ -393,6 +393,81 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +double fn11(double x, double y) { + double r = 0; + for (int i = 0; (r = x); ++i) { + if (i == 3) break; + r += x; + } + return r; +} // fn11(x,y) == x + +double fn11_darg0(double x, double y); +// CHECK: double fn11_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: for (int i = 0; (_d_r = _d_x) , (r = x); ++i) { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_r += _d_x; +// CHECK-NEXT: r += x; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn12(double x, double y) { + double r = 0; + for (int i = 0; double c = x; ++i) { + if (i == 3) break; + c += x; + r = c; + } + return r; +} // fn11(x,y) == 2*x + +double fn12_darg0(double x, double y); +// CHECK: double fn12_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: double c; +// CHECK-NEXT: for (int i = 0; (_d_c = _d_x) , (c = x); ++i) { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_c += _d_x; +// CHECK-NEXT: c += x; +// CHECK-NEXT: _d_r = _d_c; +// CHECK-NEXT: r = c; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn13(double u, double v) { + double res = 0; + for (; (res = u * v) && (u = 0) ;) {} + return res; +} // = u*v + +double fn13_darg0(double u, double v); +// CHECK: double fn13_darg0(double u, double v) { +// CHECK-NEXT: double _d_u = 1; +// CHECK-NEXT: double _d_v = 0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: for (; (((_d_res = _d_u * v + u * _d_v) , (res = u * v)) && ((_d_u = 0) , (u = 0)));) { +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -430,4 +505,19 @@ int main() { clad::differentiate(fn10, 0); printf("Result is = %.2f\n", fn10_darg0(3, 5)); // CHECK-EXEC: Result is = 30.00 + + clad::differentiate(fn11, 0); + printf("Result is = %.2f\n", fn11_darg0(3, 5)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(-3, 6)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(1, 5)); // CHECK-EXEC: Result is = 1.00 + + clad::differentiate(fn12, 0); + printf("Result is = %.2f\n", fn12_darg0(3, 5)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(-3, 6)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(1, 5)); // CHECK-EXEC: Result is = 2.00 + + clad::differentiate(fn13, 0); + printf("Result is = %.2f\n", fn13_darg0(3, 4)); // CHECK-EXEC: Result is = 4.00 + printf("Result is = %.2f\n", fn13_darg0(-3, 5)); // CHECK-EXEC: Result is = 5.00 + printf("Result is = %.2f\n", fn13_darg0(1, 6)); // CHECK-EXEC: Result is = 6.00 }