diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 5beab0cb3..a0b4f59a6 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -668,18 +668,53 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { const Stmt* init = FS->getInit(); StmtDiff initDiff = init ? Visit(init) : StmtDiff{}; addToCurrentBlock(initDiff.getStmt_dx()); - VarDecl* condVarDecl = FS->getConditionVariable(); - VarDecl* condVarClone = nullptr; - if (condVarDecl) { - DeclDiff condVarResult = DifferentiateVarDecl(condVarDecl); - condVarClone = condVarResult.getDecl(); - if (condVarResult.getDecl_dx()) - addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx())); + + StmtDiff condDiff = Clone(FS->getCond()); + Expr* cond = condDiff.getExpr(); + + // Declaration in the condition (if any) needs to be differentiated. + if (VarDecl* condVarDecl = FS->getConditionVariable()) { + VarDecl* condVarClone = + BuildVarDecl(condVarDecl->getType(), condVarDecl->getNameAsString(), + Clone(condVarDecl->getInit()), condVarDecl->isDirectInit(), + nullptr, condVarDecl->getInitStyle()); + VarDecl* _d_condVarDecl = BuildVarDecl( + condVarClone->getType(), "_d_" + condVarClone->getNameAsString(), + nullptr, condVarClone->isDirectInit(), nullptr, + condVarClone->getInitStyle()); + m_Variables.emplace(condVarClone, BuildDeclRef(_d_condVarDecl)); + // Here we create a fictional cond that is equal to the assignment used in + // the declaration. The declaration itself is thrown before the for-loop + // without any init value. The fictional condition is then differentiated as + // normal condition (see below). + cond = + BuildOp(BO_Assign, BuildDeclRef(condVarClone), condVarClone->getInit()); + condVarClone->setInit(nullptr); + addToCurrentBlock(BuildDeclStmt(condVarClone)); + } + + // Condition differentiation. + // This adds support for assignments in conditions. + if (cond) { + cond = cond->IgnoreParenImpCasts(); + // If it's a supported differentiable operator we wrap it back into + // parentheses and then visit. To ensure the correctness, a comma operator + // expression (cond_dx, cond) is generated and put instead of the condition. + // FIXME: Add support for other expressions in cond (unary operators, + // comparisons, function calls, etc.). Ideally, we should be able to simply + // always call Visit(cond) + BinaryOperator* condBO = dyn_cast(cond); + if (condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) { + condDiff = Visit(cond); + cond = (condDiff.getExpr_dx() && !isUnusedResult(condDiff.getExpr_dx())) + ? BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), + BuildParens(condDiff.getExpr())) + : condDiff.getExpr(); + } } - Expr* cond = FS->getCond() ? Clone(FS->getCond()) : nullptr; - const Expr* inc = FS->getInc(); // Differentiate the increment expression of the for loop + const Expr* inc = FS->getInc(); beginBlock(); StmtDiff incDiff = inc ? Visit(inc) : StmtDiff{}; CompoundStmt* decls = endBlock(); @@ -714,26 +749,23 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { incResult = incDiff.getExpr(); } + // Build the derived for loop body. const Stmt* body = FS->getBody(); beginScope(Scope::DeclScope); Stmt* bodyResult = nullptr; - if (isa(body)) { - bodyResult = Visit(body).getStmt(); - } else { - beginBlock(); - StmtDiff Result = Visit(body); - for (Stmt* S : Result.getBothStmts()) - addToCurrentBlock(S); - CompoundStmt* Block = endBlock(); - if (Block->size() == 1) - bodyResult = Block->body_front(); - else - bodyResult = Block; - } + beginBlock(); + StmtDiff bodyVisited = Visit(body); + for (Stmt* S : bodyVisited.getBothStmts()) + addToCurrentBlock(S); + CompoundStmt* bodyResultCmpd = endBlock(); + if (bodyResultCmpd->size() == 1) + bodyResult = bodyResultCmpd->body_front(); + else + bodyResult = bodyResultCmpd; endScope(); Stmt* forStmtDiff = - new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, condVarClone, + new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, nullptr, incResult, bodyResult, noLoc, noLoc, noLoc); addToCurrentBlock(forStmtDiff); @@ -751,6 +783,15 @@ StmtDiff BaseForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { return nullptr; StmtDiff retValDiff = Visit(RS->getRetValue()); + // llvm::errs() << "\n\n\n\n\nretdx: \n"; + // retValDiff.getExpr_dx()->dump(); + // llvm::errs() << "\n\n\n"; + // retValDiff.getExpr_dx()->dumpPretty(m_Context); + // llvm::errs() << "\n\n\n\n\nret: \n"; + // retValDiff.getExpr()->dump(); + // llvm::errs() << "\n\n\n"; + // retValDiff.getExpr()->dumpPretty(m_Context); + // llvm::errs() << "\n\n\n\n\n"; Stmt* returnStmt = m_Sema.ActOnReturnStmt(noLoc, retValDiff.getExpr_dx(), getCurrentScope()) .get(); @@ -1366,6 +1407,26 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { } else opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()), BuildParens(Rdiff.getExpr_dx())); + } else if (BinOp->isLogicalOp() || BinOp->isComparisonOp()) { + // For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and + // correct derivative execution. dLL = (dL, L) + Expr* dLL = + BuildParens(((Ldiff.getExpr_dx() && !isUnusedResult(Ldiff.getExpr_dx())) + ? BuildOp(BO_Comma, BuildParens(Ldiff.getExpr_dx()), + BuildParens(Ldiff.getExpr())) + : Ldiff.getExpr())); + // dRR = (dR, R) + Expr* dRR = + BuildParens(((Rdiff.getExpr_dx() && !isUnusedResult(Rdiff.getExpr_dx())) + ? BuildOp(BO_Comma, BuildParens(Rdiff.getExpr_dx()), + BuildParens(Rdiff.getExpr())) + : Rdiff.getExpr())); + opDiff = BuildOp(opCode, dLL, dRR); + + // Since the both parts are included in the opDiff, there's no point in + // including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left + // nullptr is used for treating expressions like ((A && B) && C) correctly. + return StmtDiff(opDiff, nullptr); } if (!opDiff) { // FIXME: add support for other binary operators diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 03257ff34..03fddc7a9 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -377,8 +377,8 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: double res = 0; // CHECK-NEXT: { // CHECK-NEXT: size_t _d_count = 0; -// CHECK-NEXT: size_t _d_max_count = _d_n; -// CHECK-NEXT: for (size_t count = 0; {{.*}}max_count{{.*}}; ++count) { +// CHECK-NEXT: size_t max_count; +// CHECK-NEXT: for (size_t count = 0; (_d_max_count = _d_n) , (max_count = n); ++count) { // CHECK-NEXT: if (count >= max_count) // CHECK-NEXT: break; // CHECK-NEXT: { @@ -393,6 +393,81 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +double fn11(double x, double y) { + double r = 0; + for (int i = 0; (r = x); ++i) { + if (i == 3) break; + r += x; + } + return r; +} // fn11(x,y) == x + +double fn11_darg0(double x, double y); +// CHECK: double fn11_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: for (int i = 0; (_d_r = _d_x) , (r = x); ++i) { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_r += _d_x; +// CHECK-NEXT: r += x; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn12(double x, double y) { + double r = 0; + for (int i = 0; double c = x; ++i) { + if (i == 3) break; + c += x; + r = c; + } + return r; +} // fn11(x,y) == 2*x + +double fn12_darg0(double x, double y); +// CHECK: double fn12_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: double c; +// CHECK-NEXT: for (int i = 0; (_d_c = _d_x) , (c = x); ++i) { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_c += _d_x; +// CHECK-NEXT: c += x; +// CHECK-NEXT: _d_r = _d_c; +// CHECK-NEXT: r = c; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn13(double u, double v) { + double res = 0; + for (; (res = u * v) && (u = 0) ;) {} + return res; +} // = u*v + +double fn13_darg0(double u, double v); +// CHECK: double fn13_darg0(double u, double v) { +// CHECK-NEXT: double _d_u = 1; +// CHECK-NEXT: double _d_v = 0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: for (; ((_d_res = _d_u * v + u * _d_v) , (res = u * v)) && ((_d_u = 0) , (u = 0));) { +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -430,4 +505,20 @@ int main() { clad::differentiate(fn10, 0); printf("Result is = %.2f\n", fn10_darg0(3, 5)); // CHECK-EXEC: Result is = 30.00 + + clad::differentiate(fn11, 0); + printf("Result is = %.2f\n", fn11_darg0(3, 5)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(-3, 6)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(1, 5)); // CHECK-EXEC: Result is = 1.00 + + clad::differentiate(fn12, 0); + printf("Result is = %.2f\n", fn12_darg0(3, 5)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(-3, 6)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(1, 5)); // CHECK-EXEC: Result is = 2.00 + + clad::differentiate(fn13, 0); + printf("Result is = %.2f\n", fn13_darg0(3, 4)); // CHECK-EXEC: Result is = 4.00 + printf("Result is = %.2f\n", fn13_darg0(-3, 5)); // CHECK-EXEC: Result is = 5.00 + printf("Result is = %.2f\n", fn13_darg0(1, 6)); // CHECK-EXEC: Result is = 6.00 + }