Skip to content

Commit

Permalink
Remove excessive stores for increments.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Oct 31, 2023
1 parent 3ac77d2 commit 7f5d91f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 90 deletions.
177 changes: 91 additions & 86 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1059,8 +1059,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VarDecl* condVarClone = nullptr;
if (FS->getConditionVariable()) {
condVarRes = DifferentiateSingleStmt(FS->getConditionVariableDeclStmt());
Decl* decl = cast<DeclStmt>(condVarRes.getStmt())->getSingleDecl();
condVarClone = cast<VarDecl>(decl);
if (isa<DeclStmt>(condVarRes.getStmt())) {
Decl *decl = cast<DeclStmt>(condVarRes.getStmt())->getSingleDecl();
condVarClone = cast<VarDecl>(decl);
}
}

// FIXME: for now we assume that cond has no differentiable effects,
Expand Down Expand Up @@ -1111,9 +1113,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
incDiff.getStmt_dx(),
/*isForLoop=*/true);

/// FIXME: This part is necessary to replace local variables inside loops
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from overwriting
/// local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
if (condVarRes.getExpr() && isa<Expr>(condVarRes.getExpr())) {
forwardCond = BuildOp(BO_Comma, cond.getExpr(), cast<Expr>(condVarRes.getExpr()));
}

Stmt* Forward = new (m_Context) ForStmt(m_Context,
initResult.getStmt(),
cond.getExpr(),
forwardCond,
condVarClone,
incResult,
BodyDiff.getStmt(),
Expand Down Expand Up @@ -1208,7 +1219,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff ReverseModeVisitor::VisitParenExpr(const ParenExpr* PE) {
StmtDiff subStmtDiff = Visit(PE->getSubExpr(), dfdx());
return StmtDiff(BuildParens(subStmtDiff.getExpr()),
BuildParens(subStmtDiff.getExpr_dx()));
BuildParens(subStmtDiff.getExpr_dx()),
nullptr,
BuildParens(subStmtDiff.getRevSweepExpr()));
}

StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) {
Expand Down Expand Up @@ -1779,6 +1792,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
idx++;
}
Expr* pullback = dfdx();

if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType())
pullback = getZeroInit(FD->getReturnType().getNonReferenceType());

Expand Down Expand Up @@ -2052,6 +2066,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
auto opCode = UnOp->getOpcode();
Expr* valueForRevPass = nullptr;
StmtDiff diff{};
Expr* E = UnOp->getSubExpr();
// If it is a post-increment/decrement operator, its result is a reference
Expand All @@ -2070,28 +2085,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diff = Visit(E, d);
} else if (opCode == UO_PostInc || opCode == UO_PostDec) {
diff = Visit(E, dfdx());
auto EStored = GlobalStoreAndRef(diff.getExpr());
if (EStored.getExpr() != diff.getExpr()) {
auto* assign = BuildOp(BinaryOperatorKind::BO_Assign,
Clone(diff.getExpr()), EStored.getExpr_dx());
if (isInsideLoop)
addToCurrentBlock(EStored.getExpr(), direction::forward);
addToCurrentBlock(assign, direction::reverse);
if (UsefulToStoreGlobal(diff.getRevSweepExpr())) {
auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepExpr())), direction::reverse);
}

