diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp
index 5beab0cb3..7d6e0de85 100644
--- a/lib/Differentiator/BaseForwardModeVisitor.cpp
+++ b/lib/Differentiator/BaseForwardModeVisitor.cpp
@@ -668,18 +668,39 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
   const Stmt* init = FS->getInit();
   StmtDiff initDiff = init ? Visit(init) : StmtDiff{};
   addToCurrentBlock(initDiff.getStmt_dx());
+
+  // declaration in the condition (if any) needs to be differentiated
   VarDecl* condVarDecl = FS->getConditionVariable();
   VarDecl* condVarClone = nullptr;
+  DeclDiff<VarDecl> condVarResult;
+  DeclStmt* condVarDeclStmt_dx = nullptr;
   if (condVarDecl) {
-    DeclDiff<VarDecl> condVarResult = DifferentiateVarDecl(condVarDecl);
+    condVarResult = DifferentiateVarDecl(condVarDecl);
     condVarClone = condVarResult.getDecl();
     if (condVarResult.getDecl_dx())
-      addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx()));
+      condVarDeclStmt_dx = BuildDeclStmt(condVarResult.getDecl_dx());
+  }
+
+  // condition
+  StmtDiff condDiff = Clone(FS->getCond());
+  if (Expr* cond =
+          condDiff
+              .getExpr()) { // this adds support for assignments in conditions
+    while (CastExpr* condCast = dyn_cast<CastExpr>(cond))
+      cond = condCast->getSubExpr();
+    while (ParenExpr* condParen = dyn_cast<ParenExpr>(cond))
+      cond = condParen->getSubExpr();
+    if (BinaryOperator* condBO = dyn_cast<BinaryOperator>(cond)) {
+      if (condBO->isAssignmentOp())
+        condDiff = Visit(new (m_Context) ParenExpr(
+            noLoc, noLoc,
+            cond)); // if it's an assignment operator we wrap it back into
+                    // parentheses (as it is expected to be) and then visit
+    }
   }
-  Expr* cond = FS->getCond() ? Clone(FS->getCond()) : nullptr;
-  const Expr* inc = FS->getInc();
 
   // Differentiate the increment expression of the for loop
+  const Expr* inc = FS->getInc();
   beginBlock();
   StmtDiff incDiff = inc ? Visit(inc) : StmtDiff{};
   CompoundStmt* decls = endBlock();
@@ -714,27 +735,28 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
     incResult = incDiff.getExpr();
   }
 
+  // build the derived for loop body
   const Stmt* body = FS->getBody();
   beginScope(Scope::DeclScope);
   Stmt* bodyResult = nullptr;
-  if (isa<CompoundStmt>(body)) {
-    bodyResult = Visit(body).getStmt();
-  } else {
-    beginBlock();
-    StmtDiff Result = Visit(body);
-    for (Stmt* S : Result.getBothStmts())
-      addToCurrentBlock(S);
-    CompoundStmt* Block = endBlock();
-    if (Block->size() == 1)
-      bodyResult = Block->body_front();
-    else
-      bodyResult = Block;
-  }
+  beginBlock();
+  StmtDiff bodyVisited = Visit(body);
+  if (condVarDeclStmt_dx)
+    addToCurrentBlock(condVarDeclStmt_dx);
+  if (condDiff.getStmt_dx())
+    addToCurrentBlock(condDiff.getStmt_dx());
+  for (Stmt* S : bodyVisited.getBothStmts())
+    addToCurrentBlock(S);
+  CompoundStmt* bodyResultCmpd = endBlock();
+  if (bodyResultCmpd->size() == 1)
+    bodyResult = bodyResultCmpd->body_front();
+  else
+    bodyResult = bodyResultCmpd;
   endScope();
 
-  Stmt* forStmtDiff =
-      new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, condVarClone,
-                              incResult, bodyResult, noLoc, noLoc, noLoc);
+  Stmt* forStmtDiff = new (m_Context)
+      ForStmt(m_Context, initDiff.getStmt(), condDiff.getExpr(), condVarClone,
+              incResult, bodyResult, noLoc, noLoc, noLoc);
 
   addToCurrentBlock(forStmtDiff);
   CompoundStmt* Block = endBlock();
diff --git a/test/FirstDerivative/Loops.C b/test/FirstDerivative/Loops.C
index 03257ff34..7db23a25f 100644
--- a/test/FirstDerivative/Loops.C
+++ b/test/FirstDerivative/Loops.C
@@ -377,8 +377,9 @@ double fn10_darg0(double x, size_t n);
 // CHECK-NEXT:     double res = 0;
 // CHECK-NEXT:     {
 // CHECK-NEXT:         size_t _d_count = 0;
