Skip to content

Commit

Permalink
Fix condition declarations and assignments in for loops
Browse files Browse the repository at this point in the history
These changes fix the differentiation of variable declarations in for
loop conditions that used to result into wrong derivatives. The commit
also tackles the problem of having an assignment operator that affects
the derivative in the for loop condition.

Fixes: vgvassilev#273
  • Loading branch information
gojakuch committed May 23, 2024
1 parent 41565dd commit 16183dd
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 21 deletions.
62 changes: 42 additions & 20 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,18 +668,39 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
const Stmt* init = FS->getInit();
StmtDiff initDiff = init ? Visit(init) : StmtDiff{};
addToCurrentBlock(initDiff.getStmt_dx());

// declaration in the condition (if any) needs to be differentiated
VarDecl* condVarDecl = FS->getConditionVariable();
VarDecl* condVarClone = nullptr;
DeclDiff<VarDecl> condVarResult;
DeclStmt* condVarDeclStmt_dx = nullptr;
if (condVarDecl) {
DeclDiff<VarDecl> condVarResult = DifferentiateVarDecl(condVarDecl);
condVarResult = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarResult.getDecl();
if (condVarResult.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx()));
condVarDeclStmt_dx = BuildDeclStmt(condVarResult.getDecl_dx());
}

// condition
StmtDiff condDiff = Clone(FS->getCond());
if (Expr* cond =
condDiff
.getExpr()) { // this adds support for assignments in conditions
while (CastExpr* condCast = dyn_cast<CastExpr>(cond))
cond = condCast->getSubExpr();
while (ParenExpr* condParen = dyn_cast<ParenExpr>(cond))
cond = condParen->getSubExpr();
if (BinaryOperator* condBO = dyn_cast<BinaryOperator>(cond)) {
if (condBO->isAssignmentOp())
condDiff = Visit(new (m_Context) ParenExpr(
noLoc, noLoc,
cond)); // if it's an assignment operator we wrap it back into
// parentheses (as it is expected to be) and then visit
}
}
Expr* cond = FS->getCond() ? Clone(FS->getCond()) : nullptr;
const Expr* inc = FS->getInc();

// Differentiate the increment expression of the for loop
const Expr* inc = FS->getInc();
beginBlock();
StmtDiff incDiff = inc ? Visit(inc) : StmtDiff{};
CompoundStmt* decls = endBlock();
Expand Down Expand Up @@ -714,27 +735,28 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
incResult = incDiff.getExpr();
}

// build the derived for loop body
const Stmt* body = FS->getBody();
beginScope(Scope::DeclScope);
Stmt* bodyResult = nullptr;
if (isa<CompoundStmt>(body)) {
bodyResult = Visit(body).getStmt();
} else {
beginBlock();
StmtDiff Result = Visit(body);
for (Stmt* S : Result.getBothStmts())
addToCurrentBlock(S);
CompoundStmt* Block = endBlock();
if (Block->size() == 1)
bodyResult = Block->body_front();
else
bodyResult = Block;
}
beginBlock();
StmtDiff bodyVisited = Visit(body);
if (condVarDeclStmt_dx)
addToCurrentBlock(condVarDeclStmt_dx);
if (condDiff.getStmt_dx())
addToCurrentBlock(condDiff.getStmt_dx());
for (Stmt* S : bodyVisited.getBothStmts())
addToCurrentBlock(S);
CompoundStmt* bodyResultCmpd = endBlock();
if (bodyResultCmpd->size() == 1)
bodyResult = bodyResultCmpd->body_front();
else
bodyResult = bodyResultCmpd;
endScope();

Stmt* forStmtDiff =
new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, condVarClone,
incResult, bodyResult, noLoc, noLoc, noLoc);
Stmt* forStmtDiff = new (m_Context)
ForStmt(m_Context, initDiff.getStmt(), condDiff.getExpr(), condVarClone,
incResult, bodyResult, noLoc, noLoc, noLoc);

addToCurrentBlock(forStmtDiff);
CompoundStmt* Block = endBlock();
Expand Down
78 changes: 77 additions & 1 deletion test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,9 @@ double fn10_darg0(double x, size_t n);
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {
// CHECK-NEXT: size_t _d_count = 0;
// CHECK-NEXT: size_t _d_max_count = _d_n;
// CHECK-NEXT: for (size_t count = 0; {{.*}}max_count{{.*}}; ++count) {
// CHECK-NEXT: size_t _d_max_count = _d_n;
// CHECK-NEXT: {
// CHECK-NEXT: if (count >= max_count)
// CHECK-NEXT: break;
// CHECK-NEXT: {
Expand All @@ -388,11 +389,75 @@ double fn10_darg0(double x, size_t n);
// CHECK-NEXT: res += y * y;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }

double fn11(double x, double y) {
double r = 0;
for (int i = 0; (r = x); ++i) {
if (i == 3) break;
r += x;
}
return r;
} // fn11(x,y) == x

double fn11_darg0(double x, double y);
// CHECK: double fn11_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: double _d_r = 0;
// CHECK-NEXT: double r = 0;
// CHECK-NEXT: {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: for (int i = 0; (r = x); ++i) {
// CHECK-NEXT: (_d_r = _d_x);
// CHECK-NEXT: {
// CHECK-NEXT: if (i == 3)
// CHECK-NEXT: break;
// CHECK-NEXT: _d_r += _d_x;
// CHECK-NEXT: r += x;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return _d_r;
// CHECK-NEXT: }

double fn12(double x, double y) {
double r = 0;
for (int i = 0; double c = x; ++i) {
if (i == 3) break;
c += x;
r = c;
}
return r;
} // fn11(x,y) == 2*x

double fn12_darg0(double x, double y);
// CHECK: double fn12_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: double _d_r = 0;
// CHECK-NEXT: double r = 0;
// CHECK-NEXT: {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: for (int i = 0; {{.*}}c{{.*}}; ++i) {
// CHECK-NEXT: double _d_c = _d_x;
// CHECK-NEXT: {
// CHECK-NEXT: if (i == 3)
// CHECK-NEXT: break;
// CHECK-NEXT: _d_c += _d_x;
// CHECK-NEXT: c += x;
// CHECK-NEXT: _d_r = _d_c;
// CHECK-NEXT: r = c;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return _d_r;
// CHECK-NEXT: }

#define TEST(fn)\
auto d_##fn = clad::differentiate(fn, "i");\
printf("%.2f\n", d_##fn.execute(3, 5));
Expand Down Expand Up @@ -430,4 +495,15 @@ int main() {

clad::differentiate(fn10, 0);
printf("Result is = %.2f\n", fn10_darg0(3, 5)); // CHECK-EXEC: Result is = 30.00

clad::differentiate(fn11, 0);
printf("Result is = %.2f\n", fn11_darg0(3, 5)); // CHECK-EXEC: Result is = 1.00
printf("Result is = %.2f\n", fn11_darg0(-3, 6)); // CHECK-EXEC: Result is = 1.00
printf("Result is = %.2f\n", fn11_darg0(1, 5)); // CHECK-EXEC: Result is = 1.00

clad::differentiate(fn12, 0);
printf("Result is = %.2f\n", fn12_darg0(3, 5)); // CHECK-EXEC: Result is = 2.00
printf("Result is = %.2f\n", fn12_darg0(-3, 6)); // CHECK-EXEC: Result is = 2.00
printf("Result is = %.2f\n", fn12_darg0(1, 5)); // CHECK-EXEC: Result is = 2.00

}

0 comments on commit 16183dd

Please sign in to comment.