Skip to content

Commit

Permalink
Fix condition declarations & assignments, enable logical operators in…
Browse files Browse the repository at this point in the history
… for loops

These changes fix the differentiation of variable declarations in for
loop conditions that used to result into wrong derivatives. The commit
also tackles the problem of having an assignment operator that affects
the derivative in the for-loop condition, as well as adds support for
logical operators in for-loops and allows to combine assignments with
them.

Fixes: vgvassilev#273
  • Loading branch information
gojakuch committed May 30, 2024
1 parent e3c1ff0 commit b1a2c36
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 26 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/BaseForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class BaseForwardModeVisitor
// Decl is not Stmt, so it cannot be visited directly.
virtual DeclDiff<clang::VarDecl>
DifferentiateVarDecl(const clang::VarDecl* VD);
virtual DeclDiff<clang::VarDecl>
DifferentiateVarDecl(const clang::VarDecl* VD, bool ignoreInit);
/// Shorthand for warning on differentiation of unsupported operators
void unsupportedOpWarn(clang::SourceLocation loc,
llvm::ArrayRef<llvm::StringRef> args = {}) {
Expand Down
118 changes: 94 additions & 24 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,18 +668,59 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
const Stmt* init = FS->getInit();
StmtDiff initDiff = init ? Visit(init) : StmtDiff{};
addToCurrentBlock(initDiff.getStmt_dx());
VarDecl* condVarDecl = FS->getConditionVariable();
VarDecl* condVarClone = nullptr;
if (condVarDecl) {
DeclDiff<VarDecl> condVarResult = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarResult.getDecl();

StmtDiff condDiff = Clone(FS->getCond());
Expr* cond = condDiff.getExpr();

// The declaration in the condition needs to be differentiated.
if (VarDecl* condVarDecl = FS->getConditionVariable()) {
// Here we create a fictional cond that is equal to the assignment used in
// the declaration. The declaration itself is thrown before the for-loop
// without any init value. The fictional condition is then differentiated as
// a normal condition would be (see below). For example, the declaration
// inside `for (;double t = x;) {}` will be first processed into the
// following code:
// ```
// {
// double t;
// for (;t = x;) {}
// }
// ```
// which will then get differentiated normally as a for-loop with a
// differentiable condition in the next section.
DeclDiff<VarDecl> condVarResult = DifferentiateVarDecl(condVarDecl, /*ignoreInit=*/true);
VarDecl* condVarClone = condVarResult.getDecl();
if (condVarResult.getDecl_dx())
addToCurrentBlock(BuildDeclStmt(condVarResult.getDecl_dx()));
auto condInit = condVarClone->getInit();
condVarClone->setInit(nullptr);
cond = BuildOp(BO_Assign, BuildDeclRef(condVarClone), condInit);
addToCurrentBlock(BuildDeclStmt(condVarClone));
}

// Condition differentiation.
// This adds support for assignments in conditions.
if (cond) {
cond = cond->IgnoreParenImpCasts();
// If it's a supported differentiable operator we wrap it back into
// parentheses and then visit. To ensure the correctness, a comma operator
// expression (cond_dx, cond) is generated and put instead of the condition.
// FIXME: Add support for other expressions in cond (unary operators,
// comparisons, function calls, etc.). Ideally, we should be able to simply
// always call Visit(cond)
BinaryOperator* condBO = dyn_cast<BinaryOperator>(cond);
if (condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) {
condDiff = Visit(cond);
if (condDiff.getExpr_dx() && !isUnusedResult(condDiff.getExpr_dx()))
cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()),
BuildParens(condDiff.getExpr()));
else
cond = condDiff.getExpr();
}
}
Expr* cond = FS->getCond() ? Clone(FS->getCond()) : nullptr;
const Expr* inc = FS->getInc();

// Differentiate the increment expression of the for loop
const Expr* inc = FS->getInc();
beginBlock();
StmtDiff incDiff = inc ? Visit(inc) : StmtDiff{};
CompoundStmt* decls = endBlock();
Expand Down Expand Up @@ -714,27 +755,24 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
incResult = incDiff.getExpr();
}