ResultRef = diff.getExpr_dx();
valueForRevPass = diff.getRevSweepExpr();
if (m_ExternalSource)
m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff);
} else if (opCode == UO_PreInc || opCode == UO_PreDec) {
diff = Visit(E, dfdx());
auto EStored = GlobalStoreAndRef(diff.getExpr());
if (EStored.getExpr() != diff.getExpr()) {
auto* assign = BuildOp(BinaryOperatorKind::BO_Assign,
Clone(diff.getExpr()), EStored.getExpr_dx());
if (isInsideLoop)
addToCurrentBlock(EStored.getExpr(), direction::forward);
addToCurrentBlock(assign, direction::reverse);
if (UsefulToStoreGlobal(diff.getRevSweepExpr())) {
auto op = opCode == UO_PreInc ? UO_PreDec : UO_PreInc;
addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepExpr())), direction::reverse);
}
auto op = opCode == UO_PreInc ? BinaryOperatorKind::BO_Add : BinaryOperatorKind::BO_Sub;
auto sum = BuildOp(op, diff.getRevSweepExpr(), ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1));
valueForRevPass = utils::BuildParenExpr(m_Sema, sum);
} else if (opCode == UnaryOperatorKind::UO_Real ||
opCode == UnaryOperatorKind::UO_Imag) {
diff = VisitWithExplicitNoDfDx(E);
Expand Down Expand Up @@ -2139,7 +2150,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diff = StmtDiff(E);
}
Expr* op = BuildOp(opCode, diff.getExpr());
return StmtDiff(op, ResultRef);
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}

StmtDiff
Expand Down Expand Up @@ -2197,11 +2208,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
if (!RDelayed.isConstant) {
Expr* dr = nullptr;
if (dfdx()) {
StmtDiff LResult;
if (isa<DeclRefExpr>(LStored->IgnoreImpCasts()))
LResult = {LStored, LStored};
else
LResult = GlobalStoreAndRef(LStored, "_t", /*force=*/true);
StmtDiff LResult = GlobalStoreAndRef(LStored);
LStored = LResult.getExpr();
dr = BuildOp(BO_Mul, LResult.getExpr_dx(), dfdx());
dr = StoreAndRef(dr, direction::reverse);
Expand All @@ -2228,27 +2235,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr))
// Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is
// produced instead of 1 / (R * R).
Expr* LStored = Ldiff.getExpr();
if (!RDelayed.isConstant) {
Expr* dr = nullptr;
StmtDiff LResult;
if (dfdx()) {
if (isa<DeclRefExpr>(LStored->IgnoreParenImpCasts()))
LResult = {LStored, LStored};
else
LResult = GlobalStoreAndRef(LStored, "_t", /*force=*/true);
LStored = LResult.getExpr();
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
dr = BuildOp(BO_Mul,
dfdx(),
BuildOp(UO_Minus,
BuildOp(BO_Div, LResult.getExpr_dx(), RxR)));
BuildOp(BO_Div, Ldiff.getRevSweepExpr(), RxR)));
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(Ldiff.getRevSweepExpr(), RResult.getRevSweepExpr());
} else if (BinOp->isAssignmentOp()) {
if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning,
Expand Down Expand Up @@ -2333,15 +2333,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Ldiff.updateStmt(storeE);
}

// if (/*!L->HasSideEffects(m_Context)*/true) {
// Lstored = GlobalStoreAndRef(Ldiff.getExpr(), "_t", /*force*/true);
// auto assign = BuildOp(BO_Assign, Ldiff.getExpr(), Lstored.getExpr_dx());
// if (isInsideLoop) {
// addToCurrentBlock(Lstored.getExpr(), direction::forward);
// }
// addToCurrentBlock(assign, direction::reverse);
// }

Expr* LCloned = Ldiff.getExpr();
// For x, AssignedDiff is _d_x, for x[i] its _d_x[i], for reference exprs
// like (x = y) it propagates recursively, so _d_x is also returned.
Expand All @@ -2364,10 +2355,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If assigned expr is dependent, first update its derivative;
auto Lblock_begin = Lblock->body_rbegin();
auto Lblock_end = Lblock->body_rend();
// if (Lblock->size()) {
// addToCurrentBlock(*Lblock_begin, direction::reverse);
// Lblock_begin = std::next(Lblock_begin);
// }

for (auto S : essentialRevBlock)
addToCurrentBlock(S, direction::reverse);
Expand All @@ -2378,14 +2365,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

