Skip to content

Commit

Permalink
Add support for the logical not operator in for-loop conds (forward…
Browse files Browse the repository at this point in the history
… mode)

This commit adds support for logical negation operators with side-effects
inside for-loop conditions in the forward mode.

Fixes: vgvassilev#911
  • Loading branch information
gojakuch authored and MihailMihov committed Jun 5, 2024
1 parent 4865ce1 commit 6368542
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
17 changes: 11 additions & 6 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,16 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
// If it's a supported differentiable operator we wrap it back into
// parentheses and then visit. To ensure the correctness, a comma operator
// expression (cond_dx, cond) is generated and put instead of the condition.
// FIXME: Add support for other expressions in cond (unary operators,
// comparisons, function calls, etc.). Ideally, we should be able to simply
// always call Visit(cond)
BinaryOperator* condBO = dyn_cast<BinaryOperator>(cond);
if (condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) {
// FIXME: Add support for other expressions in cond (comparisons, function
// calls, etc.). Ideally, we should be able to simply always call
// Visit(cond)
auto* condBO = dyn_cast<BinaryOperator>(cond);
auto* condUO = dyn_cast<UnaryOperator>(cond);
if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) ||
condUO) {
condDiff = Visit(cond);
if (condDiff.getExpr_dx() && !isUnusedResult(condDiff.getExpr_dx()))
if (condDiff.getExpr_dx() &&
(!isUnusedResult(condDiff.getExpr_dx()) || condUO))
cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()),
BuildParens(condDiff.getExpr()));
else
Expand Down Expand Up @@ -1286,6 +1289,8 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_AddrOf) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_LNot) {
return StmtDiff(op, diff.getExpr_dx());
} else {
unsupportedOpWarn(UnOp->getEndLoc());
auto zero =
Expand Down
21 changes: 21 additions & 0 deletions test/FirstDerivative/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,24 @@ double fn14_darg0(double x);
// CHECK-NEXT: return _d_x;
// CHECK-NEXT: }

double fn15(double u, double v) {
double res = 0;
for (; !(res = u * v) ;) {}
return 2*res;
}

double fn15_darg0(double u, double v);
//CHECK: double fn15_darg0(double u, double v) {
// CHECK-NEXT: double _d_u = 1;
// CHECK-NEXT: double _d_v = 0;
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: for (; (_d_res = _d_u * v + u * _d_v) , !(res = u * v);) {
// CHECK-NEXT: }
// CHECK-NEXT: return 0 * res + 2 * _d_res;
// CHECK-NEXT: }


#define TEST(fn)\
auto d_##fn = clad::differentiate(fn, "i");\
printf("%.2f\n", d_##fn.execute(3, 5));
Expand Down Expand Up @@ -556,4 +574,7 @@ int main() {
printf("Result is = %.2f\n", fn14_darg0(3)); // CHECK-EXEC: Result is = 4.00
printf("Result is = %.2f\n", fn14_darg0(-3)); // CHECK-EXEC: Result is = 4.00
printf("Result is = %.2f\n", fn14_darg0(1)); // CHECK-EXEC: Result is = 4.00

clad::differentiate(fn15, 0);
printf("Result is = %.2f\n", fn15_darg0(7, 3)); // CHECK-EXEC: Result is = 6.00
}

0 comments on commit 6368542

Please sign in to comment.