// Build the derived for loop body.
const Stmt* body = FS->getBody();
beginScope(Scope::DeclScope);
Stmt* bodyResult = nullptr;
if (isa<CompoundStmt>(body)) {
bodyResult = Visit(body).getStmt();
} else {
beginBlock();
StmtDiff Result = Visit(body);
for (Stmt* S : Result.getBothStmts())
addToCurrentBlock(S);
CompoundStmt* Block = endBlock();
if (Block->size() == 1)
bodyResult = Block->body_front();
else
bodyResult = Block;
}
beginBlock();
StmtDiff bodyVisited = Visit(body);
for (Stmt* S : bodyVisited.getBothStmts())
addToCurrentBlock(S);
CompoundStmt* bodyResultCmpd = endBlock();
if (bodyResultCmpd->size() == 1)
bodyResult = bodyResultCmpd->body_front();
else
bodyResult = bodyResultCmpd;
endScope();

Stmt* forStmtDiff =
new (m_Context) ForStmt(m_Context, initDiff.getStmt(), cond, condVarClone,
incResult, bodyResult, noLoc, noLoc, noLoc);
Stmt* forStmtDiff = new (m_Context)
ForStmt(m_Context, initDiff.getStmt(), cond, /*condVar=*/nullptr,
incResult, bodyResult, noLoc, noLoc, noLoc);

