Skip to content

Commit

Permalink
Handle variable declarations in conditions of if-statements
Browse files Browse the repository at this point in the history
Fixes: #865
  • Loading branch information
gojakuch committed May 16, 2024
1 parent c4ac006 commit 91f81d4
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 83 deletions.
108 changes: 44 additions & 64 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,51 +800,65 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to this scope.
beginScope(Scope::DeclScope | Scope::ControlScope);

StmtDiff cond = Clone(If->getCond());
// Condition has to be stored as a "global" variable, to take the correct
// branch in the reverse pass.
// If we are inside loop, the condition has to be stored in a stack after
// the if statement.
Expr* PushCond = nullptr;
Expr* PopCond = nullptr;
auto condExpr = Visit(cond.getExpr());
// Create a block "around" if statement, e.g:
// {
// ...
// if (...) {...}
// }
beginBlock(direction::forward);
beginBlock(direction::reverse);
StmtDiff condDiff;
// this ensures we can differentiate conditions that affect the derivatives
// as well as declarations inside the condition:
beginBlock(direction::reverse);
if (const auto* condDeclStmt = If->getConditionVariableDeclStmt())
condDiff = Visit(condDeclStmt);
else
condDiff = Visit(If->getCond());
CompoundStmt* RCS = endBlock(direction::reverse);
if (!RCS->body_empty()) {
std::reverse(
RCS->body_begin(),
RCS->body_end()); // it is reversed in the endBlock() but we don't
// actually need this, so we reverse it once again
addToCurrentBlock(RCS, direction::reverse);
}

if (isInsideLoop) {
// If we are inside for loop, cond will be stored in the following way:
// forward:
// _t = cond;
// if (_t) { ... }
// clad::push(..., _t);
// reverse:
// If we are inside for loop, condDiff will be stored in the following
// way: forward: _t = cond; if (_t) { ... } clad::push(..., _t); reverse:
// if (clad::pop(...)) { ... }
// Simply doing
// if (clad::push(..., _t) { ... }
// is incorrect when if contains return statement inside: return will
// skip corresponding push.
cond = StoreAndRef(condExpr.getExpr(), direction::forward, "_t",
/*forceDeclCreation=*/true);
StmtDiff condPushPop = GlobalStoreAndRef(cond.getExpr(), "_cond",
/*force=*/true);
condDiff = StoreAndRef(condDiff.getExpr(), m_Context.BoolTy,
direction::forward, "_t",
/*forceDeclCreation=*/true);
StmtDiff condPushPop =
GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond",
/*force=*/true);
PushCond = condPushPop.getExpr();
PopCond = condPushPop.getExpr_dx();
} else
cond = GlobalStoreAndRef(condExpr.getExpr(), "_cond");
condDiff =
GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond");
// Convert cond to boolean condition. We are modifying each Stmt in
// StmtDiff.
for (Stmt*& S : cond.getBothStmts())
for (Stmt*& S : condDiff.getBothStmts())
if (S)
S = m_Sema
.ActOnCondition(getCurrentScope(), noLoc, cast<Expr>(S),
Sema::ConditionKind::Boolean)
.get()
.second;

// Create a block "around" if statement, e.g:
// {
// ...
// if (...) {...}
// }
beginBlock(direction::forward);
beginBlock(direction::reverse);
const Stmt* init = If->getInit();
StmtDiff initResult = init ? Visit(init) : StmtDiff{};
// If there is Init, it's derivative will be output in the block before if:
Expand All @@ -858,24 +872,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// This is done to avoid variable names clashes.
addToCurrentBlock(initResult.getStmt_dx());

VarDecl* condVarClone = nullptr;
if (const VarDecl* condVarDecl = If->getConditionVariable()) {
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
if (condVarDeclDiff.getDecl_dx())
addToBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()), m_Globals);
}

// Condition is just cloned as it is, not derived.
// FIXME: if condition changes one of the variables, it may be reasonable
// to derive it, e.g.
// if (x += x) {...}
// should result in:
// {
// _d_y += _d_x
// if (y += x) {...}
// }

auto VisitBranch = [&](const Stmt* Branch) -> StmtDiff {
if (!Branch)
return {};
Expand All @@ -902,37 +898,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff thenDiff = VisitBranch(If->getThen());
StmtDiff elseDiff = VisitBranch(If->getElse());

// It is problematic to specify both condVarDecl and cond thorugh
// Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor.
Stmt* Forward = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt(),
condVarClone,
cond.getExpr(),
noLoc,
noLoc,
thenDiff.getStmt(),
noLoc,
elseDiff.getStmt());
Stmt* Forward = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt(), nullptr,
condDiff.getExpr(), noLoc, noLoc, thenDiff.getStmt(), noLoc,
elseDiff.getStmt());
addToCurrentBlock(Forward, direction::forward);

Expr* reverseCond = cond.getExpr_dx();
Expr* reverseCond = condDiff.getExpr_dx();
if (isInsideLoop) {
addToCurrentBlock(PushCond, direction::forward);
reverseCond = PopCond;
}
Stmt* Reverse = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt_dx(),
condVarClone,
reverseCond,
noLoc,
noLoc,
thenDiff.getStmt_dx(),
noLoc,
elseDiff.getStmt_dx());
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), nullptr,
reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc,
elseDiff.getStmt_dx());
addToCurrentBlock(Reverse, direction::reverse);
CompoundStmt* ForwardBlock = endBlock(direction::forward);
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
Expand Down
4 changes: 4 additions & 0 deletions test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ float func(float x, float y) {
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = y;
Expand All @@ -36,6 +37,7 @@ float func(float x, float y) {
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _ret_value0 = x + y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
Expand Down Expand Up @@ -91,6 +93,7 @@ float func2(float x) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: float z = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = z > 9;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _ret_value0 = x + x;
Expand All @@ -99,6 +102,7 @@ float func2(float x) {
//CHECK-NEXT: _ret_value0 = x * x;
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
Expand Down
18 changes: 14 additions & 4 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ double f2(double x, double y) {
//CHECK: void f2_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _cond0 = x < y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
Expand Down Expand Up @@ -160,18 +162,22 @@ double f5(double x, double y) {
//CHECK-NEXT: double z = 0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _cond1 = y < 0;
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: z = t;
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: _d_t += 1;
Expand Down Expand Up @@ -223,18 +229,22 @@ double f6(double x, double y) {
//CHECK-NEXT: double z = 0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _cond1 = y < 0;
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: z = t;
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: _d_t += 1;
Expand Down
4 changes: 4 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -915,9 +915,11 @@ double sq_defined_later(double x) {

// CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, double *_d_x, char *_d_c, char *_d_s) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: {
// CHECK-NEXT: _cond0 = c == 'a' && s[0] == 'a';
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: }
// CHECK-NEXT: goto _label1;
// CHECK-NEXT: _label1:
// CHECK-NEXT: ;
Expand Down Expand Up @@ -957,9 +959,11 @@ double sq_defined_later(double x) {

//CHECK: void recFun_pullback(double x, double y, double _d_y0, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: {
Expand Down
Loading

0 comments on commit 91f81d4

Please sign in to comment.