Skip to content

Commit

Permalink
Remove cloning from VisitBinaryOperator by introducing placeholders i…
Browse files Browse the repository at this point in the history
…n DelayedStoreAndRef.
  • Loading branch information
PetroZarytskyi committed Dec 29, 2023
1 parent b9c1ace commit e623842
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 40 deletions.
28 changes: 21 additions & 7 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ namespace clad {
std::vector<Stmts> m_Reverse;
/// Accumulates local variables for all visited blocks.
std::vector<std::map<clang::VarDecl*, clang::VarDecl*>> m_Locals;
/// Stores all expressions used as placeholders which have to be
/// reset later.
std::set<const clang::Expr*> m_Placeholders;
/// Stack is used to pass the arguments (dfdx) to further nodes
/// in the Visit method.
std::stack<clang::Expr*> m_Stack;
Expand Down Expand Up @@ -180,6 +183,16 @@ namespace clad {
return blk;
}

clang::Expr* Clone(const clang::Expr* E) {
if (m_Placeholders.find(E)!=m_Placeholders.end())
return const_cast<clang::Expr*>(E);
return VisitorBase::Clone(E);
}

clang::Stmt* Clone(const clang::Stmt* E) {
return VisitorBase::Clone(E);
}

/// Add more comments.
void EmitRevSweepDecls();

Expand Down Expand Up @@ -282,15 +295,15 @@ namespace clad {
struct DelayedStoreResult {
ReverseModeVisitor& V;
StmtDiff Result;
bool isConstant;
bool isInsideLoop;
bool needsUpdate;
clang::Expr* Placeholder;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
bool pIsConstant, bool pIsInsideLoop,
bool pNeedsUpdate = false)
: V(pV), Result(pResult), isConstant(pIsConstant),
isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {}
void Finalize(clang::Expr* New);
bool pIsInsideLoop, bool pNeedsUpdate = false,
clang::Expr* pPlaceholder = nullptr)
: V(pV), Result(pResult), isInsideLoop(pIsInsideLoop),
needsUpdate(pNeedsUpdate), Placeholder(pPlaceholder) {}
void Finalize(StmtDiff New);
};

/// Sometimes (e.g. when visiting multiplication/division operator), we
Expand All @@ -302,7 +315,8 @@ namespace clad {
/// This is what DelayedGlobalStoreAndRef does. E is expected to be the
/// original (uncloned) expression.
DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t");
llvm::StringRef prefix = "_t",
bool forceNoRecompute = false);

struct CladTapeResult {
ReverseModeVisitor& V;
Expand Down
95 changes: 62 additions & 33 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2204,17 +2204,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to reduce cloning complexity and only clones once. Storing it in a
// global variable allows to save current result and make it accessible
// in the reverse pass.
std::unique_ptr<DelayedStoreResult> RDelayed;
StmtDiff RResult;
// If R has no side effects, it can be just cloned
// (no need to store it).
if (!ShouldRecompute(R)) {
RDelayed = std::unique_ptr<DelayedStoreResult>(
new DelayedStoreResult(DelayedGlobalStoreAndRef(R)));
RResult = RDelayed->Result;
} else {
RResult = StmtDiff(Clone(R));
}
DelayedStoreResult RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff& RResult = RDelayed.Result;

Expr* dl = nullptr;
if (dfdx())
Expand All @@ -2228,24 +2219,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
LStored = GlobalStoreAndRef(LStored.getExpr(), /*prefix=*/"_t",
/*force=*/true);
Expr::EvalResult dummy;
if (RDelayed ||
!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
Expr* dr = nullptr;
if (dfdx())
dr = BuildOp(BO_Mul, LStored.getRevSweepAsExpr(), dfdx());
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
if (RDelayed)
RDelayed->Finalize(Rdiff.getExpr());
RDelayed.Finalize(Rdiff);
}
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
} else if (opCode == BO_Div) {
// xi = xl / xr
// dxi/xl = 1 / xr
// df/dxl += df/dxi * dxi/xl = df/dxi * (1/xr)
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t", /*forceNoRecompute=*/false);
StmtDiff& RResult = RDelayed.Result;
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
Expr* dl = nullptr;
Expand All @@ -2260,7 +2249,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr))
// Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is
// produced instead of 1 / (R * R).
if (!RDelayed.isConstant) {
Expr::EvalResult dummy;
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
Expr* dr = nullptr;
if (dfdx()) {
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expand All @@ -2271,7 +2261,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
dr = StoreAndRef(dr, direction::reverse);
}
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
RDelayed.Finalize(Rdiff);
}
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
Expand Down Expand Up @@ -2445,22 +2435,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Ldiff.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, Rdiff.getExpr());
} else if (opCode == BO_DivAssign) {
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t", /*forceNoRecompute=*/true);
StmtDiff& RResult = RDelayed.Result;
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff,
BuildOp(BO_Div, oldValue, RStored)),
direction::reverse);
if (!RDelayed.isConstant) {
Expr::EvalResult dummy;
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expr* dr = BuildOp(BO_Mul, oldValue,
BuildOp(UO_Minus, BuildOp(BO_Div, LCloned, RxR)));
dr = StoreAndRef(dr, direction::reverse);
Rdiff = Visit(R, dr);
RDelayed.Finalize(Rdiff.getExpr());
RDelayed.Finalize(Rdiff);
}
valueForRevPass = BuildOp(BO_Div, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
Expand Down Expand Up @@ -2967,47 +2958,85 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return {};
}

