Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle variable declarations in conditions of if-statements (#865) #891

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 47 additions & 71 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,14 +800,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to this scope.
beginScope(Scope::DeclScope | Scope::ControlScope);

StmtDiff cond = Clone(If->getCond());
// Condition has to be stored as a "global" variable, to take the correct
// branch in the reverse pass.
// If we are inside loop, the condition has to be stored in a stack after
// the if statement.
Expr* PushCond = nullptr;
Expr* PopCond = nullptr;
auto condExpr = Visit(cond.getExpr());
// Create a block "around" if statement, e.g:
// {
// ...
// if (...) {...}
// }
beginBlock(direction::forward);
beginBlock(direction::reverse);
StmtDiff condDiff;
// if the statement has an init, we process it
if (If->hasInitStorage()) {
StmtDiff initDiff = Visit(If->getInit());
addToCurrentBlock(initDiff.getStmt(), direction::forward);
addToCurrentBlock(initDiff.getStmt_dx(), direction::reverse);
}
// this ensures we can differentiate conditions that affect the derivatives
// as well as declarations inside the condition:
beginBlock(direction::reverse);
if (const auto* condDeclStmt = If->getConditionVariableDeclStmt())
condDiff = Visit(condDeclStmt);
else
condDiff = Visit(If->getCond());
CompoundStmt* RCS = endBlock(direction::reverse);
if (!RCS->body_empty()) {
std::reverse(
RCS->body_begin(),
RCS->body_end()); // it is reversed in the endBlock() but we don't
// actually need this, so we reverse it once again
addToCurrentBlock(RCS, direction::reverse);
}

if (isInsideLoop) {
// If we are inside for loop, cond will be stored in the following way:
// forward:
Expand All @@ -820,62 +848,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// if (clad::push(..., _t) { ... }
// is incorrect when if contains return statement inside: return will
// skip corresponding push.
cond = StoreAndRef(condExpr.getExpr(), direction::forward, "_t",
/*forceDeclCreation=*/true);
StmtDiff condPushPop = GlobalStoreAndRef(cond.getExpr(), "_cond",
/*force=*/true);
condDiff = StoreAndRef(condDiff.getExpr(), m_Context.BoolTy,
direction::forward, "_t",
/*forceDeclCreation=*/true);
StmtDiff condPushPop =
GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond",
/*force=*/true);
PushCond = condPushPop.getExpr();
PopCond = condPushPop.getExpr_dx();
} else
cond = GlobalStoreAndRef(condExpr.getExpr(), "_cond");
condDiff =
GlobalStoreAndRef(condDiff.getExpr(), m_Context.BoolTy, "_cond");
// Convert cond to boolean condition. We are modifying each Stmt in
// StmtDiff.
for (Stmt*& S : cond.getBothStmts())
for (Stmt*& S : condDiff.getBothStmts())
if (S)
S = m_Sema
.ActOnCondition(getCurrentScope(), noLoc, cast<Expr>(S),
Sema::ConditionKind::Boolean)
.get()
.second;

// Create a block "around" if statement, e.g:
// {
// ...
// if (...) {...}
// }
beginBlock(direction::forward);
beginBlock(direction::reverse);
const Stmt* init = If->getInit();
StmtDiff initResult = init ? Visit(init) : StmtDiff{};
// If there is Init, it's derivative will be output in the block before if:
// E.g., for:
// if (int x = 1; ...) {...}
// result will be:
// {
// int _d_x = 0;
// if (int x = 1; ...) {...}
// }
// This is done to avoid variable names clashes.
addToCurrentBlock(initResult.getStmt_dx());

VarDecl* condVarClone = nullptr;
if (const VarDecl* condVarDecl = If->getConditionVariable()) {
DeclDiff<VarDecl> condVarDeclDiff = DifferentiateVarDecl(condVarDecl);
condVarClone = condVarDeclDiff.getDecl();
if (condVarDeclDiff.getDecl_dx())
addToBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()), m_Globals);
}

// Condition is just cloned as it is, not derived.
// FIXME: if condition changes one of the variables, it may be reasonable
// to derive it, e.g.
// if (x += x) {...}
// should result in:
// {
// _d_y += _d_x
// if (y += x) {...}
// }

