diff --git a/include/clad/Differentiator/CladUtils.h b/include/clad/Differentiator/CladUtils.h index 9fbaf9717..dbf860daa 100644 --- a/include/clad/Differentiator/CladUtils.h +++ b/include/clad/Differentiator/CladUtils.h @@ -316,8 +316,12 @@ namespace clad { bool hasNonDifferentiableAttribute(const clang::Decl* D); bool hasNonDifferentiableAttribute(const clang::Expr* E); - /// FIXME: add documentation - std::vector GetInnermostReturnExpr(const clang::Expr* E); + + /// Collects every DeclRefExpr, MemberExpr, ArraySubscriptExpr in an + /// assignment operator or a ternary if operator. This is useful to when we + /// need to decide what needs to be stored on tape in reverse mode. + void GetInnermostReturnExpr(const clang::Expr* E, + llvm::SmallVectorImpl& Exprs); } // namespace utils } diff --git a/lib/Differentiator/CladUtils.cpp b/lib/Differentiator/CladUtils.cpp index 3b3bae886..11721ef62 100644 --- a/lib/Differentiator/CladUtils.cpp +++ b/lib/Differentiator/CladUtils.cpp @@ -544,17 +544,18 @@ namespace clad { return false; } - std::vector GetInnermostReturnExpr(const clang::Expr* E) { - struct Finder : public ConstStmtVisitor { - std::vector m_return_exprs; + void GetInnermostReturnExpr(const clang::Expr* E, + llvm::SmallVectorImpl& Exprs) { + struct Finder : public StmtVisitor { + llvm::SmallVectorImpl& m_Exprs; public: - std::vector Find(const clang::Expr* E) { + Finder(clang::Expr* E, llvm::SmallVectorImpl& Exprs) + : m_Exprs(Exprs) { Visit(E); - return m_return_exprs; } - void VisitBinaryOperator(const clang::BinaryOperator* BO) { + void VisitBinaryOperator(clang::BinaryOperator* BO) { if (BO->isAssignmentOp() || BO->isCompoundAssignmentOp()) { Visit(BO->getLHS()); } else if (BO->getOpcode() == clang::BO_Comma) { @@ -564,41 +565,37 @@ namespace clad { } } - void VisitConditionalOperator(const clang::ConditionalOperator* CO) { + void VisitConditionalOperator(clang::ConditionalOperator* CO) { // FIXME: in cases like (cond ? x : y) = 2; both x and y will be // stored. Visit(CO->getTrueExpr()); Visit(CO->getFalseExpr()); } - void VisitUnaryOperator(const clang::UnaryOperator* UnOp) { + void VisitUnaryOperator(clang::UnaryOperator* UnOp) { auto opCode = UnOp->getOpcode(); if (opCode == clang::UO_PreInc || opCode == clang::UO_PreDec) Visit(UnOp->getSubExpr()); } - void VisitDeclRefExpr(const clang::DeclRefExpr* DRE) { - m_return_exprs.push_back(const_cast(DRE)); + void VisitDeclRefExpr(clang::DeclRefExpr* DRE) { + m_Exprs.push_back(DRE); } - void VisitParenExpr(const clang::ParenExpr* PE) { - Visit(PE->getSubExpr()); - } + void VisitParenExpr(clang::ParenExpr* PE) { Visit(PE->getSubExpr()); } - void VisitMemberExpr(const clang::MemberExpr* ME) { - m_return_exprs.push_back(const_cast(ME)); - } + void VisitMemberExpr(clang::MemberExpr* ME) { m_Exprs.push_back(ME); } - void VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE) { - m_return_exprs.push_back(const_cast(ASE)); + void VisitArraySubscriptExpr(clang::ArraySubscriptExpr* ASE) { + m_Exprs.push_back(ASE); } - void VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE) { + void VisitImplicitCastExpr(clang::ImplicitCastExpr* ICE) { Visit(ICE->getSubExpr()); } }; - Finder finder; - return finder.Find(E); + // FIXME: Fix the constness on the callers of this function. + Finder finder(const_cast(E), Exprs); } bool IsAutoOrAutoPtrType(QualType T) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 218218158..914fdf5c3 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -2330,7 +2330,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Ldiff = Visit(L, dfdx()); Stmts essentialRevBlock = EndBlockWithoutCreatingCS(direction::essential_reverse); auto* Lblock = endBlock(direction::reverse); - auto return_exprs = utils::GetInnermostReturnExpr(Ldiff.getExpr()); + llvm::SmallVector ExprsToStore; + utils::GetInnermostReturnExpr(Ldiff.getExpr(), ExprsToStore); if (L->HasSideEffects(m_Context)) { Expr* E = Ldiff.getExpr(); auto* storeE = StoreAndRef(E, m_Context.getLValueReferenceType(E->getType())); @@ -2368,7 +2369,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Lblock_begin = std::next(Lblock_begin); } - for (auto& E : return_exprs) { + for (auto& E : ExprsToStore) { auto pushPop = StoreAndRestore(E); addToCurrentBlock(pushPop.getExpr(), direction::forward); addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse); diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 833a8a450..af21d1061 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -599,10 +599,11 @@ void TBRAnalyzer::VisitDeclStmt(const DeclStmt* DS) { auto& VDExpr = getCurBlockVarsData()[VD]; /// if the declared variable is ref type attach its VarData to the /// VarData of the RHS variable. - auto returnExprs = utils::GetInnermostReturnExpr(init); - if (VDExpr.type == VarData::REF_TYPE && !returnExprs.empty()) + llvm::SmallVector ExprsToStore; + utils::GetInnermostReturnExpr(init, ExprsToStore); + if (VDExpr.type == VarData::REF_TYPE && !ExprsToStore.empty()) // NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access) - VDExpr.val.m_RefData = returnExprs[0]; + VDExpr.val.m_RefData = ExprsToStore[0]; } } } @@ -695,8 +696,9 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) { Visit(R); resetMode(); } - const auto returnExprs = utils::GetInnermostReturnExpr(L); - for (const auto* innerExpr : returnExprs) { + llvm::SmallVector ExprsToStore; + utils::GetInnermostReturnExpr(L, ExprsToStore); + for (const auto* innerExpr : ExprsToStore) { /// Mark corresponding SourceLocation as required/not required to be /// stored for all expressions that could be used changed. markLocation(innerExpr); @@ -726,8 +728,9 @@ void TBRAnalyzer::VisitUnaryOperator(const clang::UnaryOperator* UnOp) { // FIXME: this doesn't support all the possible references /// Mark corresponding SourceLocation as required/not required to be /// stored for all expressions that could be used in this operation. - const auto innerExprs = utils::GetInnermostReturnExpr(E); - for (const auto* innerExpr : innerExprs) { + llvm::SmallVector ExprsToStore; + utils::GetInnermostReturnExpr(E, ExprsToStore); + for (const auto* innerExpr : ExprsToStore) { /// Mark corresponding SourceLocation as required/not required to be /// stored for all expressions that could be changed. markLocation(innerExpr); @@ -754,7 +757,8 @@ void TBRAnalyzer::VisitCallExpr(const clang::CallExpr* CE) { resetMode(); const auto* B = arg->IgnoreParenImpCasts(); // FIXME: this supports only DeclRefExpr - const auto innerExpr = utils::GetInnermostReturnExpr(arg); + llvm::SmallVector ExprsToStore; + utils::GetInnermostReturnExpr(arg, ExprsToStore); if (passByRef) { /// Mark SourceLocation as required to store for ref-type arguments. if (isa(B) || isa(B)) {