-// CHECK-NEXT:         size_t _d_max_count = _d_n;
 // CHECK-NEXT:         for (size_t count = 0; {{.*}}max_count{{.*}}; ++count) {
+// CHECK-NEXT:         size_t _d_max_count = _d_n;
+// CHECK-NEXT:             {
 // CHECK-NEXT:             if (count >= max_count)
 // CHECK-NEXT:                 break;
 // CHECK-NEXT:             {
@@ -388,11 +389,75 @@ double fn10_darg0(double x, size_t n);
 // CHECK-NEXT:                     res += y * y;
 // CHECK-NEXT:                 }
 // CHECK-NEXT:             }
+// CHECK-NEXT:             }
 // CHECK-NEXT:         }
 // CHECK-NEXT:     }
 // CHECK-NEXT:     return _d_res;
 // CHECK-NEXT: }
 
+double fn11(double x, double y) {
+    double r = 0;
+    for (int i = 0; (r = x); ++i) {
+        if (i == 3) break;
+        r += x;
+    }
+    return r;
+} // fn11(x,y) == x
+
+double fn11_darg0(double x, double y);
+// CHECK:      double fn11_darg0(double x, double y) {
+// CHECK-NEXT:          double _d_x = 1;
+// CHECK-NEXT:          double _d_y = 0;
+// CHECK-NEXT:          double _d_r = 0;
+// CHECK-NEXT:          double r = 0;
+// CHECK-NEXT:          {
+// CHECK-NEXT:              int _d_i = 0;
+// CHECK-NEXT:              for (int i = 0; (r = x); ++i) {
+// CHECK-NEXT:                  (_d_r = _d_x);
+// CHECK-NEXT:                  {
+// CHECK-NEXT:                      if (i == 3)
+// CHECK-NEXT:                          break;
+// CHECK-NEXT:                      _d_r += _d_x;
+// CHECK-NEXT:                      r += x;
+// CHECK-NEXT:                  }
+// CHECK-NEXT:              }
+// CHECK-NEXT:          }
+// CHECK-NEXT:          return _d_r;
+// CHECK-NEXT:      }
+
+double fn12(double x, double y) {
+    double r = 0;
+    for (int i = 0; double c = x; ++i) {
+        if (i == 3) break;
+        c += x;
+        r = c;
+    }
+    return r;
+} // fn11(x,y) == 2*x
+
+double fn12_darg0(double x, double y);
+// CHECK:      double fn12_darg0(double x, double y) {
+// CHECK-NEXT:          double _d_x = 1;
+// CHECK-NEXT:          double _d_y = 0;
+// CHECK-NEXT:          double _d_r = 0;
+// CHECK-NEXT:          double r = 0;
+// CHECK-NEXT:          {
+// CHECK-NEXT:              int _d_i = 0;
+// CHECK-NEXT:              for (int i = 0; c; ++i) {
+// CHECK-NEXT:                  double _d_c = _d_x;
+// CHECK-NEXT:                  {
+// CHECK-NEXT:                      if (i == 3)
+// CHECK-NEXT:                          break;
+// CHECK-NEXT:                      _d_c += _d_x;
+// CHECK-NEXT:                      c += x;
+// CHECK-NEXT:                      _d_r = _d_c;
+// CHECK-NEXT:                      r = c;
+// CHECK-NEXT:                  }
+// CHECK-NEXT:              }
+// CHECK-NEXT:          }
+// CHECK-NEXT:          return _d_r;
+// CHECK-NEXT:      }
+
 #define TEST(fn)\
 auto d_##fn = clad::differentiate(fn, "i");\
 printf("%.2f\n", d_##fn.execute(3, 5));
@@ -430,4 +495,15 @@ int main() {
 
   clad::differentiate(fn10, 0);
   printf("Result is = %.2f\n", fn10_darg0(3, 5)); // CHECK-EXEC: Result is = 30.00
+
+  clad::differentiate(fn11, 0);
+  printf("Result is = %.2f\n", fn11_darg0(3, 5)); // CHECK-EXEC: Result is = 1.00
+  printf("Result is = %.2f\n", fn11_darg0(-3, 6)); // CHECK-EXEC: Result is = 1.00
+  printf("Result is = %.2f\n", fn11_darg0(1, 5)); // CHECK-EXEC: Result is = 1.00
+
+  clad::differentiate(fn12, 0);
+  printf("Result is = %.2f\n", fn12_darg0(3, 5)); // CHECK-EXEC: Result is = 2.00
+  printf("Result is = %.2f\n", fn12_darg0(-3, 6)); // CHECK-EXEC: Result is = 2.00
+  printf("Result is = %.2f\n", fn12_darg0(1, 5)); // CHECK-EXEC: Result is = 2.00
+
 }