diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 5b78db25b..7b33c9ee5 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1295,15 +1295,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector reverseIndices(Indices.size()); llvm::SmallVector forwSweepDerivativeIndices(Indices.size()); for (std::size_t i = 0; i < Indices.size(); i++) { - // FIXME: Remove redundant indices vectors. StmtDiff IdxDiff = Visit(Indices[i]); - clonedIndices[i] = Clone(IdxDiff.getExpr()); - reverseIndices[i] = Clone(IdxDiff.getExpr()); - forwSweepDerivativeIndices[i] = IdxDiff.getExpr(); + clonedIndices[i] = IdxDiff.getExpr(); + reverseIndices[i] = IdxDiff.getRevSweepAsExpr(); + forwSweepDerivativeIndices[i] = Clone(IdxDiff.getExpr()); } auto* cloned = BuildArraySubscript(BaseDiff.getExpr(), clonedIndices); auto* valueForRevSweep = - BuildArraySubscript(BaseDiff.getExpr(), reverseIndices); + BuildArraySubscript(BaseDiff.getRevSweepAsExpr(), reverseIndices); Expr* target = BaseDiff.getExpr_dx(); if (!target) return cloned; @@ -1352,9 +1351,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /// If the variable's declaration is not function global, find /// its reverse sweep declaration and rebuild clonedRevDRE. - auto foundRevDecl = m_Locals.back().find(VD); - if (foundRevDecl!=m_Locals.back().end()) - clonedRevDRE = cast(BuildDeclRef((*foundRevDecl).second)); + for (auto locals = m_Locals.rbegin(), e = m_Locals.rend(); locals!=e; ++locals) { + auto foundRevDecl = locals->find(VD); + if (foundRevDecl != locals->end()) { + clonedRevDRE = + cast(BuildDeclRef((*foundRevDecl).second)); + break; + } + } if (isVectorValued) { if (m_VectorOutput.size() <= outputArrayCursor) @@ -1575,7 +1579,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, argDiffStore = GlobalStoreAndRef(argDiff.getExpr(), "_t", /*force=*/true); else - argDiffStore = {argDiff.getExpr(), argDiff.getExpr()}; + argDiffStore = argDiff; // We need to pass the actual argument in the cloned call expression, // instead of a temporary, for arguments passed by reference. This is @@ -1617,14 +1621,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // ``` VarDecl* argDiffLocalVD = BuildVarDecl( argDiffStore.getExpr_dx()->getType(), - CreateUniqueIdentifier("_r"), argDiffStore.getExpr_dx(), + CreateUniqueIdentifier("_r"), argDiffStore.getRevSweepAsExpr(), /*DirectInit=*/false, /*TSI=*/nullptr, VarDecl::InitializationStyle::CInit); auto& block = getCurrentBlock(direction::reverse); block.insert(block.begin() + insertionPoint, BuildDeclStmt(argDiffLocalVD)); // Restore agrs - auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getExpr(), + auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getRevSweepAsExpr(), BuildDeclRef(argDiffLocalVD)); block.insert(block.begin() + insertionPoint + 1, op); @@ -1644,18 +1648,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else { // Restore args auto& block = getCurrentBlock(direction::reverse); - auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getExpr(), + auto* op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getRevSweepAsExpr(), argDiffStore.getExpr()); block.insert(block.begin() + insertionPoint, op); // We added restoration of the original arg. Thus we need to // correspondingly adjust the insertion point. insertionPoint += 1; - argDiffStore = {argDiff.getExpr(), argDiffStore.getExpr_dx()}; + argDiffStore.updateStmt(argDiff.getExpr()); } } CallArgs.push_back(argDiffStore.getExpr()); - DerivedCallArgs.push_back(argDiffStore.getExpr_dx()); + DerivedCallArgs.push_back(argDiffStore.getRevSweepAsExpr()); } VarDecl* gradVarDecl = nullptr; @@ -2911,14 +2915,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto CladTape = MakeCladTapeFor(Clone(E), Type); Expr* Push = CladTape.Push; Expr* Pop = CladTape.Pop; - return {Push, Pop}; + return {Push, Pop, nullptr, Pop}; } Expr* init = nullptr; if (const auto* const AT = dyn_cast(Type)) init = getArraySizeExpr(AT, m_Context, *this); Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix, init)); - return {Ref, Ref}; + return {Ref, Ref, nullptr, Ref}; } StmtDiff ReverseModeVisitor::StoreAndRestore(clang::Expr* E,