From dc4102ee89add3f2b47b5706d6a547b94766a6eb Mon Sep 17 00:00:00 2001 From: Atell Krasnopolski Date: Thu, 23 May 2024 19:56:04 +0200 Subject: [PATCH] Fix condition declarations & assignments, enable logical operators in for loops These changes fix the differentiation of variable declarations in for loop conditions that used to result in wrong derivatives. The commit also tackles the problem of having an assignment operator that affects the derivative in the for-loop condition, adds support for logical operators in for-loops, and allows combining assignments with them. Fixes: #273 --- include/clad/Differentiator/CladUtils.h | 4 + lib/Differentiator/BaseForwardModeVisitor.cpp | 109 ++++++++++++++---- lib/Differentiator/CladUtils.cpp | 13 +++ lib/Differentiator/ReverseModeVisitor.cpp | 38 +++--- test/FirstDerivative/Loops.C | 95 ++++++++++++++- 5 files changed, 207 insertions(+), 52 deletions(-) diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index ae2d26a99..a784f7e21 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -27,6 +27,10 @@ namespace clad { /// function `FD`. std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD); + // Unwraps S to a single statement if it's a compound statement only + // containing 1 statement.
+ clang::Stmt* unwrapIfSingleStmt(clang::Stmt* S); + /// Creates and returns a compound statement having statements as follows: /// {`S`, all the statement of `initial` in sequence} clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C, diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 5beab0cb3..43d5a8f8d 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -668,18 +668,64 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { const Stmt* init = FS->getInit(); StmtDiff initDiff = init ? Visit(init) : StmtDiff{}; addToCurrentBlock(initDiff.getStmt_dx()); - VarDecl* condVarDecl = FS->getConditionVariable(); - VarDecl* condVarClone = nullptr; - if (condVarDecl) { - DeclDiff condVarResult = DifferentiateVarDecl(condVarDecl); - condVarClone = condVarResult.getDecl(); - if (condVarResult.getDecl_dx()) - addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx())); + + StmtDiff condDiff = Clone(FS->getCond()); + Expr* cond = condDiff.getExpr(); + + // Declaration in the condition (if any) needs to be differentiated. + if (VarDecl* condVarDecl = FS->getConditionVariable()) { + VarDecl* condVarClone = + BuildVarDecl(condVarDecl->getType(), condVarDecl->getNameAsString(), + Clone(condVarDecl->getInit()), condVarDecl->isDirectInit(), + nullptr, condVarDecl->getInitStyle()); + VarDecl* _d_condVarDecl = BuildVarDecl( + condVarClone->getType(), "_d_" + condVarClone->getNameAsString(), + /*Init=*/nullptr, condVarClone->isDirectInit(), nullptr, + condVarClone->getInitStyle()); + m_Variables.emplace(condVarClone, BuildDeclRef(_d_condVarDecl)); + // Here we create a fictional cond that is equal to the assignment used in + // the declaration. The declaration itself is emitted before the for-loop + // without any init value. The fictional condition is then differentiated as + // a normal condition would be (see below).
For example, the declaration + // inside `for (;double t = x;) {}` will be first processed into the + // following code: + // ``` + // { + // double t; + // for (;t = x;) {} + // } + // ``` + // which will then get differentiated normally as a for-loop with a + // differentiable condition in the next section. + auto condInit = condVarClone->getInit(); + condVarClone->setInit(nullptr); + cond = BuildOp(BO_Assign, BuildDeclRef(condVarClone), condInit); + addToCurrentBlock(BuildDeclStmt(condVarClone)); + } + + // Condition differentiation. + // This adds support for assignments in conditions. + if (cond) { + cond = cond->IgnoreParenImpCasts(); + // If it's a supported differentiable operator we wrap it back into + // parentheses and then visit. To ensure the correctness, a comma operator + // expression (cond_dx, cond) is generated and put instead of the condition. + // FIXME: Add support for other expressions in cond (unary operators, + // comparisons, function calls, etc.). Ideally, we should be able to simply + // always call Visit(cond) + BinaryOperator* condBO = dyn_cast(cond); + if (condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) { + condDiff = Visit(cond); + if (condDiff.getExpr_dx() && !isUnusedResult(condDiff.getExpr_dx())) + cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), + BuildParens(condDiff.getExpr())); + else + cond = condDiff.getExpr(); + } } - Expr* cond = FS->getCond() ? Clone(FS->getCond()) : nullptr; - const Expr* inc = FS->getInc(); // Differentiate the increment expression of the for loop + const Expr* inc = FS->getInc(); beginBlock(); StmtDiff incDiff = inc ? Visit(inc) : StmtDiff{}; CompoundStmt* decls = endBlock(); @@ -714,27 +760,20 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { incResult = incDiff.getExpr(); } + // Build the derived for loop body. 
const Stmt* body = FS->getBody(); beginScope(Scope::DeclScope); Stmt* bodyResult = nullptr; - if (isa(body)) { - bodyResult = Visit(body).getStmt(); - } else { - beginBlock(); - StmtDiff Result = Visit(body); - for (Stmt* S : Result.getBothStmts()) - addToCurrentBlock(S); - CompoundStmt* Block = endBlock(); - if (Block->size() == 1) - bodyResult = Block->body_front(); - else - bodyResult = Block; - } + beginBlock(); + StmtDiff bodyVisited = Visit(body); + for (Stmt* S : bodyVisited.getBothStmts()) + addToCurrentBlock(S); + bodyResult = utils::unwrapIfSingleStmt(endBlock()); endScope(); - Stmt* forStmtDiff = - new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, condVarClone, - incResult, bodyResult, noLoc, noLoc, noLoc); + Stmt* forStmtDiff = new (m_Context) + ForStmt(m_Context, initDiff.getStmt(), cond, /*condVar=*/nullptr, + incResult, bodyResult, noLoc, noLoc, noLoc); addToCurrentBlock(forStmtDiff); CompoundStmt* Block = endBlock(); @@ -1366,6 +1405,26 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { } else opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()), BuildParens(Rdiff.getExpr_dx())); + } else if (BinOp->isLogicalOp()) { + // For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and + // correct derivative execution. + auto buildOneSide = [this](StmtDiff& Xdiff) { + if (Xdiff.getExpr_dx() && !isUnusedResult(Xdiff.getExpr_dx())) + return BuildParens(BuildOp(BO_Comma, BuildParens(Xdiff.getExpr_dx()), + BuildParens(Xdiff.getExpr()))); + else + return BuildParens(Xdiff.getExpr()); + }; + // dLL = (dL, L) + Expr* dLL = buildOneSide(Ldiff); + // dRR = (dR, R) + Expr* dRR = buildOneSide(Rdiff); + opDiff = BuildOp(opCode, dLL, dRR); + + // Since both parts are included in the opDiff, there's no point in + // including them as a Stmt_dx. Moreover, the fact that Stmt_dx is left + // nullptr is used for treating expressions like ((A && B) && C) correctly.
+ return StmtDiff(opDiff, nullptr); } if (!opDiff) { // FIXME: add support for other binary operators diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index c180d0b1a..831bbcb08 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -102,6 +102,19 @@ namespace clad { } } + Stmt* unwrapIfSingleStmt(Stmt* S) { + if (!S) + return nullptr; + if (!isa(S)) + return S; + auto* CS = cast(S); + if (CS->size() == 0) + return nullptr; + if (CS->size() == 1) + return CS->body_front(); + return CS; + } + CompoundStmt* PrependAndCreateCompoundStmt(ASTContext& C, Stmt* initial, Stmt* S) { llvm::SmallVector block; diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index f9ec442d8..4aba32332 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -783,19 +783,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Forward, Reverse); } - static Stmt* unwrapIfSingleStmt(Stmt* S) { - if (!S) - return nullptr; - if (!isa(S)) - return S; - auto* CS = cast(S); - if (CS->size() == 0) - return nullptr; - if (CS->size() == 1) - return CS->body_front(); - return CS; - } - StmtDiff ReverseModeVisitor::VisitIfStmt(const clang::IfStmt* If) { // Control scope of the IfStmt. E.g., in if (double x = ...) {...}, x goes // to this scope. 
@@ -888,8 +875,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, m_ExternalSource ->ActBeforeFinalizingVisitBranchSingleStmtInIfVisitStmt(); - Stmt* Forward = unwrapIfSingleStmt(endBlock(direction::forward)); - Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); + Stmt* Forward = utils::unwrapIfSingleStmt(endBlock(direction::forward)); + Stmt* Reverse = utils::unwrapIfSingleStmt(BranchDiff.getStmt_dx()); return StmtDiff(Forward, Reverse); }; @@ -914,8 +901,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CompoundStmt* ForwardBlock = endBlock(direction::forward); CompoundStmt* ReverseBlock = endBlock(direction::reverse); endScope(); - return StmtDiff(unwrapIfSingleStmt(ForwardBlock), - unwrapIfSingleStmt(ReverseBlock)); + return StmtDiff(utils::unwrapIfSingleStmt(ForwardBlock), + utils::unwrapIfSingleStmt(ReverseBlock)); } StmtDiff ReverseModeVisitor::VisitConditionalOperator( @@ -941,8 +928,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto Result = DifferentiateSingleExpr(Branch, dfdx); StmtDiff BranchDiff = Result.first; StmtDiff ExprDiff = Result.second; - Stmt* Forward = unwrapIfSingleStmt(BranchDiff.getStmt()); - Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); + Stmt* Forward = utils::unwrapIfSingleStmt(BranchDiff.getStmt()); + Stmt* Reverse = utils::unwrapIfSingleStmt(BranchDiff.getStmt_dx()); return {StmtDiff(Forward, Reverse), ExprDiff}; }; @@ -1128,7 +1115,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Reverse = endBlock(direction::reverse); endScope(); - return {unwrapIfSingleStmt(Forward), unwrapIfSingleStmt(Reverse)}; + return {utils::unwrapIfSingleStmt(Forward), + utils::unwrapIfSingleStmt(Reverse)}; } StmtDiff @@ -2766,7 +2754,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse); CompoundStmt* RCS = endBlock(direction::reverse); std::reverse(RCS->body_begin(), 
RCS->body_end()); - Stmt* ReverseResult = unwrapIfSingleStmt(RCS); + Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS); return StmtDiff(SDiff.getStmt(), ReverseResult); } @@ -2780,7 +2768,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CompoundStmt* RCS = endBlock(direction::reverse); Stmt* ForwardResult = endBlock(direction::forward); std::reverse(RCS->body_begin(), RCS->body_end()); - Stmt* ReverseResult = unwrapIfSingleStmt(RCS); + Stmt* ReverseResult = utils::unwrapIfSingleStmt(RCS); return {StmtDiff(ForwardResult, ReverseResult), EDiff}; } @@ -2929,7 +2917,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, for (Decl* decl : decls) addToBlock(BuildDeclStmt(decl), m_Globals); Stmt* initAssignments = MakeCompoundStmt(inits); - initAssignments = unwrapIfSingleStmt(initAssignments); + initAssignments = utils::unwrapIfSingleStmt(initAssignments); return StmtDiff(initAssignments); } @@ -3574,7 +3562,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterProcessingSingleStmtBodyInVisitForLoop(); - Stmt* reverseBlock = unwrapIfSingleStmt(bodyDiff.getStmt_dx()); + Stmt* reverseBlock = utils::unwrapIfSingleStmt(bodyDiff.getStmt_dx()); bodyDiff = {endBlock(direction::forward), reverseBlock}; // for forward-pass loop statement body endScope(); @@ -3619,7 +3607,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(condVarDiff, direction::reverse); addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse); bodyDiff = {bodyDiff.getStmt(), - unwrapIfSingleStmt(endBlock(direction::reverse))}; + utils::unwrapIfSingleStmt(endBlock(direction::reverse))}; return bodyDiff; } diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 03257ff34..03fddc7a9 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -377,8 +377,8 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: double res = 0; // 
CHECK-NEXT: { // CHECK-NEXT: size_t _d_count = 0; -// CHECK-NEXT: size_t _d_max_count = _d_n; -// CHECK-NEXT: for (size_t count = 0; {{.*}}max_count{{.*}}; ++count) { +// CHECK-NEXT: size_t max_count; +// CHECK-NEXT: for (size_t count = 0; (_d_max_count = _d_n) , (max_count = n); ++count) { // CHECK-NEXT: if (count >= max_count) // CHECK-NEXT: break; // CHECK-NEXT: { @@ -393,6 +393,81 @@ double fn10_darg0(double x, size_t n); // CHECK-NEXT: return _d_res; // CHECK-NEXT: } +double fn11(double x, double y) { + double r = 0; + for (int i = 0; (r = x); ++i) { + if (i == 3) break; + r += x; + } + return r; +} // fn11(x,y) == x + +double fn11_darg0(double x, double y); +// CHECK: double fn11_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: for (int i = 0; (_d_r = _d_x) , (r = x); ++i) { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_r += _d_x; +// CHECK-NEXT: r += x; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn12(double x, double y) { + double r = 0; + for (int i = 0; double c = x; ++i) { + if (i == 3) break; + c += x; + r = c; + } + return r; +} // fn12(x,y) == 2*x + +double fn12_darg0(double x, double y); +// CHECK: double fn12_darg0(double x, double y) { +// CHECK-NEXT: double _d_x = 1; +// CHECK-NEXT: double _d_y = 0; +// CHECK-NEXT: double _d_r = 0; +// CHECK-NEXT: double r = 0; +// CHECK-NEXT: { +// CHECK-NEXT: int _d_i = 0; +// CHECK-NEXT: double c; +// CHECK-NEXT: for (int i = 0; (_d_c = _d_x) , (c = x); ++i) { +// CHECK-NEXT: if (i == 3) +// CHECK-NEXT: break; +// CHECK-NEXT: _d_c += _d_x; +// CHECK-NEXT: c += x; +// CHECK-NEXT: _d_r = _d_c; +// CHECK-NEXT: r = c; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return _d_r; +// CHECK-NEXT: } + +double fn13(double u, double v) { + double res = 0; + for (; (res = u *
v) && (u = 0) ;) {} + return res; +} // = u*v + +double fn13_darg0(double u, double v); +// CHECK: double fn13_darg0(double u, double v) { +// CHECK-NEXT: double _d_u = 1; +// CHECK-NEXT: double _d_v = 0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: for (; ((_d_res = _d_u * v + u * _d_v) , (res = u * v)) && ((_d_u = 0) , (u = 0));) { +// CHECK-NEXT: } +// CHECK-NEXT: return _d_res; +// CHECK-NEXT: } + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -430,4 +505,20 @@ int main() { clad::differentiate(fn10, 0); printf("Result is = %.2f\n", fn10_darg0(3, 5)); // CHECK-EXEC: Result is = 30.00 + + clad::differentiate(fn11, 0); + printf("Result is = %.2f\n", fn11_darg0(3, 5)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(-3, 6)); // CHECK-EXEC: Result is = 1.00 + printf("Result is = %.2f\n", fn11_darg0(1, 5)); // CHECK-EXEC: Result is = 1.00 + + clad::differentiate(fn12, 0); + printf("Result is = %.2f\n", fn12_darg0(3, 5)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(-3, 6)); // CHECK-EXEC: Result is = 2.00 + printf("Result is = %.2f\n", fn12_darg0(1, 5)); // CHECK-EXEC: Result is = 2.00 + + clad::differentiate(fn13, 0); + printf("Result is = %.2f\n", fn13_darg0(3, 4)); // CHECK-EXEC: Result is = 4.00 + printf("Result is = %.2f\n", fn13_darg0(-3, 5)); // CHECK-EXEC: Result is = 5.00 + printf("Result is = %.2f\n", fn13_darg0(1, 6)); // CHECK-EXEC: Result is = 6.00 + }