addToCurrentBlock(forStmtDiff);
CompoundStmt* Block = endBlock();
Expand Down Expand Up @@ -1366,6 +1404,25 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
} else
opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()),
BuildParens(Rdiff.getExpr_dx()));
} else if (BinOp->isLogicalOp()) {
// For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and
// correct derivative execution.
auto buildOneSide = [this](StmtDiff& Xdiff) {
if (Xdiff.getExpr_dx() && !isUnusedResult(Xdiff.getExpr_dx()))
return BuildParens(BuildOp(BO_Comma, BuildParens(Xdiff.getExpr_dx()),
BuildParens(Xdiff.getExpr())));
return BuildParens(Xdiff.getExpr());
};
// dLL = (dL, L)
Expr* dLL = buildOneSide(Ldiff);
// dRR = (dR, R)
Expr* dRR = buildOneSide(Rdiff);
opDiff = BuildOp(opCode, dLL, dRR);

// Since the both parts are included in the opDiff, there's no point in
// including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left
// nullptr is used for treating expressions like ((A && B) && C) correctly.
return StmtDiff(opDiff, nullptr);
}
if (!opDiff) {
// FIXME: add support for other binary operators
Expand All @@ -1386,7 +1443,20 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {

DeclDiff<VarDecl>
BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
StmtDiff initDiff = VD->getInit() ? Visit(VD->getInit()) : StmtDiff{};
return DifferentiateVarDecl(VD, false);
}

DeclDiff<VarDecl>
BaseForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD, bool ignoreInit) {
StmtDiff initDiff{};
const Expr* init = VD->getInit();
if (init) {
if (!ignoreInit)
initDiff = Visit(init);
else
initDiff = StmtDiff(Clone(init));
}

// Here we are assuming that derived type and the original type are same.
// This may not necessarily be true in the future.
VarDecl* VDClone =
Expand Down
130 changes: 128 additions & 2 deletions test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,9 @@ double fn10_darg0(double x, size_t n);
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {
// CHECK-NEXT: size_t _d_count = 0;
// CHECK-NEXT: size_t _d_max_count = _d_n;
// CHECK-NEXT: for (size_t count = 0; {{.*}}max_count{{.*}}; ++count) {
// CHECK-NEXT: size_t _d_max_count;
// CHECK-NEXT: size_t max_count;
// CHECK-NEXT: for (size_t count = 0; (_d_max_count = _d_n) , (max_count = n); ++count) {
// CHECK-NEXT: if (count >= max_count)
// CHECK-NEXT: break;
// CHECK-NEXT: {
Expand All @@ -393,6 +394,111 @@ double fn10_darg0(double x, size_t n);
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }

double fn11(double x, double y) {
double r = 0;
for (int i = 0; (r = x); ++i) {
if (i == 3) break;
r += x;
}
return r;
} // fn11(x,y) == x

double fn11_darg0(double x, double y);
// CHECK: double fn11_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: double _d_r = 0;
// CHECK-NEXT: double r = 0;
// CHECK-NEXT: {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: for (int i = 0; (_d_r = _d_x) , (r = x); ++i) {
// CHECK-NEXT: if (i == 3)
// CHECK-NEXT: break;
// CHECK-NEXT: _d_r += _d_x;
// CHECK-NEXT: r += x;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return _d_r;
// CHECK-NEXT: }

double fn12(double x, double y) {
double r = 0;
for (int i = 0; double c = x; ++i) {
if (i == 3) break;
c += x;
r = c;
}
return r;
} // fn11(x,y) == 2*x

double fn12_darg0(double x, double y);
// CHECK: double fn12_darg0(double x, double y) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_y = 0;
// CHECK-NEXT: double _d_r = 0;
// CHECK-NEXT: double r = 0;
// CHECK-NEXT: {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: double _d_c;
// CHECK-NEXT: double c;
// CHECK-NEXT: for (int i = 0; (_d_c = _d_x) , (c = x); ++i) {
// CHECK-NEXT: if (i == 3)
// CHECK-NEXT: break;
// CHECK-NEXT: _d_c += _d_x;
// CHECK-NEXT: c += x;
// CHECK-NEXT: _d_r = _d_c;
// CHECK-NEXT: r = c;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return _d_r;
// CHECK-NEXT: }

double fn13(double u, double v) {
double res = 0;
for (; (res = u * v) && (u = 0) ;) {}
return res;
} // = u*v

double fn13_darg0(double u, double v);
// CHECK: double fn13_darg0(double u, double v) {
// CHECK-NEXT: double _d_u = 1;
// CHECK-NEXT: double _d_v = 0;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: for (; ((_d_res = _d_u * v + u * _d_v) , (res = u * v)) && ((_d_u = 0) , (u = 0));) {
// CHECK-NEXT: }
// CHECK-NEXT: return _d_res;
// CHECK-NEXT: }

double fn14(double x) {
double r = 0;
double t = x;
for (int i = 0; (r = t) || false; ++i) {
if (i == 3) break;
x += r;
}
return x;
} // = 4*x

double fn14_darg0(double x);
// CHECK: double fn14_darg0(double x) {
// CHECK-NEXT: double _d_x = 1;
// CHECK-NEXT: double _d_r = 0;
// CHECK-NEXT: double r = 0;
// CHECK-NEXT: double _d_t = _d_x;
// CHECK-NEXT: double t = x;
// CHECK-NEXT: {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: for (int i = 0; ((_d_r = _d_t) , (r = t)) || false; ++i) {
// CHECK-NEXT: if (i == 3)
// CHECK-NEXT: break;
// CHECK-NEXT: _d_x += _d_r;
// CHECK-NEXT: x += r;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

#define TEST(fn)\
auto d_##fn = clad::differentiate(fn, "i");\
printf("%.2f\n", d_##fn.execute(3, 5));
Expand Down Expand Up @@ -430,4 +536,24 @@ int main() {

clad::differentiate(fn10, 0);
printf("Result is = %.2f\n", fn10_darg0(3, 5)); // CHECK-EXEC: Result is = 30.00

clad::differentiate(fn11, 0);
printf("Result is = %.2f\n", fn11_darg0(3, 5)); // CHECK-EXEC: Result is = 1.00
printf("Result is = %.2f\n", fn11_darg0(-3, 6)); // CHECK-EXEC: Result is = 1.00
printf("Result is = %.2f\n", fn11_darg0(1, 5)); // CHECK-EXEC: Result is = 1.00

clad::differentiate(fn12, 0);
printf("Result is = %.2f\n", fn12_darg0(3, 5)); // CHECK-EXEC: Result is = 2.00
printf("Result is = %.2f\n", fn12_darg0(-3, 6)); // CHECK-EXEC: Result is = 2.00
printf("Result is = %.2f\n", fn12_darg0(1, 5)); // CHECK-EXEC: Result is = 2.00

clad::differentiate(fn13, 0);
printf("Result is = %.2f\n", fn13_darg0(3, 4)); // CHECK-EXEC: Result is = 4.00
printf("Result is = %.2f\n", fn13_darg0(-3, 5)); // CHECK-EXEC: Result is = 5.00
printf("Result is = %.2f\n", fn13_darg0(1, 6)); // CHECK-EXEC: Result is = 6.00

clad::differentiate(fn14, 0);
printf("Result is = %.2f\n", fn14_darg0(3)); // CHECK-EXEC: Result is = 4.00
printf("Result is = %.2f\n", fn14_darg0(-3)); // CHECK-EXEC: Result is = 4.00
printf("Result is = %.2f\n", fn14_darg0(1)); // CHECK-EXEC: Result is = 4.00
}

0 comments on commit b1a2c36

Please sign in to comment.