Skip to content

Commit

Permalink
Replace getExpr with getRevSweepAsExpr where necessary.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 29, 2023
1 parent 93d4773 commit b9c1ace
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,15 +1295,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVector<Expr*, 4> reverseIndices(Indices.size());
llvm::SmallVector<Expr*, 4> forwSweepDerivativeIndices(Indices.size());
for (std::size_t i = 0; i < Indices.size(); i++) {
// FIXME: Remove redundant indices vectors.
StmtDiff IdxDiff = Visit(Indices[i]);
clonedIndices[i] = Clone(IdxDiff.getExpr());
reverseIndices[i] = Clone(IdxDiff.getExpr());
forwSweepDerivativeIndices[i] = IdxDiff.getExpr();
clonedIndices[i] = IdxDiff.getExpr();
reverseIndices[i] = IdxDiff.getRevSweepAsExpr();
forwSweepDerivativeIndices[i] = Clone(IdxDiff.getExpr());
}
auto* cloned = BuildArraySubscript(BaseDiff.getExpr(), clonedIndices);
auto* valueForRevSweep =
BuildArraySubscript(BaseDiff.getExpr(), reverseIndices);
BuildArraySubscript(BaseDiff.getRevSweepAsExpr(), reverseIndices);
Expr* target = BaseDiff.getExpr_dx();
if (!target)
return cloned;
Expand Down Expand Up @@ -1352,9 +1351,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

/// If the variable's declaration is not function global, find
/// its reverse sweep declaration and rebuild clonedRevDRE.
auto foundRevDecl = m_Locals.back().find(VD);
if (foundRevDecl!=m_Locals.back().end())
clonedRevDRE = cast<DeclRefExpr>(BuildDeclRef((*foundRevDecl).second));
for (auto locals = m_Locals.rbegin(), e = m_Locals.rend(); locals!=e; ++locals) {
auto foundRevDecl = locals->find(VD);
if (foundRevDecl != locals->end()) {
clonedRevDRE =
cast<DeclRefExpr>(BuildDeclRef((*foundRevDecl).second));
break;
}
}

if (isVectorValued) {
if (m_VectorOutput.size() <= outputArrayCursor)
Expand Down Expand Up @@ -1575,7 +1579,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
argDiffStore =
GlobalStoreAndRef(argDiff.getExpr(), "_t", /*force=*/true);
else
argDiffStore = {argDiff.getExpr(), argDiff.getExpr()};
argDiffStore = argDiff;

// We need to pass the actual argument in the cloned call expression,
// instead of a temporary, for arguments passed by reference. This is
Expand Down Expand Up @@ -1617,14 +1621,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// ```
VarDecl* argDiffLocalVD = BuildVarDecl(
argDiffStore.getExpr_dx()->getType(),
CreateUniqueIdentifier("_r"), argDiffStore.getExpr_dx(),
CreateUniqueIdentifier("_r"), argDiffStore.getRevSweepAsExpr(),
/*DirectInit=*/false, /*TSI=*/nullptr,
VarDecl::InitializationStyle::CInit);
auto& block = getCurrentBlock(direction::reverse);
block.insert(block.begin() + insertionPoint,
BuildDeclStmt(argDiffLocalVD));
// Restore agrs
auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getExpr(),
auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getRevSweepAsExpr(),
BuildDeclRef(argDiffLocalVD));
block.insert(block.begin() + insertionPoint + 1, op);

Expand All @@ -1644,18 +1648,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
} else {
// Restore args
auto& block = getCurrentBlock(direction::reverse);
auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getExpr(),
auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getRevSweepAsExpr(),
argDiffStore.getExpr());
block.insert(block.begin() + insertionPoint, op);
// We added restoration of the original arg. Thus we need to
// correspondingly adjust the insertion point.
insertionPoint += 1;

argDiffStore = {argDiff.getExpr(), argDiffStore.getExpr_dx()};
argDiffStore.updateStmt(argDiff.getExpr());
}
}
CallArgs.push_back(argDiffStore.getExpr());
DerivedCallArgs.push_back(argDiffStore.getExpr_dx());
DerivedCallArgs.push_back(argDiffStore.getRevSweepAsExpr());
}

VarDecl* gradVarDecl = nullptr;
Expand Down Expand Up @@ -2911,14 +2915,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto CladTape = MakeCladTapeFor(Clone(E), Type);
Expr* Push = CladTape.Push;
Expr* Pop = CladTape.Pop;
return {Push, Pop};
return {Push, Pop, nullptr, Pop};
}
Expr* init = nullptr;
if (const auto* const AT = dyn_cast<ArrayType>(Type))
init = getArraySizeExpr(AT, m_Context, *this);

Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix, init));
return {Ref, Ref};
return {Ref, Ref, nullptr, Ref};
}

StmtDiff ReverseModeVisitor::StoreAndRestore(clang::Expr* E,
Expand Down

0 comments on commit b9c1ace

Please sign in to comment.