diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9e0e7b78e..6357d78fe 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1059,8 +1059,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VarDecl* condVarClone = nullptr; if (FS->getConditionVariable()) { condVarRes = DifferentiateSingleStmt(FS->getConditionVariableDeclStmt()); - Decl* decl = cast(condVarRes.getStmt())->getSingleDecl(); - condVarClone = cast(decl); + if (isa(condVarRes.getStmt())) { + Decl *decl = cast(condVarRes.getStmt())->getSingleDecl(); + condVarClone = cast(decl); + } } // FIXME: for now we assume that cond has no differentiable effects, @@ -1111,9 +1113,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, incDiff.getStmt_dx(), /*isForLoop=*/true); + /// FIXME: This part in necessary to replace local variables inside loops + /// with function globals and replace initializations with assignments. + /// This is a temporary measure to avoid the bug that arises from overwriting + /// local variables on different loop passes. + Expr* forwardCond = cond.getExpr(); + if (condVarRes.getExpr() && isa(condVarRes.getExpr())) { + forwardCond = BuildOp(BO_Comma, cond.getExpr(), cast(condVarRes.getExpr())); + } + Stmt* Forward = new (m_Context) ForStmt(m_Context, initResult.getStmt(), - cond.getExpr(), + forwardCond, condVarClone, incResult, BodyDiff.getStmt(), @@ -1208,7 +1219,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitParenExpr(const ParenExpr* PE) { StmtDiff subStmtDiff = Visit(PE->getSubExpr(), dfdx()); return StmtDiff(BuildParens(subStmtDiff.getExpr()), - BuildParens(subStmtDiff.getExpr_dx())); + BuildParens(subStmtDiff.getExpr_dx()), + nullptr, + BuildParens(subStmtDiff.getRevSweepExpr())); } StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) { @@ -1779,6 +1792,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, idx++; } Expr* pullback = dfdx(); + if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType()) pullback = getZeroInit(FD->getReturnType().getNonReferenceType()); @@ -2052,6 +2066,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { auto opCode = UnOp->getOpcode(); + Expr* valueForRevPass = nullptr; StmtDiff diff{}; Expr* E = UnOp->getSubExpr(); // If it is a post-increment/decrement operator, its result is a reference @@ -2070,28 +2085,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diff = Visit(E, d); } else if (opCode == UO_PostInc || opCode == UO_PostDec) { diff = Visit(E, dfdx()); - auto EStored = GlobalStoreAndRef(diff.getExpr()); - if (EStored.getExpr() != diff.getExpr()) { - auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, - Clone(diff.getExpr()), EStored.getExpr_dx()); - if (isInsideLoop) - addToCurrentBlock(EStored.getExpr(), direction::forward); - addToCurrentBlock(assign, direction::reverse); + if (UsefulToStoreGlobal(diff.getRevSweepExpr())) { + auto op = opCode == UO_PostInc ? UO_PostDec : UO_PostInc; + addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepExpr())), direction::reverse); } ResultRef = diff.getExpr_dx(); + valueForRevPass = diff.getRevSweepExpr(); if (m_ExternalSource) m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff); } else if (opCode == UO_PreInc || opCode == UO_PreDec) { diff = Visit(E, dfdx()); - auto EStored = GlobalStoreAndRef(diff.getExpr()); - if (EStored.getExpr() != diff.getExpr()) { - auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, - Clone(diff.getExpr()), EStored.getExpr_dx()); - if (isInsideLoop) - addToCurrentBlock(EStored.getExpr(), direction::forward); - addToCurrentBlock(assign, direction::reverse); + if (UsefulToStoreGlobal(diff.getRevSweepExpr())) { + auto op = opCode == UO_PreInc ? UO_PreDec : UO_PreInc; + addToCurrentBlock(BuildOp(op, Clone(diff.getRevSweepExpr())), direction::reverse); } + auto op = opCode == UO_PreInc ? BinaryOperatorKind::BO_Add : BinaryOperatorKind::BO_Sub; + auto sum = BuildOp(op, diff.getRevSweepExpr(), ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1)); + valueForRevPass = utils::BuildParenExpr(m_Sema, sum); } else if (opCode == UnaryOperatorKind::UO_Real || opCode == UnaryOperatorKind::UO_Imag) { diff = VisitWithExplicitNoDfDx(E); @@ -2139,7 +2150,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, diff = StmtDiff(E); } Expr* op = BuildOp(opCode, diff.getExpr()); - return StmtDiff(op, ResultRef); + return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } StmtDiff @@ -2197,11 +2208,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (!RDelayed.isConstant) { Expr* dr = nullptr; if (dfdx()) { - StmtDiff LResult; - if (isa(LStored->IgnoreImpCasts())) - LResult = {LStored, LStored}; - else - LResult = GlobalStoreAndRef(LStored, "_t", /*force=*/true); + StmtDiff LResult = GlobalStoreAndRef(LStored); LStored = LResult.getExpr(); dr = BuildOp(BO_Mul, LResult.getExpr_dx(), dfdx()); dr = StoreAndRef(dr, direction::reverse); @@ -2228,27 +2235,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr)) // Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is // produced instead of 1 / (R * R). - Expr* LStored = Ldiff.getExpr(); if (!RDelayed.isConstant) { Expr* dr = nullptr; - StmtDiff LResult; if (dfdx()) { - if (isa(LStored->IgnoreParenImpCasts())) - LResult = {LStored, LStored}; - else - LResult = GlobalStoreAndRef(LStored, "_t", /*force=*/true); - LStored = LResult.getExpr(); Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); dr = BuildOp(BO_Mul, dfdx(), BuildOp(UO_Minus, - BuildOp(BO_Div, LResult.getExpr_dx(), RxR))); + BuildOp(BO_Div, Ldiff.getRevSweepExpr(), RxR))); dr = StoreAndRef(dr, direction::reverse); } Rdiff = Visit(R, dr); RDelayed.Finalize(Rdiff.getExpr()); } - std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr()); + std::tie(Ldiff, Rdiff) = std::make_pair(Ldiff.getRevSweepExpr(), RResult.getRevSweepExpr()); } else if (BinOp->isAssignmentOp()) { if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) { diag(DiagnosticsEngine::Warning, @@ -2333,15 +2333,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Ldiff.updateStmt(storeE); } - // if (/*!L->HasSideEffects(m_Context)*/true) { - // Lstored = GlobalStoreAndRef(Ldiff.getExpr(), "_t", /*force*/true); - // auto assign = BuildOp(BO_Assign, Ldiff.getExpr(), Lstored.getExpr_dx()); - // if (isInsideLoop) { - // addToCurrentBlock(Lstored.getExpr(), direction::forward); - // } - // addToCurrentBlock(assign, direction::reverse); - // } - Expr* LCloned = Ldiff.getExpr(); // For x, AssignedDiff is _d_x, for x[i] its _d_x[i], for reference exprs // like (x = y) it propagates recursively, so _d_x is also returned. @@ -2364,10 +2355,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // If assigned expr is dependent, first update its derivative; auto Lblock_begin = Lblock->body_rbegin(); auto Lblock_end = Lblock->body_rend(); - // if (Lblock->size()) { - // addToCurrentBlock(*Lblock_begin, direction::reverse); - // Lblock_begin = std::next(Lblock_begin); - // } for (auto S : essentialRevBlock) addToCurrentBlock(S, direction::reverse); @@ -2378,14 +2365,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } for (auto E : return_exprs) { - Lstored = GlobalStoreAndRef(E); - if (Lstored.getExpr() != E) { - auto* assign = - BuildOp(BinaryOperatorKind::BO_Assign, E, Lstored.getExpr_dx()); - if (isInsideLoop) - addToCurrentBlock(Lstored.getExpr(), direction::forward); - addToCurrentBlock(assign, direction::reverse); - } + auto pushPop = StoreAndRestore(E); + addToCurrentBlock(pushPop.getExpr(), direction::forward); + addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse); } if (m_ExternalSource) @@ -2414,7 +2396,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, AssignedDiff, BuildOp(BO_Mul, oldValue, RResult.getExpr_dx())), direction::reverse); - Expr* LRef = LCloned; if (!RDelayed.isConstant) { // Create a reference variable to keep the result of LHS, since it // must be used on 2 places: when storing to a global variable @@ -2429,24 +2410,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // double & _ref0 = (x *= y); // _t0 = _ref0; // double r = _ref0 *= z; - StmtDiff LResult; - if (LCloned->HasSideEffects(m_Context)) { - auto RefType = getNonConstType(L->getType(), m_Context, m_Sema); - // RefType = m_Context.getLValueReferenceType(RefType); - LRef = StoreAndRef(LCloned, RefType, direction::forward, "_ref", - /*forceDeclCreation=*/true); - LResult = GlobalStoreAndRef(LRef, "_t", /*force=*/true); - } else - LResult = {LRef, LRef}; if (isInsideLoop) - addToCurrentBlock(LResult.getExpr(), direction::forward); - Expr* dr = BuildOp(BO_Mul, LResult.getExpr_dx(), oldValue); + addToCurrentBlock(LCloned, direction::forward); + Expr* dr = BuildOp(BO_Mul, LCloned, oldValue); dr = StoreAndRef(dr, direction::reverse); Rdiff = Visit(R, dr); RDelayed.Finalize(Rdiff.getExpr()); } - std::tie(Ldiff, Rdiff) = std::make_pair(LRef, RResult.getExpr()); + std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr()); } else if (opCode == BO_DivAssign) { auto RDelayed = DelayedGlobalStoreAndRef(R); StmtDiff RResult = RDelayed.Result; @@ -2455,29 +2427,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, AssignedDiff, BuildOp(BO_Div, oldValue, RStored)), direction::reverse); - Expr* LRef = LCloned; if (!RDelayed.isConstant) { - StmtDiff LResult; - if (LCloned->HasSideEffects(m_Context)) { - QualType RefType = m_Context.getLValueReferenceType( - getNonConstType(L->getType(), m_Context, m_Sema)); - LRef = StoreAndRef(LCloned, RefType, direction::forward, "_ref", - /*forceDeclCreation=*/true); - LResult = GlobalStoreAndRef(LRef, "_t", /*force=*/true); - } else - LResult = {LRef, LRef}; if (isInsideLoop) - addToCurrentBlock(LResult.getExpr(), direction::forward); + addToCurrentBlock(LCloned, direction::forward); Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); Expr* dr = BuildOp( BO_Mul, oldValue, - BuildOp(UO_Minus, BuildOp(BO_Div, LResult.getExpr_dx(), RxR))); + BuildOp(UO_Minus, BuildOp(BO_Div, LCloned, RxR))); dr = StoreAndRef(dr, direction::reverse); Rdiff = Visit(R, dr); RDelayed.Finalize(Rdiff.getExpr()); } - std::tie(Ldiff, Rdiff) = std::make_pair(LRef, RResult.getExpr()); + std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr()); } else llvm_unreachable("unknown assignment opCode"); if (m_ExternalSource) @@ -2515,7 +2477,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return BuildOp(opCode, LExpr, RExpr); } Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr()); - return StmtDiff(op, ResultRef); + return StmtDiff(op, ResultRef, nullptr, valueForRevPass); } VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { @@ -2707,6 +2669,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitDeclStmt(const DeclStmt* DS) { + llvm::SmallVector inits; llvm::SmallVector decls; llvm::SmallVector declsDiff; // Need to put array decls inlined. @@ -2740,6 +2703,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // } if (VDDiff.getDecl()->getDeclName() != VD->getDeclName()) m_DeclReplacements[VD] = VDDiff.getDecl(); + + /// FIXME: This part in necessary to replace local variables inside loops + /// with function globals and replace initializations with assignments. + /// This is a temporary measure to avoid the bug that arises from overwriting + /// local variables on different loop passes. + if (isInsideLoop) { + if (VD->getType()->isBuiltinType() && !VD->getType().isConstQualified()) { + auto *decl = VDDiff.getDecl(); + if (decl->getInit()) { + auto *declRef = BuildDeclRef(decl); + auto pushPop = StoreAndRestore(declRef, /*prefix=*/"_t", /*force=*/true); + if (pushPop.getExpr() != declRef) { + addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse); + } + auto *assignment = BuildOp(BO_Assign, declRef, decl->getInit()); + inits.push_back(BuildOp(BO_Comma, pushPop.getExpr(), assignment)); + } + decl->setInit(getZeroInit(VD->getType())); + } + } + decls.push_back(VDDiff.getDecl()); if (isa(VD->getType())) localDeclsDiff.push_back(VDDiff.getDecl_dx()); @@ -2769,16 +2753,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, declsDiff.append(localDeclsDiff.begin(), localDeclsDiff.end()); m_ExternalSource->ActBeforeFinalizingVisitDeclStmt(decls, declsDiff); } + + /// FIXME: This part in necessary to replace local variables inside loops + /// with function globals and replace initializations with assignments. + /// This is a temporary measure to avoid the bug that arises from overwriting + /// local variables on different loop passes. + if (isInsideLoop) { + if (auto *VD = dyn_cast(decls[0])) { + if (VD->getType()->isBuiltinType() && !VD->getType().isConstQualified()) { + addToBlock(DSClone, m_Globals); + Stmt *initAssignments = MakeCompoundStmt(inits); + initAssignments = unwrapIfSingleStmt(initAssignments); + return StmtDiff(initAssignments); + } + } + } + return StmtDiff(DSClone); } StmtDiff ReverseModeVisitor::VisitImplicitCastExpr(const ImplicitCastExpr* ICE) { - StmtDiff subExprDiff = Visit(ICE->getSubExpr(), dfdx()); // Casts should be handled automatically when the result is used by // Sema::ActOn.../Build... - return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx(), - subExprDiff.getForwSweepStmt_dx()); + return Visit(ICE->getSubExpr(), dfdx()); } StmtDiff ReverseModeVisitor::VisitMemberExpr(const MemberExpr* ME) { @@ -3027,10 +3025,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // statement. Sema::ConditionResult condResult; if (condVarDecl) { - Decl* condVarClone = cast(condVarRes.getStmt()) + if (condVarRes.getStmt()) { + if (isa(condVarRes.getStmt())) { + Decl *condVarClone = cast(condVarRes.getStmt()) ->getSingleDecl(); - condResult = m_Sema.ActOnConditionVariable(condVarClone, noLoc, + condResult = m_Sema.ActOnConditionVariable(condVarClone, noLoc, + Sema::ConditionKind::Boolean); + } else { + condResult = m_Sema.ActOnCondition(getCurrentScope(), noLoc, cast(condVarRes.getStmt()), Sema::ConditionKind::Boolean); + } + } } else { condResult = m_Sema.ActOnCondition(getCurrentScope(), noLoc, condClone, Sema::ConditionKind::Boolean); diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index 3fa47c780..252b33350 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -739,10 +739,6 @@ void TBRAnalyzer::VisitUnaryOperator(const clang::UnaryOperator* UnOp) { /// Mark corresponding SourceLocation as required/not required to be /// stored for all expressions that could be changed. markLocation(innerExpr); - /// Set them to not required to store because the values were changed. - /// (if some value was not changed, this could only happen if it was - /// already not required to store). - setIsRequired(innerExpr, /*isReq=*/false); } } }