auto VisitBranch = [&](const Stmt* Branch) -> StmtDiff {
if (!Branch)
return {};
Expand All @@ -902,37 +895,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff thenDiff = VisitBranch(If->getThen());
StmtDiff elseDiff = VisitBranch(If->getElse());

// It is problematic to specify both condVarDecl and cond thorugh
// Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to delete this comment?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

condVarDecl is always set to nullptr here now, because it doesn't change anything, so I'm not sure if this is relevant. we generally don't need condVarDecl in the generated statement, only cond

Copy link
Collaborator

@vaithak vaithak May 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, would it make sense to use ActOnIfStmt now?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's possible, but I'm not sure what benefit that would have.

Stmt* Forward = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt(),
condVarClone,
cond.getExpr(),
noLoc,
noLoc,
thenDiff.getStmt(),
noLoc,
elseDiff.getStmt());
Stmt* Forward = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), nullptr, nullptr,
condDiff.getExpr(), noLoc, noLoc, thenDiff.getStmt(), noLoc,
elseDiff.getStmt());
addToCurrentBlock(Forward, direction::forward);

Expr* reverseCond = cond.getExpr_dx();
Expr* reverseCond = condDiff.getExpr_dx();
if (isInsideLoop) {
addToCurrentBlock(PushCond, direction::forward);
reverseCond = PopCond;
}
Stmt* Reverse = clad_compat::IfStmt_Create(m_Context,
noLoc,
If->isConstexpr(),
initResult.getStmt_dx(),
condVarClone,
reverseCond,
noLoc,
noLoc,
thenDiff.getStmt_dx(),
noLoc,
elseDiff.getStmt_dx());
Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), nullptr, nullptr, reverseCond,
noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, elseDiff.getStmt_dx());
addToCurrentBlock(Reverse, direction::reverse);
CompoundStmt* ForwardBlock = endBlock(direction::forward);
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
Expand Down
4 changes: 4 additions & 0 deletions test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ float func(float x, float y) {
//CHECK-NEXT: float _t1;
//CHECK-NEXT: float _t2;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = y;
Expand All @@ -36,6 +37,7 @@ float func(float x, float y) {
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _ret_value0 = x + y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
Expand Down Expand Up @@ -91,6 +93,7 @@ float func2(float x) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: float z = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = z > 9;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _ret_value0 = x + x;
Expand All @@ -99,6 +102,7 @@ float func2(float x) {
//CHECK-NEXT: _ret_value0 = x * x;
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
Expand Down
18 changes: 14 additions & 4 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@ double f2(double x, double y) {
//CHECK: void f2_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _cond0 = x < y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < y;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
Expand Down Expand Up @@ -160,18 +162,22 @@ double f5(double x, double y) {
//CHECK-NEXT: double z = 0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _cond1 = y < 0;
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: z = t;
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: _d_t += 1;
Expand Down Expand Up @@ -223,18 +229,22 @@ double f6(double x, double y) {
//CHECK-NEXT: double z = 0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double t = x * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x < 0;
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: _t0 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: _cond1 = y < 0;
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: z = t;
//CHECK-NEXT: _t1 = t;
//CHECK-NEXT: t = -t;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: _d_t += 1;
Expand Down
4 changes: 4 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -915,9 +915,11 @@ double sq_defined_later(double x) {

// CHECK: void check_and_return_pullback(double x, char c, const char *s, double _d_y, double *_d_x, char *_d_c, char *_d_s) {
// CHECK-NEXT: bool _cond0;
// CHECK-NEXT: {
// CHECK-NEXT: _cond0 = c == 'a' && s[0] == 'a';
// CHECK-NEXT: if (_cond0)
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: }
// CHECK-NEXT: goto _label1;
// CHECK-NEXT: _label1:
// CHECK-NEXT: ;
Expand Down Expand Up @@ -957,9 +959,11 @@ double sq_defined_later(double x) {

//CHECK: void recFun_pullback(double x, double y, double _d_y0, double *_d_x, double *_d_y) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = x > y;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: }
//CHECK-NEXT: goto _label1;
//CHECK-NEXT: _label1:
//CHECK-NEXT: {
Expand Down
Loading
Loading