void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) {
if (isConstant || !needsUpdate)
void ReverseModeVisitor::DelayedStoreResult::Finalize(StmtDiff New) {
class PlaceholderReplacer : public RecursiveASTVisitor<PlaceholderReplacer> {
public:
const Expr* placeholder;
Expr* newExpr;
PlaceholderReplacer(const Expr* Placeholder)
: placeholder(Placeholder), newExpr(nullptr) {}

bool VisitExpr(Expr *E) {
for (auto iter = E->child_begin(), e = E->child_end();
iter!=e; ++iter) {
if (*iter == placeholder)
*iter = newExpr;
else
TraverseStmt(*iter);
}
return true;
}
};

if (!needsUpdate)
return;

if (Placeholder!=nullptr) {
PlaceholderReplacer repl(Placeholder);
repl.newExpr = New.getExpr();
for (Stmt* S : V.getCurrentBlock(direction::forward))
repl.TraverseStmt(S);
repl.newExpr = New.getRevSweepAsExpr();
for (Stmt* S : V.getCurrentBlock(direction::reverse))
repl.TraverseStmt(S);
Result = New;
V.m_Placeholders.erase(Placeholder);
return;
}

if (isInsideLoop) {
auto* Push = cast<CallExpr>(Result.getExpr());
unsigned lastArg = Push->getNumArgs() - 1;
Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New).get());
Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New.getExpr()).get());
} else {
V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New),
V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New.getExpr()),
direction::forward);
}
}

ReverseModeVisitor::DelayedStoreResult
ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E,
llvm::StringRef prefix) {
llvm::StringRef prefix,
bool forceNoRecompute) {
assert(E && "must be provided");
if (!UsefulToStore(E)) {
StmtDiff Ediff = Visit(E);
Expr::EvalResult evalRes;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context);
return DelayedStoreResult{*this, Ediff,
/*isConstant*/ isConst,
/*isInsideLoop*/ false,
/*pNeedsUpdate=*/false};
}
if (!forceNoRecompute && ShouldRecompute(E)) {
Expr* PH
= ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0);
m_Placeholders.insert(PH);
return DelayedStoreResult{*this, StmtDiff{PH, nullptr, nullptr, PH},
/*isInsideLoop*/ false,
/*pNeedsUpdate=*/true,
/*pPlaceholder=*/PH};
}
if (isInsideLoop) {
Expr* dummy = E;
auto CladTape = MakeCladTapeFor(dummy);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return DelayedStoreResult{*this, StmtDiff{Push, nullptr, nullptr, Pop},
/*isConstant*/ false,
/*isInsideLoop*/ true, /*pNeedsUpdate=*/true};
}
Expr* Ref = BuildDeclRef(GlobalStoreImpl(
getNonConstType(E->getType(), m_Context, m_Sema), prefix));
// Return reference to the declaration instead of original expression.
return DelayedStoreResult{*this, StmtDiff{Ref, nullptr, nullptr, Ref},
/*isConstant*/ false,
/*isInsideLoop*/ false, /*pNeedsUpdate=*/true};
}

Expand Down

0 comments on commit e623842

Please sign in to comment.