diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index 3d16cd8e4..491a36683 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -706,13 +706,16 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { // If it's a supported differentiable operator we wrap it back into // parentheses and then visit. To ensure the correctness, a comma operator // expression (cond_dx, cond) is generated and put instead of the condition. - // FIXME: Add support for other expressions in cond (unary operators, - // comparisons, function calls, etc.). Ideally, we should be able to simply - // always call Visit(cond) - BinaryOperator* condBO = dyn_cast(cond); - if (condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) { + // FIXME: Add support for other expressions in cond (comparisons, function + // calls, etc.). Ideally, we should be able to simply always call + // Visit(cond) + auto* condBO = dyn_cast(cond); + auto* condUO = dyn_cast(cond); + if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) || + condUO) { condDiff = Visit(cond); - if (condDiff.getExpr_dx() && !isUnusedResult(condDiff.getExpr_dx())) + if (condDiff.getExpr_dx() && + (!isUnusedResult(condDiff.getExpr_dx()) || condUO)) cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), BuildParens(condDiff.getExpr())); else @@ -1286,6 +1289,8 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); + } else if (opKind == UnaryOperatorKind::UO_LNot) { + return StmtDiff(op, diff.getExpr_dx()); } else { unsupportedOpWarn(UnOp->getEndLoc()); auto zero = diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C index 0682c1801..43a35c749 100644 --- a/test/FirstDerivative/Loops.C +++ b/test/FirstDerivative/Loops.C @@ -499,6 +499,24 @@ double fn14_darg0(double x); // CHECK-NEXT: return _d_x; // CHECK-NEXT: } +double fn15(double u, double v) { + double res = 0; + for (; !(res = u * v) ;) {} + return 2*res; +} + +double fn15_darg0(double u, double v); +//CHECK: double fn15_darg0(double u, double v) { +// CHECK-NEXT: double _d_u = 1; +// CHECK-NEXT: double _d_v = 0; +// CHECK-NEXT: double _d_res = 0; +// CHECK-NEXT: double res = 0; +// CHECK-NEXT: for (; (_d_res = _d_u * v + u * _d_v) , !(res = u * v);) { +// CHECK-NEXT: } +// CHECK-NEXT: return 0 * res + 2 * _d_res; +// CHECK-NEXT: } + + #define TEST(fn)\ auto d_##fn = clad::differentiate(fn, "i");\ printf("%.2f\n", d_##fn.execute(3, 5)); @@ -556,4 +574,7 @@ int main() { printf("Result is = %.2f\n", fn14_darg0(3)); // CHECK-EXEC: Result is = 4.00 printf("Result is = %.2f\n", fn14_darg0(-3)); // CHECK-EXEC: Result is = 4.00 printf("Result is = %.2f\n", fn14_darg0(1)); // CHECK-EXEC: Result is = 4.00 + + clad::differentiate(fn15, 0); + printf("Result is = %.2f\n", fn15_darg0(7, 3)); // CHECK-EXEC: Result is = 6.00 }