diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index ae2d26a99..a784f7e21 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -27,6 +27,10 @@ namespace clad { /// function `FD`. std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD); + /// Returns the single child of `S` if `S` is a compound statement with + /// exactly one child, nullptr if `S` is null or an empty compound, else `S`. + clang::Stmt* unwrapIfSingleStmt(clang::Stmt* S); + /// Creates and returns a compound statement having statements as follows: /// {`S`, all the statement of `initial` in sequence} clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C, diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 5beab0cb3..43d5a8f8d 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -668,18 +668,64 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { const Stmt* init = FS->getInit(); StmtDiff initDiff = init ? Visit(init) : StmtDiff{}; addToCurrentBlock(initDiff.getStmt_dx()); - VarDecl* condVarDecl = FS->getConditionVariable(); - VarDecl* condVarClone = nullptr; - if (condVarDecl) { - DeclDiff condVarResult = DifferentiateVarDecl(condVarDecl); - condVarClone = condVarResult.getDecl(); - if (condVarResult.getDecl_dx()) - addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx())); + + StmtDiff condDiff = Clone(FS->getCond()); + Expr* cond = condDiff.getExpr(); + + // Declaration in the condition (if any) needs to be differentiated. 
+ if (VarDecl* condVarDecl = FS->getConditionVariable()) { + VarDecl* condVarClone = + BuildVarDecl(condVarDecl->getType(), condVarDecl->getNameAsString(), + Clone(condVarDecl->getInit()), condVarDecl->isDirectInit(), + nullptr, condVarDecl->getInitStyle()); + VarDecl* _d_condVarDecl = BuildVarDecl( + condVarClone->getType(), "_d_" + condVarClone->getNameAsString(), + /*Init=*/nullptr, condVarClone->isDirectInit(), nullptr, + condVarClone->getInitStyle()); + m_Variables.emplace(condVarClone, BuildDeclRef(_d_condVarDecl)); + // Here we create a fictional cond that is equal to the assignment used in + // the declaration. The declaration itself is hoisted before the for-loop + // without any init value. The fictional condition is then differentiated as + // a normal condition would be (see below). For example, the declaration + // inside `for (;double t = x;) {}` will be first processed into the + // following code: + // ``` + // { + // double t; + // for (;t = x;) {} + // } + // ``` + // which will then get differentiated normally as a for-loop with a + // differentiable condition in the next section. + Expr* condInit = condVarClone->getInit(); + condVarClone->setInit(nullptr); + cond = BuildOp(BO_Assign, BuildDeclRef(condVarClone), condInit); + addToCurrentBlock(BuildDeclStmt(condVarClone)); + } + + // Condition differentiation. + // This adds support for assignments in conditions. + if (cond) { + cond = cond->IgnoreParenImpCasts(); + // If it's a supported differentiable operator we wrap it back into + // parentheses and then visit. To ensure correctness, a comma operator + // expression (cond_dx, cond) is generated and put instead of the condition. + // FIXME: Add support for other expressions in cond (unary operators, + // comparisons, function calls, etc.). 
Ideally, we should be able to simply + // always call Visit(cond). + BinaryOperator* condBO = dyn_cast<BinaryOperator>(cond); + if (condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) { + condDiff = Visit(cond); + if (condDiff.getExpr_dx() && !isUnusedResult(condDiff.getExpr_dx())) + cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), + BuildParens(condDiff.getExpr())); + else + cond = condDiff.getExpr(); + } } - Expr* cond = FS->getCond() ? Clone(FS->getCond()) : nullptr; - const Expr* inc = FS->getInc(); // Differentiate the increment expression of the for loop + const Expr* inc = FS->getInc(); beginBlock(); StmtDiff incDiff = inc ? Visit(inc) : StmtDiff{}; CompoundStmt* decls = endBlock(); @@ -714,27 +760,20 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { incResult = incDiff.getExpr(); } + // Build the derived for loop body. const Stmt* body = FS->getBody(); beginScope(Scope::DeclScope); Stmt* bodyResult = nullptr; - if (isa<CompoundStmt>(body)) { - bodyResult = Visit(body).getStmt(); - } else { - beginBlock(); - StmtDiff Result = Visit(body); - for (Stmt* S : Result.getBothStmts()) - addToCurrentBlock(S); - CompoundStmt* Block = endBlock(); - if (Block->size() == 1) - bodyResult = Block->body_front(); - else - bodyResult = Block; - } + beginBlock(); + StmtDiff bodyVisited = Visit(body); + for (Stmt* S : bodyVisited.getBothStmts()) + addToCurrentBlock(S); + bodyResult = utils::unwrapIfSingleStmt(endBlock()); endScope(); - Stmt* forStmtDiff = - new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, condVarClone, - incResult, bodyResult, noLoc, noLoc, noLoc); + Stmt* forStmtDiff = new (m_Context) + ForStmt(m_Context, initDiff.getStmt(), cond, /*condVar=*/nullptr, + incResult, bodyResult, noLoc, noLoc, noLoc); addToCurrentBlock(forStmtDiff); CompoundStmt* Block = endBlock(); @@ -1366,6 +1405,26 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { } else opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()), 
BuildParens(Rdiff.getExpr_dx())); + } else if (BinOp->isLogicalOp()) { + // For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and + // correct derivative execution. + auto buildOneSide = [this](StmtDiff& Xdiff) { + if (Xdiff.getExpr_dx() && !isUnusedResult(Xdiff.getExpr_dx())) + return BuildParens(BuildOp(BO_Comma, BuildParens(Xdiff.getExpr_dx()), + BuildParens(Xdiff.getExpr()))); + else + return BuildParens(Xdiff.getExpr()); + }; + // dLL = (dL, L) + Expr* dLL = buildOneSide(Ldiff); + // dRR = (dR, R) + Expr* dRR = buildOneSide(Rdiff); + opDiff = BuildOp(opCode, dLL, dRR); + + // Since both parts are included in the opDiff, there's no point in + // including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left + // nullptr is used for treating expressions like ((A && B) && C) correctly. + return StmtDiff(opDiff, nullptr); } if (!opDiff) { // FIXME: add support for other binary operators diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index c180d0b1a..831bbcb08 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -102,6 +102,19 @@ namespace clad { } } + Stmt* unwrapIfSingleStmt(Stmt* S) { + if (!S) + return nullptr; + if (!isa<CompoundStmt>(S)) + return S; + auto* CS = cast<CompoundStmt>(S); + if (CS->size() == 0) + return nullptr; + if (CS->size() == 1) + return CS->body_front(); + return CS; + } + CompoundStmt* PrependAndCreateCompoundStmt(ASTContext& C, Stmt* initial, Stmt* S) { llvm::SmallVector block; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f9ec442d8..4aba32332 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -783,19 +783,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Forward, Reverse); } - static Stmt* unwrapIfSingleStmt(Stmt* S) { - if (!S) - return nullptr; - if (!isa<CompoundStmt>(S)) - return S; - auto* CS = cast<CompoundStmt>(S); - if (CS->size() == 0) - 
return nullptr; - if (CS->size() == 1) - return CS->body_front(); - return CS; - } - StmtDiff ReverseModeVisitor::VisitIfStmt(const clang::IfStmt* If) { // Control scope of the IfStmt. E.g., in if (double x = ...) {...}, x goes // to this scope. @@ -888,8 +875,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource ->ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt(); - Stmt* Forward = unwrapIfSingleStmt(endBlock(direction::forward)); - Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); + Stmt* Forward = utils::unwrapIfSingleStmt(endBlock(direction::forward)); + Stmt* Reverse = utils::unwrapIfSingleStmt(BranchDiff.getStmt_dx()); return StmtDiff(Forward, Reverse); }; @@ -914,8 +901,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CompoundStmt* ForwardBlock = endBlock(direction::forward); CompoundStmt* ReverseBlock = endBlock(direction::reverse); endScope(); - return StmtDiff(unwrapIfSingleStmt(ForwardBlock), - unwrapIfSingleStmt(ReverseBlock)); + return StmtDiff(utils::unwrapIfSingleStmt(ForwardBlock), + utils::unwrapIfSingleStmt(ReverseBlock)); } StmtDiff ReverseModeVisitor::VisitConditionalOperator( @@ -941,8 +928,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto Result = DifferentiateSingleExpr(Branch, dfdx); StmtDiff BranchDiff = Result.first; StmtDiff ExprDiff = Result.second; - Stmt* Forward = unwrapIfSingleStmt(BranchDiff.getStmt()); - Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); + Stmt* Forward = utils::unwrapIfSingleStmt(BranchDiff.getStmt()); + Stmt* Reverse = utils::unwrapIfSingleStmt(BranchDiff.getStmt_dx()); return {StmtDiff(Forward, Reverse), ExprDiff}; }; @@ -1128,7 +1115,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Reverse = endBlock(direction::reverse); endScope(); - return {unwrapIfSingleStmt(Forward), unwrapIfSingleStmt(Reverse)}; + return {utils::unwrapIfSingleStmt(Forward), + utils::unwrapIfSingleStmt(Reverse)}; } 
StmtDiff @@ -2766,7 +2754,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse); CompoundStmt* RCS = endBlock(direction::reverse); std::reverse(RCS->body_begin(), RCS->body_end()); - Stmt* ReverseResult = unwrapIfSingleStmt(RCS); + Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS); return StmtDiff(SDiff.getStmt(), ReverseResult); } @@ -2780,7 +2768,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CompoundStmt* RCS = endBlock(direction::reverse); Stmt* ForwardResult = endBlock(direction::forward); std::reverse(RCS->body_begin(), RCS->body_end()); - Stmt* ReverseResult = unwrapIfSingleStmt(RCS); + Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS); return {StmtDiff(ForwardResult, ReverseResult), EDiff}; } @@ -2929,7 +2917,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (Decl* decl : decls) addToBlock(BuildDeclStmt(decl), m_Globals); Stmt* initAssignments = MakeCompoundStmt(inits); - initAssignments = unwrapIfSingleStmt(initAssignments); + initAssignments = utils::unwrapIfSingleStmt(initAssignments); return StmtDiff(initAssignments); } @@ -3574,7 +3562,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterProcessingSingleStmtBodyInVisitForLoop(); - Stmt* reverseBlock = unwrapIfSingleStmt(bodyDiff.getStmt_dx()); + Stmt* reverseBlock = utils::unwrapIfSingleStmt(bodyDiff.getStmt_dx()); bodyDiff = {endBlock(direction::forward), reverseBlock}; // for forward-pass loop statement body endScope(); @@ -3619,7 +3607,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(condVarDiff, direction::reverse); addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse); bodyDiff = {bodyDiff.getStmt(), - unwrapIfSingleStmt(endBlock(direction::reverse))}; + utils::unwrapIfSingleStmt(endBlock(direction::reverse))}; return bodyDiff; } diff --git 
a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 03257ff34..03fddc7a9 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -377,8 +377,8 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: double res = 0; // CHECK-NEXT: { // CHECK-NEXT: size_t _d_count = 0; -// CHECK-NEXT: size_t _d_max_count = _d_n; -// CHECK-NEXT: for (size_t count = 0; {{.*}}max_count{{.*}}; ++count) { +// CHECK-NEXT: size_t max_count; +// CHECK-NEXT: for (size_t count = 0; (_d_max_count = _d_n) , (max_count = n); ++count) { // CHECK-NEXT: if (count >= max_count) // CHECK-NEXT: break; // CHECK-NEXT: { @@ -393,6 +393,81 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +double fn11(double x, double y) { + double r = 0; + for (int i = 0; (r = x); ++i) { + if (i == 3) break; + r += x; + } + return r; +} // fn11(x,y) == x + +double fn11_darg0(double x, double y); +// CHECK: double fn11_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: for (int i = 0; (_d_r = _d_x) , (r = x); ++i) { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_r += _d_x; +// CHECK-NEXT: r += x; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn12(double x, double y) { + double r = 0; + for (int i = 0; double c = x; ++i) { + if (i == 3) break; + c += x; + r = c; + } + return r; +} // fn12(x,y) == 2*x + +double fn12_darg0(double x, double y); +// CHECK: double fn12_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: double c; +// CHECK-NEXT: for (int i = 0; (_d_c = _d_x) , (c = x); ++i) { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// 
CHECK-NEXT: _d_c += _d_x; +// CHECK-NEXT: c += x; +// CHECK-NEXT: _d_r = _d_c; +// CHECK-NEXT: r = c; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn13(double u, double v) { + double res = 0; + for (; (res = u * v) && (u = 0) ;) {} + return res; +} // = u*v + +double fn13_darg0(double u, double v); +// CHECK: double fn13_darg0(double u, double v) { +// CHECK-NEXT: double _d_u = 1; +// CHECK-NEXT: double _d_v = 0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: for (; ((_d_res = _d_u * v + u * _d_v) , (res = u * v)) && ((_d_u = 0) , (u = 0));) { +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -430,4 +505,20 @@ int main() { clad::differentiate(fn10, 0); printf("Result is = %.2f\n", fn10_darg0(3, 5)); // CHECK-EXEC: Result is = 30.00 + + clad::differentiate(fn11, 0); + printf("Result is = %.2f\n", fn11_darg0(3, 5)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(-3, 6)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(1, 5)); // CHECK-EXEC: Result is = 1.00 + + clad::differentiate(fn12, 0); + printf("Result is = %.2f\n", fn12_darg0(3, 5)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(-3, 6)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(1, 5)); // CHECK-EXEC: Result is = 2.00 + + clad::differentiate(fn13, 0); + printf("Result is = %.2f\n", fn13_darg0(3, 4)); // CHECK-EXEC: Result is = 4.00 + printf("Result is = %.2f\n", fn13_darg0(-3, 5)); // CHECK-EXEC: Result is = 5.00 + printf("Result is = %.2f\n", fn13_darg0(1, 6)); // CHECK-EXEC: Result is = 6.00 + }