Skip to content

Commit

Permalink
Handle variable declarations in conditions of if-statements
Browse files Browse the repository at this point in the history
This patch handles variable declarations in conditions of if-statements in reverse mode differentiation. This is done by actually visiting the conditions instead of just cloning them inside the VisitIfStmt and by choosing the condition variable DeclStmt over cond when possible. This also helps to handle conditions that may affect derivatives as a side-effect. Also a couple of blocks have been added to wrap some statements and preserve the correct logic in the constructed derivatives and some unused code has been removed from the method.

Fixes: vgvassilev#865
  • Loading branch information
gojakuch authored and vgvassilev committed May 17, 2024
1 parent 5bcb602 commit 8a9f873
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 91 deletions.
118 changes: 47 additions & 71 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,14 +800,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to this scope.
beginScope(Scope::DeclScope | Scope::ControlScope);

StmtDiff cond = Clone(If->getCond());
// Condition has to be stored as a "global" variable, to take the correct
// branch in the reverse pass.
// If we are inside loop, the condition has to be stored in a stack after
// the if statement.
Expr* PushCond = nullptr;
Expr* PopCond = nullptr;
auto condExpr = Visit(cond.getExpr());
// Create a block "around" if statement, e.g:
// {
// ...
// if (...) {...}
// }
beginBlock(direction::forward);
beginBlock(direction::reverse);
StmtDiff condDiff;
// if the statement has an init, we process it
if (If->hasInitStorage()) {
StmtDiff initDiff = Visit(If->getInit());
addToCurrentBlock(initDiff.getStmt(), direction::forward);
addToCurrentBlock(initDiff.getStmt_dx(), direction::reverse);
}
// this ensures we can differentiate conditions that affect the derivatives
// as well as declarations inside the condition:
beginBlock(direction::reverse);
if (const auto* condDeclStmt = If->getConditionVariableDeclStmt())
condDiff = Visit(condDeclStmt);
else
condDiff = Visit(If->getCond());
CompoundStmt* RCS = endBlock(direction::reverse);
if (!RCS->body_empty()) {
std::reverse(
RCS->body_begin(),
RCS->body_end()); // it is reversed in the endBlock() but we don't
// actually need this, so we reverse it once again
addToCurrentBlock(RCS, direction::reverse);
}

if (isInsideLoop) {
// If we are inside for loop, cond will be stored in the following way:
// forward:
Expand All @@ -820,62 +848,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// if (clad::push(..., _t) { ... }
// is incorrect when if contains return statement inside: return will
// skip corresponding push.
cond = StoreAndRef(condExpr.getExpr(), direction::forward, "_t",
/*forceDeclCreation=*/true);
StmtDiff condPushPop = GlobalStoreAndRef(cond.getExpr(), "_cond",
/*force=*/true);
condDiff = StoreAndRef(condDiff.getExpr(), m_Context.BoolTy,
direction::forward, "_t",
/*forceDeclCreation=*/true);
StmtDiff condPushPop =
GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond",
/*force=*/true);
PushCond = condPushPop.getExpr();
PopCond = condPushPop.getExpr_dx();
} else
cond = GlobalStoreAndRef(condExpr.getExpr(), "_cond");
condDiff =
GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond");
// Convert cond to boolean condition. We are modifying each Stmt in
// StmtDiff.
for (Stmt*& S : cond.getBothStmts())
for (Stmt*& S : condDiff.getBothStmts())
if (S)
S = m_Sema
.ActOnCondition(getCurrentScope(), noLoc, cast<Expr>(S),
Sema::ConditionKind::Boolean)
.get()
.second;

// Create a block "around" if statement, e.g:
// {
// ...
// if (...) {...}
// }
beginBlock(direction::forward);
beginBlock(direction::reverse);
const Stmt* init = If->getInit();
StmtDiff initResult = init ? Visit(init) : StmtDiff{};
// If there is Init, it's derivative will be output in the block before if:
// E.g., for:
// if (int x = 1; ...) {...}
// result will be:
// {
// int _d_x = 0;
// if (int x = 1; ...) {...}
// }
// This is done to avoid variable names clashes.
addToCurrentBlock(initResult.getStmt_dx());

