Skip to content

Commit

Permalink
Fix issue #865 and allow conditions in if statements to affect the de…
Browse files Browse the repository at this point in the history
…rivatives
  • Loading branch information
gojakuch committed May 14, 2024
1 parent 5fa1c55 commit b26af55
Showing 1 changed file with 42 additions and 64 deletions.
106 changes: 42 additions & 64 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,51 +800,63 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to this scope.
beginScope(Scope::DeclScope | Scope::ControlScope);

StmtDiff cond = Clone(If->getCond());
// Condition has to be stored as a "global" variable, to take the correct
// branch in the reverse pass.
// If we are inside loop, the condition has to be stored in a stack after
// the if statement.
Expr* PushCond = nullptr;
Expr* PopCond = nullptr;
auto condExpr = Visit(cond.getExpr());
// Create a block "around" if statement, e.g:
// {
// ...
// if (...) {...}
// }
beginBlock(direction::forward);
beginBlock(direction::reverse);
StmtDiff condDiff;
// this ensures we can differentiate conditions that affect the derivatives
// as well as declarations inside the condition:
beginBlock(direction::reverse);
if (const auto* condDeclStmt = If->getConditionVariableDeclStmt())
condDiff = Visit(condDeclStmt);
else
condDiff = Visit(If->getCond());
auto* RCS = endBlock(direction::reverse);
std::reverse(
RCS->body_begin(),
RCS->body_end()); // it is reversed in the endBlock() but we don't
// actually need this, so we reverse it once again
addToCurrentBlock(RCS, direction::reverse);

if (isInsideLoop) {
// If we are inside for loop, cond will be stored in the following way:
// forward:
// _t = cond;
// if (_t) { ... }
// clad::push(..., _t);
// reverse:
// If we are inside for loop, condDiff will be stored in the following
// way: forward: _t = cond; if (_t) { ... } clad::push(..., _t); reverse:
// if (clad::pop(...)) { ... }
// Simply doing
// if (clad::push(..., _t) { ... }
// is incorrect when if contains return statement inside: return will
// skip corresponding push.
cond = StoreAndRef(condExpr.getExpr(), direction::forward, "_t",
/*forceDeclCreation=*/true);
StmtDiff condPushPop = GlobalStoreAndRef(cond.getExpr(), "_cond",
/*force=*/true);
condDiff = StoreAndRef(condDiff.getExpr(), m_Context.BoolTy,
direction::forward, "_t",
/*forceDeclCreation=*/true);
StmtDiff condPushPop =
GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond",
/*force=*/true);
PushCond = condPushPop.getExpr();
PopCond = condPushPop.getExpr_dx();
} else
cond = GlobalStoreAndRef(condExpr.getExpr(), "_cond");
condDiff =
GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond");
// Convert cond to boolean condition. We are modifying each Stmt in
// StmtDiff.
for (Stmt*& S : cond.getBothStmts())
for (Stmt*& S : condDiff.getBothStmts())
if (S)
S = m_Sema
.ActOnCondition(getCurrentScope(), noLoc, cast<Expr>(S),
Sema::ConditionKind::Boolean)
.get()
.second;

// Create a block "around" if statement, e.g:
// {
// ...
// if (...) {...}
// }
beginBlock(direction::forward);
beginBlock(direction::reverse);
const Stmt* init = If->getInit();
StmtDiff initResult = init ? Visit(init) : StmtDiff{};
// If there is Init, it's derivative will be output in the block before if:
Expand All @@ -858,24 +870,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// This is done to avoid variable names clashes.
addToCurrentBlock(initResult.getStmt_dx());

VarDecl* condVarClone = nullptr;
if (const VarDecl* condVarDecl = If->getConditionVariable()) {
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
if (condVarDeclDiff.getDecl_dx())
addToBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()), m_Globals);
}

// Condition is just cloned as it is, not derived.
// FIXME: if condition changes one of the variables, it may be reasonable
// to derive it, e.g.
// if (x += x) {...}
// should result in:
// {
// _d_y += _d_x
// if (y += x) {...}
// }

auto VisitBranch = [&](const Stmt* Branch) -> StmtDiff {
if (!Branch)
return {};
Expand All @@ -902,37 +896,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff thenDiff = VisitBranch(If->getThen());
StmtDiff elseDiff = VisitBranch(If->getElse());

// It is problematic to specify both condVarDecl and cond thorugh
// Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor.
Stmt* Forward = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt(),
condVarClone,
cond.getExpr(),
noLoc,
noLoc,
thenDiff.getStmt(),
noLoc,
elseDiff.getStmt());
Stmt* Forward = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt(), nullptr,
condDiff.getExpr(), noLoc, noLoc, thenDiff.getStmt(), noLoc,
elseDiff.getStmt());
addToCurrentBlock(Forward, direction::forward);

Expr* reverseCond = cond.getExpr_dx();
Expr* reverseCond = condDiff.getExpr_dx();
if (isInsideLoop) {
addToCurrentBlock(PushCond, direction::forward);
reverseCond = PopCond;
}
Stmt* Reverse = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt_dx(),
condVarClone,
reverseCond,
noLoc,
noLoc,
thenDiff.getStmt_dx(),
noLoc,
elseDiff.getStmt_dx());
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), nullptr,
reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc,
elseDiff.getStmt_dx());
addToCurrentBlock(Reverse, direction::reverse);
CompoundStmt* ForwardBlock = endBlock(direction::forward);
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
Expand Down

0 comments on commit b26af55

Please sign in to comment.