Skip to content

Commit

Permalink
Add more valueForRevPass.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Oct 30, 2023
1 parent 69e1425 commit c632e94
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2163,6 +2163,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff Ldiff{};
StmtDiff Rdiff{};
StmtDiff Lstored{};
Expr* valueForRevPass = nullptr;
auto L = BinOp->getLHS();
auto R = BinOp->getRHS();
// If it is an assignment operator, its result is a reference to LHS and
Expand Down Expand Up @@ -2250,8 +2251,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

if (!LDRE && !RDRE)
return Clone(BinOp);
Expr* LExpr = LDRE ? Visit(L).getExpr() : L;
Expr* RExpr = RDRE ? Visit(R).getExpr() : R;
Expr* LExpr = LDRE ? Visit(L).getRevSweepExpr() : L;
Expr* RExpr = RDRE ? Visit(R).getRevSweepExpr() : R;

return BuildOp(opCode, LExpr, RExpr);
}
Expand Down Expand Up @@ -2370,14 +2371,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

if (opCode == BO_Assign) {
Rdiff = Visit(R, oldValue);
valueForRevPass = Rdiff.getRevSweepExpr();
} else if (opCode == BO_AddAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, oldValue);
valueForRevPass = BuildOp(BO_Add, Rdiff.getRevSweepExpr(), Ldiff.getRevSweepExpr());
} else if (opCode == BO_SubAssign) {
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue),
direction::reverse);
Rdiff = Visit(R, BuildOp(UO_Minus, oldValue));
valueForRevPass = BuildOp(BO_Sub, Rdiff.getRevSweepExpr(), Ldiff.getRevSweepExpr());
} else if (opCode == BO_MulAssign) {
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expand Down Expand Up @@ -2408,6 +2412,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
valueForRevPass = BuildOp(BO_Mul, Rdiff.getRevSweepExpr(), Ldiff.getRevSweepExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr());
} else if (opCode == BO_DivAssign) {
auto RDelayed = DelayedGlobalStoreAndRef(R);
Expand All @@ -2429,6 +2434,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
valueForRevPass = BuildOp(BO_Div, Rdiff.getRevSweepExpr(), Ldiff.getRevSweepExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr());
} else
llvm_unreachable("unknown assignment opCode");
Expand All @@ -2445,6 +2451,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
Ldiff = Visit(L, zero);
Rdiff = Visit(R, dfdx());
valueForRevPass = Ldiff.getRevSweepExpr();
ResultRef = Ldiff.getExpr();
} else {
// We should not output any warning on visiting boolean conditions
Expand Down

0 comments on commit c632e94

Please sign in to comment.