VarDecl* condVarClone = nullptr;
if (const VarDecl* condVarDecl = If->getConditionVariable()) {
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
if (condVarDeclDiff.getDecl_dx())
addToBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()), m_Globals);
}

// Condition is just cloned as it is, not derived.
// FIXME: if condition changes one of the variables, it may be reasonable
// to derive it, e.g.
// if (x += x) {...}
// should result in:
// {
// _d_y += _d_x
// if (y += x) {...}
// }

auto VisitBranch = [&](const Stmt* Branch) -> StmtDiff {
if (!Branch)
return {};
Expand All @@ -902,37 +895,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff thenDiff = VisitBranch(If->getThen());
StmtDiff elseDiff = VisitBranch(If->getElse());

// It is problematic to specify both condVarDecl and cond thorugh
// Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor.
Stmt* Forward = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt(),
condVarClone,
cond.getExpr(),
noLoc,
noLoc,
thenDiff.getStmt(),
noLoc,
elseDiff.getStmt());
Stmt* Forward = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), nullptr, nullptr,
condDiff.getExpr(), noLoc, noLoc, thenDiff.getStmt(), noLoc,
elseDiff.getStmt());
addToCurrentBlock(Forward, direction::forward);

Expr* reverseCond = cond.getExpr_dx();
Expr* reverseCond = condDiff.getExpr_dx();
if (isInsideLoop) {
addToCurrentBlock(PushCond, direction::forward);
reverseCond = PopCond;
}
Stmt* Reverse = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt_dx(),
condVarClone,
reverseCond,
noLoc,
noLoc,
thenDiff.getStmt_dx(),
noLoc,
elseDiff.getStmt_dx());
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), nullptr, nullptr, reverseCond,
noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, elseDiff.getStmt_dx());
addToCurrentBlock(Reverse, direction::reverse);
CompoundStmt* ForwardBlock = endBlock(direction::forward);
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
Expand Down
4 changes: 4 additions & 0 deletions test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ float func(float x, float y) {
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = y;
Expand All @@ -36,6 +37,7 @@ float func(float x, float y) {
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _ret_value0 = x + y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
Expand Down Expand Up @@ -91,6 +93,7 @@ float func2(float x) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: float z = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = z > 9;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _ret_value0 = x + x;
Expand All @@ -99,6 +102,7 @@ float func2(float x) {
//CHECK-NEXT: _ret_value0 = x * x;
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
Expand Down
18 changes: 14 additions & 4 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ double f2(double x, double y) {
//CHECK: void f2_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _cond0 = x < y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
Expand Down Expand Up @@ -160,18 +162,22 @@ double f5(double x, double y) {
//CHECK-NEXT: double z = 0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _cond1 = y < 0;
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: z = t;
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: _d_t += 1;
Expand Down Expand Up @@ -223,18 +229,22 @@ double f6(double x, double y) {
//CHECK-NEXT: double z = 0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _cond1 = y < 0;
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: z = t;
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: _d_t += 1;
Expand Down
4 changes: 4 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -915,9 +915,11 @@ double sq_defined_later(double x) {

// CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, double *_d_x, char *_d_c, char *_d_s) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: {
// CHECK-NEXT: _cond0 = c == 'a' && s[0] == 'a';
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: }
// CHECK-NEXT: goto _label1;
// CHECK-NEXT: _label1:
// CHECK-NEXT: ;
Expand Down Expand Up @@ -957,9 +959,11 @@ double sq_defined_later(double x) {

//CHECK: void recFun_pullback(double x, double y, double _d_y0, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: {
Expand Down
Loading

0 comments on commit 8a9f873

Please sign in to comment.