for (auto E : return_exprs) {
Lstored = GlobalStoreAndRef(E);
if (Lstored.getExpr() != E) {
auto* assign =
BuildOp(BinaryOperatorKind::BO_Assign, E, Lstored.getExpr_dx());
if (isInsideLoop)
addToCurrentBlock(Lstored.getExpr(), direction::forward);
addToCurrentBlock(assign, direction::reverse);
}
auto pushPop = StoreAndRestore(E);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
}

if (m_ExternalSource)
Expand Down Expand Up @@ -2414,7 +2396,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
AssignedDiff,
BuildOp(BO_Mul, oldValue, RResult.getExpr_dx())),
direction::reverse);
Expr* LRef = LCloned;
if (!RDelayed.isConstant) {
// Create a reference variable to keep the result of LHS, since it
// must be used on 2 places: when storing to a global variable
Expand All @@ -2429,24 +2410,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// double & _ref0 = (x *= y);
// _t0 = _ref0;
// double r = _ref0 *= z;
StmtDiff LResult;
if (LCloned->HasSideEffects(m_Context)) {
auto RefType = getNonConstType(L->getType(), m_Context, m_Sema);
// RefType = m_Context.getLValueReferenceType(RefType);
LRef = StoreAndRef(LCloned, RefType, direction::forward, "_ref",
/*forceDeclCreation=*/true);
LResult = GlobalStoreAndRef(LRef, "_t", /*force=*/true);
} else
LResult = {LRef, LRef};

if (isInsideLoop)
addToCurrentBlock(LResult.getExpr(), direction::forward);
Expr* dr = BuildOp(BO_Mul, LResult.getExpr_dx(), oldValue);
addToCurrentBlock(LCloned, direction::forward);
Expr* dr = BuildOp(BO_Mul, LCloned, oldValue);
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(LRef, RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr());
} else if (opCode == BO_DivAssign) {
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expand All @@ -2455,29 +2427,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
AssignedDiff,
BuildOp(BO_Div, oldValue, RStored)),
direction::reverse);
Expr* LRef = LCloned;
if (!RDelayed.isConstant) {
StmtDiff LResult;
if (LCloned->HasSideEffects(m_Context)) {
QualType RefType = m_Context.getLValueReferenceType(
getNonConstType(L->getType(), m_Context, m_Sema));
LRef = StoreAndRef(LCloned, RefType, direction::forward, "_ref",
/*forceDeclCreation=*/true);
LResult = GlobalStoreAndRef(LRef, "_t", /*force=*/true);
} else
LResult = {LRef, LRef};
if (isInsideLoop)
addToCurrentBlock(LResult.getExpr(), direction::forward);
addToCurrentBlock(LCloned, direction::forward);
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expr* dr = BuildOp(
BO_Mul,
oldValue,
BuildOp(UO_Minus, BuildOp(BO_Div, LResult.getExpr_dx(), RxR)));
BuildOp(UO_Minus, BuildOp(BO_Div, LCloned, RxR)));
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
}
std::tie(Ldiff, Rdiff) = std::make_pair(LRef, RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr());
} else
llvm_unreachable("unknown assignment opCode");
if (m_ExternalSource)
Expand Down Expand Up @@ -2515,7 +2477,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return BuildOp(opCode, LExpr, RExpr);
}
Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr());
return StmtDiff(op, ResultRef);
return StmtDiff(op, ResultRef, nullptr, valueForRevPass);
}

VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
Expand Down Expand Up @@ -2707,6 +2669,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitDeclStmt(const DeclStmt* DS) {
llvm::SmallVector<Stmt*, 16> inits;
llvm::SmallVector<Decl*, 4> decls;
llvm::SmallVector<Decl*, 4> declsDiff;
// Need to put array decls inlined.
Expand Down Expand Up @@ -2740,6 +2703,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// }
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName())
m_DeclReplacements[VD] = VDDiff.getDecl();

/// FIXME: This part is necessary to replace local variables inside loops
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from overwriting
/// local variables on different loop passes.
if (isInsideLoop) {
if (VD->getType()->isBuiltinType() && !VD->getType().isConstQualified()) {
auto *decl = VDDiff.getDecl();
if (decl->getInit()) {
auto *declRef = BuildDeclRef(decl);
auto pushPop = StoreAndRestore(declRef, /*prefix=*/"_t", /*force=*/true);
if (pushPop.getExpr() != declRef) {
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
}
auto *assignment = BuildOp(BO_Assign, declRef, decl->getInit());
inits.push_back(BuildOp(BO_Comma, pushPop.getExpr(), assignment));
}
decl->setInit(getZeroInit(VD->getType()));
}
}

decls.push_back(VDDiff.getDecl());
if (isa<VariableArrayType>(VD->getType()))
localDeclsDiff.push_back(VDDiff.getDecl_dx());
Expand Down Expand Up @@ -2769,16 +2753,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
declsDiff.append(localDeclsDiff.begin(), localDeclsDiff.end());
m_ExternalSource->ActBeforeFinalizingVisitDeclStmt(decls, declsDiff);
}

/// FIXME: This part is necessary to replace local variables inside loops
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from overwriting
/// local variables on different loop passes.
if (isInsideLoop) {
if (auto *VD = dyn_cast<VarDecl>(decls[0])) {
if (VD->getType()->isBuiltinType() && !VD->getType().isConstQualified()) {
addToBlock(DSClone, m_Globals);
Stmt *initAssignments = MakeCompoundStmt(inits);
initAssignments = unwrapIfSingleStmt(initAssignments);
return StmtDiff(initAssignments);
}
}
}

return StmtDiff(DSClone);
}

StmtDiff
ReverseModeVisitor::VisitImplicitCastExpr(const ImplicitCastExpr* ICE) {
StmtDiff subExprDiff = Visit(ICE->getSubExpr(), dfdx());
// Casts should be handled automatically when the result is used by
// Sema::ActOn.../Build...
return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx(),
subExprDiff.getForwSweepStmt_dx());
return Visit(ICE->getSubExpr(), dfdx());
}

StmtDiff ReverseModeVisitor::VisitMemberExpr(const MemberExpr* ME) {
Expand Down Expand Up @@ -3027,10 +3025,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// statement.
Sema::ConditionResult condResult;
if (condVarDecl) {
Decl* condVarClone = cast<DeclStmt>(condVarRes.getStmt())
if (condVarRes.getStmt()) {
if (isa<DeclStmt>(condVarRes.getStmt())) {
Decl *condVarClone = cast<DeclStmt>(condVarRes.getStmt())
->getSingleDecl();
condResult = m_Sema.ActOnConditionVariable(condVarClone, noLoc,
condResult = m_Sema.ActOnConditionVariable(condVarClone, noLoc,
Sema::ConditionKind::Boolean);
} else {
condResult = m_Sema.ActOnCondition(getCurrentScope(), noLoc, cast<Expr>(condVarRes.getStmt()),
Sema::ConditionKind::Boolean);
}
}
} else {
condResult = m_Sema.ActOnCondition(getCurrentScope(), noLoc, condClone,
Sema::ConditionKind::Boolean);
Expand Down
4 changes: 0 additions & 4 deletions lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -739,10 +739,6 @@ void TBRAnalyzer::VisitUnaryOperator(const clang::UnaryOperator* UnOp) {
/// Mark corresponding SourceLocation as required/not required to be
/// stored for all expressions that could be changed.
markLocation(innerExpr);
/// Set them to not required to store because the values were changed.
/// (if some value was not changed, this could only happen if it was
/// already not required to store).
setIsRequired(innerExpr, /*isReq=*/false);
}
}
}
Expand Down

0 comments on commit 7f5d91f

Please sign in to comment.