diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index b93e11f1c..cf1820a56 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -7,13 +7,13 @@ #ifndef CLAD_REVERSE_MODE_VISITOR_H #define CLAD_REVERSE_MODE_VISITOR_H +#include "clad/Differentiator/Compatibility.h" +#include "clad/Differentiator/VisitorBase.h" +#include "clad/Differentiator/ReverseModeVisitorDirectionKinds.h" +#include "clad/Differentiator/ParseDiffArgsTypes.h" #include "clang/AST/RecursiveASTVisitor.h" #include "clang/AST/StmtVisitor.h" #include "clang/Sema/Sema.h" -#include "clad/Differentiator/Compatibility.h" -#include "clad/Differentiator/ParseDiffArgsTypes.h" -#include "clad/Differentiator/ReverseModeVisitorDirectionKinds.h" -#include "clad/Differentiator/VisitorBase.h" #include #include @@ -21,593 +21,602 @@ #include namespace clad { -class ErrorEstimationHandler; -class ExternalRMVSource; -class MultiplexExternalRMVSource; - -/// A visitor for processing the function code in reverse mode. -/// Used to compute derivatives by clad::gradient. -class ReverseModeVisitor - : public clang::ConstStmtVisitor, - public VisitorBase { -protected: - // FIXME: We should remove friend-dependency of the plugin classes here. - // For this we will need to separate out AST related functions in - // a separate namespace, as well as add getters/setters function of - // several private/protected members of the visitor classes. - friend class ErrorEstimationHandler; - llvm::SmallVector m_IndependentVars; - /// In addition to a sequence of forward-accumulated Stmts (m_Blocks), in - /// the reverse mode we also accumulate Stmts for the reverse pass which - /// will be executed on return. - std::vector m_Reverse; - std::vector m_EssentialReverse; - /// Stack is used to pass the arguments (dfdx) to further nodes - /// in the Visit method. - std::stack m_Stack; - /// A sequence of DeclStmts containing "tape" variable declarations - /// that will be put immediately in the beginning of derivative function - /// block. - Stmts m_Globals; - //// A reference to the output parameter of the gradient function. - clang::Expr* m_Result; - /// Based on To-Be-Recorded analysis performed before differentiation, - /// tells UsefulToStoreGlobal whether a variable with a given - /// SourceLocation has to be stored before being changed or not. - std::set m_ToBeRecorded; - /// A flag indicating if the Stmt we are currently visiting is inside loop. - bool isInsideLoop = false; - /// Output variable of vector-valued function - std::string outputArrayStr; - std::vector m_LoopBlock; - unsigned outputArrayCursor = 0; - unsigned numParams = 0; - bool isVectorValued = false; - bool use_enzyme = false; - // FIXME: Should we make this an object instead of a pointer? - // Downside of making it an object: We will need to include - // 'MultiplexExternalRMVSource.h' file - MultiplexExternalRMVSource* m_ExternalSource = nullptr; - clang::Expr* m_Pullback = nullptr; - const char* funcPostfix() const { - if (isVectorValued) - return "_jac"; - else if (use_enzyme) - return "_grad_enzyme"; - else - return "_grad"; - } - - /// Removes the local const qualifiers from a QualType and returns a new - /// type. - static clang::QualType getNonConstType(clang::QualType T, - clang::ASTContext& C, clang::Sema& S) { - clang::Qualifiers quals(T.getQualifiers()); - quals.removeConst(); - return S.BuildQualifiedType(T.getUnqualifiedType(), noLoc, quals); - } - // Function to Differentiate with Clad as Backend - void DifferentiateWithClad(); - - // Function to Differentiate with Enzyme as Backend - void DifferentiateWithEnzyme(); - -public: - using direction = rmv::direction; - clang::Expr* dfdx() { - if (m_Stack.empty()) - return nullptr; - return m_Stack.top(); - } - StmtDiff Visit(const clang::Stmt* stmt, clang::Expr* dfdS = nullptr) { - // No need to push the same expr multiple times. - bool push = !(!m_Stack.empty() && (dfdS == dfdx())); - if (push) - m_Stack.push(dfdS); - auto result = - clang::ConstStmtVisitor::Visit(stmt); - if (push) - m_Stack.pop(); - return result; - } - - /// This visit method explicitly sets `dfdx` to `nullptr` for this visit. - /// - /// This method is helpful when we need derivative of some expression but we - /// do not want `_d_expression += dfdx` statments to be (automatically) - /// added. - /// - /// FIXME: Think of a better way for handling this situation. Maybe we - /// should improve the overall dfdx design and approach. One other way of - /// designing `VisitWithExplicitNoDfDx` in a more general way is - /// to develop a function that takes an expression E and returns the - /// corresponding derivative without any side effects. The difference - /// between this function and the current `VisitWithExplicitNoDfDx` will be - /// 1) better intent through the function name 2) We will also get - /// derivatives of expressions other than `DeclRefExpr` and `MemberExpr`. - StmtDiff VisitWithExplicitNoDfDx(const clang::Stmt* stmt) { - m_Stack.push(nullptr); - auto result = - clang::ConstStmtVisitor::Visit(stmt); - m_Stack.pop(); - return result; - } - - /// Get the latest block of code (i.e. place for statements output). - Stmts& getCurrentBlock(direction d = direction::forward) { - if (d == direction::forward) - return m_Blocks.back(); - else if (d == direction::reverse) - return m_Reverse.back(); - else - return m_EssentialReverse.back(); - } - /// Create new block. - Stmts& beginBlock(direction d = direction::forward) { - if (d == direction::forward) - m_Blocks.push_back({}); - else if (d == direction::reverse) - m_Reverse.push_back({}); - else - m_EssentialReverse.push_back({}); - return getCurrentBlock(d); - } - /// Remove the block from the stack, wrap it in CompoundStmt and return it. - clang::CompoundStmt* endBlock(direction d = direction::forward) { - if (d == direction::forward) { - auto CS = MakeCompoundStmt(getCurrentBlock(direction::forward)); - m_Blocks.pop_back(); - return CS; - } else if (d == direction::reverse) { - auto CS = MakeCompoundStmt(getCurrentBlock(direction::reverse)); - std::reverse(CS->body_begin(), CS->body_end()); - m_Reverse.pop_back(); - return CS; - } else { - auto CS = MakeCompoundStmt(getCurrentBlock(d)); - m_EssentialReverse.pop_back(); - return CS; - } - } - - Stmts EndBlockWithoutCreatingCS(direction d = direction::forward) { - auto blk = getCurrentBlock(d); - if (d == direction::forward) - m_Blocks.pop_back(); - else if (d == direction::reverse) - m_Reverse.pop_back(); - else - m_EssentialReverse.pop_back(); - return blk; - } - /// Output a statement to the current block. If Stmt is null or is an unused - /// expression, it is not output and false is returned. - bool addToCurrentBlock(clang::Stmt* S, direction d = direction::forward) { - return addToBlock(S, getCurrentBlock(d)); - } - - /// Adds a given statement to the global block. - /// - /// \param[in] S The statement to add to the block. - /// - /// \returns True if the statement was added to the block, false otherwise. - bool AddToGlobalBlock(clang::Stmt* S) { return addToBlock(S, m_Globals); } - - /// Stores the result of an expression in a temporary variable (of the same - /// type as is the result of the expression) and returns a reference to it. - /// If force decl creation is true, this will allways create a temporary - /// variable declaration. Otherwise, temporary variable is created only - /// if E requires evaluation (e.g. there is no point to store literals or - /// direct references in intermediate variables) - clang::Expr* StoreAndRef(clang::Expr* E, direction d = direction::forward, - llvm::StringRef prefix = "_t", - bool forceDeclCreation = false, - clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit) { - assert(E && "cannot infer type from null expression"); - return StoreAndRef(E, getNonConstType(E->getType(), m_Context, m_Sema), d, - prefix, forceDeclCreation, IS); - } - - /// An overload allowing to specify the type for the variable. - clang::Expr* StoreAndRef(clang::Expr* E, clang::QualType Type, - direction d = direction::forward, - llvm::StringRef prefix = "_t", - bool forceDeclCreation = false, - clang::VarDecl::InitializationStyle IS = - clang::VarDecl::InitializationStyle::CInit) { - // Name reverse temporaries as "_r" instead of "_t". - if ((d == direction::reverse) && (prefix == "_t")) - prefix = "_r"; - Stmts* blk = nullptr; - if (d == direction::essential_reverse) - if (!m_EssentialReverse.empty()) - blk = &getCurrentBlock(direction::essential_reverse); + class ErrorEstimationHandler; + class ExternalRMVSource; + class MultiplexExternalRMVSource; + + /// A visitor for processing the function code in reverse mode. + /// Used to compute derivatives by clad::gradient. + class ReverseModeVisitor + : public clang::ConstStmtVisitor, + public VisitorBase { + protected: + // FIXME: We should remove friend-dependency of the plugin classes here. + // For this we will need to separate out AST related functions in + // a separate namespace, as well as add getters/setters function of + // several private/protected members of the visitor classes. + friend class ErrorEstimationHandler; + llvm::SmallVector m_IndependentVars; + /// In addition to a sequence of forward-accumulated Stmts (m_Blocks), in + /// the reverse mode we also accumulate Stmts for the reverse pass which + /// will be executed on return. + std::vector m_Reverse; + std::vector m_EssentialReverse; + /// Stack is used to pass the arguments (dfdx) to further nodes + /// in the Visit method. + std::stack m_Stack; + /// A sequence of DeclStmts containing "tape" variable declarations + /// that will be put immediately in the beginning of derivative function + /// block. + Stmts m_Globals; + //// A reference to the output parameter of the gradient function. + clang::Expr* m_Result; + /// Based on To-Be-Recorded analysis performed before differentiation, + /// tells UsefulToStoreGlobal whether a variable with a given + /// SourceLocation has to be stored before being changed or not. + std::map m_ToBeRecorded; + /// A flag indicating if the Stmt we are currently visiting is inside loop. + bool isInsideLoop = false; + /// Output variable of vector-valued function + std::string outputArrayStr; + std::vector m_LoopBlock; + unsigned outputArrayCursor = 0; + unsigned numParams = 0; + bool isVectorValued = false; + bool use_enzyme = false; + // FIXME: Should we make this an object instead of a pointer? + // Downside of making it an object: We will need to include + // 'MultiplexExternalRMVSource.h' file + MultiplexExternalRMVSource* m_ExternalSource = nullptr; + clang::Expr* m_Pullback = nullptr; + const char* funcPostfix() const { + if (isVectorValued) + return "_jac"; + else if (use_enzyme) + return "_grad_enzyme"; else - blk = &getCurrentBlock(direction::reverse); - else - blk = &getCurrentBlock(d); - return VisitorBase::StoreAndRef(E, Type, *blk, prefix, forceDeclCreation, - IS); - } - - /// For an expr E, decides if it is useful to store it in a global temporary - /// variable and replace E's further usage by a reference to that variable - /// to avoid recomputiation. - bool UsefulToStoreGlobal(clang::Expr* E); - - /// Builds a variable declaration and stores it in the function - /// global scope. - /// - /// \param[in] Type The type of variable declaration to build. - /// - /// \param[in] prefix The prefix (if any) to the declration name. - /// - /// \param[in] init The variable declaration initializer. - /// - /// \returns A variable declaration that is already added to the - /// global scope. - clang::VarDecl* GlobalStoreImpl(clang::QualType Type, llvm::StringRef prefix, - clang::Expr* init = nullptr); - /// Creates a (global in the function scope) variable declaration, puts - /// it into m_Globals block (to be inserted into the beginning of fn's - /// body). Returns reference R to the created declaration. If E is not null, - /// puts an additional assignment statement (R = E) in the forward block. - /// Alternatively, if isInsideLoop is true, stores E in a stack. Returns - /// StmtDiff, where .getExpr() is intended to be used in forward pass and - /// .getExpr_dx() in the reverse pass. Two expressions can be different in - /// some cases, e.g. clad::push/pop inside loops. - StmtDiff GlobalStoreAndRef(clang::Expr* E, clang::QualType Type, - llvm::StringRef prefix = "_t", bool force = false); - StmtDiff GlobalStoreAndRef(clang::Expr* E, llvm::StringRef prefix = "_t", - bool force = false); - - //// A type returned by DelayedGlobalStoreAndRef - /// .Result is a reference to the created (yet uninitialized) global - /// variable. When the expression is finally visited and rebuilt, .Finalize - /// must be called with new rebuilt expression, to initialize the global - /// variable. Alternatively, expression may be not worth storing in a global - /// varialbe and is easy to clone (e.g. it is a constant literal). Then - /// .Result is cloned E, .isConstant is true and .Finalize does nothing. - struct DelayedStoreResult { - ReverseModeVisitor& V; - StmtDiff Result; - bool isConstant; - bool isInsideLoop; - bool needsUpdate; - DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult, - bool pIsConstant, bool pIsInsideLoop, - bool pNeedsUpdate = false) - : V(pV), Result(pResult), isConstant(pIsConstant), - isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {} - void Finalize(clang::Expr* New); - }; + return "_grad"; + } - /// Sometimes (e.g. when visiting multiplication/division operator), we - /// need to allocate global variable for an expression (e.g. for RHS) before - /// we visit that expression for efficiency reasons, since we may use that - /// global variable for visiting another expression (e.g. LHS) instead of - /// cloning LHS. The global variable will be assigned with the actual - /// expression only later, after the expression is visited and rebuilt. - /// This is what DelayedGlobalStoreAndRef does. E is expected to be the - /// original (uncloned) expression. - DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E, - llvm::StringRef prefix = "_t"); - - struct CladTapeResult { - ReverseModeVisitor& V; - clang::Expr* Push; - clang::Expr* Pop; - clang::Expr* Ref; - /// A request to get expr accessing last element in the tape - /// (clad::back(Ref)). Since it is required only rarely, it is built on - /// demand in the method. - clang::Expr* Last(); - }; + /// Removes the local const qualifiers from a QualType and returns a new + /// type. + static clang::QualType + getNonConstType(clang::QualType T, clang::ASTContext& C, clang::Sema& S) { + clang::Qualifiers quals(T.getQualifiers()); + quals.removeConst(); + return S.BuildQualifiedType(T.getUnqualifiedType(), noLoc, quals); + } + // Function to Differentiate with Clad as Backend + void DifferentiateWithClad(); - /// Make a clad::tape to store variables. - /// If E is supposed to be stored in a tape, will create a global - /// declaration of tape of corresponding type and return a result struct - /// with reference to the tape and constructed calls to push/pop methods. - /// - /// \param[in] E The expression to build the tape for. - /// - /// \param[in] prefix The prefix value for the name of the tape. - /// - /// \returns A struct containg necessary call expressions for the built - /// tape - CladTapeResult MakeCladTapeFor(clang::Expr* E, llvm::StringRef prefix = "_t"); - -public: - ReverseModeVisitor(DerivativeBuilder& builder); - virtual ~ReverseModeVisitor(); - - ///\brief Produces the gradient of a given function. - /// - ///\param[in] FD - the function that will be differentiated. - /// - ///\returns The gradient of the function, potentially created enclosing - /// context and if generated, its overload. - /// - /// We name the gradient of f as 'f_grad'. - /// If the gradient of the same function is requested several times - /// with different parameters, but same parameter types, every such request - /// will create f_grad function with the same signature, which will be - /// ambiguous. E.g. - /// double f(double x, double y, double z) { ... } - /// clad::gradient(f, "x, y"); - /// clad::gradient(f, "x, z"); - /// will create 2 definitions for f_grad with the same signature. - /// - /// Improved naming scheme is required. Hence, we append the indices to of - /// the requested parameters to 'f_grad', i.e. in the previous example "x, - /// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'. - DerivativeAndOverload Derive(const clang::FunctionDecl* FD, - const DiffRequest& request); - DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD, - const DiffRequest& request); - StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); - StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); - StmtDiff VisitCallExpr(const clang::CallExpr* CE); - virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); - StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); - StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); - StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); - virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); - StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); - StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL); - StmtDiff VisitForStmt(const clang::ForStmt* FS); - StmtDiff VisitIfStmt(const clang::IfStmt* If); - StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE); - StmtDiff VisitInitListExpr(const clang::InitListExpr* ILE); - StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL); - StmtDiff VisitMemberExpr(const clang::MemberExpr* ME); - StmtDiff VisitParenExpr(const clang::ParenExpr* PE); - virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); - StmtDiff VisitStmt(const clang::Stmt* S); - StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); - StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC); - /// Decl is not Stmt, so it cannot be visited directly. - StmtDiff VisitWhileStmt(const clang::WhileStmt* WS); - StmtDiff VisitDoStmt(const clang::DoStmt* DS); - StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS); - StmtDiff VisitBreakStmt(const clang::BreakStmt* BS); - StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE); - StmtDiff VisitCXXConstructExpr(const clang::CXXConstructExpr* CE); - StmtDiff - VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE); - StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE); - VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); - - /// A helper method to differentiate a single Stmt in the reverse mode. - /// Internally, calls Visit(S, expr). Its result is wrapped into a - /// CompoundStmt (if several statements are created) and proper Stmt - /// order is maintained. - /// - /// \param[in] S The statement to differentiate. - /// - /// \param[in] dfdS The expression to propogate to Visit - /// - /// \returns The orignal (cloned) and differentiated forms of S - StmtDiff DifferentiateSingleStmt(const clang::Stmt* S, - clang::Expr* dfdS = nullptr); - /// A helper method used to keep substatements created by Visit(E, expr) in - /// separate forward/reverse blocks instead of putting them into current - /// blocks. First result is a StmtDiff of forward/reverse blocks with - /// additionally created Stmts, second is a direct result of call to Visit. - std::pair - DifferentiateSingleExpr(const clang::Expr* E, clang::Expr* dfdE = nullptr); - /// Shorthand for warning on differentiation of unsupported operators - void unsupportedOpWarn(clang::SourceLocation loc, - llvm::ArrayRef args = {}) { - diag(clang::DiagnosticsEngine::Warning, loc, - "attempt to differentiate unsupported operator, ignored.", args); - } - /// Builds an overload for the gradient function that has derived params for - /// all the arguments of the requested function and it calls the original - /// gradient function internally - clang::FunctionDecl* CreateGradientOverload(); - - /// Returns the type that should be used to represent the derivative of a - /// variable of type `yType` with respect to a parameter variable of type - /// `xType`. - /// - /// FIXME: Parameter derivative type rules are different from the derivative - /// type rules for local variables. We should remove this inconsistency. - /// See the following issue for more details: - /// https://github.com/vgvassilev/clad/issues/385 - clang::QualType GetParameterDerivativeType(clang::QualType yType, - clang::QualType xType); - - /// Allows to easily create and manage a counter for counting the number of - /// executed iterations of a loop. - /// - /// It is required to save the number of executed iterations to use the - /// same number of iterations in the reverse pass. - /// If we are currently inside a loop, then a clad tape object is created - /// to be used as the counter; otherwise, a temporary global variable (in - /// function scope) is created to be used as the counter. - class LoopCounter { - clang::Expr* m_Ref = nullptr; - clang::Expr* m_Pop = nullptr; - clang::Expr* m_Push = nullptr; - ReverseModeVisitor& m_RMV; + // Function to Differentiate with Enzyme as Backend + void DifferentiateWithEnzyme(); public: - LoopCounter(ReverseModeVisitor& RMV); - /// Returns `clad::push(_t, 0UL)` expression if clad tape is used - /// for counter; otherwise, returns nullptr. - clang::Expr* getPush() const { return m_Push; } - - /// Returns `clad::pop(_t)` expression if clad tape is used for - /// for counter; otherwise, returns nullptr. - clang::Expr* getPop() const { return m_Pop; } - - /// Returns reference to the last object of the clad tape if clad tape - /// is used as the counter; otherwise returns reference to the counter - /// variable. - clang::Expr* getRef() const { return m_Ref; } - - /// Returns counter post-increment expression (`counter++`). - clang::Expr* getCounterIncrement() { - return m_RMV.BuildOp(clang::UnaryOperatorKind::UO_PostInc, m_Ref); + using direction = rmv::direction; + clang::Expr* dfdx() { + if (m_Stack.empty()) + return nullptr; + return m_Stack.top(); } - - /// Returns counter post-decrement expression (`counter--`) - clang::Expr* getCounterDecrement() { - return m_RMV.BuildOp(clang::UnaryOperatorKind::UO_PostDec, m_Ref); + StmtDiff Visit(const clang::Stmt* stmt, clang::Expr* dfdS = nullptr) { + // No need to push the same expr multiple times. + bool push = !(!m_Stack.empty() && (dfdS == dfdx())); + if (push) + m_Stack.push(dfdS); + auto result = + clang::ConstStmtVisitor::Visit(stmt); + if (push) + m_Stack.pop(); + return result; } - /// Returns `ConditionResult` object for the counter. - clang::Sema::ConditionResult getCounterConditionResult() { - return m_RMV.m_Sema.ActOnCondition(m_RMV.m_CurScope, noLoc, m_Ref, - clang::Sema::ConditionKind::Boolean); + /// This visit method explicitly sets `dfdx` to `nullptr` for this visit. + /// + /// This method is helpful when we need derivative of some expression but we + /// do not want `_d_expression += dfdx` statments to be (automatically) + /// added. + /// + /// FIXME: Think of a better way for handling this situation. Maybe we + /// should improve the overall dfdx design and approach. One other way of + /// designing `VisitWithExplicitNoDfDx` in a more general way is + /// to develop a function that takes an expression E and returns the + /// corresponding derivative without any side effects. The difference + /// between this function and the current `VisitWithExplicitNoDfDx` will be + /// 1) better intent through the function name 2) We will also get + /// derivatives of expressions other than `DeclRefExpr` and `MemberExpr`. + StmtDiff VisitWithExplicitNoDfDx(const clang::Stmt* stmt) { + m_Stack.push(nullptr); + auto result = + clang::ConstStmtVisitor::Visit(stmt); + m_Stack.pop(); + return result; } - }; - /// Helper function to differentiate a loop body. - /// - ///\param[in] body body of the loop - ///\param[in] loopCounter associated `LoopCounter` object of the loop. - ///\param[in] condVarDiff derived statements of the condition - /// variable, if any. - ///\param[in] forLoopIncDiff derived statements of the `for` loop - /// increment statement, if any. - ///\param[in] isForLoop should be true if we are differentiating a `for` - /// loop body; otherwise false. - ///\returns {forward pass statements, reverse pass statements} for the loop - /// body. - StmtDiff DifferentiateLoopBody(const clang::Stmt* body, - LoopCounter& loopCounter, - clang::Stmt* condVarDifff = nullptr, - clang::Stmt* forLoopIncDiff = nullptr, - bool isForLoop = false); - - /// This class modifies forward and reverse blocks of the loop - /// body so that `break` and `continue` statements are correctly - /// handled. `break` and `continue` statements are handled by - /// enclosing entire reverse block loop body in a switch statement - /// and only executing the statements, with the help of case labels, - /// that were executed in the associated forward iteration. This is - /// determined by keeping track of which `break`/`continue` statement - /// was hit in which iteration and that in turn helps to determine which - /// case label should be selected. - /// - /// Class usage: - /// - /// ```cpp - /// auto activeBreakContStmtHandler = PushBreakContStmtHandler(); - /// activeBreakContHandler->BeginCFSwitchStmtScope(); - /// .... - /// Differentiate loop body, and save results in StmtDiff BodyDiff - /// ... - /// activeBreakContHandler->EndCFSwitchStmtScope(); - /// activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); - /// PopBreakContStmtHandler(); - /// ``` - class BreakContStmtHandler { - /// Keeps track of all the created switch cases. It is required - /// because we need to register all the switch cases later with the - /// switch statement that will be used to manage the control flow in - /// the reverse block. - llvm::SmallVector m_SwitchCases; - - /// `m_ControlFlowTape` tape keeps track of which `break`/`continue` - /// statement was hit in which iteration. - /// \note `m_ControlFlowTape` is only initialized if the body contains - /// `continue` or `break` statement. - std::unique_ptr m_ControlFlowTape; - - /// Each `break` and `continue` statement is assigned a unique number, - /// starting from 1, that is used as the case label corresponding to that - /// `break`/`continue` statement. `m_CaseCounter` stores the value that was - /// used for last `break`/`continue` statement. - std::size_t m_CaseCounter = 0; - - ReverseModeVisitor& m_RMV; - - /// Builds and returns a literal expression of type `std::size_t` with - /// `value` as value. - clang::Expr* CreateSizeTLiteralExpr(std::size_t value); - - /// Initialise the `m_ControlFlowTape`. - /// \note `m_ControlFlowTape` is not initialised in the constructor - /// because it is only initialised if it is required. It is only required - /// if body contains `break` or `continue` statement. - void InitializeCFTape(); - - /// Builds and returns `clad::push(tapeRef, value)` expression. - clang::Expr* CreateCFTapePushExpr(std::size_t value); + /// Get the latest block of code (i.e. place for statements output). + Stmts& getCurrentBlock(direction d = direction::forward) { + if (d == direction::forward) + return m_Blocks.back(); + else if (d == direction::reverse) + return m_Reverse.back(); + else + return m_EssentialReverse.back(); + } + /// Create new block. + Stmts& beginBlock(direction d = direction::forward) { + if (d == direction::forward) + m_Blocks.push_back({}); + else if (d == direction::reverse) + m_Reverse.push_back({}); + else + m_EssentialReverse.push_back({}); + return getCurrentBlock(d); + } + /// Remove the block from the stack, wrap it in CompoundStmt and return it. + clang::CompoundStmt* endBlock(direction d = direction::forward) { + if (d == direction::forward) { + auto CS = MakeCompoundStmt(getCurrentBlock(direction::forward)); + m_Blocks.pop_back(); + return CS; + } else if (d == direction::reverse) { + auto CS = MakeCompoundStmt(getCurrentBlock(direction::reverse)); + std::reverse(CS->body_begin(), CS->body_end()); + m_Reverse.pop_back(); + return CS; + } else { + auto CS = MakeCompoundStmt(getCurrentBlock(d)); + m_EssentialReverse.pop_back(); + return CS; + } + } - public: - BreakContStmtHandler(ReverseModeVisitor& RMV) : m_RMV(RMV) {} + Stmts EndBlockWithoutCreatingCS(direction d = direction::forward) { + auto blk = getCurrentBlock(d); + if (d == direction::forward) + m_Blocks.pop_back(); + else if (d == direction::reverse) + m_Reverse.pop_back(); + else + m_EssentialReverse.pop_back(); + return blk; + } + /// Output a statement to the current block. If Stmt is null or is an unused + /// expression, it is not output and false is returned. + bool addToCurrentBlock(clang::Stmt* S, direction d = direction::forward) { + return addToBlock(S, getCurrentBlock(d)); + } - /// Begins control flow switch statement scope. - /// Control flow switch statement is used to refer to the - /// switch statement that manages the control flow of the reverse - /// block. - void BeginCFSwitchStmtScope() const; + /// Adds a given statement to the global block. + /// + /// \param[in] S The statement to add to the block. + /// + /// \returns True if the statement was added to the block, false otherwise. + bool AddToGlobalBlock(clang::Stmt* S) { return addToBlock(S, m_Globals); } + + /// Stores the result of an expression in a temporary variable (of the same + /// type as is the result of the expression) and returns a reference to it. + /// If force decl creation is true, this will allways create a temporary + /// variable declaration. Otherwise, temporary variable is created only + /// if E requires evaluation (e.g. there is no point to store literals or + /// direct references in intermediate variables) + clang::Expr* StoreAndRef(clang::Expr* E, direction d = direction::forward, + llvm::StringRef prefix = "_t", + bool forceDeclCreation = false, + clang::VarDecl::InitializationStyle IS = + clang::VarDecl::InitializationStyle::CInit) { + assert(E && "cannot infer type from null expression"); + return StoreAndRef(E, getNonConstType(E->getType(), m_Context, m_Sema), d, + prefix, forceDeclCreation, IS); + } - /// Ends control flow switch statement scope. - void EndCFSwitchStmtScope() const; + /// An overload allowing to specify the type for the variable. + clang::Expr* StoreAndRef(clang::Expr* E, clang::QualType Type, + direction d = direction::forward, + llvm::StringRef prefix = "_t", + bool forceDeclCreation = false, + clang::VarDecl::InitializationStyle IS = + clang::VarDecl::InitializationStyle::CInit) { + // Name reverse temporaries as "_r" instead of "_t". + if ((d == direction::reverse) && (prefix == "_t")) + prefix = "_r"; + Stmts* blk = nullptr; + if (d == direction::essential_reverse) { + if (!m_EssentialReverse.empty()) + blk = &getCurrentBlock(direction::essential_reverse); + else + blk = &getCurrentBlock(direction::reverse); + } else + blk = &getCurrentBlock(d); + return VisitorBase::StoreAndRef(E, Type, *blk, prefix, + forceDeclCreation, IS); + } - /// Builds and returns a switch case statement that corresponds - /// to a `break` or `continue` statement and is registered in the - /// control flow switch statement. - clang::CaseStmt* GetNextCFCaseStmt(); + /// For an expr E, decides if it is useful to store it in a global temporary + /// variable and replace E's further usage by a reference to that variable + /// to avoid recomputiation. + bool UsefulToStoreGlobal(clang::Expr* E); + + /// Builds a variable declaration and stores it in the function + /// global scope. + /// + /// \param[in] Type The type of variable declaration to build. + /// + /// \param[in] prefix The prefix (if any) to the declration name. + /// + /// \param[in] init The variable declaration initializer. + /// + /// \returns A variable declaration that is already added to the + /// global scope. + clang::VarDecl* GlobalStoreImpl(clang::QualType Type, + llvm::StringRef prefix, + clang::Expr* init = nullptr); + /// Creates a (global in the function scope) variable declaration, puts + /// it into m_Globals block (to be inserted into the beginning of fn's + /// body). Returns reference R to the created declaration. If E is not null, + /// puts an additional assignment statement (R = E) in the forward block. + /// Alternatively, if isInsideLoop is true, stores E in a stack. Returns + /// StmtDiff, where .getExpr() is intended to be used in forward pass and + /// .getExpr_dx() in the reverse pass. Two expressions can be different in + /// some cases, e.g. clad::push/pop inside loops. + StmtDiff GlobalStoreAndRef(clang::Expr* E, + clang::QualType Type, + llvm::StringRef prefix = "_t", + bool force = false); + StmtDiff GlobalStoreAndRef(clang::Expr* E, + llvm::StringRef prefix = "_t", + bool force = false); + + //// A type returned by DelayedGlobalStoreAndRef + /// .Result is a reference to the created (yet uninitialized) global + /// variable. When the expression is finally visited and rebuilt, .Finalize + /// must be called with new rebuilt expression, to initialize the global + /// variable. Alternatively, expression may be not worth storing in a global + /// varialbe and is easy to clone (e.g. it is a constant literal). Then + /// .Result is cloned E, .isConstant is true and .Finalize does nothing. + struct DelayedStoreResult { + ReverseModeVisitor& V; + StmtDiff Result; + bool isConstant; + bool isInsideLoop; + bool needsUpdate; + DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult, + bool pIsConstant, bool pIsInsideLoop, + bool pNeedsUpdate = false) + : V(pV), Result(pResult), isConstant(pIsConstant), + isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {} + void Finalize(clang::Expr* New); + }; + + /// Sometimes (e.g. when visiting multiplication/division operator), we + /// need to allocate global variable for an expression (e.g. for RHS) before + /// we visit that expression for efficiency reasons, since we may use that + /// global variable for visiting another expression (e.g. LHS) instead of + /// cloning LHS. The global variable will be assigned with the actual + /// expression only later, after the expression is visited and rebuilt. + /// This is what DelayedGlobalStoreAndRef does. E is expected to be the + /// original (uncloned) expression. + DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E, + llvm::StringRef prefix = "_t"); + + struct CladTapeResult { + ReverseModeVisitor& V; + clang::Expr* Push; + clang::Expr* Pop; + clang::Expr* Ref; + /// A request to get expr accessing last element in the tape + /// (clad::back(Ref)). Since it is required only rarely, it is built on + /// demand in the method. + clang::Expr* Last(); + }; + + /// Make a clad::tape to store variables. + /// If E is supposed to be stored in a tape, will create a global + /// declaration of tape of corresponding type and return a result struct + /// with reference to the tape and constructed calls to push/pop methods. + /// + /// \param[in] E The expression to build the tape for. + /// + /// \param[in] prefix The prefix value for the name of the tape. + /// + /// \returns A struct containg necessary call expressions for the built + /// tape + CladTapeResult MakeCladTapeFor(clang::Expr* E, + llvm::StringRef prefix = "_t"); - /// Builds and returns `clad::push(TapeRef, m_CurrentCounter)` - /// expression, where `TapeRef` and `m_CurrentCounter` are replaced - /// by their actual values respectively. - clang::Stmt* CreateCFTapePushExprToCurrentCase(); + public: + ReverseModeVisitor(DerivativeBuilder& builder); + virtual ~ReverseModeVisitor(); + + ///\brief Produces the gradient of a given function. + /// + ///\param[in] FD - the function that will be differentiated. + /// + ///\returns The gradient of the function, potentially created enclosing + /// context and if generated, its overload. + /// + /// We name the gradient of f as 'f_grad'. + /// If the gradient of the same function is requested several times + /// with different parameters, but same parameter types, every such request + /// will create f_grad function with the same signature, which will be + /// ambiguous. E.g. + /// double f(double x, double y, double z) { ... } + /// clad::gradient(f, "x, y"); + /// clad::gradient(f, "x, z"); + /// will create 2 definitions for f_grad with the same signature. + /// + /// Improved naming scheme is required. Hence, we append the indices to of + /// the requested parameters to 'f_grad', i.e. in the previous example "x, + /// y" will give 'f_grad_0_1' and "x, z" will give 'f_grad_0_2'. + DerivativeAndOverload Derive(const clang::FunctionDecl* FD, + const DiffRequest& request); + DerivativeAndOverload DerivePullback(const clang::FunctionDecl* FD, + const DiffRequest& request); + StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); + StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); + StmtDiff VisitCallExpr(const clang::CallExpr* CE); + virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); + StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); + StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); + StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); + virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); + StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); + StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL); + StmtDiff VisitForStmt(const clang::ForStmt* FS); + StmtDiff VisitIfStmt(const clang::IfStmt* If); + StmtDiff VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE); + StmtDiff VisitInitListExpr(const clang::InitListExpr* ILE); + StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL); + StmtDiff VisitMemberExpr(const clang::MemberExpr* ME); + StmtDiff VisitParenExpr(const clang::ParenExpr* PE); + virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); + StmtDiff VisitStmt(const clang::Stmt* S); + StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); + StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC); + /// Decl is not Stmt, so it cannot be visited directly. + StmtDiff VisitWhileStmt(const clang::WhileStmt* WS); + StmtDiff VisitDoStmt(const clang::DoStmt* DS); + StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS); + StmtDiff VisitBreakStmt(const clang::BreakStmt* BS); + StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE); + StmtDiff VisitCXXConstructExpr(const clang::CXXConstructExpr* CE); + StmtDiff + VisitMaterializeTemporaryExpr(const clang::MaterializeTemporaryExpr* MTE); + StmtDiff VisitCXXStaticCastExpr(const clang::CXXStaticCastExpr* SCE); + VarDeclDiff DifferentiateVarDecl(const clang::VarDecl* VD); + + /// A helper method to differentiate a single Stmt in the reverse mode. + /// Internally, calls Visit(S, expr). Its result is wrapped into a + /// CompoundStmt (if several statements are created) and proper Stmt + /// order is maintained. + /// + /// \param[in] S The statement to differentiate. + /// + /// \param[in] dfdS The expression to propogate to Visit + /// + /// \returns The orignal (cloned) and differentiated forms of S + StmtDiff DifferentiateSingleStmt(const clang::Stmt* S, + clang::Expr* dfdS = nullptr); + /// A helper method used to keep substatements created by Visit(E, expr) in + /// separate forward/reverse blocks instead of putting them into current + /// blocks. First result is a StmtDiff of forward/reverse blocks with + /// additionally created Stmts, second is a direct result of call to Visit. + std::pair + DifferentiateSingleExpr(const clang::Expr* E, clang::Expr* dfdE = nullptr); + /// Shorthand for warning on differentiation of unsupported operators + void unsupportedOpWarn(clang::SourceLocation loc, + llvm::ArrayRef args = {}) { + diag(clang::DiagnosticsEngine::Warning, + loc, + "attempt to differentiate unsupported operator, ignored.", + args); + } + /// Builds an overload for the gradient function that has derived params for + /// all the arguments of the requested function and it calls the original + /// gradient function internally + clang::FunctionDecl* CreateGradientOverload(); + + /// Returns the type that should be used to represent the derivative of a + /// variable of type `yType` with respect to a parameter variable of type + /// `xType`. + /// + /// FIXME: Parameter derivative type rules are different from the derivative + /// type rules for local variables. We should remove this inconsistency. + /// See the following issue for more details: + /// https://github.com/vgvassilev/clad/issues/385 + clang::QualType GetParameterDerivativeType(clang::QualType yType, + clang::QualType xType); + + /// Allows to easily create and manage a counter for counting the number of + /// executed iterations of a loop. + /// + /// It is required to save the number of executed iterations to use the + /// same number of iterations in the reverse pass. + /// If we are currently inside a loop, then a clad tape object is created + /// to be used as the counter; otherwise, a temporary global variable (in + /// function scope) is created to be used as the counter. + class LoopCounter { + clang::Expr *m_Ref = nullptr; + clang::Expr *m_Pop = nullptr; + clang::Expr *m_Push = nullptr; + ReverseModeVisitor& m_RMV; + + public: + LoopCounter(ReverseModeVisitor& RMV); + /// Returns `clad::push(_t, 0UL)` expression if clad tape is used + /// for counter; otherwise, returns nullptr. + clang::Expr* getPush() const { return m_Push; } + + /// Returns `clad::pop(_t)` expression if clad tape is used for + /// for counter; otherwise, returns nullptr. + clang::Expr* getPop() const { return m_Pop; } + + /// Returns reference to the last object of the clad tape if clad tape + /// is used as the counter; otherwise returns reference to the counter + /// variable. + clang::Expr* getRef() const { return m_Ref; } + + /// Returns counter post-increment expression (`counter++`). + clang::Expr* getCounterIncrement() { + return m_RMV.BuildOp(clang::UnaryOperatorKind::UO_PostInc, m_Ref); + } + + /// Returns counter post-decrement expression (`counter--`) + clang::Expr* getCounterDecrement() { + return m_RMV.BuildOp(clang::UnaryOperatorKind::UO_PostDec, m_Ref); + } + + /// Returns `ConditionResult` object for the counter. + clang::Sema::ConditionResult getCounterConditionResult() { + return m_RMV.m_Sema.ActOnCondition(m_RMV.m_CurScope, noLoc, m_Ref, + clang::Sema::ConditionKind::Boolean); + } + }; + + /// Helper function to differentiate a loop body. + /// + ///\param[in] body body of the loop + ///\param[in] loopCounter associated `LoopCounter` object of the loop. + ///\param[in] condVarDiff derived statements of the condition + /// variable, if any. + ///\param[in] forLoopIncDiff derived statements of the `for` loop + /// increment statement, if any. + ///\param[in] isForLoop should be true if we are differentiating a `for` + /// loop body; otherwise false. + ///\returns {forward pass statements, reverse pass statements} for the loop + /// body. + StmtDiff DifferentiateLoopBody(const clang::Stmt* body, + LoopCounter& loopCounter, + clang::Stmt* condVarDifff = nullptr, + clang::Stmt* forLoopIncDiff = nullptr, + bool isForLoop = false); + + /// This class modifies forward and reverse blocks of the loop + /// body so that `break` and `continue` statements are correctly + /// handled. `break` and `continue` statements are handled by + /// enclosing entire reverse block loop body in a switch statement + /// and only executing the statements, with the help of case labels, + /// that were executed in the associated forward iteration. This is + /// determined by keeping track of which `break`/`continue` statement + /// was hit in which iteration and that in turn helps to determine which + /// case label should be selected. + /// + /// Class usage: + /// + /// ```cpp + /// auto activeBreakContStmtHandler = PushBreakContStmtHandler(); + /// activeBreakContHandler->BeginCFSwitchStmtScope(); + /// .... + /// Differentiate loop body, and save results in StmtDiff BodyDiff + /// ... + /// activeBreakContHandler->EndCFSwitchStmtScope(); + /// activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + /// PopBreakContStmtHandler(); + /// ``` + class BreakContStmtHandler { + /// Keeps track of all the created switch cases. It is required + /// because we need to register all the switch cases later with the + /// switch statement that will be used to manage the control flow in + /// the reverse block. + llvm::SmallVector m_SwitchCases; + + /// `m_ControlFlowTape` tape keeps track of which `break`/`continue` + /// statement was hit in which iteration. + /// \note `m_ControlFlowTape` is only initialized if the body contains + /// `continue` or `break` statement. + std::unique_ptr m_ControlFlowTape; + + /// Each `break` and `continue` statement is assigned a unique number, + /// starting from 1, that is used as the case label corresponding to that `break`/`continue` + /// statement. `m_CaseCounter` stores the value that was used for last + /// `break`/`continue` statement. + std::size_t m_CaseCounter = 0; + + ReverseModeVisitor& m_RMV; + + /// Builds and returns a literal expression of type `std::size_t` with + /// `value` as value. + clang::Expr* CreateSizeTLiteralExpr(std::size_t value); + + /// Initialise the `m_ControlFlowTape`. + /// \note `m_ControlFlowTape` is not initialised in the constructor + /// because it is only initialised if it is required. It is only required + /// if body contains `break` or `continue` statement. + void InitializeCFTape(); + + /// Builds and returns `clad::push(tapeRef, value)` expression. + clang::Expr* CreateCFTapePushExpr(std::size_t value); + + public: + BreakContStmtHandler(ReverseModeVisitor& RMV) : m_RMV(RMV) {} + + /// Begins control flow switch statement scope. + /// Control flow switch statement is used to refer to the + /// switch statement that manages the control flow of the reverse + /// block. + void BeginCFSwitchStmtScope() const; + + /// Ends control flow switch statement scope. + void EndCFSwitchStmtScope() const; + + /// Builds and returns a switch case statement that corresponds + /// to a `break` or `continue` statement and is registered in the + /// control flow switch statement. + clang::CaseStmt* GetNextCFCaseStmt(); + + /// Builds and returns `clad::push(TapeRef, m_CurrentCounter)` + /// expression, where `TapeRef` and `m_CurrentCounter` are replaced + /// by their actual values respectively. + clang::Stmt* CreateCFTapePushExprToCurrentCase(); + + /// Does final modifications on forward and reverse blocks + /// so that `break` and `continue` statements are handled + /// accurately. + void UpdateForwAndRevBlocks(StmtDiff& bodyDiff); + }; + // Keeps track of active control flow switch statements. + llvm::SmallVector m_BreakContStmtHandlers; + + BreakContStmtHandler* GetActiveBreakContStmtHandler() { + return &m_BreakContStmtHandlers.back(); + } + BreakContStmtHandler* PushBreakContStmtHandler() { + m_BreakContStmtHandlers.emplace_back(*this); + return &m_BreakContStmtHandlers.back(); + } + void PopBreakContStmtHandler() { + m_BreakContStmtHandlers.pop_back(); + } - /// Does final modifications on forward and reverse blocks - /// so that `break` and `continue` statements are handled - /// accurately. - void UpdateForwAndRevBlocks(StmtDiff& bodyDiff); + /// Registers an external RMV source. + /// + /// Multiple external RMV source can be registered by calling this function + /// multiple times. + ///\paramp[in] source An external RMV source + void AddExternalSource(ExternalRMVSource& source); + + /// Computes and returns the sequence of derived function parameter types. + /// + /// Information about the original function and the differentiation mode + /// are taken from the data member variables. In particular, `m_Function`, + /// `m_Mode` data members should be correctly set before using this + /// function. + llvm::SmallVector ComputeParamTypes(const DiffParams& diffParams); + + /// Builds and returns the sequence of derived function parameters. + /// + /// Information about the original function, derived function, derived + /// function parameter types and the differentiation mode are implicitly + /// taken from the data member variables. In particular, `m_Function`, + /// `m_Mode` and `m_Derivative` should be correctly set before using this + /// function. + llvm::SmallVector + BuildParams(DiffParams& diffParams); + + clang::QualType ComputeAdjointType(clang::QualType T); + clang::QualType ComputeParamType(clang::QualType T); + + std::vector GetInnermostReturnExpr(clang::Expr* E); }; - // Keeps track of active control flow switch statements. - llvm::SmallVector m_BreakContStmtHandlers; - - BreakContStmtHandler* GetActiveBreakContStmtHandler() { - return &m_BreakContStmtHandlers.back(); - } - BreakContStmtHandler* PushBreakContStmtHandler() { - m_BreakContStmtHandlers.emplace_back(*this); - return &m_BreakContStmtHandlers.back(); - } - void PopBreakContStmtHandler() { m_BreakContStmtHandlers.pop_back(); } - - /// Registers an external RMV source. - /// - /// Multiple external RMV source can be registered by calling this function - /// multiple times. - ///\paramp[in] source An external RMV source - void AddExternalSource(ExternalRMVSource& source); - - /// Computes and returns the sequence of derived function parameter types. - /// - /// Information about the original function and the differentiation mode - /// are taken from the data member variables. In particular, `m_Function`, - /// `m_Mode` data members should be correctly set before using this - /// function. - llvm::SmallVector - ComputeParamTypes(const DiffParams& diffParams); - - /// Builds and returns the sequence of derived function parameters. - /// - /// Information about the original function, derived function, derived - /// function parameter types and the differentiation mode are implicitly - /// taken from the data member variables. In particular, `m_Function`, - /// `m_Mode` and `m_Derivative` should be correctly set before using this - /// function. - llvm::SmallVector BuildParams(DiffParams& diffParams); - - clang::QualType ComputeAdjointType(clang::QualType T); - clang::QualType ComputeParamType(clang::QualType T); - - std::vector GetInnermostReturnExpr(clang::Expr* E); -}; } // end namespace clad #endif // CLAD_REVERSE_MODE_VISITOR_H diff --git a/include/clad/Differentiator/TBRAnalyzer.h b/include/clad/Differentiator/TBRAnalyzer.h index c3c223f43..b2677e4e7 100644 --- a/include/clad/Differentiator/TBRAnalyzer.h +++ b/include/clad/Differentiator/TBRAnalyzer.h @@ -223,7 +223,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { enum Mode { markingMode = 1, nonLinearMode = 2 }; /// Tells if the variable at a given location is required to store. Basically, /// is the result of analysis. - std::set TBRLocs; + std::map TBRLocs; /// Stores modes in a stack (used to retrieve the old mode after entering /// a new one). @@ -297,7 +297,7 @@ class TBRAnalyzer : public clang::ConstStmtVisitor { TBRAnalyzer& operator=(const TBRAnalyzer&&) = delete; /// Returns the result of the whole analysis - std::set getResult() { return TBRLocs; } + std::map getResult() { return TBRLocs; } /// Visitors void Analyze(const clang::FunctionDecl* FD); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 63ed52a07..d8b9e36a1 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -48,2169 +48,2396 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return nullptr; } -Expr* ReverseModeVisitor::CladTapeResult::Last() { - LookupResult& Back = V.GetCladTapeBack(); - CXXScopeSpec CSS; - CSS.Extend(V.m_Context, V.GetCladNamespace(), noLoc, noLoc); - Expr* BackDRE = V.m_Sema - .BuildDeclarationNameExpr(CSS, Back, + Expr* ReverseModeVisitor::CladTapeResult::Last() { + LookupResult& Back = V.GetCladTapeBack(); + CXXScopeSpec CSS; + CSS.Extend(V.m_Context, V.GetCladNamespace(), noLoc, noLoc); + Expr* BackDRE = V.m_Sema + .BuildDeclarationNameExpr(CSS, Back, + /*AcceptInvalidDecl=*/false) + .get(); + Expr* Call = + V.m_Sema.ActOnCallExpr(V.getCurrentScope(), BackDRE, noLoc, Ref, noLoc) + .get(); + return Call; + } + + ReverseModeVisitor::CladTapeResult + ReverseModeVisitor::MakeCladTapeFor(Expr* E, llvm::StringRef prefix) { + assert(E && "must be provided"); + if (auto IE = dyn_cast(E)) { + E = IE->getSubExpr()->IgnoreImplicit(); + } + QualType EQt = E->getType(); + if (dyn_cast(EQt)) + EQt = GetCladArrayOfType(utils::GetValueType(EQt)); + QualType TapeType = + GetCladTapeOfType(getNonConstType(EQt, m_Context, m_Sema)); + LookupResult& Push = GetCladTapePush(); + LookupResult& Pop = GetCladTapePop(); + Expr* TapeRef = + BuildDeclRef(GlobalStoreImpl(TapeType, prefix, getZeroInit(TapeType))); + auto VD = cast(cast(TapeRef)->getDecl()); + // Add fake location, since Clang AST does assert(Loc.isValid()) somewhere. + VD->setLocation(m_Function->getLocation()); + CXXScopeSpec CSS; + CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); + auto PopDRE = m_Sema + .BuildDeclarationNameExpr(CSS, Pop, /*AcceptInvalidDecl=*/false) .get(); - Expr* Call = - V.m_Sema.ActOnCallExpr(V.getCurrentScope(), BackDRE, noLoc, Ref, noLoc) - .get(); - return Call; -} - -ReverseModeVisitor::CladTapeResult -ReverseModeVisitor::MakeCladTapeFor(Expr* E, llvm::StringRef prefix) { - assert(E && "must be provided"); - if (auto IE = dyn_cast(E)) - E = IE->getSubExpr()->IgnoreImplicit(); - QualType EQt = E->getType(); - if (dyn_cast(EQt)) - EQt = GetCladArrayOfType(utils::GetValueType(EQt)); - QualType TapeType = - GetCladTapeOfType(getNonConstType(EQt, m_Context, m_Sema)); - LookupResult& Push = GetCladTapePush(); - LookupResult& Pop = GetCladTapePop(); - Expr* TapeRef = - BuildDeclRef(GlobalStoreImpl(TapeType, prefix, getZeroInit(TapeType))); - auto VD = cast(cast(TapeRef)->getDecl()); - // Add fake location, since Clang AST does assert(Loc.isValid()) somewhere. - VD->setLocation(m_Function->getLocation()); - CXXScopeSpec CSS; - CSS.Extend(m_Context, GetCladNamespace(), noLoc, noLoc); - auto PopDRE = m_Sema - .BuildDeclarationNameExpr(CSS, Pop, - /*AcceptInvalidDecl=*/false) - .get(); - auto PushDRE = m_Sema - .BuildDeclarationNameExpr(CSS, Push, - /*AcceptInvalidDecl=*/false) - .get(); - Expr* PopExpr = - m_Sema.ActOnCallExpr(getCurrentScope(), PopDRE, noLoc, TapeRef, noLoc) - .get(); - Expr* exprToPush = E; - if (auto AT = dyn_cast(E->getType())) { - Expr* init = getArraySizeExpr(AT, m_Context, *this); - exprToPush = BuildOp(BO_Comma, E, init); + auto PushDRE = m_Sema + .BuildDeclarationNameExpr(CSS, Push, + /*AcceptInvalidDecl=*/false) + .get(); + Expr* PopExpr = + m_Sema.ActOnCallExpr(getCurrentScope(), PopDRE, noLoc, TapeRef, noLoc) + .get(); + Expr* exprToPush = E; + if (auto AT = dyn_cast(E->getType())) { + Expr* init = getArraySizeExpr(AT, m_Context, *this); + exprToPush = BuildOp(BO_Comma, E, init); + } + Expr* CallArgs[] = {TapeRef, exprToPush}; + Expr* PushExpr = + m_Sema.ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, CallArgs, noLoc) + .get(); + return CladTapeResult{*this, PushExpr, PopExpr, TapeRef}; } - Expr* CallArgs[] = {TapeRef, exprToPush}; - Expr* PushExpr = - m_Sema.ActOnCallExpr(getCurrentScope(), PushDRE, noLoc, CallArgs, noLoc) - .get(); - return CladTapeResult{*this, PushExpr, PopExpr, TapeRef}; -} -ReverseModeVisitor::ReverseModeVisitor(DerivativeBuilder& builder) - : VisitorBase(builder), m_Result(nullptr) {} - -ReverseModeVisitor::~ReverseModeVisitor() { - if (m_ExternalSource) { - // Inform external sources that `ReverseModeVisitor` object no longer - // exists. - // FIXME: Make this so the lifetime scope of the source matches. - // m_ExternalSource->ForgetRMV(); - // Free the external sources multiplexer since we own this resource. - delete m_ExternalSource; - } -} + ReverseModeVisitor::ReverseModeVisitor(DerivativeBuilder& builder) + : VisitorBase(builder), m_Result(nullptr) {} -FunctionDecl* ReverseModeVisitor::CreateGradientOverload() { - auto gradientParams = m_Derivative->parameters(); - auto gradientNameInfo = m_Derivative->getNameInfo(); - // Calculate the total number of parameters that would be required for - // automatic differentiation in the derived function if all args are - // requested. - // FIXME: Here we are assuming all function parameters are of differentiable - // type. Ideally, we should not make any such assumption. - std::size_t totalDerivedParamsSize = m_Function->getNumParams() * 2; - std::size_t numOfDerivativeParams = m_Function->getNumParams(); - - // Account for the this pointer. - if (isa(m_Function) && !utils::IsStaticMethod(m_Function)) - ++numOfDerivativeParams; - // All output parameters will be of type `clad::array_ref`. These - // parameters will be casted to correct type before the call to the actual - // derived function. - // We require each output parameter to be of same type in the overloaded - // derived function due to limitations of generating the exact derived - // function type at the compile-time (without clad plugin help). - QualType outputParamType = GetCladArrayRefOfType(m_Context.VoidTy); - - llvm::SmallVector paramTypes; - - // Add types for representing original function parameters. - for (auto PVD : m_Function->parameters()) - paramTypes.push_back(PVD->getType()); - // Add types for representing parameter derivatives. - // FIXME: We are assuming all function parameters are differentiable. We - // should not make any such assumptions. - for (std::size_t i = 0; i < numOfDerivativeParams; ++i) - paramTypes.push_back(outputParamType); - - auto gradFuncOverloadEPI = - dyn_cast(m_Function->getType())->getExtProtoInfo(); - QualType gradientFunctionOverloadType = - m_Context.getFunctionType(m_Context.VoidTy, paramTypes, - // Cast to function pointer. - gradFuncOverloadEPI); - - DeclContext* DC = const_cast(m_Function->getDeclContext()); - m_Sema.CurContext = DC; - DeclWithContext gradientOverloadFDWC = - m_Builder.cloneFunction(m_Function, *this, DC, noLoc, gradientNameInfo, - gradientFunctionOverloadType); - FunctionDecl* gradientOverloadFD = gradientOverloadFDWC.first; - - beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | - Scope::DeclScope); - m_Sema.PushFunctionScope(); - m_Sema.PushDeclContext(getCurrentScope(), gradientOverloadFD); - - llvm::SmallVector overloadParams; - llvm::SmallVector callArgs; - - overloadParams.reserve(totalDerivedParamsSize); - callArgs.reserve(gradientParams.size()); - - for (auto PVD : m_Function->parameters()) { - auto VD = utils::BuildParmVarDecl( - m_Sema, gradientOverloadFD, PVD->getIdentifier(), PVD->getType(), - PVD->getStorageClass(), /*defArg=*/nullptr, PVD->getTypeSourceInfo()); - overloadParams.push_back(VD); - callArgs.push_back(BuildDeclRef(VD)); + ReverseModeVisitor::~ReverseModeVisitor() { + if (m_ExternalSource) { + // Inform external sources that `ReverseModeVisitor` object no longer + // exists. + // FIXME: Make this so the lifetime scope of the source matches. + // m_ExternalSource->ForgetRMV(); + // Free the external sources multiplexer since we own this resource. + delete m_ExternalSource; + } } - for (std::size_t i = 0; i < numOfDerivativeParams; ++i) { - IdentifierInfo* II = nullptr; - StorageClass SC = StorageClass::SC_None; - std::size_t effectiveGradientIndex = m_Function->getNumParams() + i; - // `effectiveGradientIndex < gradientParams.size()` implies that this - // parameter represents an actual derivative of one of the function - // original parameters. - if (effectiveGradientIndex < gradientParams.size()) { - auto GVD = gradientParams[effectiveGradientIndex]; - II = CreateUniqueIdentifier("_temp_" + GVD->getNameAsString()); - SC = GVD->getStorageClass(); - } else { - II = CreateUniqueIdentifier("_d_" + std::to_string(i)); + FunctionDecl* ReverseModeVisitor::CreateGradientOverload() { + auto gradientParams = m_Derivative->parameters(); + auto gradientNameInfo = m_Derivative->getNameInfo(); + // Calculate the total number of parameters that would be required for + // automatic differentiation in the derived function if all args are + // requested. + // FIXME: Here we are assuming all function parameters are of differentiable + // type. Ideally, we should not make any such assumption. + std::size_t totalDerivedParamsSize = m_Function->getNumParams() * 2; + std::size_t numOfDerivativeParams = m_Function->getNumParams(); + + // Account for the this pointer. + if (isa(m_Function) && !utils::IsStaticMethod(m_Function)) + ++numOfDerivativeParams; + // All output parameters will be of type `clad::array_ref`. These + // parameters will be casted to correct type before the call to the actual + // derived function. + // We require each output parameter to be of same type in the overloaded + // derived function due to limitations of generating the exact derived + // function type at the compile-time (without clad plugin help). + QualType outputParamType = GetCladArrayRefOfType(m_Context.VoidTy); + + llvm::SmallVector paramTypes; + + // Add types for representing original function parameters. + for (auto PVD : m_Function->parameters()) + paramTypes.push_back(PVD->getType()); + // Add types for representing parameter derivatives. + // FIXME: We are assuming all function parameters are differentiable. We + // should not make any such assumptions. + for (std::size_t i = 0; i < numOfDerivativeParams; ++i) + paramTypes.push_back(outputParamType); + + auto gradFuncOverloadEPI = + dyn_cast(m_Function->getType())->getExtProtoInfo(); + QualType gradientFunctionOverloadType = + m_Context.getFunctionType(m_Context.VoidTy, paramTypes, + // Cast to function pointer. + gradFuncOverloadEPI); + + DeclContext* DC = const_cast(m_Function->getDeclContext()); + m_Sema.CurContext = DC; + DeclWithContext gradientOverloadFDWC = + m_Builder.cloneFunction(m_Function, *this, DC, noLoc, gradientNameInfo, + gradientFunctionOverloadType); + FunctionDecl* gradientOverloadFD = gradientOverloadFDWC.first; + + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), gradientOverloadFD); + + llvm::SmallVector overloadParams; + llvm::SmallVector callArgs; + + overloadParams.reserve(totalDerivedParamsSize); + callArgs.reserve(gradientParams.size()); + + for (auto PVD : m_Function->parameters()) { + auto VD = utils::BuildParmVarDecl( + m_Sema, gradientOverloadFD, PVD->getIdentifier(), PVD->getType(), + PVD->getStorageClass(), /*defArg=*/nullptr, PVD->getTypeSourceInfo()); + overloadParams.push_back(VD); + callArgs.push_back(BuildDeclRef(VD)); } - auto PVD = utils::BuildParmVarDecl(m_Sema, gradientOverloadFD, II, - outputParamType, SC); - overloadParams.push_back(PVD); - } - for (auto PVD : overloadParams) - if (PVD->getIdentifier()) - m_Sema.PushOnScopeChains(PVD, getCurrentScope(), - /*AddToContext=*/false); - - gradientOverloadFD->setParams(overloadParams); - gradientOverloadFD->setBody(/*B=*/nullptr); - - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); - beginBlock(); - - // Build derivatives to be used in the call to the actual derived function. - // These are initialised by effectively casting the derivative parameters of - // overloaded derived function to the correct type. - for (std::size_t i = m_Function->getNumParams(); i < gradientParams.size(); - ++i) { - auto overloadParam = overloadParams[i]; - auto gradientParam = gradientParams[i]; - - auto gradientVD = - BuildVarDecl(gradientParam->getType(), gradientParam->getName(), - BuildDeclRef(overloadParam)); - callArgs.push_back(BuildDeclRef(gradientVD)); - addToCurrentBlock(BuildDeclStmt(gradientVD)); - } + for (std::size_t i = 0; i < numOfDerivativeParams; ++i) { + IdentifierInfo* II = nullptr; + StorageClass SC = StorageClass::SC_None; + std::size_t effectiveGradientIndex = m_Function->getNumParams() + i; + // `effectiveGradientIndex < gradientParams.size()` implies that this + // parameter represents an actual derivative of one of the function + // original parameters. + if (effectiveGradientIndex < gradientParams.size()) { + auto GVD = gradientParams[effectiveGradientIndex]; + II = CreateUniqueIdentifier("_temp_" + GVD->getNameAsString()); + SC = GVD->getStorageClass(); + } else { + II = CreateUniqueIdentifier("_d_" + std::to_string(i)); + } + auto PVD = utils::BuildParmVarDecl(m_Sema, gradientOverloadFD, II, + outputParamType, SC); + overloadParams.push_back(PVD); + } - Expr* callExpr = BuildCallExprToFunction(m_Derivative, callArgs, - /*UseRefQualifiedThisObj=*/true); - addToCurrentBlock(callExpr); - Stmt* gradientOverloadBody = endBlock(); + for (auto PVD : overloadParams) { + if (PVD->getIdentifier()) + m_Sema.PushOnScopeChains(PVD, getCurrentScope(), + /*AddToContext=*/false); + } - gradientOverloadFD->setBody(gradientOverloadBody); + gradientOverloadFD->setParams(overloadParams); + gradientOverloadFD->setBody(/*B=*/nullptr); + + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); + beginBlock(); + + // Build derivatives to be used in the call to the actual derived function. + // These are initialised by effectively casting the derivative parameters of + // overloaded derived function to the correct type. + for (std::size_t i = m_Function->getNumParams(); i < gradientParams.size(); + ++i) { + auto overloadParam = overloadParams[i]; + auto gradientParam = gradientParams[i]; + + auto gradientVD = + BuildVarDecl(gradientParam->getType(), gradientParam->getName(), + BuildDeclRef(overloadParam)); + callArgs.push_back(BuildDeclRef(gradientVD)); + addToCurrentBlock(BuildDeclStmt(gradientVD)); + } - endScope(); // Function body scope - m_Sema.PopFunctionScopeInfo(); - m_Sema.PopDeclContext(); - endScope(); // Function decl scope + Expr* callExpr = BuildCallExprToFunction(m_Derivative, callArgs, + /*UseRefQualifiedThisObj=*/true); + addToCurrentBlock(callExpr); + Stmt* gradientOverloadBody = endBlock(); - return gradientOverloadFD; -} + gradientOverloadFD->setBody(gradientOverloadBody); -DerivativeAndOverload ReverseModeVisitor::Derive(const FunctionDecl* FD, - const DiffRequest& request) { - if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerive(); - silenceDiags = !request.VerboseDiags; - m_Function = FD; - - // reverse mode plugins may have request mode other than - // `DiffMode::reverse`, but they still need the `DiffMode::reverse` mode - // specific behaviour, because they are "reverse" mode plugins. - m_Mode = DiffMode::reverse; - if (request.Mode == DiffMode::jacobian) - m_Mode = DiffMode::jacobian; - m_Pullback = ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1); - assert(m_Function && "Must not be null."); - - DiffParams args{}; - DiffInputVarsInfo DVI; - if (request.Args) { - DVI = request.DVI; - for (auto dParam : DVI) - args.push_back(dParam.param); - } else - std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - if (args.empty()) - return {}; + endScope(); // Function body scope + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); // Function decl scope - if (m_ExternalSource) - m_ExternalSource->ActAfterParsingDiffArgs(request, args); - // Save the type of the output parameter(s) that is add by clad to the - // derived function - if (request.Mode == DiffMode::jacobian) { - isVectorValued = true; - unsigned lastArgN = m_Function->getNumParams() - 1; - outputArrayStr = m_Function->getParamDecl(lastArgN)->getNameAsString(); + return gradientOverloadFD; } - // Check if DiffRequest asks for use of enzyme as backend - if (request.use_enzyme) - use_enzyme = true; - - auto derivativeBaseName = request.BaseFunctionName; - std::string gradientName = derivativeBaseName + funcPostfix(); - // To be consistent with older tests, nothing is appended to 'f_grad' if - // we differentiate w.r.t. all the parameters at once. - if (isVectorValued) { - // If Jacobian is asked, the last parameter is the result parameter - // and should be ignored - if (args.size() != FD->getNumParams() - 1) { - for (auto arg : args) { - auto it = std::find(FD->param_begin(), FD->param_end() - 1, arg); - auto idx = std::distance(FD->param_begin(), it); - gradientName += ('_' + std::to_string(idx)); - } + DerivativeAndOverload + ReverseModeVisitor::Derive(const FunctionDecl* FD, + const DiffRequest& request) { + if (m_ExternalSource) + m_ExternalSource->ActOnStartOfDerive(); + silenceDiags = !request.VerboseDiags; + m_Function = FD; + + // reverse mode plugins may have request mode other than + // `DiffMode::reverse`, but they still need the `DiffMode::reverse` mode + // specific behaviour, because they are "reverse" mode plugins. + m_Mode = DiffMode::reverse; + if (request.Mode == DiffMode::jacobian) + m_Mode = DiffMode::jacobian; + m_Pullback = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1); + assert(m_Function && "Must not be null."); + + DiffParams args{}; + DiffInputVarsInfo DVI; + if (request.Args) { + DVI = request.DVI; + for (auto dParam : DVI) + args.push_back(dParam.param); } - } else if (args.size() != FD->getNumParams()) { - for (auto arg : args) { - auto it = std::find(FD->param_begin(), FD->param_end(), arg); - auto idx = std::distance(FD->param_begin(), it); - gradientName += ('_' + std::to_string(idx)); + else + std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); + if (args.empty()) + return {}; + + if (m_ExternalSource) + m_ExternalSource->ActAfterParsingDiffArgs(request, args); + // Save the type of the output parameter(s) that is add by clad to the + // derived function + if (request.Mode == DiffMode::jacobian) { + isVectorValued = true; + unsigned lastArgN = m_Function->getNumParams() - 1; + outputArrayStr = m_Function->getParamDecl(lastArgN)->getNameAsString(); } - } - IdentifierInfo* II = &m_Context.Idents.get(gradientName); - DeclarationNameInfo name(II, noLoc); - - // If we are in error estimation mode, we have an extra `double&` - // parameter that stores the final error - unsigned numExtraParam = 0; - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnParamTypes(numExtraParam); - - auto paramTypes = ComputeParamTypes(args); - - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); - - // If reverse mode differentiates only part of the arguments it needs to - // generate an overload that can take in all the diff variables - bool shouldCreateOverload = false; - // FIXME: Gradient overload doesn't know how to handle additional parameters - // added by the plugins yet. - if (!isVectorValued && numExtraParam == 0) - shouldCreateOverload = true; - - auto originalFnType = dyn_cast(m_Function->getType()); - // For a function f of type R(A1, A2, ..., An), - // the type of the gradient function is void(A1, A2, ..., An, R*, R*, ..., - // R*) . the type of the jacobian function is void(A1, A2, ..., An, R*, R*) - // and for error estimation, the function type is - // void(A1, A2, ..., An, R*, R*, ..., R*, double&) - QualType gradientFunctionType = m_Context.getFunctionType( - m_Context.VoidTy, - llvm::ArrayRef(paramTypes.data(), paramTypes.size()), - // Cast to function pointer. - originalFnType->getExtProtoInfo()); - - // Create the gradient function declaration. - llvm::SaveAndRestore SaveContext(m_Sema.CurContext); - llvm::SaveAndRestore SaveScope(m_CurScope); - DeclContext* DC = const_cast(m_Function->getDeclContext()); - m_Sema.CurContext = DC; - DeclWithContext result = m_Builder.cloneFunction(m_Function, *this, DC, noLoc, - name, gradientFunctionType); - FunctionDecl* gradientFD = result.first; - m_Derivative = gradientFD; - - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnScope(); - - // Function declaration scope - beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | - Scope::DeclScope); - m_Sema.PushFunctionScope(); - m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); - - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnScope(); - - auto params = BuildParams(args); - - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParams(params); - - llvm::ArrayRef paramsRef = - clad_compat::makeArrayRef(params.data(), params.size()); - gradientFD->setParams(paramsRef); - gradientFD->setBody(nullptr); - - if (isVectorValued) { - // Reference to the output parameter. - m_Result = BuildDeclRef(params.back()); - numParams = args.size(); - - // Creates the ArraySubscriptExprs for the independent variables - size_t idx = 0; - for (auto arg : args) { - // FIXME: fix when adding array inputs, now we are just skipping all - // array/pointer inputs (not treating them as independent variables). - if (utils::isArrayOrPointerType(arg->getType())) { - if (arg->getName() == "p") - m_Variables[arg] = m_Result; + // Check if DiffRequest asks for use of enzyme as backend + if (request.use_enzyme) + use_enzyme = true; + + auto derivativeBaseName = request.BaseFunctionName; + std::string gradientName = derivativeBaseName + funcPostfix(); + // To be consistent with older tests, nothing is appended to 'f_grad' if + // we differentiate w.r.t. all the parameters at once. + if(isVectorValued){ + // If Jacobian is asked, the last parameter is the result parameter + // and should be ignored + if (args.size() != FD->getNumParams()-1){ + for (auto arg : args) { + auto it = std::find(FD->param_begin(), FD->param_end()-1, arg); + auto idx = std::distance(FD->param_begin(), it); + gradientName += ('_' + std::to_string(idx)); + } + } + }else{ + if (args.size() != FD->getNumParams()){ + for (auto arg : args) { + auto it = std::find(FD->param_begin(), FD->param_end(), arg); + auto idx = std::distance(FD->param_begin(), it); + gradientName += ('_' + std::to_string(idx)); + } + } + } + + IdentifierInfo* II = &m_Context.Idents.get(gradientName); + DeclarationNameInfo name(II, noLoc); + + // If we are in error estimation mode, we have an extra `double&` + // parameter that stores the final error + unsigned numExtraParam = 0; + if (m_ExternalSource) + m_ExternalSource->ActBeforeCreatingDerivedFnParamTypes(numExtraParam); + + auto paramTypes = ComputeParamTypes(args); + + if (m_ExternalSource) + m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); + + // If reverse mode differentiates only part of the arguments it needs to + // generate an overload that can take in all the diff variables + bool shouldCreateOverload = false; + // FIXME: Gradient overload doesn't know how to handle additional parameters + // added by the plugins yet. + if (!isVectorValued && numExtraParam == 0) + shouldCreateOverload = true; + + auto originalFnType = dyn_cast(m_Function->getType()); + // For a function f of type R(A1, A2, ..., An), + // the type of the gradient function is void(A1, A2, ..., An, R*, R*, ..., + // R*) . the type of the jacobian function is void(A1, A2, ..., An, R*, R*) + // and for error estimation, the function type is + // void(A1, A2, ..., An, R*, R*, ..., R*, double&) + QualType gradientFunctionType = m_Context.getFunctionType( + m_Context.VoidTy, + llvm::ArrayRef(paramTypes.data(), paramTypes.size()), + // Cast to function pointer. + originalFnType->getExtProtoInfo()); + + // Create the gradient function declaration. + llvm::SaveAndRestore SaveContext(m_Sema.CurContext); + llvm::SaveAndRestore SaveScope(m_CurScope); + DeclContext* DC = const_cast(m_Function->getDeclContext()); + m_Sema.CurContext = DC; + DeclWithContext result = m_Builder.cloneFunction( + m_Function, *this, DC, noLoc, name, gradientFunctionType); + FunctionDecl* gradientFD = result.first; + m_Derivative = gradientFD; + + if (m_ExternalSource) + m_ExternalSource->ActBeforeCreatingDerivedFnScope(); + + // Function declaration scope + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); + + if (m_ExternalSource) + m_ExternalSource->ActAfterCreatingDerivedFnScope(); + + auto params = BuildParams(args); + + if (m_ExternalSource) + m_ExternalSource->ActAfterCreatingDerivedFnParams(params); + + llvm::ArrayRef paramsRef = + clad_compat::makeArrayRef(params.data(), params.size()); + gradientFD->setParams(paramsRef); + gradientFD->setBody(nullptr); + + if (isVectorValued) { + // Reference to the output parameter. + m_Result = BuildDeclRef(params.back()); + numParams = args.size(); + + // Creates the ArraySubscriptExprs for the independent variables + size_t idx = 0; + for (auto arg : args) { + // FIXME: fix when adding array inputs, now we are just skipping all + // array/pointer inputs (not treating them as independent variables). + if (utils::isArrayOrPointerType(arg->getType())) { + if (arg->getName() == "p") + m_Variables[arg] = m_Result; + idx += 1; + continue; + } + auto size_type = m_Context.getSizeType(); + unsigned size_type_bits = m_Context.getIntWidth(size_type); + // Create the idx literal. + auto i = + IntegerLiteral::Create(m_Context, llvm::APInt(size_type_bits, idx), + size_type, noLoc); + // Create the jacobianMatrix[idx] expression. + auto result_at_i = + m_Sema.CreateBuiltinArraySubscriptExpr(m_Result, noLoc, i, noLoc) + .get(); + m_Variables[arg] = result_at_i; idx += 1; - continue; + m_IndependentVars.push_back(arg); } - auto size_type = m_Context.getSizeType(); - unsigned size_type_bits = m_Context.getIntWidth(size_type); - // Create the idx literal. - auto i = IntegerLiteral::Create( - m_Context, llvm::APInt(size_type_bits, idx), size_type, noLoc); - // Create the jacobianMatrix[idx] expression. - auto result_at_i = - m_Sema.CreateBuiltinArraySubscriptExpr(m_Result, noLoc, i, noLoc) - .get(); - m_Variables[arg] = result_at_i; - idx += 1; - m_IndependentVars.push_back(arg); } - } - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope(); + if (m_ExternalSource) + m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope(); + + // Function body scope. + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); + beginBlock(); + if (m_ExternalSource) + m_ExternalSource->ActOnStartOfDerivedFnBody(request); - // Function body scope. - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); - beginBlock(); - if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerivedFnBody(request); + Stmt* gradientBody = nullptr; - Stmt* gradientBody = nullptr; + if (!use_enzyme) + DifferentiateWithClad(); + else + DifferentiateWithEnzyme(); + + gradientBody = endBlock(); + m_Derivative->setBody(gradientBody); + endScope(); // Function body scope + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); // Function decl scope + + FunctionDecl* gradientOverloadFD = nullptr; + if (shouldCreateOverload) { + gradientOverloadFD = + CreateGradientOverload(); + } - if (!use_enzyme) - DifferentiateWithClad(); - else - DifferentiateWithEnzyme(); + return DerivativeAndOverload{result.first, gradientOverloadFD}; + } - gradientBody = endBlock(); - m_Derivative->setBody(gradientBody); - endScope(); // Function body scope - m_Sema.PopFunctionScopeInfo(); - m_Sema.PopDeclContext(); - endScope(); // Function decl scope + DerivativeAndOverload + ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, + const DiffRequest& request) { + auto* analyzer = new TBRAnalyzer(&m_Context); + analyzer->Analyze(FD); + m_ToBeRecorded = analyzer->getResult(); + delete analyzer; - FunctionDecl* gradientOverloadFD = nullptr; - if (shouldCreateOverload) - gradientOverloadFD = CreateGradientOverload(); + // for (auto pair : m_ToBeRecorded) { + // auto line = + // m_Context.getSourceManager().getPresumedLoc(pair.first).getLine(); auto + // column = + // m_Context.getSourceManager().getPresumedLoc(pair.first).getColumn(); + // llvm::errs() << line << "|" <ActOnStartOfDerive(); + silenceDiags = !request.VerboseDiags; + m_Function = FD; + m_Mode = DiffMode::experimental_pullback; + assert(m_Function && "Must not be null."); -DerivativeAndOverload -ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, - const DiffRequest& request) { - auto* analyzer = new TBRAnalyzer(&m_Context); - analyzer->Analyze(FD); - m_ToBeRecorded = analyzer->getResult(); - delete analyzer; - - // for (auto pair : m_ToBeRecorded) { - // auto line = - // m_Context.getSourceManager().getPresumedLoc(pair.first).getLine(); auto - // column = - // m_Context.getSourceManager().getPresumedLoc(pair.first).getColumn(); - // llvm::errs() << line << "|" <ActOnStartOfDerive(); - silenceDiags = !request.VerboseDiags; - m_Function = FD; - m_Mode = DiffMode::experimental_pullback; - assert(m_Function && "Must not be null."); - - DiffParams args{}; - std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); + DiffParams args{}; + std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); #ifndef NDEBUG - bool isStaticMethod = utils::IsStaticMethod(FD); - assert((!args.empty() || !isStaticMethod) && - "Cannot generate pullback function of a function " - "with no differentiable arguments"); + bool isStaticMethod = utils::IsStaticMethod(FD); + assert((!args.empty() || !isStaticMethod) && + "Cannot generate pullback function of a function " + "with no differentiable arguments"); #endif - if (m_ExternalSource) - m_ExternalSource->ActAfterParsingDiffArgs(request, args); + if (m_ExternalSource) + m_ExternalSource->ActAfterParsingDiffArgs(request, args); - auto derivativeName = utils::ComputeEffectiveFnName(m_Function) + "_pullback"; - auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); + auto derivativeName = + utils::ComputeEffectiveFnName(m_Function) + "_pullback"; + auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); - auto paramTypes = ComputeParamTypes(args); - auto originalFnType = dyn_cast(m_Function->getType()); + auto paramTypes = ComputeParamTypes(args); + auto originalFnType = dyn_cast(m_Function->getType()); - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); + if (m_ExternalSource) + m_ExternalSource->ActAfterCreatingDerivedFnParamTypes(paramTypes); - QualType pullbackFnType = m_Context.getFunctionType( - m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); + QualType pullbackFnType = m_Context.getFunctionType( + m_Context.VoidTy, paramTypes, originalFnType->getExtProtoInfo()); - llvm::SaveAndRestore saveContext(m_Sema.CurContext); - llvm::SaveAndRestore saveScope(m_CurScope); - m_Sema.CurContext = const_cast(m_Function->getDeclContext()); + llvm::SaveAndRestore saveContext(m_Sema.CurContext); + llvm::SaveAndRestore saveScope(m_CurScope); + m_Sema.CurContext = const_cast(m_Function->getDeclContext()); - DeclWithContext fnBuildRes = m_Builder.cloneFunction( - m_Function, *this, m_Sema.CurContext, noLoc, DNI, pullbackFnType); - m_Derivative = fnBuildRes.first; + DeclWithContext fnBuildRes = m_Builder.cloneFunction( + m_Function, *this, m_Sema.CurContext, noLoc, DNI, pullbackFnType); + m_Derivative = fnBuildRes.first; - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnScope(); + if (m_ExternalSource) + m_ExternalSource->ActBeforeCreatingDerivedFnScope(); - beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | - Scope::DeclScope); - m_Sema.PushFunctionScope(); - m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnScope(); + if (m_ExternalSource) + m_ExternalSource->ActAfterCreatingDerivedFnScope(); - auto params = BuildParams(args); - if (m_ExternalSource) - m_ExternalSource->ActAfterCreatingDerivedFnParams(params); + auto params = BuildParams(args); + if (m_ExternalSource) + m_ExternalSource->ActAfterCreatingDerivedFnParams(params); - m_Derivative->setParams(params); - m_Derivative->setBody(nullptr); + m_Derivative->setParams(params); + m_Derivative->setBody(nullptr); - if (m_ExternalSource) - m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope(); + if (m_ExternalSource) + m_ExternalSource->ActBeforeCreatingDerivedFnBodyScope(); - beginScope(Scope::FnScope | Scope::DeclScope); - m_DerivativeFnScope = getCurrentScope(); + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); - beginBlock(); - if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDerivedFnBody(request); + beginBlock(); + if (m_ExternalSource) + m_ExternalSource->ActOnStartOfDerivedFnBody(request); - StmtDiff bodyDiff = Visit(m_Function->getBody()); - Stmt* forward = bodyDiff.getStmt(); - Stmt* reverse = bodyDiff.getStmt_dx(); + StmtDiff bodyDiff = Visit(m_Function->getBody()); + Stmt* forward = bodyDiff.getStmt(); + Stmt* reverse = bodyDiff.getStmt_dx(); - // Create the body of the function. - // Firstly, all "global" Stmts are put into fn's body. - for (Stmt* S : m_Globals) - addToCurrentBlock(S, direction::forward); - // Forward pass. - if (auto CS = dyn_cast(forward)) - for (Stmt* S : CS->body()) + // Create the body of the function. + // Firstly, all "global" Stmts are put into fn's body. + for (Stmt* S : m_Globals) addToCurrentBlock(S, direction::forward); + // Forward pass. + if (auto CS = dyn_cast(forward)) + for (Stmt* S : CS->body()) + addToCurrentBlock(S, direction::forward); - // Reverse pass. - if (auto RCS = dyn_cast(reverse)) - for (Stmt* S : RCS->body()) - addToCurrentBlock(S, direction::forward); + // Reverse pass. + if (auto RCS = dyn_cast(reverse)) + for (Stmt* S : RCS->body()) + addToCurrentBlock(S, direction::forward); - if (m_ExternalSource) - m_ExternalSource->ActOnEndOfDerivedFnBody(); + if (m_ExternalSource) + m_ExternalSource->ActOnEndOfDerivedFnBody(); + + Stmt* fnBody = endBlock(); + m_Derivative->setBody(fnBody); + endScope(); // Function body scope + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); // Function decl scope + + return DerivativeAndOverload{fnBuildRes.first, nullptr}; + } + + void ReverseModeVisitor::DifferentiateWithClad() { + auto* analyzer = new TBRAnalyzer(&m_Context); + analyzer->Analyze(m_Function); + m_ToBeRecorded = analyzer->getResult(); + delete analyzer; + + // for (auto pair : m_ToBeRecorded) { + // auto line = + // m_Context.getSourceManager().getPresumedLoc(pair.first).getLine(); auto + // column = + // m_Context.getSourceManager().getPresumedLoc(pair.first).getColumn(); + // llvm::errs() << line << "|" <setBody(fnBody); - endScope(); // Function body scope - m_Sema.PopFunctionScopeInfo(); - m_Sema.PopDeclContext(); - endScope(); // Function decl scope + llvm::ArrayRef paramsRef = m_Derivative->parameters(); - return DerivativeAndOverload{fnBuildRes.first, nullptr}; -} + // create derived variables for parameters which are not part of + // independent variables (args). + for (std::size_t i = 0; i < m_Function->getNumParams(); ++i) { + ParmVarDecl* param = paramsRef[i]; + // derived variables are already created for independent variables. + if (m_Variables.count(param)) + continue; + // in vector mode last non diff parameter is output parameter. + if (isVectorValued && i == m_Function->getNumParams() - 1) + continue; + auto VDDerivedType = param->getType(); + // We cannot initialize derived variable for pointer types because + // we do not know the correct size. + if (utils::isArrayOrPointerType(VDDerivedType)) + continue; + auto VDDerived = + BuildVarDecl(VDDerivedType, "_d_" + param->getNameAsString(), + getZeroInit(VDDerivedType)); + m_Variables[param] = BuildDeclRef(VDDerived); + addToBlock(BuildDeclStmt(VDDerived), m_Globals); + } + // Start the visitation process which outputs the statements in the + // current block. + StmtDiff BodyDiff = Visit(m_Function->getBody()); + Stmt* Forward = BodyDiff.getStmt(); + Stmt* Reverse = BodyDiff.getStmt_dx(); + // Create the body of the function. + // Firstly, all "global" Stmts are put into fn's body. + for (Stmt* S : m_Globals) + addToCurrentBlock(S, direction::forward); + // Forward pass. + if (auto CS = dyn_cast(Forward)) + for (Stmt* S : CS->body()) + addToCurrentBlock(S, direction::forward); + else + addToCurrentBlock(Forward, direction::forward); + // Reverse pass. + if (auto RCS = dyn_cast(Reverse)) + for (Stmt* S : RCS->body()) + addToCurrentBlock(S, direction::forward); + else + addToCurrentBlock(Reverse, direction::forward); -void ReverseModeVisitor::DifferentiateWithClad() { - auto* analyzer = new TBRAnalyzer(&m_Context); - analyzer->Analyze(m_Function); - m_ToBeRecorded = analyzer->getResult(); - delete analyzer; - - // for (auto pair : m_ToBeRecorded) { - // auto line = - // m_Context.getSourceManager().getPresumedLoc(pair.first).getLine(); auto - // column = - // m_Context.getSourceManager().getPresumedLoc(pair.first).getColumn(); - // llvm::errs() << line << "|" < paramsRef = m_Derivative->parameters(); - - // create derived variables for parameters which are not part of - // independent variables (args). - for (std::size_t i = 0; i < m_Function->getNumParams(); ++i) { - ParmVarDecl* param = paramsRef[i]; - // derived variables are already created for independent variables. - if (m_Variables.count(param)) - continue; - // in vector mode last non diff parameter is output parameter. - if (isVectorValued && i == m_Function->getNumParams() - 1) - continue; - auto VDDerivedType = param->getType(); - // We cannot initialize derived variable for pointer types because - // we do not know the correct size. - if (utils::isArrayOrPointerType(VDDerivedType)) - continue; - auto VDDerived = - BuildVarDecl(VDDerivedType, "_d_" + param->getNameAsString(), - getZeroInit(VDDerivedType)); - m_Variables[param] = BuildDeclRef(VDDerived); - addToBlock(BuildDeclStmt(VDDerived), m_Globals); + if (m_ExternalSource) + m_ExternalSource->ActOnEndOfDerivedFnBody(); } - // Start the visitation process which outputs the statements in the - // current block. - StmtDiff BodyDiff = Visit(m_Function->getBody()); - Stmt* Forward = BodyDiff.getStmt(); - Stmt* Reverse = BodyDiff.getStmt_dx(); - // Create the body of the function. - // Firstly, all "global" Stmts are put into fn's body. - for (Stmt* S : m_Globals) - addToCurrentBlock(S, direction::forward); - // Forward pass. - if (auto CS = dyn_cast(Forward)) - for (Stmt* S : CS->body()) - addToCurrentBlock(S, direction::forward); - else - addToCurrentBlock(Forward, direction::forward); - // Reverse pass. - if (auto RCS = dyn_cast(Reverse)) - for (Stmt* S : RCS->body()) - addToCurrentBlock(S, direction::forward); - else - addToCurrentBlock(Reverse, direction::forward); - if (m_ExternalSource) - m_ExternalSource->ActOnEndOfDerivedFnBody(); -} + void ReverseModeVisitor::DifferentiateWithEnzyme() { + unsigned numParams = m_Function->getNumParams(); + auto origParams = m_Function->parameters(); + llvm::ArrayRef paramsRef = m_Derivative->parameters(); + auto originalFnType = dyn_cast(m_Function->getType()); -void ReverseModeVisitor::DifferentiateWithEnzyme() { - unsigned numParams = m_Function->getNumParams(); - auto origParams = m_Function->parameters(); - llvm::ArrayRef paramsRef = m_Derivative->parameters(); - auto originalFnType = dyn_cast(m_Function->getType()); - - // Extract Pointer from Clad Array Ref - llvm::SmallVector cladRefParams; - for (unsigned i = 0; i < numParams; i++) { - QualType paramType = origParams[i]->getOriginalType(); - if (paramType->isRealType()) { - cladRefParams.push_back(nullptr); - continue; - } - - paramType = m_Context.getPointerType( - QualType(paramType->getPointeeOrArrayElementType(), 0)); - auto arrayRefNameExpr = BuildDeclRef(paramsRef[numParams + i]); - auto getPointerExpr = BuildCallExprToMemFn(arrayRefNameExpr, "ptr", {}); - auto arrayRefToArrayStmt = BuildVarDecl( - paramType, "d_" + paramsRef[i]->getNameAsString(), getPointerExpr); - addToCurrentBlock(BuildDeclStmt(arrayRefToArrayStmt), direction::forward); - cladRefParams.push_back(arrayRefToArrayStmt); - } - // Prepare Arguments and Parameters to enzyme_autodiff - llvm::SmallVector enzymeArgs; - llvm::SmallVector enzymeParams; - llvm::SmallVector enzymeRealParams; - llvm::SmallVector enzymeRealParamsRef; - - // First add the function itself as a parameter/argument - enzymeArgs.push_back(BuildDeclRef(const_cast(m_Function))); - DeclContext* fdDeclContext = - const_cast(m_Function->getDeclContext()); - enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( - fdDeclContext, noLoc, m_Function->getType())); - - // Add rest of the parameters/arguments - for (unsigned i = 0; i < numParams; i++) { - // First Add the original parameter - enzymeArgs.push_back(BuildDeclRef(paramsRef[i])); - enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( - fdDeclContext, noLoc, paramsRef[i]->getType())); - - // If the original parameter is not of array/pointer type, then we don't - // have to extract its pointer from clad array_ref and add it to the - // enzyme parameters, so we can skip the rest of the code - if (!cladRefParams[i]) { - // If original parameter is of a differentiable real type(but not - // array/pointer), then add it to the list of params whose gradient must - // be extracted later from the EnzymeGradient structure - if (paramsRef[i]->getOriginalType()->isRealFloatingType()) { - enzymeRealParams.push_back(paramsRef[i]); - enzymeRealParamsRef.push_back(paramsRef[numParams + i]); + // Extract Pointer from Clad Array Ref + llvm::SmallVector cladRefParams; + for (unsigned i = 0; i < numParams; i++) { + QualType paramType = origParams[i]->getOriginalType(); + if (paramType->isRealType()) { + cladRefParams.push_back(nullptr); + continue; } - continue; + + paramType = m_Context.getPointerType( + QualType(paramType->getPointeeOrArrayElementType(), 0)); + auto arrayRefNameExpr = BuildDeclRef(paramsRef[numParams + i]); + auto getPointerExpr = BuildCallExprToMemFn(arrayRefNameExpr, "ptr", {}); + auto arrayRefToArrayStmt = BuildVarDecl( + paramType, "d_" + paramsRef[i]->getNameAsString(), getPointerExpr); + addToCurrentBlock(BuildDeclStmt(arrayRefToArrayStmt), direction::forward); + cladRefParams.push_back(arrayRefToArrayStmt); } - // Then add the corresponding clad array ref pointer variable - enzymeArgs.push_back(BuildDeclRef(cladRefParams[i])); + // Prepare Arguments and Parameters to enzyme_autodiff + llvm::SmallVector enzymeArgs; + llvm::SmallVector enzymeParams; + llvm::SmallVector enzymeRealParams; + llvm::SmallVector enzymeRealParamsRef; + + // First add the function itself as a parameter/argument + enzymeArgs.push_back(BuildDeclRef(const_cast(m_Function))); + DeclContext* fdDeclContext = + const_cast(m_Function->getDeclContext()); enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( - fdDeclContext, noLoc, cladRefParams[i]->getType())); - } + fdDeclContext, noLoc, m_Function->getType())); + + // Add rest of the parameters/arguments + for (unsigned i = 0; i < numParams; i++) { + // First Add the original parameter + enzymeArgs.push_back(BuildDeclRef(paramsRef[i])); + enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( + fdDeclContext, noLoc, paramsRef[i]->getType())); + + // If the original parameter is not of array/pointer type, then we don't + // have to extract its pointer from clad array_ref and add it to the + // enzyme parameters, so we can skip the rest of the code + if (!cladRefParams[i]) { + // If original parameter is of a differentiable real type(but not + // array/pointer), then add it to the list of params whose gradient must + // be extracted later from the EnzymeGradient structure + if (paramsRef[i]->getOriginalType()->isRealFloatingType()) { + enzymeRealParams.push_back(paramsRef[i]); + enzymeRealParamsRef.push_back(paramsRef[numParams + i]); + } + continue; + } + // Then add the corresponding clad array ref pointer variable + enzymeArgs.push_back(BuildDeclRef(cladRefParams[i])); + enzymeParams.push_back(m_Sema.BuildParmVarDeclForTypedef( + fdDeclContext, noLoc, cladRefParams[i]->getType())); + } - llvm::SmallVector enzymeParamsType; - for (auto i : enzymeParams) - enzymeParamsType.push_back(i->getType()); + llvm::SmallVector enzymeParamsType; + for (auto i : enzymeParams) + enzymeParamsType.push_back(i->getType()); - QualType QT; - if (enzymeRealParams.size()) { - // Find the EnzymeGradient datastructure - auto gradDecl = LookupTemplateDeclInCladNamespace("EnzymeGradient"); + QualType QT; + if (enzymeRealParams.size()) { + // Find the EnzymeGradient datastructure + auto gradDecl = LookupTemplateDeclInCladNamespace("EnzymeGradient"); - TemplateArgumentListInfo TLI{}; - llvm::APSInt argValue(std::to_string(enzymeRealParams.size())); - TemplateArgument TA(m_Context, argValue, m_Context.UnsignedIntTy); - TLI.addArgument(TemplateArgumentLoc(TA, TemplateArgumentLocInfo())); + TemplateArgumentListInfo TLI{}; + llvm::APSInt argValue(std::to_string(enzymeRealParams.size())); + TemplateArgument TA(m_Context, argValue, m_Context.UnsignedIntTy); + TLI.addArgument(TemplateArgumentLoc(TA, TemplateArgumentLocInfo())); - QT = InstantiateTemplate(gradDecl, TLI); - } else { - QT = m_Context.VoidTy; - } + QT = InstantiateTemplate(gradDecl, TLI); + } else { + QT = m_Context.VoidTy; + } - // Prepare Function call - std::string enzymeCallName = - "__enzyme_autodiff_" + m_Function->getNameAsString(); - IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName); - DeclarationName nameEnzyme(IIEnzyme); - QualType enzymeFunctionType = - m_Sema.BuildFunctionType(QT, enzymeParamsType, noLoc, nameEnzyme, - originalFnType->getExtProtoInfo()); - FunctionDecl* enzymeCallFD = FunctionDecl::Create( - m_Context, fdDeclContext, noLoc, noLoc, nameEnzyme, enzymeFunctionType, - m_Function->getTypeSourceInfo(), SC_Extern); - enzymeCallFD->setParams(enzymeParams); - Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); - - // Prepare the statements that assign the gradients to - // non array/pointer type parameters of the original function - if (enzymeRealParams.size() != 0) { - auto gradDeclStmt = BuildVarDecl(QT, "grad", enzymeCall, true); - addToCurrentBlock(BuildDeclStmt(gradDeclStmt), direction::forward); - - for (unsigned i = 0; i < enzymeRealParams.size(); i++) { - auto LHSExpr = BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsRef[i])); - - auto ME = utils::BuildMemberExpr(m_Sema, getCurrentScope(), - BuildDeclRef(gradDeclStmt), "d_arr"); - - Expr* gradIndex = dyn_cast( - IntegerLiteral::Create(m_Context, llvm::APSInt(std::to_string(i)), - m_Context.UnsignedIntTy, noLoc)); - Expr* RHSExpr = - m_Sema.CreateBuiltinArraySubscriptExpr(ME, noLoc, gradIndex, noLoc) - .get(); + // Prepare Function call + std::string enzymeCallName = + "__enzyme_autodiff_" + m_Function->getNameAsString(); + IdentifierInfo* IIEnzyme = &m_Context.Idents.get(enzymeCallName); + DeclarationName nameEnzyme(IIEnzyme); + QualType enzymeFunctionType = + m_Sema.BuildFunctionType(QT, enzymeParamsType, noLoc, nameEnzyme, + originalFnType->getExtProtoInfo()); + FunctionDecl* enzymeCallFD = FunctionDecl::Create( + m_Context, fdDeclContext, noLoc, noLoc, nameEnzyme, enzymeFunctionType, + m_Function->getTypeSourceInfo(), SC_Extern); + enzymeCallFD->setParams(enzymeParams); + Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); + + // Prepare the statements that assign the gradients to + // non array/pointer type parameters of the original function + if (enzymeRealParams.size() != 0) { + auto gradDeclStmt = BuildVarDecl(QT, "grad", enzymeCall, true); + addToCurrentBlock(BuildDeclStmt(gradDeclStmt), direction::forward); + + for (unsigned i = 0; i < enzymeRealParams.size(); i++) { + auto LHSExpr = BuildOp(UO_Deref, BuildDeclRef(enzymeRealParamsRef[i])); + + auto ME = utils::BuildMemberExpr(m_Sema, getCurrentScope(), + BuildDeclRef(gradDeclStmt), "d_arr"); + + Expr* gradIndex = dyn_cast( + IntegerLiteral::Create(m_Context, llvm::APSInt(std::to_string(i)), + m_Context.UnsignedIntTy, noLoc)); + Expr* RHSExpr = + m_Sema.CreateBuiltinArraySubscriptExpr(ME, noLoc, gradIndex, noLoc) + .get(); - auto assignExpr = BuildOp(BO_Assign, LHSExpr, RHSExpr); - addToCurrentBlock(assignExpr, direction::forward); + auto assignExpr = BuildOp(BO_Assign, LHSExpr, RHSExpr); + addToCurrentBlock(assignExpr, direction::forward); + } + } else { + // Add Function call to block + Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); + addToCurrentBlock(enzymeCall); } - } else { - // Add Function call to block - Expr* enzymeCall = BuildCallExprToFunction(enzymeCallFD, enzymeArgs); - addToCurrentBlock(enzymeCall); } -} -StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { - diag(DiagnosticsEngine::Warning, S->getBeginLoc(), - "attempted to differentiate unsupported statement, no changes applied"); - // Unknown stmt, just clone it. - return StmtDiff(Clone(S)); -} + StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { + diag( + DiagnosticsEngine::Warning, + S->getBeginLoc(), + "attempted to differentiate unsupported statement, no changes applied"); + // Unknown stmt, just clone it. + return StmtDiff(Clone(S)); + } -StmtDiff ReverseModeVisitor::VisitCompoundStmt(const CompoundStmt* CS) { - beginScope(Scope::DeclScope); - beginBlock(direction::forward); - beginBlock(direction::reverse); - for (Stmt* S : CS->body()) { - if (m_ExternalSource) - m_ExternalSource->ActBeforeDifferentiatingStmtInVisitCompoundStmt(); - StmtDiff SDiff = DifferentiateSingleStmt(S); - addToCurrentBlock(SDiff.getStmt(), direction::forward); - addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse); + StmtDiff ReverseModeVisitor::VisitCompoundStmt(const CompoundStmt* CS) { + beginScope(Scope::DeclScope); + beginBlock(direction::forward); + beginBlock(direction::reverse); + for (Stmt* S : CS->body()) { + if (m_ExternalSource) + m_ExternalSource->ActBeforeDifferentiatingStmtInVisitCompoundStmt(); + StmtDiff SDiff = DifferentiateSingleStmt(S); + addToCurrentBlock(SDiff.getStmt(), direction::forward); + addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse); - if (m_ExternalSource) - m_ExternalSource->ActAfterProcessingStmtInVisitCompoundStmt(); + if (m_ExternalSource) + m_ExternalSource->ActAfterProcessingStmtInVisitCompoundStmt(); + } + CompoundStmt* Forward = endBlock(direction::forward); + CompoundStmt* Reverse = endBlock(direction::reverse); + endScope(); + return StmtDiff(Forward, Reverse); } - CompoundStmt* Forward = endBlock(direction::forward); - CompoundStmt* Reverse = endBlock(direction::reverse); - endScope(); - return StmtDiff(Forward, Reverse); -} -static Stmt* unwrapIfSingleStmt(Stmt* S) { - if (!S) - return nullptr; - if (!isa(S)) - return S; - auto CS = cast(S); - if (CS->size() == 0) - return nullptr; - else if (CS->size() == 1) - return CS->body_front(); - else - return CS; -} + static Stmt* unwrapIfSingleStmt(Stmt* S) { + if (!S) + return nullptr; + if (!isa(S)) + return S; + auto CS = cast(S); + if (CS->size() == 0) + return nullptr; + else if (CS->size() == 1) + return CS->body_front(); + else + return CS; + } + + StmtDiff ReverseModeVisitor::VisitIfStmt(const clang::IfStmt* If) { + // Control scope of the IfStmt. E.g., in if (double x = ...) {...}, x goes + // to this scope. + beginScope(Scope::DeclScope | Scope::ControlScope); + + StmtDiff cond = Clone(If->getCond()); + // Condition has to be stored as a "global" variable, to take the correct + // branch in the reverse pass. + // If we are inside loop, the condition has to be stored in a stack after + // the if statement. + Expr* PushCond = nullptr; + Expr* PopCond = nullptr; + auto condExpr = Visit(cond.getExpr()); + if (isInsideLoop) { + // If we are inside for loop, cond will be stored in the following way: + // forward: + // _t = cond; + // if (_t) { ... } + // clad::push(..., _t); + // reverse: + // if (clad::pop(...)) { ... } + // Simply doing + // if (clad::push(..., _t) { ... } + // is incorrect when if contains return statement inside: return will + // skip corresponding push. + cond = StoreAndRef(condExpr.getExpr(), direction::forward, "_t", + /*forceDeclCreation=*/true); + StmtDiff condPushPop = GlobalStoreAndRef(cond.getExpr(), "_cond", + /*force=*/true); + PushCond = condPushPop.getExpr(); + PopCond = condPushPop.getExpr_dx(); + } else + cond = GlobalStoreAndRef(condExpr.getExpr(), "_cond"); + // Convert cond to boolean condition. We are modifying each Stmt in + // StmtDiff. + for (Stmt*& S : cond.getBothStmts()) + if (S) + S = m_Sema + .ActOnCondition(m_CurScope, + noLoc, + cast(S), + Sema::ConditionKind::Boolean) + .get() + .second; + + // Create a block "around" if statement, e.g: + // { + // ... + // if (...) {...} + // } + beginBlock(direction::forward); + beginBlock(direction::reverse); + const Stmt* init = If->getInit(); + StmtDiff initResult = init ? Visit(init) : StmtDiff{}; + // If there is Init, it's derivative will be output in the block before if: + // E.g., for: + // if (int x = 1; ...) {...} + // result will be: + // { + // int _d_x = 0; + // if (int x = 1; ...) {...} + // } + // This is done to avoid variable names clashes. + addToCurrentBlock(initResult.getStmt_dx()); + + VarDecl* condVarClone = nullptr; + if (const VarDecl* condVarDecl = If->getConditionVariable()) { + VarDeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl); + condVarClone = condVarDeclDiff.getDecl(); + if (condVarDeclDiff.getDecl_dx()) + addToBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()), m_Globals); + } + + // Condition is just cloned as it is, not derived. + // FIXME: if condition changes one of the variables, it may be reasonable + // to derive it, e.g. + // if (x += x) {...} + // should result in: + // { + // _d_y += _d_x + // if (y += x) {...} + // } + + auto VisitBranch = [&](const Stmt* Branch) -> StmtDiff { + if (!Branch) + return {}; + if (isa(Branch)) { + StmtDiff BranchDiff = Visit(Branch); + return BranchDiff; + } else { + beginBlock(direction::forward); + if (m_ExternalSource) + m_ExternalSource->ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt(); + StmtDiff BranchDiff = DifferentiateSingleStmt(Branch, /*dfdS=*/nullptr); + addToCurrentBlock(BranchDiff.getStmt(), direction::forward); + + if (m_ExternalSource) + m_ExternalSource->ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt(); + + Stmt* Forward = unwrapIfSingleStmt(endBlock(direction::forward)); + Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); + return StmtDiff(Forward, Reverse); + } + }; + + StmtDiff thenDiff = VisitBranch(If->getThen()); + StmtDiff elseDiff = VisitBranch(If->getElse()); + + // It is problematic to specify both condVarDecl and cond thorugh + // Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor. + Stmt* Forward = clad_compat::IfStmt_Create(m_Context, + noLoc, + If->isConstexpr(), + initResult.getStmt(), + condVarClone, + cond.getExpr(), + noLoc, + noLoc, + thenDiff.getStmt(), + noLoc, + elseDiff.getStmt()); + addToCurrentBlock(Forward, direction::forward); -StmtDiff ReverseModeVisitor::VisitIfStmt(const clang::IfStmt* If) { - // Control scope of the IfStmt. E.g., in if (double x = ...) {...}, x goes - // to this scope. - beginScope(Scope::DeclScope | Scope::ControlScope); - - StmtDiff cond = Clone(If->getCond()); - // Condition has to be stored as a "global" variable, to take the correct - // branch in the reverse pass. - // If we are inside loop, the condition has to be stored in a stack after - // the if statement. - Expr* PushCond = nullptr; - Expr* PopCond = nullptr; - auto condExpr = Visit(cond.getExpr()); - if (isInsideLoop) { - // If we are inside for loop, cond will be stored in the following way: - // forward: - // _t = cond; - // if (_t) { ... } - // clad::push(..., _t); - // reverse: - // if (clad::pop(...)) { ... } - // Simply doing - // if (clad::push(..., _t) { ... } - // is incorrect when if contains return statement inside: return will - // skip corresponding push. - cond = StoreAndRef(condExpr.getExpr(), direction::forward, "_t", - /*forceDeclCreation=*/true); - StmtDiff condPushPop = GlobalStoreAndRef(cond.getExpr(), "_cond", - /*force=*/true); - PushCond = condPushPop.getExpr(); - PopCond = condPushPop.getExpr_dx(); - } else - cond = GlobalStoreAndRef(condExpr.getExpr(), "_cond"); - // Convert cond to boolean condition. We are modifying each Stmt in - // StmtDiff. - for (Stmt*& S : cond.getBothStmts()) - if (S) + Expr* reverseCond = cond.getExpr_dx(); + if (isInsideLoop) { + addToCurrentBlock(PushCond, direction::forward); + reverseCond = PopCond; + } + Stmt* Reverse = clad_compat::IfStmt_Create(m_Context, + noLoc, + If->isConstexpr(), + initResult.getStmt_dx(), + condVarClone, + reverseCond, + noLoc, + noLoc, + thenDiff.getStmt_dx(), + noLoc, + elseDiff.getStmt_dx()); + addToCurrentBlock(Reverse, direction::reverse); + CompoundStmt* ForwardBlock = endBlock(direction::forward); + CompoundStmt* ReverseBlock = endBlock(direction::reverse); + endScope(); + return StmtDiff(unwrapIfSingleStmt(ForwardBlock), + unwrapIfSingleStmt(ReverseBlock)); + } + + StmtDiff ReverseModeVisitor::VisitConditionalOperator( + const clang::ConditionalOperator* CO) { + StmtDiff cond = Clone(CO->getCond()); + // Condition has to be stored as a "global" variable, to take the correct + // branch in the reverse pass. + cond = GlobalStoreAndRef(Visit(cond.getExpr()).getExpr(), "_cond"); + // Convert cond to boolean condition. We are modifying each Stmt in + // StmtDiff. + for (Stmt*& S : cond.getBothStmts()) S = m_Sema - .ActOnCondition(m_CurScope, noLoc, cast(S), + .ActOnCondition(m_CurScope, + noLoc, + cast(S), Sema::ConditionKind::Boolean) .get() .second; - // Create a block "around" if statement, e.g: - // { - // ... - // if (...) {...} - // } - beginBlock(direction::forward); - beginBlock(direction::reverse); - const Stmt* init = If->getInit(); - StmtDiff initResult = init ? Visit(init) : StmtDiff{}; - // If there is Init, it's derivative will be output in the block before if: - // E.g., for: - // if (int x = 1; ...) {...} - // result will be: - // { - // int _d_x = 0; - // if (int x = 1; ...) {...} - // } - // This is done to avoid variable names clashes. - addToCurrentBlock(initResult.getStmt_dx()); - - VarDecl* condVarClone = nullptr; - if (const VarDecl* condVarDecl = If->getConditionVariable()) { - VarDeclDiff condVarDeclDiff = DifferentiateVarDecl(condVarDecl); - condVarClone = condVarDeclDiff.getDecl(); - if (condVarDeclDiff.getDecl_dx()) - addToBlock(BuildDeclStmt(condVarDeclDiff.getDecl_dx()), m_Globals); - } + auto ifTrue = CO->getTrueExpr(); + auto ifFalse = CO->getFalseExpr(); - // Condition is just cloned as it is, not derived. - // FIXME: if condition changes one of the variables, it may be reasonable - // to derive it, e.g. - // if (x += x) {...} - // should result in: - // { - // _d_y += _d_x - // if (y += x) {...} - // } - - auto VisitBranch = [&](const Stmt* Branch) -> StmtDiff { - if (!Branch) - return {}; - if (isa(Branch)) { - StmtDiff BranchDiff = Visit(Branch); - return BranchDiff; - } else { - beginBlock(direction::forward); - if (m_ExternalSource) - m_ExternalSource - ->ActBeforeDifferentiatingSingleStmtBranchInVisitIfStmt(); - StmtDiff BranchDiff = DifferentiateSingleStmt(Branch, /*dfdS=*/nullptr); - addToCurrentBlock(BranchDiff.getStmt(), direction::forward); + auto VisitBranch = [&](const Expr* Branch, + Expr* dfdx) -> std::pair { + auto Result = DifferentiateSingleExpr(Branch, dfdx); + StmtDiff BranchDiff = Result.first; + StmtDiff ExprDiff = Result.second; + Stmt* Forward = unwrapIfSingleStmt(BranchDiff.getStmt()); + Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); + return {StmtDiff(Forward, Reverse), ExprDiff}; + }; - if (m_ExternalSource) - m_ExternalSource - ->ActBeforeFinalisingVisitBranchSingleStmtInIfVisitStmt(); + StmtDiff ifTrueDiff; + StmtDiff ifTrueExprDiff; + StmtDiff ifFalseDiff; + StmtDiff ifFalseExprDiff; + + std::tie(ifTrueDiff, ifTrueExprDiff) = VisitBranch(ifTrue, dfdx()); + std::tie(ifFalseDiff, ifFalseExprDiff) = VisitBranch(ifFalse, dfdx()); + + auto BuildIf = [&](Expr* Cond, Stmt* Then, Stmt* Else) -> Stmt* { + if (!Then && !Else) + return nullptr; + if (!Then) + Then = m_Sema.ActOnNullStmt(noLoc).get(); + return clad_compat::IfStmt_Create(m_Context, + noLoc, + false, + nullptr, + nullptr, + Cond, + noLoc, + noLoc, + Then, + noLoc, + Else); + }; - Stmt* Forward = unwrapIfSingleStmt(endBlock(direction::forward)); - Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); - return StmtDiff(Forward, Reverse); + Stmt* Forward = + BuildIf(cond.getExpr(), ifTrueDiff.getStmt(), ifFalseDiff.getStmt()); + Stmt* Reverse = BuildIf(cond.getExpr_dx(), + ifTrueDiff.getStmt_dx(), + ifFalseDiff.getStmt_dx()); + if (Forward) + addToCurrentBlock(Forward, direction::forward); + if (Reverse) + addToCurrentBlock(Reverse, direction::reverse); + + Expr* condExpr = m_Sema + .ActOnConditionalOp(noLoc, + noLoc, + cond.getExpr(), + ifTrueExprDiff.getExpr(), + ifFalseExprDiff.getExpr()) + .get(); + // If result is a glvalue, we should keep it as it can potentially be + // assigned as in (c ? a : b) = x; + if ((CO->isModifiableLvalue(m_Context) == Expr::MLV_Valid) && + ifTrueExprDiff.getExpr_dx() && ifFalseExprDiff.getExpr_dx()) { + Expr* ResultRef = m_Sema + .ActOnConditionalOp(noLoc, + noLoc, + cond.getExpr_dx(), + ifTrueExprDiff.getExpr_dx(), + ifFalseExprDiff.getExpr_dx()) + .get(); + if (ResultRef->isModifiableLvalue(m_Context) != Expr::MLV_Valid) + ResultRef = nullptr; + return StmtDiff(condExpr, ResultRef); } - }; + return StmtDiff(condExpr); + } - StmtDiff thenDiff = VisitBranch(If->getThen()); - StmtDiff elseDiff = VisitBranch(If->getElse()); + StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) { + beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | + Scope::ContinueScope); - // It is problematic to specify both condVarDecl and cond thorugh - // Sema::ActOnIfStmt, therefore we directly use the IfStmt constructor. - Stmt* Forward = clad_compat::IfStmt_Create( - m_Context, noLoc, If->isConstexpr(), initResult.getStmt(), condVarClone, - cond.getExpr(), noLoc, noLoc, thenDiff.getStmt(), noLoc, - elseDiff.getStmt()); - addToCurrentBlock(Forward, direction::forward); + LoopCounter loopCounter(*this); + if (loopCounter.getPush()) + addToCurrentBlock(loopCounter.getPush()); + beginBlock(direction::forward); + beginBlock(direction::reverse); + const Stmt* init = FS->getInit(); + if (m_ExternalSource) + m_ExternalSource->ActBeforeDifferentiatingLoopInitStmt(); + StmtDiff initResult = init ? DifferentiateSingleStmt(init) : StmtDiff{}; + + // Save the isInsideLoop value (we may be inside another loop). + llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); + isInsideLoop = true; + + StmtDiff condVarRes; + VarDecl* condVarClone = nullptr; + if (FS->getConditionVariable()) { + condVarRes = DifferentiateSingleStmt(FS->getConditionVariableDeclStmt()); + Decl* decl = cast(condVarRes.getStmt())->getSingleDecl(); + condVarClone = cast(decl); + } - Expr* reverseCond = cond.getExpr_dx(); - if (isInsideLoop) { - addToCurrentBlock(PushCond, direction::forward); - reverseCond = PopCond; - } - Stmt* Reverse = clad_compat::IfStmt_Create( - m_Context, noLoc, If->isConstexpr(), initResult.getStmt_dx(), - condVarClone, reverseCond, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, - elseDiff.getStmt_dx()); - addToCurrentBlock(Reverse, direction::reverse); - CompoundStmt* ForwardBlock = endBlock(direction::forward); - CompoundStmt* ReverseBlock = endBlock(direction::reverse); - endScope(); - return StmtDiff(unwrapIfSingleStmt(ForwardBlock), - unwrapIfSingleStmt(ReverseBlock)); -} + // FIXME: for now we assume that cond has no differentiable effects, + // but it is not generally true, e.g. for (...; (x = y); ...)... + StmtDiff cond; + if (FS->getCond()) + cond = Visit(FS->getCond()); + auto IDRE = dyn_cast(FS->getInc()); + const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc(); + + // Differentiate the increment expression of the for loop + // incExprDiff.getExpr() is the reconstructed expression, incDiff.getStmt() + // a block with all the intermediate statements used to reconstruct it on + // the forward pass, incDiff.getStmt_dx() is the reverse pass block. + StmtDiff incDiff; + StmtDiff incExprDiff; + if (inc) + std::tie(incDiff, incExprDiff) = DifferentiateSingleExpr(inc); + Expr* incResult = nullptr; + // If any additional statements were created, enclose them into lambda. + CompoundStmt* Additional = cast(incDiff.getStmt()); + bool anyNonExpr = std::any_of(Additional->body_begin(), + Additional->body_end(), + [](Stmt* S) { return !isa(S); }); + if (anyNonExpr) { + incResult = wrapInLambda(*this, m_Sema, inc, [&] { + std::tie(incDiff, incExprDiff) = DifferentiateSingleExpr(inc); + for (Stmt* S : cast(incDiff.getStmt())->body()) + addToCurrentBlock(S); + addToCurrentBlock(incDiff.getExpr()); + }); + } + // Otherwise, join all exprs by comma operator. + else if (incExprDiff.getExpr()) { + auto CommaJoin = [this](Expr* Acc, Stmt* S) { + Expr* E = cast(S); + return BuildOp(BO_Comma, E, BuildParens(Acc)); + }; + incResult = std::accumulate(Additional->body_rbegin(), + Additional->body_rend(), + incExprDiff.getExpr(), + CommaJoin); + } -StmtDiff ReverseModeVisitor::VisitConditionalOperator( - const clang::ConditionalOperator* CO) { - StmtDiff cond = Clone(CO->getCond()); - // Condition has to be stored as a "global" variable, to take the correct - // branch in the reverse pass. - cond = GlobalStoreAndRef(Visit(cond.getExpr()).getExpr(), "_cond"); - // Convert cond to boolean condition. We are modifying each Stmt in - // StmtDiff. - for (Stmt*& S : cond.getBothStmts()) - S = m_Sema - .ActOnCondition(m_CurScope, noLoc, cast(S), - Sema::ConditionKind::Boolean) - .get() - .second; - - auto ifTrue = CO->getTrueExpr(); - auto ifFalse = CO->getFalseExpr(); - - auto VisitBranch = [&](const Expr* Branch, - Expr* dfdx) -> std::pair { - auto Result = DifferentiateSingleExpr(Branch, dfdx); - StmtDiff BranchDiff = Result.first; - StmtDiff ExprDiff = Result.second; - Stmt* Forward = unwrapIfSingleStmt(BranchDiff.getStmt()); - Stmt* Reverse = unwrapIfSingleStmt(BranchDiff.getStmt_dx()); - return {StmtDiff(Forward, Reverse), ExprDiff}; - }; - - StmtDiff ifTrueDiff; - StmtDiff ifTrueExprDiff; - StmtDiff ifFalseDiff; - StmtDiff ifFalseExprDiff; - - std::tie(ifTrueDiff, ifTrueExprDiff) = VisitBranch(ifTrue, dfdx()); - std::tie(ifFalseDiff, ifFalseExprDiff) = VisitBranch(ifFalse, dfdx()); - - auto BuildIf = [&](Expr* Cond, Stmt* Then, Stmt* Else) -> Stmt* { - if (!Then && !Else) - return nullptr; - if (!Then) - Then = m_Sema.ActOnNullStmt(noLoc).get(); - return clad_compat::IfStmt_Create(m_Context, noLoc, false, nullptr, nullptr, - Cond, noLoc, noLoc, Then, noLoc, Else); - }; - - Stmt* Forward = - BuildIf(cond.getExpr(), ifTrueDiff.getStmt(), ifFalseDiff.getStmt()); - Stmt* Reverse = BuildIf(cond.getExpr_dx(), ifTrueDiff.getStmt_dx(), - ifFalseDiff.getStmt_dx()); - if (Forward) + const Stmt* body = FS->getBody(); + StmtDiff BodyDiff = DifferentiateLoopBody(body, loopCounter, + condVarRes.getStmt_dx(), + incDiff.getStmt_dx(), + /*isForLoop=*/true); + + Stmt* Forward = new (m_Context) ForStmt(m_Context, + initResult.getStmt(), + cond.getExpr(), + condVarClone, + incResult, + BodyDiff.getStmt(), + noLoc, + noLoc, + noLoc); + + // Create a condition testing counter for being zero, and its decrement. + // To match the number of iterations in the forward pass, the reverse loop + // will look like: for(; Counter; Counter--) ... + Expr* + CounterCondition = loopCounter.getCounterConditionResult().get().second; + Expr* CounterDecrement = loopCounter.getCounterDecrement(); + + Stmt* ReverseResult = BodyDiff.getStmt_dx(); + if (!ReverseResult) + ReverseResult = new (m_Context) NullStmt(noLoc); + Stmt* Reverse = new (m_Context) ForStmt(m_Context, + nullptr, + CounterCondition, + nullptr, + CounterDecrement, + ReverseResult, + noLoc, + noLoc, + noLoc); addToCurrentBlock(Forward, direction::forward); - if (Reverse) + Forward = endBlock(direction::forward); + addToCurrentBlock(loopCounter.getPop(), direction::reverse); + addToCurrentBlock(initResult.getStmt_dx(), direction::reverse); addToCurrentBlock(Reverse, direction::reverse); + Reverse = endBlock(direction::reverse); + endScope(); - Expr* condExpr = m_Sema - .ActOnConditionalOp(noLoc, noLoc, cond.getExpr(), - ifTrueExprDiff.getExpr(), - ifFalseExprDiff.getExpr()) - .get(); - // If result is a glvalue, we should keep it as it can potentially be - // assigned as in (c ? a : b) = x; - if ((CO->isModifiableLvalue(m_Context) == Expr::MLV_Valid) && - ifTrueExprDiff.getExpr_dx() && ifFalseExprDiff.getExpr_dx()) { - Expr* ResultRef = m_Sema - .ActOnConditionalOp(noLoc, noLoc, cond.getExpr_dx(), - ifTrueExprDiff.getExpr_dx(), - ifFalseExprDiff.getExpr_dx()) - .get(); - if (ResultRef->isModifiableLvalue(m_Context) != Expr::MLV_Valid) - ResultRef = nullptr; - return StmtDiff(condExpr, ResultRef); + return {unwrapIfSingleStmt(Forward), unwrapIfSingleStmt(Reverse)}; } - return StmtDiff(condExpr); -} -StmtDiff ReverseModeVisitor::VisitForStmt(const ForStmt* FS) { - beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | - Scope::ContinueScope); - - LoopCounter loopCounter(*this); - if (loopCounter.getPush()) - addToCurrentBlock(loopCounter.getPush()); - beginBlock(direction::forward); - beginBlock(direction::reverse); - const Stmt* init = FS->getInit(); - if (m_ExternalSource) - m_ExternalSource->ActBeforeDifferentiatingLoopInitStmt(); - StmtDiff initResult = init ? DifferentiateSingleStmt(init) : StmtDiff{}; - - // Save the isInsideLoop value (we may be inside another loop). - llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); - isInsideLoop = true; - - StmtDiff condVarRes; - VarDecl* condVarClone = nullptr; - if (FS->getConditionVariable()) { - condVarRes = DifferentiateSingleStmt(FS->getConditionVariableDeclStmt()); - Decl* decl = cast(condVarRes.getStmt())->getSingleDecl(); - condVarClone = cast(decl); + StmtDiff + ReverseModeVisitor::VisitCXXDefaultArgExpr(const CXXDefaultArgExpr* DE) { + return Visit(DE->getExpr(), dfdx()); } - // FIXME: for now we assume that cond has no differentiable effects, - // but it is not generally true, e.g. for (...; (x = y); ...)... - StmtDiff cond; - if (FS->getCond()) - cond = Visit(FS->getCond()); - auto IDRE = dyn_cast(FS->getInc()); - const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc(); - - // Differentiate the increment expression of the for loop - // incExprDiff.getExpr() is the reconstructed expression, incDiff.getStmt() - // a block with all the intermediate statements used to reconstruct it on - // the forward pass, incDiff.getStmt_dx() is the reverse pass block. - StmtDiff incDiff; - StmtDiff incExprDiff; - if (inc) - std::tie(incDiff, incExprDiff) = DifferentiateSingleExpr(inc); - Expr* incResult = nullptr; - // If any additional statements were created, enclose them into lambda. - CompoundStmt* Additional = cast(incDiff.getStmt()); - bool anyNonExpr = - std::any_of(Additional->body_begin(), Additional->body_end(), - [](Stmt* S) { return !isa(S); }); - if (anyNonExpr) { - incResult = wrapInLambda(*this, m_Sema, inc, [&] { - std::tie(incDiff, incExprDiff) = DifferentiateSingleExpr(inc); - for (Stmt* S : cast(incDiff.getStmt())->body()) - addToCurrentBlock(S); - addToCurrentBlock(incDiff.getExpr()); - }); - } - // Otherwise, join all exprs by comma operator. - else if (incExprDiff.getExpr()) { - auto CommaJoin = [this](Expr* Acc, Stmt* S) { - Expr* E = cast(S); - return BuildOp(BO_Comma, E, BuildParens(Acc)); - }; - incResult = - std::accumulate(Additional->body_rbegin(), Additional->body_rend(), - incExprDiff.getExpr(), CommaJoin); + StmtDiff + ReverseModeVisitor::VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr* BL) { + return Clone(BL); } - const Stmt* body = FS->getBody(); - StmtDiff BodyDiff = DifferentiateLoopBody( - body, loopCounter, condVarRes.getStmt_dx(), incDiff.getStmt_dx(), - /*isForLoop=*/true); - - Stmt* Forward = new (m_Context) - ForStmt(m_Context, initResult.getStmt(), cond.getExpr(), condVarClone, - incResult, BodyDiff.getStmt(), noLoc, noLoc, noLoc); - - // Create a condition testing counter for being zero, and its decrement. - // To match the number of iterations in the forward pass, the reverse loop - // will look like: for(; Counter; Counter--) ... - Expr* CounterCondition = loopCounter.getCounterConditionResult().get().second; - Expr* CounterDecrement = loopCounter.getCounterDecrement(); - - Stmt* ReverseResult = BodyDiff.getStmt_dx(); - if (!ReverseResult) - ReverseResult = new (m_Context) NullStmt(noLoc); - Stmt* Reverse = new (m_Context) - ForStmt(m_Context, nullptr, CounterCondition, nullptr, CounterDecrement, - ReverseResult, noLoc, noLoc, noLoc); - addToCurrentBlock(Forward, direction::forward); - Forward = endBlock(direction::forward); - addToCurrentBlock(loopCounter.getPop(), direction::reverse); - addToCurrentBlock(initResult.getStmt_dx(), direction::reverse); - addToCurrentBlock(Reverse, direction::reverse); - Reverse = endBlock(direction::reverse); - endScope(); - - return {unwrapIfSingleStmt(Forward), unwrapIfSingleStmt(Reverse)}; -} - -StmtDiff -ReverseModeVisitor::VisitCXXDefaultArgExpr(const CXXDefaultArgExpr* DE) { - return Visit(DE->getExpr(), dfdx()); -} - -StmtDiff -ReverseModeVisitor::VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr* BL) { - return Clone(BL); -} - -StmtDiff ReverseModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { - // Initially, df/df = 1. - const Expr* value = RS->getRetValue(); - QualType type = value->getType(); - auto dfdf = m_Pullback; - if (isa(dfdf) || isa(dfdf)) { - ExprResult tmp = dfdf; - dfdf = m_Sema - .ImpCastExprToType(tmp.get(), type, - m_Sema.PrepareScalarCast(tmp, type)) - .get(); - } - auto ReturnResult = DifferentiateSingleExpr(value, dfdf); - StmtDiff ReturnDiff = ReturnResult.first; - StmtDiff ExprDiff = ReturnResult.second; - Stmt* Reverse = ReturnDiff.getStmt_dx(); - // If the original function returns at this point, some part of the reverse - // pass (corresponding to other branches that do not return here) must be - // skipped. We create a label in the reverse pass and jump to it via goto. - LabelDecl* LD = LabelDecl::Create(m_Context, m_Sema.CurContext, noLoc, - CreateUniqueIdentifier("_label")); - m_Sema.PushOnScopeChains(LD, m_DerivativeFnScope, true); - // Attach label to the last Stmt in the corresponding Reverse Stmt. - if (!Reverse) - Reverse = m_Sema.ActOnNullStmt(noLoc).get(); - Stmt* LS = m_Sema.ActOnLabelStmt(noLoc, LD, noLoc, Reverse).get(); - addToCurrentBlock(LS, direction::reverse); - for (Stmt* S : cast(ReturnDiff.getStmt())->body()) - addToCurrentBlock(S, direction::forward); - - // FIXME: When the return type of a function is a class, ExprDiff.getExpr() - // returns nullptr, which is a bug. For the time being, the only use case of - // a return type being class is in pushforwards. Hence a special case has - // been made to to not do the StoreAndRef operation when return type is - // ValueAndPushforward. - if (!isCladValueAndPushforwardType(type)) { - if (m_ExternalSource) - m_ExternalSource->ActBeforeFinalisingVisitReturnStmt(ExprDiff); - } + StmtDiff ReverseModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { + // Initially, df/df = 1. + const Expr* value = RS->getRetValue(); + QualType type = value->getType(); + auto dfdf = m_Pullback; + if (isa(dfdf) || isa(dfdf)) { + ExprResult tmp = dfdf; + dfdf = m_Sema + .ImpCastExprToType(tmp.get(), type, + m_Sema.PrepareScalarCast(tmp, type)) + .get(); + } + auto ReturnResult = DifferentiateSingleExpr(value, dfdf); + StmtDiff ReturnDiff = ReturnResult.first; + StmtDiff ExprDiff = ReturnResult.second; + Stmt* Reverse = ReturnDiff.getStmt_dx(); + // If the original function returns at this point, some part of the reverse + // pass (corresponding to other branches that do not return here) must be + // skipped. We create a label in the reverse pass and jump to it via goto. + LabelDecl* LD = LabelDecl::Create( + m_Context, m_Sema.CurContext, noLoc, CreateUniqueIdentifier("_label")); + m_Sema.PushOnScopeChains(LD, m_DerivativeFnScope, true); + // Attach label to the last Stmt in the corresponding Reverse Stmt. + if (!Reverse) + Reverse = m_Sema.ActOnNullStmt(noLoc).get(); + Stmt* LS = m_Sema.ActOnLabelStmt(noLoc, LD, noLoc, Reverse).get(); + addToCurrentBlock(LS, direction::reverse); + for (Stmt* S : cast(ReturnDiff.getStmt())->body()) + addToCurrentBlock(S, direction::forward); - // Create goto to the label. - return m_Sema.ActOnGotoStmt(noLoc, noLoc, LD).get(); -} + // FIXME: When the return type of a function is a class, ExprDiff.getExpr() + // returns nullptr, which is a bug. For the time being, the only use case of + // a return type being class is in pushforwards. Hence a special case has + // been made to to not do the StoreAndRef operation when return type is + // ValueAndPushforward. + if (!isCladValueAndPushforwardType(type)) { + if (m_ExternalSource) + m_ExternalSource->ActBeforeFinalisingVisitReturnStmt(ExprDiff); + } -StmtDiff ReverseModeVisitor::VisitParenExpr(const ParenExpr* PE) { - StmtDiff subStmtDiff = Visit(PE->getSubExpr(), dfdx()); - return StmtDiff(BuildParens(subStmtDiff.getExpr()), - BuildParens(subStmtDiff.getExpr_dx())); -} + // Create goto to the label. + return m_Sema.ActOnGotoStmt(noLoc, noLoc, LD).get(); + } + + StmtDiff ReverseModeVisitor::VisitParenExpr(const ParenExpr* PE) { + StmtDiff subStmtDiff = Visit(PE->getSubExpr(), dfdx()); + return StmtDiff(BuildParens(subStmtDiff.getExpr()), + BuildParens(subStmtDiff.getExpr_dx())); + } + + StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) { + QualType ILEType = ILE->getType(); + llvm::SmallVector clonedExprs(ILE->getNumInits()); + if (isArrayOrPointerType(ILEType)) { + for (unsigned i = 0, e = ILE->getNumInits(); i < e; i++) { + Expr* I = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i); + Expr* array_at_i = m_Sema + .ActOnArraySubscriptExpr(getCurrentScope(), + dfdx(), noLoc, I, noLoc) + .get(); + Expr* clonedEI = Visit(ILE->getInit(i), array_at_i).getExpr(); + clonedExprs[i] = clonedEI; + } -StmtDiff ReverseModeVisitor::VisitInitListExpr(const InitListExpr* ILE) { - QualType ILEType = ILE->getType(); - llvm::SmallVector clonedExprs(ILE->getNumInits()); - if (isArrayOrPointerType(ILEType)) { - for (unsigned i = 0, e = ILE->getNumInits(); i < e; i++) { - Expr* I = - ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, i); - Expr* array_at_i = m_Sema - .ActOnArraySubscriptExpr(getCurrentScope(), dfdx(), - noLoc, I, noLoc) - .get(); - Expr* clonedEI = Visit(ILE->getInit(i), array_at_i).getExpr(); - clonedExprs[i] = clonedEI; - } - - Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get(); - return StmtDiff(clonedILE); - } else { - // FIXME: This is a makeshift arrangement to differentiate an InitListExpr - // that represents a ValueAndPushforward type. Ideally this must be - // differentiated at VisitCXXConstructExpr + Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get(); + return StmtDiff(clonedILE); + } else { + // FIXME: This is a makeshift arrangement to differentiate an InitListExpr + // that represents a ValueAndPushforward type. Ideally this must be + // differentiated at VisitCXXConstructExpr #ifndef NDEBUG - bool isValueAndPushforward = isCladValueAndPushforwardType(ILEType); - assert(isValueAndPushforward && - "Only InitListExpr that represents arrays or ValueAndPushforward " - "Object initialization is supported"); + bool isValueAndPushforward = isCladValueAndPushforwardType(ILEType); + assert(isValueAndPushforward && + "Only InitListExpr that represents arrays or ValueAndPushforward " + "Object initialization is supported"); #endif - // Here we assume that the adjoint expression of the first element in - // InitList is dfdx().value and the adjoint for the second element is - // dfdx().pushforward. At this point the top of the Tape must contain a - // ValueAndPushforward object that represents derivative of the - // ValueAndPushforward object returned by the function whose derivative is - // requested. - Expr* dValueExpr = - utils::BuildMemberExpr(m_Sema, getCurrentScope(), dfdx(), "value"); - StmtDiff clonedValueEI = Visit(ILE->getInit(0), dValueExpr).getExpr(); - clonedExprs[0] = clonedValueEI.getExpr(); - - Expr* dPushforwardExpr = utils::BuildMemberExpr(m_Sema, getCurrentScope(), - dfdx(), "pushforward"); - Expr* clonedPushforwardEI = - Visit(ILE->getInit(1), dPushforwardExpr).getExpr(); - clonedExprs[1] = clonedPushforwardEI; - - Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get(); - return StmtDiff(clonedILE); + // Here we assume that the adjoint expression of the first element in + // InitList is dfdx().value and the adjoint for the second element is + // dfdx().pushforward. At this point the top of the Tape must contain a + // ValueAndPushforward object that represents derivative of the + // ValueAndPushforward object returned by the function whose derivative is + // requested. + Expr* dValueExpr = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), dfdx(), "value"); + StmtDiff clonedValueEI = Visit(ILE->getInit(0), dValueExpr).getExpr(); + clonedExprs[0] = clonedValueEI.getExpr(); + + Expr* dPushforwardExpr = utils::BuildMemberExpr(m_Sema, getCurrentScope(), + dfdx(), "pushforward"); + Expr* clonedPushforwardEI = + Visit(ILE->getInit(1), dPushforwardExpr).getExpr(); + clonedExprs[1] = clonedPushforwardEI; + + Expr* clonedILE = m_Sema.ActOnInitList(noLoc, clonedExprs, noLoc).get(); + return StmtDiff(clonedILE); + } } -} -StmtDiff -ReverseModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { - auto ASI = SplitArraySubscript(ASE); - const Expr* Base = ASI.first; - const auto& Indices = ASI.second; - StmtDiff BaseDiff = Visit(Base); - llvm::SmallVector clonedIndices(Indices.size()); - llvm::SmallVector reverseIndices(Indices.size()); - llvm::SmallVector forwSweepDerivativeIndices(Indices.size()); - for (std::size_t i = 0; i < Indices.size(); i++) { - /// FIXME: Remove redundant indices vectors. - StmtDiff IdxDiff = Visit(Indices[i]); - clonedIndices[i] = Clone(IdxDiff.getExpr()); - reverseIndices[i] = Clone(IdxDiff.getExpr()); - // reverseIndices[i] = Clone(IdxDiff.getExpr()); - forwSweepDerivativeIndices[i] = IdxDiff.getExpr(); - } - auto cloned = BuildArraySubscript(BaseDiff.getExpr(), clonedIndices); - auto valueForRevSweep = - BuildArraySubscript(BaseDiff.getExpr(), reverseIndices); - Expr* target = BaseDiff.getExpr_dx(); - if (!target) - return cloned; - Expr* result = nullptr; - Expr* forwSweepDerivative = nullptr; - if (utils::isArrayOrPointerType(target->getType())) { - // Create the target[idx] expression. - result = BuildArraySubscript(target, reverseIndices); - forwSweepDerivative = - BuildArraySubscript(target, forwSweepDerivativeIndices); - } else if (isCladArrayType(target->getType())) { - result = m_Sema - .ActOnArraySubscriptExpr(getCurrentScope(), target, - ASE->getExprLoc(), - reverseIndices.back(), noLoc) - .get(); - forwSweepDerivative = m_Sema - .ActOnArraySubscriptExpr( - getCurrentScope(), target, ASE->getExprLoc(), - forwSweepDerivativeIndices.back(), noLoc) - .get(); - } else - result = target; - // Create the (target += dfdx) statement. - if (dfdx()) { - auto add_assign = BuildOp(BO_AddAssign, result, dfdx()); - // Add it to the body statements. - addToCurrentBlock(add_assign, direction::reverse); + StmtDiff + ReverseModeVisitor::VisitArraySubscriptExpr(const ArraySubscriptExpr* ASE) { + auto ASI = SplitArraySubscript(ASE); + const Expr* Base = ASI.first; + const auto& Indices = ASI.second; + StmtDiff BaseDiff = Visit(Base); + llvm::SmallVector clonedIndices(Indices.size()); + llvm::SmallVector reverseIndices(Indices.size()); + llvm::SmallVector forwSweepDerivativeIndices(Indices.size()); + for (std::size_t i = 0; i < Indices.size(); i++) { + /// FIXME: Remove redundant indices vectors. + StmtDiff IdxDiff = Visit(Indices[i]); + clonedIndices[i] = Clone(IdxDiff.getExpr()); + reverseIndices[i] = Clone(IdxDiff.getExpr()); + // reverseIndices[i] = Clone(IdxDiff.getExpr()); + forwSweepDerivativeIndices[i] = IdxDiff.getExpr(); + } + auto cloned = BuildArraySubscript(BaseDiff.getExpr(), clonedIndices); + auto valueForRevSweep = BuildArraySubscript(BaseDiff.getExpr(), reverseIndices); + Expr* target = BaseDiff.getExpr_dx(); + if (!target) + return cloned; + Expr* result = nullptr; + Expr* forwSweepDerivative = nullptr; + if (utils::isArrayOrPointerType(target->getType())) { + // Create the target[idx] expression. + result = BuildArraySubscript(target, reverseIndices); + forwSweepDerivative = + BuildArraySubscript(target, forwSweepDerivativeIndices); + } + else if (isCladArrayType(target->getType())) { + result = m_Sema + .ActOnArraySubscriptExpr(getCurrentScope(), target, + ASE->getExprLoc(), + reverseIndices.back(), noLoc) + .get(); + forwSweepDerivative = + m_Sema + .ActOnArraySubscriptExpr(getCurrentScope(), target, + ASE->getExprLoc(), + forwSweepDerivativeIndices.back(), noLoc) + .get(); + } else + result = target; + // Create the (target += dfdx) statement. + if (dfdx()) { + auto add_assign = BuildOp(BO_AddAssign, result, dfdx()); + // Add it to the body statements. + addToCurrentBlock(add_assign, direction::reverse); + } + return StmtDiff(cloned, result, forwSweepDerivative, valueForRevSweep); } - return StmtDiff(cloned, result, forwSweepDerivative, valueForRevSweep); -} -StmtDiff ReverseModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { - DeclRefExpr* clonedDRE = nullptr; - // Check if referenced Decl was "replaced" with another identifier inside - // the derivative - if (auto VD = dyn_cast(DRE->getDecl())) { - auto it = m_DeclReplacements.find(VD); - if (it != std::end(m_DeclReplacements)) - clonedDRE = BuildDeclRef(it->second); - else + StmtDiff ReverseModeVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { + DeclRefExpr* clonedDRE = nullptr; + // Check if referenced Decl was "replaced" with another identifier inside + // the derivative + if (auto VD = dyn_cast(DRE->getDecl())) { + auto it = m_DeclReplacements.find(VD); + if (it != std::end(m_DeclReplacements)) + clonedDRE = BuildDeclRef(it->second); + else + clonedDRE = cast(Clone(DRE)); + // If current context is different than the context of the original + // declaration (e.g. we are inside lambda), rebuild the DeclRefExpr + // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. + // Sema::BuildDeclRefExpr is responsible for adding captured fields + // to the underlying struct of a lambda. + if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { + auto referencedDecl = cast(clonedDRE->getDecl()); + clonedDRE = cast(BuildDeclRef(referencedDecl)); + } + } else clonedDRE = cast(Clone(DRE)); - // If current context is different than the context of the original - // declaration (e.g. we are inside lambda), rebuild the DeclRefExpr - // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. - // Sema::BuildDeclRefExpr is responsible for adding captured fields - // to the underlying struct of a lambda. - if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { - auto referencedDecl = cast(clonedDRE->getDecl()); - clonedDRE = cast(BuildDeclRef(referencedDecl)); - } - } else - clonedDRE = cast(Clone(DRE)); - - if (auto decl = dyn_cast(clonedDRE->getDecl())) { - if (isVectorValued) { - if (m_VectorOutput.size() <= outputArrayCursor) - return StmtDiff(clonedDRE); - auto it = m_VectorOutput[outputArrayCursor].find(decl); - if (it == std::end(m_VectorOutput[outputArrayCursor])) { - // Is not an independent variable, ignored. - return StmtDiff(clonedDRE); - } - // Create the (jacobianMatrix[idx] += dfdx) statement. - if (dfdx()) { - auto add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); - // Add it to the body statements. - addToCurrentBlock(add_assign, direction::reverse); - } - } else { - // Check DeclRefExpr is a reference to an independent variable. - auto it = m_Variables.find(decl); - if (it == std::end(m_Variables)) { - // Is not an independent variable, ignored. - return StmtDiff(clonedDRE); - } - // Create the (_d_param[idx] += dfdx) statement. - if (dfdx()) { - Expr* add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); - // Add it to the body statements. - addToCurrentBlock(add_assign, direction::reverse); + if (auto decl = dyn_cast(clonedDRE->getDecl())) { + if (isVectorValued) { + if (m_VectorOutput.size() <= outputArrayCursor) + return StmtDiff(clonedDRE); + + auto it = m_VectorOutput[outputArrayCursor].find(decl); + if (it == std::end(m_VectorOutput[outputArrayCursor])) { + // Is not an independent variable, ignored. + return StmtDiff(clonedDRE); + } + // Create the (jacobianMatrix[idx] += dfdx) statement. + if (dfdx()) { + auto add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); + // Add it to the body statements. + addToCurrentBlock(add_assign, direction::reverse); + } + } else { + // Check DeclRefExpr is a reference to an independent variable. + auto it = m_Variables.find(decl); + if (it == std::end(m_Variables)) { + // Is not an independent variable, ignored. + return StmtDiff(clonedDRE); + } + // Create the (_d_param[idx] += dfdx) statement. + if (dfdx()) { + Expr* add_assign = BuildOp(BO_AddAssign, it->second, dfdx()); + // Add it to the body statements. + addToCurrentBlock(add_assign, direction::reverse); + } + return StmtDiff(clonedDRE, it->second, it->second); } - return StmtDiff(clonedDRE, it->second, it->second); } + + return StmtDiff(clonedDRE); } - return StmtDiff(clonedDRE); -} + StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) { + return StmtDiff(Clone(IL)); + } -StmtDiff ReverseModeVisitor::VisitIntegerLiteral(const IntegerLiteral* IL) { - return StmtDiff(Clone(IL)); -} + StmtDiff ReverseModeVisitor::VisitFloatingLiteral(const FloatingLiteral* FL) { + return StmtDiff(Clone(FL)); + } -StmtDiff ReverseModeVisitor::VisitFloatingLiteral(const FloatingLiteral* FL) { - return StmtDiff(Clone(FL)); -} + StmtDiff ReverseModeVisitor::VisitCallExpr(const CallExpr* CE) { + const FunctionDecl* FD = CE->getDirectCallee(); + if (!FD) { + diag(DiagnosticsEngine::Warning, + CE->getEndLoc(), + "Differentiation of only direct calls is supported. Ignored"); + return StmtDiff(Clone(CE)); + } -StmtDiff ReverseModeVisitor::VisitCallExpr(const CallExpr* CE) { - const FunctionDecl* FD = CE->getDirectCallee(); - if (!FD) { - diag(DiagnosticsEngine::Warning, CE->getEndLoc(), - "Differentiation of only direct calls is supported. Ignored"); - return StmtDiff(Clone(CE)); - } + auto NArgs = FD->getNumParams(); + // If the function has no args and is not a member function call then we + // assume that it is not related to independent variables and does not + // contribute to gradient. + if (!NArgs && !isa(CE)) + return StmtDiff(Clone(CE)); + + // Stores the call arguments for the function to be derived + llvm::SmallVector CallArgs{}; + // Stores the dx of the call arguments for the function to be derived + llvm::SmallVector CallArgDx{}; + // Stores the call arguments for the derived function + llvm::SmallVector DerivedCallArgs{}; + // Stores tape decl and pushes for multiarg numerically differentiated + // calls. + llvm::SmallVector NumericalDiffMultiArg{}; + // If the result does not depend on the result of the call, just clone + // the call and visit arguments (since they may contain side-effects like + // f(x = y)) + // If the callee function takes arguments by reference then it can affect + // derivatives even if there is no `dfdx()` and thus we should call the + // derived function. In the case of member functions, `implicit` + // this object is always passed by reference. + if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) && + !isa(CE)) { + for (const Expr* Arg : CE->arguments()) { + StmtDiff ArgDiff = Visit(Arg, dfdx()); + CallArgs.push_back(ArgDiff.getExpr()); + } + Expr* call = m_Sema + .ActOnCallExpr(getCurrentScope(), + Clone(CE->getCallee()), + noLoc, + llvm::MutableArrayRef(CallArgs), + noLoc) + .get(); + return call; + } - auto NArgs = FD->getNumParams(); - // If the function has no args and is not a member function call then we - // assume that it is not related to independent variables and does not - // contribute to gradient. - if (!NArgs && !isa(CE)) - return StmtDiff(Clone(CE)); - - // Stores the call arguments for the function to be derived - llvm::SmallVector CallArgs{}; - // Stores the dx of the call arguments for the function to be derived - llvm::SmallVector CallArgDx{}; - // Stores the call arguments for the derived function - llvm::SmallVector DerivedCallArgs{}; - // Stores tape decl and pushes for multiarg numerically differentiated - // calls. - llvm::SmallVector NumericalDiffMultiArg{}; - // If the result does not depend on the result of the call, just clone - // the call and visit arguments (since they may contain side-effects like - // f(x = y)) - // If the callee function takes arguments by reference then it can affect - // derivatives even if there is no `dfdx()` and thus we should call the - // derived function. In the case of member functions, `implicit` - // this object is always passed by reference. - if (!dfdx() && !utils::HasAnyReferenceOrPointerArgument(FD) && - !isa(CE)) { - for (const Expr* Arg : CE->arguments()) { - StmtDiff ArgDiff = Visit(Arg, dfdx()); - CallArgs.push_back(ArgDiff.getExpr()); - } - Expr* call = - m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, - llvm::MutableArrayRef(CallArgs), noLoc) - .get(); - return call; - } + llvm::SmallVector ArgResultDecls{}; + llvm::SmallVector ArgDeclStmts{}; + // Save current index in the current block, to potentially put some + // statements there later. + std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); - llvm::SmallVector ArgResultDecls{}; - llvm::SmallVector ArgDeclStmts{}; - // Save current index in the current block, to potentially put some - // statements there later. - std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); - - // `CXXOperatorCallExpr` have the `base` expression as the first argument. - size_t skipFirstArg = 0; - - // Here we do not need to check if FD is an instance method or a static - // method because C++ forbids creating operator overloads as static methods. - if (isa(CE) && isa(FD)) - skipFirstArg = 1; - - // FIXME: We should add instructions for handling non-differentiable - // arguments. Currently we are implicitly assuming function call only - // contains differentiable arguments. - for (std::size_t i = skipFirstArg, e = CE->getNumArgs(); i != e; ++i) { - const Expr* arg = CE->getArg(i); - const auto* PVD = FD->getParamDecl(i - skipFirstArg); - StmtDiff argDiff{}; - bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); - // We do not need to create result arg for arguments passed by reference - // because the derivatives of arguments passed by reference are directly - // modified by the derived callee function. - if (passByRef) { - argDiff = Visit(arg); - QualType argResultValueType = - utils::GetValueType(argDiff.getExpr()->getType()) - .getNonReferenceType(); - // Create ArgResult variable for each reference argument because it is - // required by error estimator. For automatic differentiation, we do not - // need to create ArgResult variable for arguments passed by reference. - // ``` - // _r0 = _d_a; - // ``` - Expr* dArg = nullptr; - if (utils::isArrayOrPointerType(argDiff.getExpr()->getType())) { - Expr* init = argDiff.getExpr_dx(); - if (isa(argDiff.getExpr_dx()->getType())) - init = utils::BuildCladArrayInitByConstArray(m_Sema, - argDiff.getExpr_dx()); - - dArg = StoreAndRef(init, GetCladArrayOfType(argResultValueType), - direction::reverse, "_r", - /*forceDeclCreation=*/true, - VarDecl::InitializationStyle::CallInit); + // `CXXOperatorCallExpr` have the `base` expression as the first argument. + size_t skipFirstArg = 0; + + // Here we do not need to check if FD is an instance method or a static + // method because C++ forbids creating operator overloads as static methods. + if (isa(CE) && isa(FD)) + skipFirstArg = 1; + + // FIXME: We should add instructions for handling non-differentiable + // arguments. Currently we are implicitly assuming function call only + // contains differentiable arguments. + for (std::size_t i = skipFirstArg, e = CE->getNumArgs(); i != e; ++i) { + const Expr* arg = CE->getArg(i); + const auto* PVD = FD->getParamDecl(i - skipFirstArg); + StmtDiff argDiff{}; + bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); + // We do not need to create result arg for arguments passed by reference + // because the derivatives of arguments passed by reference are directly + // modified by the derived callee function. + if (passByRef) { + argDiff = Visit(arg); + QualType argResultValueType = + utils::GetValueType(argDiff.getExpr()->getType()) + .getNonReferenceType(); + // Create ArgResult variable for each reference argument because it is + // required by error estimator. For automatic differentiation, we do not need + // to create ArgResult variable for arguments passed by reference. + // ``` + // _r0 = _d_a; + // ``` + Expr* dArg = nullptr; + if (utils::isArrayOrPointerType(argDiff.getExpr()->getType())) { + Expr* init = argDiff.getExpr_dx(); + if (isa(argDiff.getExpr_dx()->getType())) + init = utils::BuildCladArrayInitByConstArray(m_Sema, + argDiff.getExpr_dx()); + + dArg = StoreAndRef(init, GetCladArrayOfType(argResultValueType), + direction::reverse, "_r", + /*forceDeclCreation=*/true, + VarDecl::InitializationStyle::CallInit); + } else { + dArg = StoreAndRef(argDiff.getExpr_dx(), argResultValueType, + direction::reverse, "_r", + /*forceDeclCreation=*/true); + } + ArgResultDecls.push_back( + cast(cast(dArg)->getDecl())); } else { - dArg = StoreAndRef(argDiff.getExpr_dx(), argResultValueType, - direction::reverse, "_r", + assert(!utils::isArrayOrPointerType(arg->getType()) && + "Arguments passed by pointers should be covered in pass by " + "reference calls"); + // Create temporary variables corresponding to derivative of each + // argument, so that they can be referred to when arguments is visited. + // Variables will be initialized later after arguments is visited. This + // is done to reduce cloning complexity and only clone once. The type is + // same as the call expression as it is the type used to declare the + // _gradX array + Expr* dArg; + dArg = StoreAndRef(/*E=*/nullptr, arg->getType(), direction::reverse, "_r", /*forceDeclCreation=*/true); + ArgResultDecls.push_back( + cast(cast(dArg)->getDecl())); + // Visit using uninitialized reference. + argDiff = Visit(arg, dArg); } - ArgResultDecls.push_back( - cast(cast(dArg)->getDecl())); - } else { - assert(!utils::isArrayOrPointerType(arg->getType()) && - "Arguments passed by pointers should be covered in pass by " - "reference calls"); - // Create temporary variables corresponding to derivative of each - // argument, so that they can be referred to when arguments is visited. - // Variables will be initialized later after arguments is visited. This - // is done to reduce cloning complexity and only clone once. The type is - // same as the call expression as it is the type used to declare the - // _gradX array - Expr* dArg; - dArg = - StoreAndRef(/*E=*/nullptr, arg->getType(), direction::reverse, "_r", - /*forceDeclCreation=*/true); - ArgResultDecls.push_back( - cast(cast(dArg)->getDecl())); - // Visit using uninitialized reference. - argDiff = Visit(arg, dArg); - } - - // FIXME: We may use same argDiff.getExpr_dx at two places. This can - // lead to inconsistent pushes and pops. If `isInsideLoop` is true and - // actual argument is something like "a[i]", then argDiff.getExpr() and - // argDiff.getExpr_dx() will respectively be: - // ``` - // a[clad::push(_t0, i)]; - // a[clad::pop(_t0)]; - // ``` - // The expression `a[clad::pop(_t0)]` might already be used in the AST if - // visit was called with a dfdx() present. - // And thus using this expression in the AST explicitly may lead to size - // assertion failed. - // - // We should modify the design so that the behaviour of obtained StmtDiff - // expression is consistent both inside and outside loops. - CallArgDx.push_back(argDiff.getExpr_dx()); - // Save cloned arg in a "global" variable, so that it is accessible from - // the reverse pass. - // FIXME: At this point, we assume all the variables passed by reference - // may be changed since we have no way to determine otherwise. - // FIXME: We cannot use GlobalStoreAndRef to store a whole array so now - // arrays are not stored. - StmtDiff argDiffStore; - if (passByRef && !argDiff.getExpr()->getType()->isArrayType()) - argDiffStore = GlobalStoreAndRef(argDiff.getExpr(), "_t", /*force=*/true); - else - argDiffStore = {argDiff.getExpr(), argDiff.getExpr()}; - - // We need to pass the actual argument in the cloned call expression, - // instead of a temporary, for arguments passed by reference. This is - // because, callee function may modify the argument passed as reference - // and if we use a temporary variable then the effect of the modification - // will be lost. - // For example: - // ``` - // // original statements - // modify(a); // a is passed by reference - // modify(a); // a is passed by reference - // - // // forward pass - // _t0 = a; - // modify(_t0); // _t0 is modified instead of a - // _t1 = a; // stale value of a is being used here - // modify(_t1); - // - // // correct forward pass - // _t0 = a; - // modify(a); - // _t1 = a; - // modify(a); - // ``` - // FIXME: We cannot use GlobalStoreAndRef to store a whole array so now - // arrays are not stored. - if (passByRef && !argDiff.getExpr()->getType()->isArrayType()) { - if (isInsideLoop) { - // Add tape push expression. We need to explicitly add it here because - // we cannot add it as call expression argument -- we need to pass the - // actual argument there. - addToCurrentBlock(argDiffStore.getExpr()); - // For reference arguments, we cannot pass `clad::pop(_t0)` to the - // derived function. Because it will throw "lvalue reference cannot - // bind to rvalue error". Thus we are proceeding as follows: - // ``` - // double _r0 = clad::pop(_t0); - // derivedCalleeFunction(_r0, ...) - // ``` - VarDecl* argDiffLocalVD = BuildVarDecl( - argDiffStore.getExpr_dx()->getType(), CreateUniqueIdentifier("_r"), - argDiffStore.getExpr_dx(), - /*DirectInit=*/false, /*TSI=*/nullptr, - VarDecl::InitializationStyle::CInit); - auto& block = getCurrentBlock(direction::reverse); - block.insert(block.begin() + insertionPoint, - BuildDeclStmt(argDiffLocalVD)); - // Restore agrs - auto op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getExpr(), - BuildDeclRef(argDiffLocalVD)); - block.insert(block.begin() + insertionPoint + 1, op); - - Expr* argDiffLocalE = BuildDeclRef(argDiffLocalVD); - - // We added local variable to store result of `clad::pop(...)` and - // restoration of the original arg. Thus we need to correspondingly - // adjust the insertion point. - insertionPoint += 2; - // We cannot use the already existing `argDiff.getExpr()` here because - // it will cause inconsistent pushes and pops to the clad tape. - // FIXME: Modify `GlobalStoreAndRef` such that its functioning is - // consistent with `StoreAndRef`. This way we will not need to handle - // inside loop and outside loop cases separately. - Expr* newArgE = Visit(arg).getExpr(); - argDiffStore = {newArgE, argDiffLocalE}; + + // FIXME: We may use same argDiff.getExpr_dx at two places. This can + // lead to inconsistent pushes and pops. If `isInsideLoop` is true and + // actual argument is something like "a[i]", then argDiff.getExpr() and + // argDiff.getExpr_dx() will respectively be: + // ``` + // a[clad::push(_t0, i)]; + // a[clad::pop(_t0)]; + // ``` + // The expression `a[clad::pop(_t0)]` might already be used in the AST if + // visit was called with a dfdx() present. + // And thus using this expression in the AST explicitly may lead to size + // assertion failed. + // + // We should modify the design so that the behaviour of obtained StmtDiff + // expression is consistent both inside and outside loops. + CallArgDx.push_back(argDiff.getExpr_dx()); + // Save cloned arg in a "global" variable, so that it is accessible from + // the reverse pass. + // FIXME: At this point, we assume all the variables passed by reference + // may be changed since we have no way to determine otherwise. + // FIXME: We cannot use GlobalStoreAndRef to store a whole array so now + // arrays are not stored. + StmtDiff argDiffStore; + if (passByRef && !argDiff.getExpr()->getType()->isArrayType()) { + argDiffStore = + GlobalStoreAndRef(argDiff.getExpr(), "_t", /*force=*/true); } else { - // Restore args - auto& block = getCurrentBlock(direction::reverse); - auto op = BuildOp(BinaryOperatorKind::BO_Assign, argDiff.getExpr(), - argDiffStore.getExpr()); - block.insert(block.begin() + insertionPoint, op); - // We added restoration of the original arg. Thus we need to - // correspondingly adjust the insertion point. - insertionPoint += 1; - - argDiffStore = {argDiff.getExpr(), argDiffStore.getExpr_dx()}; + argDiffStore = {argDiff.getExpr(), argDiff.getExpr()}; } - } - CallArgs.push_back(argDiffStore.getExpr()); - DerivedCallArgs.push_back(argDiffStore.getExpr_dx()); - } - VarDecl* gradVarDecl = nullptr; - Expr* gradVarExpr = nullptr; - Expr* gradArgExpr = nullptr; - IdentifierInfo* gradVarII = nullptr; - Expr* OverloadedDerivedFn = nullptr; - // If the function has a single arg and does not returns a reference or take - // arg by reference, we look for a derivative w.r.t. to this arg using the - // forward mode(it is unlikely that we need gradient of a one-dimensional' - // function). - bool asGrad = true; - - if (NArgs == 1 && !utils::HasAnyReferenceOrPointerArgument(FD) && - !isa(FD)) { - std::string customPushforward = FD->getNameAsString() + "_pushforward"; - auto pushforwardCallArgs = DerivedCallArgs; - pushforwardCallArgs.push_back(ConstantFolder::synthesizeLiteral( - DerivedCallArgs.front()->getType(), m_Context, 1)); - OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - customPushforward, pushforwardCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext())); - if (OverloadedDerivedFn) - asGrad = false; - } - // Store all the derived call output args (if any) - llvm::SmallVector DerivedCallOutputArgs{}; - // It is required because call to numerical diff and reverse mode diff - // requires (slightly) different arguments. - llvm::SmallVector pullbackCallArgs{}; - - // Stores a list of arg result variable declaration (_r0) with the - // corresponding grad variable expression (_grad0). - llvm::SmallVector, 4> argResultsAndGrads; - - // Stores differentiation result of implicit `this` object, if any. - StmtDiff baseDiff; - // If it has more args or f_darg0 was not found, we look for its pullback - // function. - if (!OverloadedDerivedFn) { - size_t idx = 0; - - /// Add base derivative expression in the derived call output args list if - /// `CE` is a call to an instance member function. - if (const auto* MD = dyn_cast(FD)) { - if (MD->isInstance()) { - const Expr* baseOriginalE = nullptr; - if (const auto* MCE = dyn_cast(CE)) - baseOriginalE = MCE->getImplicitObjectArgument(); - else if (const auto* OCE = dyn_cast(CE)) - baseOriginalE = OCE->getArg(0); - - baseDiff = Visit(baseOriginalE); - StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); + // We need to pass the actual argument in the cloned call expression, + // instead of a temporary, for arguments passed by reference. This is + // because, callee function may modify the argument passed as reference + // and if we use a temporary variable then the effect of the modification + // will be lost. + // For example: + // ``` + // // original statements + // modify(a); // a is passed by reference + // modify(a); // a is passed by reference + // + // // forward pass + // _t0 = a; + // modify(_t0); // _t0 is modified instead of a + // _t1 = a; // stale value of a is being used here + // modify(_t1); + // + // // correct forward pass + // _t0 = a; + // modify(a); + // _t1 = a; + // modify(a); + // ``` + // FIXME: We cannot use GlobalStoreAndRef to store a whole array so now + // arrays are not stored. + if (passByRef && !argDiff.getExpr()->getType()->isArrayType()) { if (isInsideLoop) { - addToCurrentBlock(baseDiffStore.getExpr()); - VarDecl* baseLocalVD = BuildVarDecl( - baseDiffStore.getExpr_dx()->getType(), - CreateUniqueIdentifier("_r"), baseDiffStore.getExpr_dx(), + // Add tape push expression. We need to explicitly add it here because + // we cannot add it as call expression argument -- we need to pass the + // actual argument there. + addToCurrentBlock(argDiffStore.getExpr()); + // For reference arguments, we cannot pass `clad::pop(_t0)` to the + // derived function. Because it will throw "lvalue reference cannot + // bind to rvalue error". Thus we are proceeding as follows: + // ``` + // double _r0 = clad::pop(_t0); + // derivedCalleeFunction(_r0, ...) + // ``` + VarDecl* argDiffLocalVD = BuildVarDecl( + argDiffStore.getExpr_dx()->getType(), + CreateUniqueIdentifier("_r"), argDiffStore.getExpr_dx(), /*DirectInit=*/false, /*TSI=*/nullptr, VarDecl::InitializationStyle::CInit); auto& block = getCurrentBlock(direction::reverse); block.insert(block.begin() + insertionPoint, - BuildDeclStmt(baseLocalVD)); + BuildDeclStmt(argDiffLocalVD)); + // Restore agrs + auto op = BuildOp(BinaryOperatorKind::BO_Assign, + argDiff.getExpr(), BuildDeclRef(argDiffLocalVD)); + block.insert(block.begin() + insertionPoint + 1, op); + + Expr* argDiffLocalE = BuildDeclRef(argDiffLocalVD); + + // We added local variable to store result of `clad::pop(...)` and + // restoration of the original arg. Thus we need to correspondingly + // adjust the insertion point. + insertionPoint += 2; + // We cannot use the already existing `argDiff.getExpr()` here because + // it will cause inconsistent pushes and pops to the clad tape. + // FIXME: Modify `GlobalStoreAndRef` such that its functioning is + // consistent with `StoreAndRef`. This way we will not need to handle + // inside loop and outside loop cases separately. + Expr* newArgE = Visit(arg).getExpr(); + argDiffStore = {newArgE, argDiffLocalE}; + } else { + // Restore args + auto& block = getCurrentBlock(direction::reverse); + auto op = BuildOp(BinaryOperatorKind::BO_Assign, + argDiff.getExpr(), argDiffStore.getExpr()); + block.insert(block.begin() + insertionPoint, op); + // We added restoration of the original arg. Thus we need to + // correspondingly adjust the insertion point. insertionPoint += 1; - Expr* baseLocalE = BuildDeclRef(baseLocalVD); - baseDiffStore = {baseDiffStore.getExpr(), baseLocalE}; + + argDiffStore = {argDiff.getExpr(), argDiffStore.getExpr_dx()}; } - baseDiff = {baseDiffStore.getExpr_dx(), baseDiff.getExpr_dx()}; - Expr* baseDerivative = baseDiff.getExpr_dx(); - if (!baseDerivative->getType()->isPointerType()) - baseDerivative = - BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative); - DerivedCallOutputArgs.push_back(baseDerivative); } + CallArgs.push_back(argDiffStore.getExpr()); + DerivedCallArgs.push_back(argDiffStore.getExpr_dx()); } - for (auto argDerivative : CallArgDx) { - gradVarDecl = nullptr; - gradVarExpr = nullptr; - gradArgExpr = nullptr; - gradVarII = CreateUniqueIdentifier(funcPostfix()); + VarDecl* gradVarDecl = nullptr; + Expr* gradVarExpr = nullptr; + Expr* gradArgExpr = nullptr; + IdentifierInfo* gradVarII = nullptr; + Expr* OverloadedDerivedFn = nullptr; + // If the function has a single arg and does not returns a reference or take + // arg by reference, we look for a derivative w.r.t. to this arg using the + // forward mode(it is unlikely that we need gradient of a one-dimensional' + // function). + bool asGrad = true; + + if (NArgs == 1 && !utils::HasAnyReferenceOrPointerArgument(FD) && + !isa(FD)) { + std::string customPushforward = FD->getNameAsString() + "_pushforward"; + auto pushforwardCallArgs = DerivedCallArgs; + pushforwardCallArgs.push_back(ConstantFolder::synthesizeLiteral( + DerivedCallArgs.front()->getType(), m_Context, 1)); + OverloadedDerivedFn = + m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( + customPushforward, pushforwardCallArgs, getCurrentScope(), + const_cast(FD->getDeclContext())); + if (OverloadedDerivedFn) + asGrad = false; + } + // Store all the derived call output args (if any) + llvm::SmallVector DerivedCallOutputArgs{}; + // It is required because call to numerical diff and reverse mode diff + // requires (slightly) different arguments. + llvm::SmallVector pullbackCallArgs{}; + + // Stores a list of arg result variable declaration (_r0) with the + // corresponding grad variable expression (_grad0). + llvm::SmallVector, 4> argResultsAndGrads; + + // Stores differentiation result of implicit `this` object, if any. + StmtDiff baseDiff; + // If it has more args or f_darg0 was not found, we look for its pullback + // function. + if (!OverloadedDerivedFn) { + size_t idx = 0; + + /// Add base derivative expression in the derived call output args list if + /// `CE` is a call to an instance member function. + if (const auto* MD = dyn_cast(FD)) { + if (MD->isInstance()) { + const Expr* baseOriginalE = nullptr; + if (const auto* MCE = dyn_cast(CE)) + baseOriginalE = MCE->getImplicitObjectArgument(); + else if (const auto* OCE = dyn_cast(CE)) + baseOriginalE = OCE->getArg(0); + + baseDiff = Visit(baseOriginalE); + StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); + if (isInsideLoop) { + addToCurrentBlock(baseDiffStore.getExpr()); + VarDecl* baseLocalVD = BuildVarDecl( + baseDiffStore.getExpr_dx()->getType(), + CreateUniqueIdentifier("_r"), baseDiffStore.getExpr_dx(), + /*DirectInit=*/false, /*TSI=*/nullptr, + VarDecl::InitializationStyle::CInit); + auto& block = getCurrentBlock(direction::reverse); + block.insert(block.begin() + insertionPoint, + BuildDeclStmt(baseLocalVD)); + insertionPoint += 1; + Expr* baseLocalE = BuildDeclRef(baseLocalVD); + baseDiffStore = {baseDiffStore.getExpr(), baseLocalE}; + } + baseDiff = {baseDiffStore.getExpr_dx(), baseDiff.getExpr_dx()}; + Expr* baseDerivative = baseDiff.getExpr_dx(); + if (!baseDerivative->getType()->isPointerType()) + baseDerivative = + BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative); + DerivedCallOutputArgs.push_back(baseDerivative); + } + } - auto PVD = FD->getParamDecl(idx); - bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); - if (passByRef) { - // If derivative type is constant array type instead of - // `clad::array_ref` or `clad::array` type, then create an - // `clad::array_ref` variable that references this constant array. It - // is required because the pullback function expects `clad::array_ref` - // type for representing array derivatives. Currently, only constant - // array data members have derivatives of constant array types. - if (isa(argDerivative->getType())) { - Expr* init = - utils::BuildCladArrayInitByConstArray(m_Sema, argDerivative); - auto derivativeArrayRefVD = BuildVarDecl( - GetCladArrayRefOfType(argDerivative->getType() - ->getPointeeOrArrayElementType() - ->getCanonicalTypeInternal()), - "_t", init); - ArgDeclStmts.push_back(BuildDeclStmt(derivativeArrayRefVD)); - argDerivative = BuildDeclRef(derivativeArrayRefVD); + for (auto argDerivative : CallArgDx) { + gradVarDecl = nullptr; + gradVarExpr = nullptr; + gradArgExpr = nullptr; + gradVarII = CreateUniqueIdentifier(funcPostfix()); + + auto PVD = FD->getParamDecl(idx); + bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); + if (passByRef) { + // If derivative type is constant array type instead of + // `clad::array_ref` or `clad::array` type, then create an + // `clad::array_ref` variable that references this constant array. It + // is required because the pullback function expects `clad::array_ref` + // type for representing array derivatives. Currently, only constant + // array data members have derivatives of constant array types. + if (isa(argDerivative->getType())) { + Expr* init = + utils::BuildCladArrayInitByConstArray(m_Sema, argDerivative); + auto derivativeArrayRefVD = BuildVarDecl( + GetCladArrayRefOfType(argDerivative->getType() + ->getPointeeOrArrayElementType() + ->getCanonicalTypeInternal()), + "_t", init); + ArgDeclStmts.push_back(BuildDeclStmt(derivativeArrayRefVD)); + argDerivative = BuildDeclRef(derivativeArrayRefVD); + } + if (isCladArrayType(argDerivative->getType())) { + gradArgExpr = argDerivative; + } else { + gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative); + } + } else { + // Declare: diffArgType _grad = 0; + gradVarDecl = BuildVarDecl( + PVD->getType(), gradVarII, + ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0)); + // Pass the address of the declared variable + gradVarExpr = BuildDeclRef(gradVarDecl); + gradArgExpr = + BuildOp(UO_AddrOf, gradVarExpr, m_Function->getLocation()); + argResultsAndGrads.push_back({ArgResultDecls[idx], gradVarExpr}); } - if (isCladArrayType(argDerivative->getType())) - gradArgExpr = argDerivative; - else - gradArgExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, argDerivative); - } else { - // Declare: diffArgType _grad = 0; - gradVarDecl = BuildVarDecl( - PVD->getType(), gradVarII, - ConstantFolder::synthesizeLiteral(PVD->getType(), m_Context, 0)); - // Pass the address of the declared variable - gradVarExpr = BuildDeclRef(gradVarDecl); - gradArgExpr = - BuildOp(UO_AddrOf, gradVarExpr, m_Function->getLocation()); - argResultsAndGrads.push_back({ArgResultDecls[idx], gradVarExpr}); + DerivedCallOutputArgs.push_back(gradArgExpr); + if (gradVarDecl) + ArgDeclStmts.push_back(BuildDeclStmt(gradVarDecl)); + idx++; + } + Expr* pullback = dfdx(); + if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType()) + pullback = getZeroInit(FD->getReturnType().getNonReferenceType()); + + // FIXME: Remove this restriction. + if (!FD->getReturnType()->isVoidType()) { + assert((pullback && !FD->getReturnType()->isVoidType()) && + "Call to function returning non-void type with no dfdx() is not " + "supported!"); } - DerivedCallOutputArgs.push_back(gradArgExpr); - if (gradVarDecl) - ArgDeclStmts.push_back(BuildDeclStmt(gradVarDecl)); - idx++; - } - Expr* pullback = dfdx(); - if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType()) - pullback = getZeroInit(FD->getReturnType().getNonReferenceType()); - - // FIXME: Remove this restriction. - if (!FD->getReturnType()->isVoidType()) { - assert((pullback && !FD->getReturnType()->isVoidType()) && - "Call to function returning non-void type with no dfdx() is not " - "supported!"); - } - - if (FD->getReturnType()->isVoidType()) { - assert(pullback == nullptr && FD->getReturnType()->isVoidType() && - "Call to function returning void type should not have any " - "corresponding dfdx()."); - } - - DerivedCallArgs.insert(DerivedCallArgs.end(), DerivedCallOutputArgs.begin(), - DerivedCallOutputArgs.end()); - pullbackCallArgs = DerivedCallArgs; - - if (pullback) - pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() - - static_cast(skipFirstArg), - pullback); - - // Try to find it in builtin derivatives - std::string customPullback = FD->getNameAsString() + "_pullback"; - OverloadedDerivedFn = m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( - customPullback, pullbackCallArgs, getCurrentScope(), - const_cast(FD->getDeclContext())); - } - // should be true if we are using numerical differentiation to differentiate - // the callee function. - bool usingNumericalDiff = false; - // Derivative was not found, check if it is a recursive call - if (!OverloadedDerivedFn) { - if (FD == m_Function && m_Mode == DiffMode::experimental_pullback) { - // Recursive call. - auto selfRef = m_Sema - .BuildDeclarationNameExpr(CXXScopeSpec(), - m_Derivative->getNameInfo(), - m_Derivative) - .get(); + if (FD->getReturnType()->isVoidType()) { + assert(pullback == nullptr && FD->getReturnType()->isVoidType() && + "Call to function returning void type should not have any " + "corresponding dfdx()."); + } + + DerivedCallArgs.insert(DerivedCallArgs.end(), + DerivedCallOutputArgs.begin(), + DerivedCallOutputArgs.end()); + pullbackCallArgs = DerivedCallArgs; + + if (pullback) + pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() - + static_cast(skipFirstArg), + pullback); + // Try to find it in builtin derivatives + std::string customPullback = FD->getNameAsString() + "_pullback"; OverloadedDerivedFn = - m_Sema - .ActOnCallExpr(getCurrentScope(), selfRef, noLoc, - llvm::MutableArrayRef(DerivedCallArgs), - noLoc) - .get(); - } else { - if (m_ExternalSource) - m_ExternalSource->ActBeforeDifferentiatingCallExpr( - pullbackCallArgs, ArgDeclStmts, dfdx()); - // Overloaded derivative was not found, request the CladPlugin to - // derive the called function. - DiffRequest pullbackRequest{}; - pullbackRequest.Function = FD; - pullbackRequest.BaseFunctionName = FD->getNameAsString(); - pullbackRequest.Mode = DiffMode::experimental_pullback; - // Silence diag outputs in nested derivation process. - pullbackRequest.VerboseDiags = false; - FunctionDecl* pullbackFD = - plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); - // Clad failed to derive it. - // FIXME: Add support for reference arguments to the numerical diff. If - // it already correctly support reference arguments then confirm the - // support and add tests for the same. - if (!pullbackFD && !utils::HasAnyReferenceOrPointerArgument(FD) && - !isa(FD)) { - // Try numerically deriving it. - // Build a clone call expression so that we can correctly - // scope the function to be differentiated. - Expr* call = + m_Builder.BuildCallToCustomDerivativeOrNumericalDiff( + customPullback, pullbackCallArgs, getCurrentScope(), + const_cast(FD->getDeclContext())); + } + + // should be true if we are using numerical differentiation to differentiate + // the callee function. + bool usingNumericalDiff = false; + // Derivative was not found, check if it is a recursive call + if (!OverloadedDerivedFn) { + if (FD == m_Function && m_Mode == DiffMode::experimental_pullback) { + // Recursive call. + auto selfRef = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, - llvm::MutableArrayRef(CallArgs), noLoc) + .BuildDeclarationNameExpr(CXXScopeSpec(), + m_Derivative->getNameInfo(), + m_Derivative) .get(); - Expr* fnCallee = cast(call)->getCallee(); - if (NArgs == 1) { - OverloadedDerivedFn = - GetSingleArgCentralDiffCall(fnCallee, DerivedCallArgs[0], - /*targetPos=*/0, - /*numArgs=*/1, DerivedCallArgs); - asGrad = !OverloadedDerivedFn; - } else { - auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema); - OverloadedDerivedFn = GetMultiArgCentralDiffCall( - fnCallee, CEType.getCanonicalType(), CE->getNumArgs(), - NumericalDiffMultiArg, DerivedCallArgs, DerivedCallOutputArgs); - } - CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(), - OverloadedDerivedFn); - if (!OverloadedDerivedFn) { - auto& block = getCurrentBlock(direction::reverse); - block.insert(block.begin(), ArgDeclStmts.begin(), ArgDeclStmts.end()); - return StmtDiff(Clone(CE)); - } else { - usingNumericalDiff = true; - } - } else if (pullbackFD) { - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); - OverloadedDerivedFn = BuildCallExprToMemFn( - baseE, pullbackFD->getName(), pullbackCallArgs, pullbackFD); - } else { - OverloadedDerivedFn = - m_Sema - .ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD), - noLoc, pullbackCallArgs, noLoc) - .get(); - } - } - } - } - if (OverloadedDerivedFn) { - // Derivative was found. - FunctionDecl* fnDecl = - dyn_cast(OverloadedDerivedFn)->getDirectCallee(); - if (!asGrad) { - if (utils::IsCladValueAndPushforwardType(fnDecl->getReturnType())) - OverloadedDerivedFn = utils::BuildMemberExpr( - m_Sema, getCurrentScope(), OverloadedDerivedFn, "pushforward"); - // If the derivative is called through _darg0 instead of _grad. - Expr* d = BuildOp(BO_Mul, dfdx(), OverloadedDerivedFn); - - PerformImplicitConversionAndAssign(ArgResultDecls[0], d); - } else { - // Put Result array declaration in the function body. - // Call the gradient, passing Result as the last Arg. - auto& block = getCurrentBlock(direction::reverse); - auto it = std::begin(block) + insertionPoint; - - // Insert the _gradX declaration statements - it = block.insert(it, ArgDeclStmts.begin(), ArgDeclStmts.end()); - it += ArgDeclStmts.size(); - it = block.insert(it, NumericalDiffMultiArg.begin(), - NumericalDiffMultiArg.end()); - it += NumericalDiffMultiArg.size(); - // Insert the CallExpr to the derived function - block.insert(it, OverloadedDerivedFn); - - if (usingNumericalDiff) { - for (auto resAndGrad : argResultsAndGrads) { - VarDecl* argRes = resAndGrad.first; - Expr* grad = resAndGrad.second; - if (isCladArrayType(grad->getType())) { - Expr* E = BuildOp(BO_MulAssign, grad, dfdx()); - // Visit each arg with df/dargi = df/dxi * Result. - PerformImplicitConversionAndAssign(argRes, E); + OverloadedDerivedFn = + m_Sema + .ActOnCallExpr(getCurrentScope(), selfRef, noLoc, + llvm::MutableArrayRef(DerivedCallArgs), + noLoc) + .get(); + } else { + if (m_ExternalSource) + m_ExternalSource->ActBeforeDifferentiatingCallExpr( + pullbackCallArgs, ArgDeclStmts, dfdx()); + // Overloaded derivative was not found, request the CladPlugin to + // derive the called function. + DiffRequest pullbackRequest{}; + pullbackRequest.Function = FD; + pullbackRequest.BaseFunctionName = FD->getNameAsString(); + pullbackRequest.Mode = DiffMode::experimental_pullback; + // Silence diag outputs in nested derivation process. + pullbackRequest.VerboseDiags = false; + FunctionDecl* pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); + // Clad failed to derive it. + // FIXME: Add support for reference arguments to the numerical diff. If + // it already correctly support reference arguments then confirm the + // support and add tests for the same. + if (!pullbackFD && !utils::HasAnyReferenceOrPointerArgument(FD) && + !isa(FD)) { + // Try numerically deriving it. + // Build a clone call expression so that we can correctly + // scope the function to be differentiated. + Expr* call = m_Sema + .ActOnCallExpr(getCurrentScope(), + Clone(CE->getCallee()), + noLoc, + llvm::MutableArrayRef(CallArgs), + noLoc) + .get(); + Expr* fnCallee = cast(call)->getCallee(); + if (NArgs == 1) { + OverloadedDerivedFn = GetSingleArgCentralDiffCall(fnCallee, + DerivedCallArgs + [0], + /*targetPos=*/0, + /*numArgs=*/1, + DerivedCallArgs); + asGrad = !OverloadedDerivedFn; + } else { + auto CEType = getNonConstType(CE->getType(), m_Context, m_Sema); + OverloadedDerivedFn = GetMultiArgCentralDiffCall( + fnCallee, CEType.getCanonicalType(), CE->getNumArgs(), + NumericalDiffMultiArg, DerivedCallArgs, DerivedCallOutputArgs); + } + CallExprDiffDiagnostics(FD->getNameAsString(), CE->getBeginLoc(), + OverloadedDerivedFn); + if (!OverloadedDerivedFn) { + auto& block = getCurrentBlock(direction::reverse); + block.insert(block.begin(), ArgDeclStmts.begin(), + ArgDeclStmts.end()); + return StmtDiff(Clone(CE)); } else { - // Visit each arg with df/dargi = df/dxi * Result. - PerformImplicitConversionAndAssign(argRes, - BuildOp(BO_Mul, dfdx(), grad)); + usingNumericalDiff = true; + } + } else if (pullbackFD) { + if (baseDiff.getExpr()) { + Expr* baseE = baseDiff.getExpr(); + OverloadedDerivedFn = BuildCallExprToMemFn( + baseE, pullbackFD->getName(), pullbackCallArgs, pullbackFD); + } else { + OverloadedDerivedFn = + m_Sema + .ActOnCallExpr(getCurrentScope(), BuildDeclRef(pullbackFD), + noLoc, pullbackCallArgs, noLoc) + .get(); } } + } + } + + if (OverloadedDerivedFn) { + // Derivative was found. + FunctionDecl* fnDecl = dyn_cast(OverloadedDerivedFn) + ->getDirectCallee(); + if (!asGrad) { + if (utils::IsCladValueAndPushforwardType(fnDecl->getReturnType())) + OverloadedDerivedFn = utils::BuildMemberExpr( + m_Sema, getCurrentScope(), OverloadedDerivedFn, "pushforward"); + // If the derivative is called through _darg0 instead of _grad. + Expr* d = BuildOp(BO_Mul, dfdx(), OverloadedDerivedFn); + + PerformImplicitConversionAndAssign(ArgResultDecls[0], d); } else { - for (auto resAndGrad : argResultsAndGrads) { - VarDecl* argRes = resAndGrad.first; - Expr* grad = resAndGrad.second; - PerformImplicitConversionAndAssign(argRes, grad); + // Put Result array declaration in the function body. + // Call the gradient, passing Result as the last Arg. + auto& block = getCurrentBlock(direction::reverse); + auto it = std::begin(block) + insertionPoint; + + // Insert the _gradX declaration statements + it = block.insert(it, ArgDeclStmts.begin(), ArgDeclStmts.end()); + it += ArgDeclStmts.size(); + it = block.insert(it, NumericalDiffMultiArg.begin(), + NumericalDiffMultiArg.end()); + it += NumericalDiffMultiArg.size(); + // Insert the CallExpr to the derived function + block.insert(it, OverloadedDerivedFn); + + if (usingNumericalDiff) { + for (auto resAndGrad : argResultsAndGrads) { + VarDecl* argRes = resAndGrad.first; + Expr* grad = resAndGrad.second; + if (isCladArrayType(grad->getType())) { + Expr* E = BuildOp(BO_MulAssign, grad, dfdx()); + // Visit each arg with df/dargi = df/dxi * Result. + PerformImplicitConversionAndAssign(argRes, E); + } else { + // Visit each arg with df/dargi = df/dxi * Result. + PerformImplicitConversionAndAssign(argRes, + BuildOp(BO_Mul, dfdx(), grad)); + } + } + } else { + for (auto resAndGrad : argResultsAndGrads) { + VarDecl* argRes = resAndGrad.first; + Expr* grad = resAndGrad.second; + PerformImplicitConversionAndAssign(argRes, grad); + } } } } - } - if (m_ExternalSource) - m_ExternalSource->ActBeforeFinalizingVisitCallExpr( + if (m_ExternalSource) + m_ExternalSource->ActBeforeFinalizingVisitCallExpr( CE, OverloadedDerivedFn, DerivedCallArgs, ArgResultDecls, asGrad); - Expr* call = nullptr; - - if (FD->getReturnType()->isReferenceType()) { - DiffRequest calleeFnForwPassReq; - calleeFnForwPassReq.Function = FD; - calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; - calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString(); - calleeFnForwPassReq.VerboseDiags = true; - FunctionDecl* calleeFnForwPassFD = - plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); - - assert(calleeFnForwPassFD && - "Clad failed to generate callee function forward pass function"); - - // FIXME: We are using the derivatives in forward pass here - // If `expr_dx()` is only meant to be used in reverse pass, - // (for example, `clad::pop(...)` expression and a corresponding - // `clad::push(...)` in the forward pass), then this can result in - // incorrect derivative or crash at runtime. Ideally, we should have - // a separate routine to use derivative in the forward pass. - - // We cannot reuse the derivatives previously computed because - // they might contain 'clad::pop(..)` expression. - if (baseDiff.getExpr_dx()) { - Expr* derivedBase = baseDiff.getExpr_dx(); - // FIXME: We may need this if-block once we support pointers, and - // passing pointers-by-reference if - // (isCladArrayType(derivedBase->getType())) - // CallArgs.push_back(derivedBase); - // else - CallArgs.push_back( - BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc)); - } - - for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { - const Expr* arg = CE->getArg(i); - const ParmVarDecl* PVD = FD->getParamDecl(i); - StmtDiff argDiff = Visit(arg); - if ((argDiff.getExpr_dx() != nullptr) && - PVD->getType()->isReferenceType()) { - Expr* derivedArg = argDiff.getExpr_dx(); + Expr* call = nullptr; + + if (FD->getReturnType()->isReferenceType()) { + DiffRequest calleeFnForwPassReq; + calleeFnForwPassReq.Function = FD; + calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; + calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString(); + calleeFnForwPassReq.VerboseDiags = true; + FunctionDecl* calleeFnForwPassFD = + plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); + + assert(calleeFnForwPassFD && + "Clad failed to generate callee function forward pass function"); + + // FIXME: We are using the derivatives in forward pass here + // If `expr_dx()` is only meant to be used in reverse pass, + // (for example, `clad::pop(...)` expression and a corresponding + // `clad::push(...)` in the forward pass), then this can result in + // incorrect derivative or crash at runtime. Ideally, we should have + // a separate routine to use derivative in the forward pass. + + // We cannot reuse the derivatives previously computed because + // they might contain 'clad::pop(..)` expression. + if (baseDiff.getExpr_dx()) { + Expr* derivedBase = baseDiff.getExpr_dx(); // FIXME: We may need this if-block once we support pointers, and // passing pointers-by-reference if - // (isCladArrayType(derivedArg->getType())) - // CallArgs.push_back(derivedArg); + // (isCladArrayType(derivedBase->getType())) + // CallArgs.push_back(derivedBase); // else CallArgs.push_back( - BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, noLoc)); - } else - CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get()); - } - if (baseDiff.getExpr()) { - Expr* baseE = baseDiff.getExpr(); - call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), - CallArgs, calleeFnForwPassFD); - } else { - call = m_Sema - .ActOnCallExpr(getCurrentScope(), - BuildDeclRef(calleeFnForwPassFD), noLoc, - CallArgs, noLoc) - .get(); - } - auto* callRes = StoreAndRef(call); - auto* resValue = - utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); - auto* resAdjoint = - utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); - return StmtDiff(resValue, nullptr, resAdjoint); - } // Recreate the original call expression. - call = m_Sema - .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, - CallArgs, noLoc) - .get(); - return StmtDiff(call); - - return {}; -} + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc)); + } -StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { - auto opCode = UnOp->getOpcode(); - StmtDiff diff{}; - Expr* E = UnOp->getSubExpr(); - // If it is a post-increment/decrement operator, its result is a reference - // and we should return it. - Expr* ResultRef = nullptr; - if (opCode == UO_Plus) - // xi = +xj - // dxi/dxj = +1.0 - // df/dxj += df/dxi * dxi/dxj = df/dxi - diff = Visit(E, dfdx()); - else if (opCode == UO_Minus) { - // xi = -xj - // dxi/dxj = -1.0 - // df/dxj += df/dxi * dxi/dxj = -df/dxi - auto d = BuildOp(UO_Minus, dfdx()); - diff = Visit(E, d); - } else if (opCode == UO_PostInc || opCode == UO_PostDec) { - diff = Visit(E, dfdx()); - auto EStored = GlobalStoreAndRef(diff.getExpr()); - if (EStored.getExpr() != diff.getExpr()) { - auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, - Clone(diff.getExpr()), EStored.getExpr_dx()); - if (isInsideLoop) - addToCurrentBlock(EStored.getExpr(), direction::forward); - addToCurrentBlock(assign, direction::reverse); - } - - ResultRef = diff.getExpr_dx(); - if (m_ExternalSource) - m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff); - } else if (opCode == UO_PreInc || opCode == UO_PreDec) { - diff = Visit(E, dfdx()); - auto EStored = GlobalStoreAndRef(diff.getExpr()); - if (EStored.getExpr() != diff.getExpr()) { - auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, - Clone(diff.getExpr()), EStored.getExpr_dx()); - if (isInsideLoop) - addToCurrentBlock(EStored.getExpr(), direction::forward); - addToCurrentBlock(assign, direction::reverse); - } - } else if (opCode == UnaryOperatorKind::UO_Real || - opCode == UnaryOperatorKind::UO_Imag) { - diff = VisitWithExplicitNoDfDx(E); - ResultRef = BuildOp(opCode, diff.getExpr_dx()); - /// Create and add `__real r += dfdx()` expression. - if (dfdx()) { - Expr* add_assign = BuildOp(BO_AddAssign, ResultRef, dfdx()); - // Add it to the body statements. - addToCurrentBlock(add_assign, direction::reverse); - } - } else { - // FIXME: This is not adding 'address-of' operator support. - // This is just making this special case differentiable that is required - // for computing hessian: - // ``` - // Class _d_this_obj; - // Class* _d_this = &_d_this_obj; - // ``` - // This code snippet should be removed once reverse mode officially - // supports pointers. - if (opCode == UnaryOperatorKind::UO_AddrOf) { - if (auto MD = dyn_cast(m_Function)) { - if (MD->isInstance()) { - auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); - if (utils::SameCanonicalType(thisType, UnOp->getType())) { - diff = Visit(E); - Expr* cloneE = - BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr()); - Expr* derivedE = - BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx()); - return {cloneE, derivedE}; - } - } + for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + const Expr* arg = CE->getArg(i); + const ParmVarDecl* PVD = FD->getParamDecl(i); + StmtDiff argDiff = Visit(arg); + if ((argDiff.getExpr_dx() != nullptr) && + PVD->getType()->isReferenceType()) { + Expr* derivedArg = argDiff.getExpr_dx(); + // FIXME: We may need this if-block once we support pointers, and + // passing pointers-by-reference if + // (isCladArrayType(derivedArg->getType())) + // CallArgs.push_back(derivedArg); + // else + CallArgs.push_back( + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, noLoc)); + } else + CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get()); } - } - // We should not output any warning on visiting boolean conditions - // FIXME: We should support boolean differentiation or ignore it - // completely - if (opCode != UO_LNot) - unsupportedOpWarn(UnOp->getEndLoc()); + if (baseDiff.getExpr()) { + Expr* baseE = baseDiff.getExpr(); + call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), + CallArgs, calleeFnForwPassFD); + } else { + call = m_Sema + .ActOnCallExpr(getCurrentScope(), + BuildDeclRef(calleeFnForwPassFD), noLoc, + CallArgs, noLoc) + .get(); + } + auto* callRes = StoreAndRef(call); + auto* resValue = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); + auto* resAdjoint = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); + return StmtDiff(resValue, nullptr, resAdjoint); + } // Recreate the original call expression. + call = m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, + CallArgs, noLoc) + .get(); + return StmtDiff(call); - if (isa(E)) - diff = Visit(E); - else - diff = StmtDiff(E); + return {}; } - Expr* op = BuildOp(opCode, diff.getExpr()); - return StmtDiff(op, ResultRef); -} -StmtDiff ReverseModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { - auto opCode = BinOp->getOpcode(); - StmtDiff Ldiff{}; - StmtDiff Rdiff{}; - StmtDiff Lstored{}; - auto L = BinOp->getLHS(); - auto R = BinOp->getRHS(); - // If it is an assignment operator, its result is a reference to LHS and - // we should return it. - Expr* ResultRef = nullptr; - - if (opCode == BO_Add) { - // xi = xl + xr - // dxi/xl = 1.0 - // df/dxl += df/dxi * dxi/xl = df/dxi - Ldiff = Visit(L, dfdx()); - // dxi/xr = 1.0 - // df/dxr += df/dxi * dxi/xr = df/dxi - Rdiff = Visit(R, dfdx()); - } else if (opCode == BO_Sub) { - // xi = xl - xr - // dxi/xl = 1.0 - // df/dxl += df/dxi * dxi/xl = df/dxi - Ldiff = Visit(L, dfdx()); - // dxi/xr = -1.0 - // df/dxl += df/dxi * dxi/xr = -df/dxi - auto dr = BuildOp(UO_Minus, dfdx()); - Rdiff = Visit(R, dr); - } else if (opCode == BO_Mul) { - // xi = xl * xr - // dxi/xl = xr - // df/dxl += df/dxi * dxi/xl = df/dxi * xr - // Create uninitialized "global" variable for the right multiplier. - // It will be assigned later after R is visited and cloned. This allows - // to reduce cloning complexity and only clones once. Storing it in a - // global variable allows to save current result and make it accessible - // in the reverse pass. - auto RDelayed = DelayedGlobalStoreAndRef(R); - StmtDiff RResult = RDelayed.Result; - Expr* dl = nullptr; - if (dfdx()) { - dl = BuildOp(BO_Mul, dfdx(), RResult.getExpr_dx()); - dl = StoreAndRef(dl, direction::reverse); - } - Ldiff = Visit(L, dl); - // dxi/xr = xl - // df/dxr += df/dxi * dxi/xr = df/dxi * xl - // Store left multiplier and assign it with L. - Expr* LStored = Ldiff.getExpr(); - // RDelayed.isConstant == true implies that R is a constant expression, - // therefore we can skip visiting it. - if (!RDelayed.isConstant) { - Expr* dr = nullptr; + StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { + auto opCode = UnOp->getOpcode(); + StmtDiff diff{}; + Expr* E = UnOp->getSubExpr(); + // If it is a post-increment/decrement operator, its result is a reference + // and we should return it. + Expr* ResultRef = nullptr; + if (opCode == UO_Plus) + // xi = +xj + // dxi/dxj = +1.0 + // df/dxj += df/dxi * dxi/dxj = df/dxi + diff = Visit(E, dfdx()); + else if (opCode == UO_Minus) { + // xi = -xj + // dxi/dxj = -1.0 + // df/dxj += df/dxi * dxi/dxj = -df/dxi + auto d = BuildOp(UO_Minus, dfdx()); + diff = Visit(E, d); + } else if (opCode == UO_PostInc || opCode == UO_PostDec) { + diff = Visit(E, dfdx()); + auto EStored = GlobalStoreAndRef(diff.getExpr()); + if (EStored.getExpr() != diff.getExpr()) { + auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, + Clone(diff.getExpr()), EStored.getExpr_dx()); + if (isInsideLoop) + addToCurrentBlock(EStored.getExpr(), direction::forward); + addToCurrentBlock(assign, direction::reverse); + } + + ResultRef = diff.getExpr_dx(); + if (m_ExternalSource) + m_ExternalSource->ActBeforeFinalisingPostIncDecOp(diff); + } else if (opCode == UO_PreInc || opCode == UO_PreDec) { + diff = Visit(E, dfdx()); + auto EStored = GlobalStoreAndRef(diff.getExpr()); + if (EStored.getExpr() != diff.getExpr()) { + auto* assign = BuildOp(BinaryOperatorKind::BO_Assign, + Clone(diff.getExpr()), EStored.getExpr_dx()); + if (isInsideLoop) + addToCurrentBlock(EStored.getExpr(), direction::forward); + addToCurrentBlock(assign, direction::reverse); + } + } else if (opCode == UnaryOperatorKind::UO_Real || + opCode == UnaryOperatorKind::UO_Imag) { + diff = VisitWithExplicitNoDfDx(E); + ResultRef = BuildOp(opCode, diff.getExpr_dx()); + /// Create and add `__real r += dfdx()` expression. if (dfdx()) { - StmtDiff LResult; - if (isa(LStored->IgnoreImpCasts())) - LResult = {LStored, LStored}; - else - LResult = GlobalStoreAndRef(LStored, "_t", /*force=*/true); - LStored = LResult.getExpr(); - dr = BuildOp(BO_Mul, LResult.getExpr_dx(), dfdx()); - dr = StoreAndRef(dr, direction::reverse); + Expr* add_assign = BuildOp(BO_AddAssign, ResultRef, dfdx()); + // Add it to the body statements. + addToCurrentBlock(add_assign, direction::reverse); + } + } else { + // FIXME: This is not adding 'address-of' operator support. + // This is just making this special case differentiable that is required + // for computing hessian: + // ``` + // Class _d_this_obj; + // Class* _d_this = &_d_this_obj; + // ``` + // This code snippet should be removed once reverse mode officially + // supports pointers. + if (opCode == UnaryOperatorKind::UO_AddrOf) { + if (auto MD = dyn_cast(m_Function)) { + if (MD->isInstance()) { + auto thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); + if (utils::SameCanonicalType(thisType, UnOp->getType())) { + diff = Visit(E); + Expr* cloneE = + BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr()); + Expr* derivedE = + BuildOp(UnaryOperatorKind::UO_AddrOf, diff.getExpr_dx()); + return {cloneE, derivedE}; + } + } + } } + // We should not output any warning on visiting boolean conditions + // FIXME: We should support boolean differentiation or ignore it + // completely + if (opCode != UO_LNot) + unsupportedOpWarn(UnOp->getEndLoc()); + + if (isa(E)) + diff = Visit(E); + else + diff = StmtDiff(E); + } + Expr* op = BuildOp(opCode, diff.getExpr()); + return StmtDiff(op, ResultRef); + } + + StmtDiff + ReverseModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { + auto opCode = BinOp->getOpcode(); + StmtDiff Ldiff{}; + StmtDiff Rdiff{}; + StmtDiff Lstored{}; + auto L = BinOp->getLHS(); + auto R = BinOp->getRHS(); + // If it is an assignment operator, its result is a reference to LHS and + // we should return it. + Expr* ResultRef = nullptr; + + if (opCode == BO_Add) { + // xi = xl + xr + // dxi/xl = 1.0 + // df/dxl += df/dxi * dxi/xl = df/dxi + Ldiff = Visit(L, dfdx()); + // dxi/xr = 1.0 + // df/dxr += df/dxi * dxi/xr = df/dxi + Rdiff = Visit(R, dfdx()); + } else if (opCode == BO_Sub) { + // xi = xl - xr + // dxi/xl = 1.0 + // df/dxl += df/dxi * dxi/xl = df/dxi + Ldiff = Visit(L, dfdx()); + // dxi/xr = -1.0 + // df/dxl += df/dxi * dxi/xr = -df/dxi + auto dr = BuildOp(UO_Minus, dfdx()); Rdiff = Visit(R, dr); - // Assign right multiplier's variable with R. - RDelayed.Finalize(Rdiff.getExpr()); - } - std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr()); - } else if (opCode == BO_Div) { - // xi = xl / xr - // dxi/xl = 1 / xr - // df/dxl += df/dxi * dxi/xl = df/dxi * (1/xr) - auto RDelayed = DelayedGlobalStoreAndRef(R); - StmtDiff RResult = RDelayed.Result; - Expr* RStored = StoreAndRef(RResult.getExpr_dx(), direction::reverse); - Expr* dl = nullptr; - if (dfdx()) { - dl = BuildOp(BO_Div, dfdx(), RStored); - dl = StoreAndRef(dl, direction::reverse); - } - Ldiff = Visit(L, dl); - // dxi/xr = -xl / (xr * xr) - // df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr)) - // Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is - // produced instead of 1 / (R * R). - Expr* LStored = Ldiff.getExpr(); - if (!RDelayed.isConstant) { - Expr* dr = nullptr; - StmtDiff LResult; + } else if (opCode == BO_Mul) { + // xi = xl * xr + // dxi/xl = xr + // df/dxl += df/dxi * dxi/xl = df/dxi * xr + // Create uninitialized "global" variable for the right multiplier. + // It will be assigned later after R is visited and cloned. This allows + // to reduce cloning complexity and only clones once. Storing it in a + // global variable allows to save current result and make it accessible + // in the reverse pass. + auto RDelayed = DelayedGlobalStoreAndRef(R); + StmtDiff RResult = RDelayed.Result; + Expr* dl = nullptr; if (dfdx()) { - if (isa(LStored->IgnoreParenImpCasts())) - LResult = {LStored, LStored}; - else - LResult = GlobalStoreAndRef(LStored, "_t", /*force=*/true); - LStored = LResult.getExpr(); - Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); - dr = BuildOp( - BO_Mul, dfdx(), - BuildOp(UO_Minus, BuildOp(BO_Div, LResult.getExpr_dx(), RxR))); - dr = StoreAndRef(dr, direction::reverse); + dl = BuildOp(BO_Mul, dfdx(), RResult.getExpr_dx()); + dl = StoreAndRef(dl, direction::reverse); + } + Ldiff = Visit(L, dl); + // dxi/xr = xl + // df/dxr += df/dxi * dxi/xr = df/dxi * xl + // Store left multiplier and assign it with L. + Expr* LStored = Ldiff.getExpr(); + // RDelayed.isConstant == true implies that R is a constant expression, + // therefore we can skip visiting it. + if (!RDelayed.isConstant) { + Expr* dr = nullptr; + if (dfdx()) { + StmtDiff LResult; + if (isa(LStored->IgnoreImpCasts())) + LResult = {LStored, LStored}; + else + LResult = GlobalStoreAndRef(LStored, "_t", /*force=*/true); + LStored = LResult.getExpr(); + dr = BuildOp(BO_Mul, LResult.getExpr_dx(), dfdx()); + dr = StoreAndRef(dr, direction::reverse); + } + Rdiff = Visit(R, dr); + // Assign right multiplier's variable with R. + RDelayed.Finalize(Rdiff.getExpr()); + } + std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr()); + } else if (opCode == BO_Div) { + // xi = xl / xr + // dxi/xl = 1 / xr + // df/dxl += df/dxi * dxi/xl = df/dxi * (1/xr) + auto RDelayed = DelayedGlobalStoreAndRef(R); + StmtDiff RResult = RDelayed.Result; + Expr* RStored = StoreAndRef(RResult.getExpr_dx(), direction::reverse); + Expr* dl = nullptr; + if (dfdx()) { + dl = BuildOp(BO_Div, dfdx(), RStored); + dl = StoreAndRef(dl, direction::reverse); + } + Ldiff = Visit(L, dl); + // dxi/xr = -xl / (xr * xr) + // df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr)) + // Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is + // produced instead of 1 / (R * R). + Expr* LStored = Ldiff.getExpr(); + if (!RDelayed.isConstant) { + Expr* dr = nullptr; + StmtDiff LResult; + if (dfdx()) { + if (isa(LStored->IgnoreParenImpCasts())) + LResult = {LStored, LStored}; + else + LResult = GlobalStoreAndRef(LStored, "_t", /*force=*/true); + LStored = LResult.getExpr(); + Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); + dr = BuildOp(BO_Mul, + dfdx(), + BuildOp(UO_Minus, + BuildOp(BO_Div, LResult.getExpr_dx(), RxR))); + dr = StoreAndRef(dr, direction::reverse); + } + Rdiff = Visit(R, dr); + RDelayed.Finalize(Rdiff.getExpr()); + } + std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr()); + } else if (BinOp->isAssignmentOp()) { + if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) { + diag(DiagnosticsEngine::Warning, + BinOp->getEndLoc(), + "derivative of an assignment attempts to assign to unassignable " + "expr, assignment ignored"); + auto LDRE = dyn_cast(L); + auto RDRE = dyn_cast(R); + + if (!LDRE && !RDRE) + return Clone(BinOp); + Expr* LExpr = LDRE ? Visit(L).getExpr() : L; + Expr* RExpr = RDRE ? Visit(R).getExpr() : R; + + return BuildOp(opCode, LExpr, RExpr); } - Rdiff = Visit(R, dr); - RDelayed.Finalize(Rdiff.getExpr()); - } - std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult.getExpr()); - } else if (BinOp->isAssignmentOp()) { - if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) { - diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(), - "derivative of an assignment attempts to assign to unassignable " - "expr, assignment ignored"); - auto LDRE = dyn_cast(L); - auto RDRE = dyn_cast(R); - - if (!LDRE && !RDRE) - return Clone(BinOp); - Expr* LExpr = LDRE ? Visit(L).getExpr() : L; - Expr* RExpr = RDRE ? Visit(R).getExpr() : R; - - return BuildOp(opCode, LExpr, RExpr); - } - // FIXME: Put this code into a separate subroutine and break out early - // using return if the diff mode is not jacobian and we are not dealing - // with the `outputArray`. - if (auto ASE = dyn_cast(L)) { - if (auto DRE = dyn_cast(ASE->getBase()->IgnoreImplicit())) { - auto type = QualType(DRE->getType()->getPointeeOrArrayElementType(), - /*Quals=*/0); - std::string DRE_str = DRE->getDecl()->getNameAsString(); - - llvm::APSInt intIdx; - auto isIdxValid = - clad_compat::Expr_EvaluateAsInt(ASE->getIdx(), intIdx, m_Context); - - if (DRE_str == outputArrayStr && isIdxValid) { - if (isVectorValued) { - outputArrayCursor = intIdx.getExtValue(); - - std::unordered_map - temp_m_Variables; - for (unsigned i = 0; i < numParams; i++) { - auto size_type = m_Context.getSizeType(); - unsigned size_type_bits = m_Context.getIntWidth(size_type); - llvm::APInt idxValue(size_type_bits, - i + (outputArrayCursor * numParams)); - auto idx = - IntegerLiteral::Create(m_Context, idxValue, size_type, noLoc); - // Create the jacobianMatrix[idx] expression. - auto result_at_i = m_Sema - .CreateBuiltinArraySubscriptExpr( - m_Result, noLoc, idx, noLoc) - .get(); - temp_m_Variables[m_IndependentVars[i]] = result_at_i; + // FIXME: Put this code into a separate subroutine and break out early + // using return if the diff mode is not jacobian and we are not dealing + // with the `outputArray`. + if (auto ASE = dyn_cast(L)) { + if (auto DRE = dyn_cast(ASE->getBase()->IgnoreImplicit())) { + auto type = QualType(DRE->getType()->getPointeeOrArrayElementType(), + /*Quals=*/0); + std::string DRE_str = DRE->getDecl()->getNameAsString(); + + llvm::APSInt intIdx; + auto isIdxValid = + clad_compat::Expr_EvaluateAsInt(ASE->getIdx(), intIdx, m_Context); + + if (DRE_str == outputArrayStr && isIdxValid) { + if (isVectorValued) { + outputArrayCursor = intIdx.getExtValue(); + + std::unordered_map + temp_m_Variables; + for (unsigned i = 0; i < numParams; i++) { + auto size_type = m_Context.getSizeType(); + unsigned size_type_bits = m_Context.getIntWidth(size_type); + llvm::APInt idxValue(size_type_bits, + i + (outputArrayCursor * numParams)); + auto idx = IntegerLiteral::Create(m_Context, idxValue, + size_type, noLoc); + // Create the jacobianMatrix[idx] expression. + auto result_at_i = m_Sema + .CreateBuiltinArraySubscriptExpr( + m_Result, noLoc, idx, noLoc) + .get(); + temp_m_Variables[m_IndependentVars[i]] = result_at_i; + } + m_VectorOutput.push_back(temp_m_Variables); } - m_VectorOutput.push_back(temp_m_Variables); + + auto dfdf = ConstantFolder::synthesizeLiteral(m_Context.IntTy, + m_Context, 1); + ExprResult tmp = dfdf; + dfdf = m_Sema + .ImpCastExprToType(tmp.get(), type, + m_Sema.PrepareScalarCast(tmp, type)) + .get(); + auto ReturnResult = DifferentiateSingleExpr(R, dfdf); + StmtDiff ReturnDiff = ReturnResult.first; + Stmt* Reverse = ReturnDiff.getStmt_dx(); + addToCurrentBlock(Reverse, direction::reverse); + for (Stmt* S : cast(ReturnDiff.getStmt())->body()) + addToCurrentBlock(S, direction::forward); } + } + } - auto dfdf = - ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 1); - ExprResult tmp = dfdf; - dfdf = m_Sema - .ImpCastExprToType(tmp.get(), type, - m_Sema.PrepareScalarCast(tmp, type)) - .get(); - auto ReturnResult = DifferentiateSingleExpr(R, dfdf); - StmtDiff ReturnDiff = ReturnResult.first; - Stmt* Reverse = ReturnDiff.getStmt_dx(); - addToCurrentBlock(Reverse, direction::reverse); - for (Stmt* S : cast(ReturnDiff.getStmt())->body()) - addToCurrentBlock(S, direction::forward); + // Visit LHS, but delay emission of its derivative statements, save them + // in Lblock + beginBlock(direction::reverse); + beginBlock(direction::essential_reverse); + Ldiff = Visit(L, dfdx()); + Stmts essentialRevBlock = EndBlockWithoutCreatingCS(direction::essential_reverse); + auto Lblock = endBlock(direction::reverse); + auto return_exprs = utils::GetInnermostReturnExpr(Ldiff.getExpr()); + if (L->HasSideEffects(m_Context)) { + Expr* E = Ldiff.getExpr(); + auto storeE = StoreAndRef(E, m_Context.getLValueReferenceType(E->getType())); + Ldiff.updateStmt(storeE); + } + + // if (/*!L->HasSideEffects(m_Context)*/true) { + // Lstored = GlobalStoreAndRef(Ldiff.getExpr(), "_t", /*force*/true); + // auto assign = BuildOp(BO_Assign, Ldiff.getExpr(), Lstored.getExpr_dx()); + // if (isInsideLoop) { + // addToCurrentBlock(Lstored.getExpr(), direction::forward); + // } + // addToCurrentBlock(assign, direction::reverse); + // } + + Expr* LCloned = Ldiff.getExpr(); + // For x, AssignedDiff is _d_x, for x[i] its _d_x[i], for reference exprs + // like (x = y) it propagates recursively, so _d_x is also returned. + Expr* AssignedDiff = Ldiff.getExpr_dx(); + if (!AssignedDiff) { + // If either LHS or RHS is a declaration reference, visit it to avoid + // naming collision + auto LDRE = dyn_cast(L); + auto RDRE = dyn_cast(R); + + if (!LDRE && !RDRE) + return Clone(BinOp); + + Expr* LExpr = LDRE ? Visit(L).getExpr() : L; + Expr* RExpr = RDRE ? Visit(R).getExpr() : R; + + return BuildOp(opCode, LExpr, RExpr); + } + ResultRef = AssignedDiff; + // If assigned expr is dependent, first update its derivative; + auto Lblock_begin = Lblock->body_rbegin(); + auto Lblock_end = Lblock->body_rend(); + // if (Lblock->size()) { + // addToCurrentBlock(*Lblock_begin, direction::reverse); + // Lblock_begin = std::next(Lblock_begin); + // } + + for (auto S : essentialRevBlock) + addToCurrentBlock(S, direction::reverse); + + if (dfdx() && Lblock_begin != Lblock_end) { + addToCurrentBlock(*Lblock_begin, direction::reverse); + Lblock_begin = std::next(Lblock_begin); + } + + for (auto E : return_exprs) { + Lstored = GlobalStoreAndRef(E); + if (Lstored.getExpr() != E) { + auto* assign = + BuildOp(BinaryOperatorKind::BO_Assign, E, Lstored.getExpr_dx()); + if (isInsideLoop) + addToCurrentBlock(Lstored.getExpr(), direction::forward); + addToCurrentBlock(assign, direction::reverse); } } - } - // Visit LHS, but delay emission of its derivative statements, save them - // in Lblock - beginBlock(direction::reverse); - beginBlock(direction::essential_reverse); - Ldiff = Visit(L, dfdx()); - Stmts essentialRevBlock = - EndBlockWithoutCreatingCS(direction::essential_reverse); - auto Lblock = endBlock(direction::reverse); - auto return_exprs = utils::GetInnermostReturnExpr(Ldiff.getExpr()); - if (L->HasSideEffects(m_Context)) { - Expr* E = Ldiff.getExpr(); - auto storeE = - StoreAndRef(E, m_Context.getLValueReferenceType(E->getType())); - Ldiff.updateStmt(storeE); - } - - // if (/*!L->HasSideEffects(m_Context)*/true) { - // Lstored = GlobalStoreAndRef(Ldiff.getExpr(), "_t", /*force*/true); - // auto assign = BuildOp(BO_Assign, Ldiff.getExpr(), - // Lstored.getExpr_dx()); if (isInsideLoop) { - // addToCurrentBlock(Lstored.getExpr(), direction::forward); - // } - // addToCurrentBlock(assign, direction::reverse); - // } + if (m_ExternalSource) + m_ExternalSource->ActAfterCloningLHSOfAssignOp(LCloned, R, opCode); + + // Save old value for the derivative of LHS, to avoid problems with cases + // like x = x. + auto oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d", + /*forceDeclCreation=*/true); + + if (opCode == BO_Assign) { + Rdiff = Visit(R, oldValue); + } else if (opCode == BO_AddAssign) { + addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), + direction::reverse); + Rdiff = Visit(R, oldValue); + } else if (opCode == BO_SubAssign) { + addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), + direction::reverse); + Rdiff = Visit(R, BuildOp(UO_Minus, oldValue)); + } else if (opCode == BO_MulAssign) { + auto RDelayed = DelayedGlobalStoreAndRef(R); + StmtDiff RResult = RDelayed.Result; + addToCurrentBlock( + BuildOp(BO_AddAssign, + AssignedDiff, + BuildOp(BO_Mul, oldValue, RResult.getExpr_dx())), + direction::reverse); + Expr* LRef = LCloned; + if (!RDelayed.isConstant) { + // Create a reference variable to keep the result of LHS, since it + // must be used on 2 places: when storing to a global variable + // accessible from the reverse pass, and when rebuilding the original + // expression for the forward pass. This allows to avoid executing + // same expression with side effects twice. E.g., on + // double r = (x *= y) *= z; + // instead of: + // _t0 = (x *= y); + // double r = (x *= y) *= z; + // which modifies x twice, we get: + // double & _ref0 = (x *= y); + // _t0 = _ref0; + // double r = _ref0 *= z; + StmtDiff LResult; + if (LCloned->HasSideEffects(m_Context)) { + auto RefType = getNonConstType(L->getType(), m_Context, m_Sema); + // RefType = m_Context.getLValueReferenceType(RefType); + LRef = StoreAndRef(LCloned, RefType, direction::forward, "_ref", + /*forceDeclCreation=*/true); + LResult = GlobalStoreAndRef(LRef, "_t", /*force=*/true); + } else + LResult = {LRef, LRef}; + + if (isInsideLoop) + addToCurrentBlock(LResult.getExpr(), direction::forward); + Expr* dr = BuildOp(BO_Mul, LResult.getExpr_dx(), oldValue); + dr = StoreAndRef(dr, direction::reverse); + Rdiff = Visit(R, dr); + RDelayed.Finalize(Rdiff.getExpr()); + } + std::tie(Ldiff, Rdiff) = std::make_pair(LRef, RResult.getExpr()); + } else if (opCode == BO_DivAssign) { + auto RDelayed = DelayedGlobalStoreAndRef(R); + StmtDiff RResult = RDelayed.Result; + Expr* RStored = StoreAndRef(RResult.getExpr_dx(), direction::reverse); + addToCurrentBlock(BuildOp(BO_AddAssign, + AssignedDiff, + BuildOp(BO_Div, oldValue, RStored)), + direction::reverse); + Expr* LRef = LCloned; + if (!RDelayed.isConstant) { + StmtDiff LResult; + if (LCloned->HasSideEffects(m_Context)) { + QualType RefType = m_Context.getLValueReferenceType( + getNonConstType(L->getType(), m_Context, m_Sema)); + LRef = StoreAndRef(LCloned, RefType, direction::forward, "_ref", + /*forceDeclCreation=*/true); + LResult = GlobalStoreAndRef(LRef, "_t", /*force=*/true); + } else + LResult = {LRef, LRef}; + if (isInsideLoop) + addToCurrentBlock(LResult.getExpr(), direction::forward); + Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); + Expr* dr = BuildOp( + BO_Mul, + oldValue, + BuildOp(UO_Minus, BuildOp(BO_Div, LResult.getExpr_dx(), RxR))); + dr = StoreAndRef(dr, direction::reverse); + Rdiff = Visit(R, dr); + RDelayed.Finalize(Rdiff.getExpr()); + } + std::tie(Ldiff, Rdiff) = std::make_pair(LRef, RResult.getExpr()); + } else + llvm_unreachable("unknown assignment opCode"); + if (m_ExternalSource) + m_ExternalSource->ActBeforeFinalisingAssignOp(LCloned, oldValue); + + // Update the derivative. + addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), direction::reverse); + // Output statements from Visit(L). + for (auto it = Lblock_begin; it != Lblock_end; ++it) + addToCurrentBlock(*it, direction::reverse); + } else if (opCode == BO_Comma) { + auto zero = + ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); + Ldiff = Visit(L, zero); + Rdiff = Visit(R, dfdx()); + ResultRef = Ldiff.getExpr(); + } else { + // We should not output any warning on visiting boolean conditions + // FIXME: We should support boolean differentiation or ignore it + // completely + if (!BinOp->isComparisonOp() && !BinOp->isLogicalOp()) + unsupportedOpWarn(BinOp->getEndLoc()); - Expr* LCloned = Ldiff.getExpr(); - // For x, AssignedDiff is _d_x, for x[i] its _d_x[i], for reference exprs - // like (x = y) it propagates recursively, so _d_x is also returned. - Expr* AssignedDiff = Ldiff.getExpr_dx(); - if (!AssignedDiff) { // If either LHS or RHS is a declaration reference, visit it to avoid // naming collision auto LDRE = dyn_cast(L); @@ -2224,1189 +2451,1073 @@ StmtDiff ReverseModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { return BuildOp(opCode, LExpr, RExpr); } - ResultRef = AssignedDiff; - // If assigned expr is dependent, first update its derivative; - auto Lblock_begin = Lblock->body_rbegin(); - auto Lblock_end = Lblock->body_rend(); - // if (Lblock->size()) { - // addToCurrentBlock(*Lblock_begin, direction::reverse); - // Lblock_begin = std::next(Lblock_begin); - // } - - for (auto S : essentialRevBlock) - addToCurrentBlock(S, direction::reverse); + Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr()); + return StmtDiff(op, ResultRef); + } + + VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { + StmtDiff initDiff; + Expr* VDDerivedInit = nullptr; + auto VDDerivedType = ComputeAdjointType(VD->getType()); + bool isDerivativeOfRefType = VD->getType()->isReferenceType(); + VarDecl* VDDerived = nullptr; + + // VDDerivedInit now serves two purposes -- as the initial derivative value + // or the size of the derivative array -- depending on the primal type. + if (auto AT = dyn_cast(VD->getType())) { + Expr* init = getArraySizeExpr(AT, m_Context, *this); + VDDerivedInit = init; + VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), + VDDerivedInit, false, nullptr, + clang::VarDecl::InitializationStyle::CallInit); + } else { + // If VD is a reference to a local variable, then the initial value is set + // to the derived variable of the corresponding local variable. + // If VD is a reference to a non-local variable (global variable, struct + // member etc), then no derived variable is available, thus `VDDerived` + // does not need to reference any variable, consequentially the + // `VDDerivedType` is the corresponding non-reference type and the initial + // value is set to 0. + // Otherwise, for non-reference types, the initial value is set to 0. + VDDerivedInit = getZeroInit(VD->getType()); + + // `specialThisDiffCase` is only required for correctly differentiating + // the following code: + // ``` + // Class _d_this_obj; + // Class* _d_this = &_d_this_obj; + // ``` + // Computation of hessian requires this code to be correctly + // differentiated. + bool specialThisDiffCase = false; + if (auto MD = dyn_cast(m_Function)) { + if (VDDerivedType->isPointerType() && MD->isInstance()) { + specialThisDiffCase = true; + } + } - if (dfdx() && Lblock_begin != Lblock_end) { - addToCurrentBlock(*Lblock_begin, direction::reverse); - Lblock_begin = std::next(Lblock_begin); - } + if (isDerivativeOfRefType) { + initDiff = Visit(VD->getInit()); + if (!initDiff.getForwSweepExpr_dx()) { + VDDerivedType = + ComputeAdjointType(VD->getType().getNonReferenceType()); + isDerivativeOfRefType = false; + } + VDDerivedInit = getZeroInit(VDDerivedType); + } - for (auto E : return_exprs) { - Lstored = GlobalStoreAndRef(E); - if (Lstored.getExpr() != E) { - auto* assign = - BuildOp(BinaryOperatorKind::BO_Assign, E, Lstored.getExpr_dx()); - if (isInsideLoop) - addToCurrentBlock(Lstored.getExpr(), direction::forward); - addToCurrentBlock(assign, direction::reverse); + // FIXME: Remove the special cases introduced by `specialThisDiffCase` + // once reverse mode supports pointers. `specialThisDiffCase` is only + // required for correctly differentiating the following code: + // ``` + // Class _d_this_obj; + // Class* _d_this = &_d_this_obj; + // ``` + // Computation of hessian requires this code to be correctly + // differentiated. + if (specialThisDiffCase && VD->getNameAsString() == "_d_this") { + VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); + initDiff = Visit(VD->getInit()); + if (initDiff.getExpr_dx()) + VDDerivedInit = initDiff.getExpr_dx(); } + // Here separate behaviour for record and non-record types is only + // necessary to preserve the old tests. + if (VDDerivedType->isRecordType()) + VDDerived = + BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), + VDDerivedInit, VD->isDirectInit(), + m_Context.getTrivialTypeSourceInfo(VDDerivedType), + VD->getInitStyle()); + else + VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), + VDDerivedInit); } - if (m_ExternalSource) - m_ExternalSource->ActAfterCloningLHSOfAssignOp(LCloned, R, opCode); - - // Save old value for the derivative of LHS, to avoid problems with cases - // like x = x. - auto oldValue = StoreAndRef(AssignedDiff, direction::reverse, "_r_d", - /*forceDeclCreation=*/true); - - if (opCode == BO_Assign) { - Rdiff = Visit(R, oldValue); - } else if (opCode == BO_AddAssign) { - addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), - direction::reverse); - Rdiff = Visit(R, oldValue); - } else if (opCode == BO_SubAssign) { - addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, oldValue), - direction::reverse); - Rdiff = Visit(R, BuildOp(UO_Minus, oldValue)); - } else if (opCode == BO_MulAssign) { - auto RDelayed = DelayedGlobalStoreAndRef(R); - StmtDiff RResult = RDelayed.Result; - addToCurrentBlock( - BuildOp(BO_AddAssign, AssignedDiff, - BuildOp(BO_Mul, oldValue, RResult.getExpr_dx())), - direction::reverse); - Expr* LRef = LCloned; - if (!RDelayed.isConstant) { - // Create a reference variable to keep the result of LHS, since it - // must be used on 2 places: when storing to a global variable - // accessible from the reverse pass, and when rebuilding the original - // expression for the forward pass. This allows to avoid executing - // same expression with side effects twice. E.g., on - // double r = (x *= y) *= z; - // instead of: - // _t0 = (x *= y); - // double r = (x *= y) *= z; - // which modifies x twice, we get: - // double & _ref0 = (x *= y); - // _t0 = _ref0; - // double r = _ref0 *= z; - StmtDiff LResult; - if (LCloned->HasSideEffects(m_Context)) { - auto RefType = getNonConstType(L->getType(), m_Context, m_Sema); - // RefType = m_Context.getLValueReferenceType(RefType); - LRef = StoreAndRef(LCloned, RefType, direction::forward, "_ref", - /*forceDeclCreation=*/true); - LResult = GlobalStoreAndRef(LRef, "_t", /*force=*/true); - } else - LResult = {LRef, LRef}; - - if (isInsideLoop) - addToCurrentBlock(LResult.getExpr(), direction::forward); - Expr* dr = BuildOp(BO_Mul, LResult.getExpr_dx(), oldValue); - dr = StoreAndRef(dr, direction::reverse); - Rdiff = Visit(R, dr); - RDelayed.Finalize(Rdiff.getExpr()); - } - std::tie(Ldiff, Rdiff) = std::make_pair(LRef, RResult.getExpr()); - } else if (opCode == BO_DivAssign) { - auto RDelayed = DelayedGlobalStoreAndRef(R); - StmtDiff RResult = RDelayed.Result; - Expr* RStored = StoreAndRef(RResult.getExpr_dx(), direction::reverse); - addToCurrentBlock(BuildOp(BO_AddAssign, AssignedDiff, - BuildOp(BO_Div, oldValue, RStored)), - direction::reverse); - Expr* LRef = LCloned; - if (!RDelayed.isConstant) { - StmtDiff LResult; - if (LCloned->HasSideEffects(m_Context)) { - QualType RefType = m_Context.getLValueReferenceType( - getNonConstType(L->getType(), m_Context, m_Sema)); - LRef = StoreAndRef(LCloned, RefType, direction::forward, "_ref", - /*forceDeclCreation=*/true); - LResult = GlobalStoreAndRef(LRef, "_t", /*force=*/true); - } else - LResult = {LRef, LRef}; - if (isInsideLoop) - addToCurrentBlock(LResult.getExpr(), direction::forward); - Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored)); - Expr* dr = BuildOp( - BO_Mul, oldValue, - BuildOp(UO_Minus, BuildOp(BO_Div, LResult.getExpr_dx(), RxR))); - dr = StoreAndRef(dr, direction::reverse); - Rdiff = Visit(R, dr); - RDelayed.Finalize(Rdiff.getExpr()); + // If `VD` is a reference to a local variable, then it is already + // differentiated and should not be differentiated again. + // If `VD` is a reference to a non-local variable then also there's no + // need to call `Visit` since non-local variables are not differentiated. + if (!isDerivativeOfRefType) { + Expr* derivedE = BuildDeclRef(VDDerived); + initDiff = StmtDiff{}; + if (VD->getInit()) { + if (isa(VD->getInit())) + initDiff = Visit(VD->getInit()); + else + initDiff = Visit(VD->getInit(), derivedE); } - std::tie(Ldiff, Rdiff) = std::make_pair(LRef, RResult.getExpr()); - } else - llvm_unreachable("unknown assignment opCode"); - if (m_ExternalSource) - m_ExternalSource->ActBeforeFinalisingAssignOp(LCloned, oldValue); - - // Update the derivative. - addToCurrentBlock(BuildOp(BO_SubAssign, AssignedDiff, oldValue), - direction::reverse); - // Output statements from Visit(L). - for (auto it = Lblock_begin; it != Lblock_end; ++it) - addToCurrentBlock(*it, direction::reverse); - } else if (opCode == BO_Comma) { - auto zero = - ConstantFolder::synthesizeLiteral(m_Context.IntTy, m_Context, 0); - Ldiff = Visit(L, zero); - Rdiff = Visit(R, dfdx()); - ResultRef = Ldiff.getExpr(); - } else { - // We should not output any warning on visiting boolean conditions - // FIXME: We should support boolean differentiation or ignore it - // completely - if (!BinOp->isComparisonOp() && !BinOp->isLogicalOp()) - unsupportedOpWarn(BinOp->getEndLoc()); - - // If either LHS or RHS is a declaration reference, visit it to avoid - // naming collision - auto LDRE = dyn_cast(L); - auto RDRE = dyn_cast(R); - - if (!LDRE && !RDRE) - return Clone(BinOp); - - Expr* LExpr = LDRE ? Visit(L).getExpr() : L; - Expr* RExpr = RDRE ? Visit(R).getExpr() : R; - - return BuildOp(opCode, LExpr, RExpr); - } - Expr* op = BuildOp(opCode, Ldiff.getExpr(), Rdiff.getExpr()); - return StmtDiff(op, ResultRef); -} - -VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { - StmtDiff initDiff; - Expr* VDDerivedInit = nullptr; - auto VDDerivedType = ComputeAdjointType(VD->getType()); - bool isDerivativeOfRefType = VD->getType()->isReferenceType(); - VarDecl* VDDerived = nullptr; - - // VDDerivedInit now serves two purposes -- as the initial derivative value - // or the size of the derivative array -- depending on the primal type. - if (auto AT = dyn_cast(VD->getType())) { - Expr* init = getArraySizeExpr(AT, m_Context, *this); - VDDerivedInit = init; - VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), - VDDerivedInit, false, nullptr, - clang::VarDecl::InitializationStyle::CallInit); - } else { - // If VD is a reference to a local variable, then the initial value is set - // to the derived variable of the corresponding local variable. - // If VD is a reference to a non-local variable (global variable, struct - // member etc), then no derived variable is available, thus `VDDerived` - // does not need to reference any variable, consequentially the - // `VDDerivedType` is the corresponding non-reference type and the initial - // value is set to 0. - // Otherwise, for non-reference types, the initial value is set to 0. - VDDerivedInit = getZeroInit(VD->getType()); - - // `specialThisDiffCase` is only required for correctly differentiating - // the following code: - // ``` - // Class _d_this_obj; - // Class* _d_this = &_d_this_obj; - // ``` - // Computation of hessian requires this code to be correctly - // differentiated. - bool specialThisDiffCase = false; - if (auto MD = dyn_cast(m_Function)) { - if (VDDerivedType->isPointerType() && MD->isInstance()) - specialThisDiffCase = true; - } - if (isDerivativeOfRefType) { - initDiff = Visit(VD->getInit()); - if (!initDiff.getForwSweepExpr_dx()) { - VDDerivedType = ComputeAdjointType(VD->getType().getNonReferenceType()); - isDerivativeOfRefType = false; + // If we are differentiating `VarDecl` corresponding to a local variable + // inside a loop, then we need to reset it to 0 at each iteration. + // + // for example, if defined inside a loop, + // ``` + // double localVar = i; + // ``` + // this statement should get differentiated to, + // ``` + // { + // *_d_i += _d_localVar; + // _d_localVar = 0; + // } + if (isInsideLoop) { + Stmt* assignToZero = BuildOp(BinaryOperatorKind::BO_Assign, + BuildDeclRef(VDDerived), + getZeroInit(VDDerivedType)); + addToCurrentBlock(assignToZero, direction::reverse); } - VDDerivedInit = getZeroInit(VDDerivedType); - } - - // FIXME: Remove the special cases introduced by `specialThisDiffCase` - // once reverse mode supports pointers. `specialThisDiffCase` is only - // required for correctly differentiating the following code: - // ``` - // Class _d_this_obj; - // Class* _d_this = &_d_this_obj; - // ``` - // Computation of hessian requires this code to be correctly - // differentiated. - if (specialThisDiffCase && VD->getNameAsString() == "_d_this") { - VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); - initDiff = Visit(VD->getInit()); - if (initDiff.getExpr_dx()) - VDDerivedInit = initDiff.getExpr_dx(); } + VarDecl* VDClone = nullptr; // Here separate behaviour for record and non-record types is only // necessary to preserve the old tests. - if (VDDerivedType->isRecordType()) - VDDerived = BuildVarDecl( - VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, - VD->isDirectInit(), m_Context.getTrivialTypeSourceInfo(VDDerivedType), - VD->getInitStyle()); - else - VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), - VDDerivedInit); - } - - // If `VD` is a reference to a local variable, then it is already - // differentiated and should not be differentiated again. - // If `VD` is a reference to a non-local variable then also there's no - // need to call `Visit` since non-local variables are not differentiated. - if (!isDerivativeOfRefType) { - Expr* derivedE = BuildDeclRef(VDDerived); - initDiff = StmtDiff{}; - if (VD->getInit()) { - if (isa(VD->getInit())) - initDiff = Visit(VD->getInit()); - else - initDiff = Visit(VD->getInit(), derivedE); + if (VD->getType()->isRecordType()) { + VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(), + initDiff.getExpr(), VD->isDirectInit(), + VD->getTypeSourceInfo(), VD->getInitStyle()); + } else { + VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(), + initDiff.getExpr(), VD->isDirectInit()); } + Expr* derivedVDE = BuildDeclRef(VDDerived); - // If we are differentiating `VarDecl` corresponding to a local variable - // inside a loop, then we need to reset it to 0 at each iteration. - // - // for example, if defined inside a loop, - // ``` - // double localVar = i; - // ``` - // this statement should get differentiated to, - // ``` - // { - // *_d_i += _d_localVar; - // _d_localVar = 0; - // } - if (isInsideLoop) { - Stmt* assignToZero = - BuildOp(BinaryOperatorKind::BO_Assign, BuildDeclRef(VDDerived), - getZeroInit(VDDerivedType)); - addToCurrentBlock(assignToZero, direction::reverse); - } - } - VarDecl* VDClone = nullptr; - // Here separate behaviour for record and non-record types is only - // necessary to preserve the old tests. - if (VD->getType()->isRecordType()) { - VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(), - initDiff.getExpr(), VD->isDirectInit(), - VD->getTypeSourceInfo(), VD->getInitStyle()); - } else { - VDClone = BuildVarDecl(CloneType(VD->getType()), VD->getNameAsString(), - initDiff.getExpr(), VD->isDirectInit()); - } - Expr* derivedVDE = BuildDeclRef(VDDerived); - - // FIXME: Add extra parantheses if derived variable pointer is pointing to a - // class type object. - if (isDerivativeOfRefType) { - Expr* assignDerivativeE = BuildOp( - BinaryOperatorKind::BO_Assign, derivedVDE, - BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getForwSweepExpr_dx())); - addToCurrentBlock(assignDerivativeE); - if (isInsideLoop) { - auto tape = MakeCladTapeFor(derivedVDE); - addToCurrentBlock(tape.Push); - auto reverseSweepDerivativePointerE = - BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop); - m_LoopBlock.back().push_back( - BuildDeclStmt(reverseSweepDerivativePointerE)); - auto revSweepDerPointerRef = BuildDeclRef(reverseSweepDerivativePointerE); - derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, revSweepDerPointerRef); - } else { - derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE); + // FIXME: Add extra parantheses if derived variable pointer is pointing to a + // class type object. + if (isDerivativeOfRefType) { + Expr* assignDerivativeE = + BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, + BuildOp(UnaryOperatorKind::UO_AddrOf, + initDiff.getForwSweepExpr_dx())); + addToCurrentBlock(assignDerivativeE); + if (isInsideLoop) { + auto tape = MakeCladTapeFor(derivedVDE); + addToCurrentBlock(tape.Push); + auto reverseSweepDerivativePointerE = + BuildVarDecl(derivedVDE->getType(), "_t", tape.Pop); + m_LoopBlock.back().push_back( + BuildDeclStmt(reverseSweepDerivativePointerE)); + auto revSweepDerPointerRef = + BuildDeclRef(reverseSweepDerivativePointerE); + derivedVDE = + BuildOp(UnaryOperatorKind::UO_Deref, revSweepDerPointerRef); + } else { + derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE); + } } - } - m_Variables.emplace(VDClone, derivedVDE); + m_Variables.emplace(VDClone, derivedVDE); - return VarDeclDiff(VDClone, VDDerived); -} + return VarDeclDiff(VDClone, VDDerived); + } -// TODO: 'shouldEmit' parameter should be removed after converting -// Error estimation framework to callback style. Some more research -// need to be done to -StmtDiff ReverseModeVisitor::DifferentiateSingleStmt(const Stmt* S, - Expr* dfdS) { - if (m_ExternalSource) - m_ExternalSource->ActOnStartOfDifferentiateSingleStmt(); - beginBlock(direction::reverse); - StmtDiff SDiff = Visit(S, dfdS); - - if (m_ExternalSource) - m_ExternalSource->ActBeforeFinalizingDifferentiateSingleStmt( - direction::reverse); - - addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse); - CompoundStmt* RCS = endBlock(direction::reverse); - std::reverse(RCS->body_begin(), RCS->body_end()); - Stmt* ReverseResult = unwrapIfSingleStmt(RCS); - return StmtDiff(SDiff.getStmt(), ReverseResult); -} + // TODO: 'shouldEmit' parameter should be removed after converting + // Error estimation framework to callback style. Some more research + // need to be done to + StmtDiff + ReverseModeVisitor::DifferentiateSingleStmt(const Stmt* S, Expr* dfdS) { + if (m_ExternalSource) + m_ExternalSource->ActOnStartOfDifferentiateSingleStmt(); + beginBlock(direction::reverse); + StmtDiff SDiff = Visit(S, dfdS); -std::pair -ReverseModeVisitor::DifferentiateSingleExpr(const Expr* E, Expr* dfdE) { - beginBlock(direction::forward); - beginBlock(direction::reverse); - StmtDiff EDiff = Visit(E, dfdE); - if (m_ExternalSource) - m_ExternalSource->ActBeforeFinalizingDifferentiateSingleExpr( - direction::reverse); - CompoundStmt* RCS = endBlock(direction::reverse); - Stmt* ForwardResult = endBlock(direction::forward); - std::reverse(RCS->body_begin(), RCS->body_end()); - Stmt* ReverseResult = unwrapIfSingleStmt(RCS); - return {StmtDiff(ForwardResult, ReverseResult), EDiff}; -} + if (m_ExternalSource) + m_ExternalSource->ActBeforeFinalizingDifferentiateSingleStmt(direction::reverse); -StmtDiff ReverseModeVisitor::VisitDeclStmt(const DeclStmt* DS) { - llvm::SmallVector decls; - llvm::SmallVector declsDiff; - // Need to put array decls inlined. - llvm::SmallVector localDeclsDiff; - // For each variable declaration v, create another declaration _d_v to - // store derivatives for potential reassignments. E.g. - // double y = x; - // -> - // double _d_y = _d_x; double y = x; - for (auto D : DS->decls()) { - if (auto VD = dyn_cast(D)) { - VarDeclDiff VDDiff = DifferentiateVarDecl(VD); - - // Check if decl's name is the same as before. The name may be changed - // if decl name collides with something in the derivative body. - // This can happen in rare cases, e.g. when the original function - // has both y and _d_y (here _d_y collides with the name produced by - // the derivation process), e.g. - // double f(double x) { - // double y = x; - // double _d_y = x; - // } - // -> - // double f_darg0(double x) { - // double _d_x = 1; - // double _d_y = _d_x; // produced as a derivative for y - // double y = x; - // double _d__d_y = _d_x; - // double _d_y = x; // copied from original funcion, collides with - // _d_y - // } - if (VDDiff.getDecl()->getDeclName() != VD->getDeclName()) - m_DeclReplacements[VD] = VDDiff.getDecl(); - decls.push_back(VDDiff.getDecl()); - if (isa(VD->getType())) - localDeclsDiff.push_back(VDDiff.getDecl_dx()); - else - declsDiff.push_back(VDDiff.getDecl_dx()); - } else { - diag(DiagnosticsEngine::Warning, D->getEndLoc(), - "Unsupported declaration"); - } + addToCurrentBlock(SDiff.getStmt_dx(), direction::reverse); + CompoundStmt* RCS = endBlock(direction::reverse); + std::reverse(RCS->body_begin(), RCS->body_end()); + Stmt* ReverseResult = unwrapIfSingleStmt(RCS); + return StmtDiff(SDiff.getStmt(), ReverseResult); } - Stmt* DSClone = BuildDeclStmt(decls); - - if (!localDeclsDiff.empty()) { - Stmt* localDSDIff = BuildDeclStmt(localDeclsDiff); - addToCurrentBlock( - localDSDIff, - clad::rmv::forward); // Doesnt work for arrays decl'd in loops. - } - if (!declsDiff.empty()) { - Stmt* DSDiff = BuildDeclStmt(declsDiff); - addToBlock(DSDiff, m_Globals); - } + std::pair + ReverseModeVisitor::DifferentiateSingleExpr(const Expr* E, Expr* dfdE) { + beginBlock(direction::forward); + beginBlock(direction::reverse); + StmtDiff EDiff = Visit(E, dfdE); + if (m_ExternalSource) + m_ExternalSource->ActBeforeFinalizingDifferentiateSingleExpr(direction::reverse); + CompoundStmt* RCS = endBlock(direction::reverse); + Stmt* ForwardResult = endBlock(direction::forward); + std::reverse(RCS->body_begin(), RCS->body_end()); + Stmt* ReverseResult = unwrapIfSingleStmt(RCS); + return {StmtDiff(ForwardResult, ReverseResult), EDiff}; + } + + StmtDiff ReverseModeVisitor::VisitDeclStmt(const DeclStmt* DS) { + llvm::SmallVector decls; + llvm::SmallVector declsDiff; + // Need to put array decls inlined. + llvm::SmallVector localDeclsDiff; + // For each variable declaration v, create another declaration _d_v to + // store derivatives for potential reassignments. E.g. + // double y = x; + // -> + // double _d_y = _d_x; double y = x; + for (auto D : DS->decls()) { + if (auto VD = dyn_cast(D)) { + VarDeclDiff VDDiff = DifferentiateVarDecl(VD); + + // Check if decl's name is the same as before. The name may be changed + // if decl name collides with something in the derivative body. + // This can happen in rare cases, e.g. when the original function + // has both y and _d_y (here _d_y collides with the name produced by + // the derivation process), e.g. + // double f(double x) { + // double y = x; + // double _d_y = x; + // } + // -> + // double f_darg0(double x) { + // double _d_x = 1; + // double _d_y = _d_x; // produced as a derivative for y + // double y = x; + // double _d__d_y = _d_x; + // double _d_y = x; // copied from original funcion, collides with + // _d_y + // } + if (VDDiff.getDecl()->getDeclName() != VD->getDeclName()) + m_DeclReplacements[VD] = VDDiff.getDecl(); + decls.push_back(VDDiff.getDecl()); + if (isa(VD->getType())) + localDeclsDiff.push_back(VDDiff.getDecl_dx()); + else + declsDiff.push_back(VDDiff.getDecl_dx()); + } else { + diag(DiagnosticsEngine::Warning, + D->getEndLoc(), + "Unsupported declaration"); + } + } - if (m_ExternalSource) { - declsDiff.append(localDeclsDiff.begin(), localDeclsDiff.end()); - m_ExternalSource->ActBeforeFinalizingVisitDeclStmt(decls, declsDiff); - } - return StmtDiff(DSClone); -} + Stmt* DSClone = BuildDeclStmt(decls); -StmtDiff -ReverseModeVisitor::VisitImplicitCastExpr(const ImplicitCastExpr* ICE) { - StmtDiff subExprDiff = Visit(ICE->getSubExpr(), dfdx()); - // Casts should be handled automatically when the result is used by - // Sema::ActOn.../Build... - return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx(), - subExprDiff.getForwSweepStmt_dx()); -} + if (!localDeclsDiff.empty()) { + Stmt* localDSDIff = BuildDeclStmt(localDeclsDiff); + addToCurrentBlock( + localDSDIff, + clad::rmv::forward); // Doesnt work for arrays decl'd in loops. + } + if (!declsDiff.empty()) { + Stmt* DSDiff = BuildDeclStmt(declsDiff); + addToBlock(DSDiff, m_Globals); + } -StmtDiff ReverseModeVisitor::VisitMemberExpr(const MemberExpr* ME) { - auto baseDiff = VisitWithExplicitNoDfDx(ME->getBase()); - auto field = ME->getMemberDecl(); - assert(!isa(field) && - "CXXMethodDecl nodes not supported yet!"); - MemberExpr* clonedME = utils::BuildMemberExpr( - m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); - if (!baseDiff.getExpr_dx()) - return {clonedME, nullptr}; - MemberExpr* derivedME = utils::BuildMemberExpr( - m_Sema, getCurrentScope(), baseDiff.getExpr_dx(), field->getName()); - if (dfdx()) { - Expr* addAssign = - BuildOp(BinaryOperatorKind::BO_AddAssign, derivedME, dfdx()); - addToCurrentBlock(addAssign, direction::reverse); - } - return {clonedME, derivedME, derivedME}; -} + if (m_ExternalSource) { + declsDiff.append(localDeclsDiff.begin(), localDeclsDiff.end()); + m_ExternalSource->ActBeforeFinalizingVisitDeclStmt(decls, declsDiff); + } + return StmtDiff(DSClone); + } + + StmtDiff + ReverseModeVisitor::VisitImplicitCastExpr(const ImplicitCastExpr* ICE) { + StmtDiff subExprDiff = Visit(ICE->getSubExpr(), dfdx()); + // Casts should be handled automatically when the result is used by + // Sema::ActOn.../Build... + return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx(), + subExprDiff.getForwSweepStmt_dx()); + } + + StmtDiff ReverseModeVisitor::VisitMemberExpr(const MemberExpr* ME) { + auto baseDiff = VisitWithExplicitNoDfDx(ME->getBase()); + auto field = ME->getMemberDecl(); + assert(!isa(field) && + "CXXMethodDecl nodes not supported yet!"); + MemberExpr* clonedME = utils::BuildMemberExpr( + m_Sema, getCurrentScope(), baseDiff.getExpr(), field->getName()); + if (!baseDiff.getExpr_dx()) + return {clonedME, nullptr}; + MemberExpr* derivedME = utils::BuildMemberExpr( + m_Sema, getCurrentScope(), baseDiff.getExpr_dx(), field->getName()); + if (dfdx()) { + Expr* addAssign = + BuildOp(BinaryOperatorKind::BO_AddAssign, derivedME, dfdx()); + addToCurrentBlock(addAssign, direction::reverse); + } + return {clonedME, derivedME, derivedME}; + } + + StmtDiff + ReverseModeVisitor::VisitExprWithCleanups(const ExprWithCleanups* EWC) { + StmtDiff subExprDiff = Visit(EWC->getSubExpr(), dfdx()); + // FIXME: We are unable to create cleanup objects currently, this can be + // potentially problematic + return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx()); + } + + bool ReverseModeVisitor::UsefulToStoreGlobal(Expr* E) { + if (!E) + return false; + // Use stricter policy when inside loops: IsEvaluatable is also true for + // arithmetical expressions consisting of constants, e.g. (1 + 2)*3. This + // chech is more expensive, but it doesn't make sense to push such constants + // into stack. + if (isInsideLoop && E->isEvaluatable(m_Context, Expr::SE_NoSideEffects)) + return false; + Expr* B = E->IgnoreParenImpCasts(); + // FIXME: find a more general way to determine that or add more options. + if (isa(B) || isa(B)) + return false; + if (isa(B)) { + auto UO = cast(B); + auto OpKind = UO->getOpcode(); + if (OpKind == UO_Plus || OpKind == UO_Minus) + return UsefulToStoreGlobal(UO->getSubExpr()); + return true; + } + // We lack context to decide if this is useful to store or not. In the + // current system that should have been decided by the parent expression. + // FIXME: Here will be the entry point of the advanced activity analysis. + if (isa(B) || isa(B)) { + // auto line = + // m_Context.getSourceManager().getPresumedLoc(B->getBeginLoc()).getLine(); + // auto column = + // m_Context.getSourceManager().getPresumedLoc(B->getBeginLoc()).getColumn(); + // llvm::errs() << line << "|" <getBeginLoc()); + if (it == m_ToBeRecorded.end()) { + return true; + } + return it->second; + } -StmtDiff -ReverseModeVisitor::VisitExprWithCleanups(const ExprWithCleanups* EWC) { - StmtDiff subExprDiff = Visit(EWC->getSubExpr(), dfdx()); - // FIXME: We are unable to create cleanup objects currently, this can be - // potentially problematic - return StmtDiff(subExprDiff.getExpr(), subExprDiff.getExpr_dx()); -} + // FIXME: Attach checkpointing. + if (isa(B)) + return false; -bool ReverseModeVisitor::UsefulToStoreGlobal(Expr* E) { - if (!E) - return false; - // Use stricter policy when inside loops: IsEvaluatable is also true for - // arithmetical expressions consisting of constants, e.g. (1 + 2)*3. This - // chech is more expensive, but it doesn't make sense to push such constants - // into stack. - if (isInsideLoop && E->isEvaluatable(m_Context, Expr::SE_NoSideEffects)) - return false; - Expr* B = E->IgnoreParenImpCasts(); - // FIXME: find a more general way to determine that or add more options. - if (isa(B) || isa(B)) - return false; - if (isa(B)) { - auto UO = cast(B); - auto OpKind = UO->getOpcode(); - if (OpKind == UO_Plus || OpKind == UO_Minus) - return UsefulToStoreGlobal(UO->getSubExpr()); return true; } - // We lack context to decide if this is useful to store or not. In the - // current system that should have been decided by the parent expression. - // FIXME: Here will be the entry point of the advanced activity analysis. - if (isa(B) || isa(B)) { - auto found = m_ToBeRecorded.find(B->getBeginLoc()); - return found != m_ToBeRecorded.end(); - } - - // FIXME: Attach checkpointing. - if (isa(B)) - return false; - return true; -} + VarDecl* ReverseModeVisitor::GlobalStoreImpl(QualType Type, + llvm::StringRef prefix, + Expr* init) { + // Create identifier before going to topmost scope + // to let Sema::LookupName see the whole scope. + auto identifier = CreateUniqueIdentifier(prefix); + // Save current scope and temporarily go to topmost function scope. + llvm::SaveAndRestore SaveScope(m_CurScope); + assert(m_DerivativeFnScope && "must be set"); + m_CurScope = m_DerivativeFnScope; + + VarDecl* Var = nullptr; + if (isa(Type)) { + Type = GetCladArrayOfType(m_Context.getBaseElementType(Type)); + Var = BuildVarDecl(Type, identifier, init, false, nullptr, + clang::VarDecl::InitializationStyle::CallInit); + } else { + Var = BuildVarDecl(Type, identifier, init, false, nullptr, + VarDecl::InitializationStyle::CInit); + } -VarDecl* ReverseModeVisitor::GlobalStoreImpl(QualType Type, - llvm::StringRef prefix, - Expr* init) { - // Create identifier before going to topmost scope - // to let Sema::LookupName see the whole scope. - auto identifier = CreateUniqueIdentifier(prefix); - // Save current scope and temporarily go to topmost function scope. - llvm::SaveAndRestore SaveScope(m_CurScope); - assert(m_DerivativeFnScope && "must be set"); - m_CurScope = m_DerivativeFnScope; - - VarDecl* Var = nullptr; - if (isa(Type)) { - Type = GetCladArrayOfType(m_Context.getBaseElementType(Type)); - Var = BuildVarDecl(Type, identifier, init, false, nullptr, - clang::VarDecl::InitializationStyle::CallInit); - } else { - Var = BuildVarDecl(Type, identifier, init, false, nullptr, - VarDecl::InitializationStyle::CInit); + // Add the declaration to the body of the gradient function. + addToBlock(BuildDeclStmt(Var), m_Globals); + return Var; } - // Add the declaration to the body of the gradient function. - addToBlock(BuildDeclStmt(Var), m_Globals); - return Var; -} + StmtDiff ReverseModeVisitor::GlobalStoreAndRef(Expr* E, + QualType Type, + llvm::StringRef prefix, + bool force) { + assert(E && "must be provided, otherwise use DelayedGlobalStoreAndRef"); + if (!force && !UsefulToStoreGlobal(E)) + return {E, E}; -StmtDiff ReverseModeVisitor::GlobalStoreAndRef(Expr* E, QualType Type, - llvm::StringRef prefix, - bool force) { - assert(E && "must be provided, otherwise use DelayedGlobalStoreAndRef"); - if (!force && !UsefulToStoreGlobal(E)) - return {E, E}; - - if (isInsideLoop) { - auto CladTape = MakeCladTapeFor(E); - Expr* Push = CladTape.Push; - Expr* Pop = CladTape.Pop; - return {Push, Pop}; - } + if (isInsideLoop) { + auto CladTape = MakeCladTapeFor(E); + Expr* Push = CladTape.Push; + Expr* Pop = CladTape.Pop; + return {Push, Pop}; + } - Expr* init = nullptr; - if (auto AT = dyn_cast(Type)) - init = getArraySizeExpr(AT, m_Context, *this); + Expr* init = nullptr; + if (auto AT = dyn_cast(Type)) { + init = getArraySizeExpr(AT, m_Context, *this); + } - Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix, init)); - if (E) { - Expr* Set = BuildOp(BO_Assign, Ref, E); - addToCurrentBlock(Set, direction::forward); + Expr* Ref = BuildDeclRef(GlobalStoreImpl(Type, prefix, init)); + if (E) { + Expr* Set = BuildOp(BO_Assign, Ref, E); + addToCurrentBlock(Set, direction::forward); + } + return {Ref, Ref}; } - return {Ref, Ref}; -} - -StmtDiff ReverseModeVisitor::GlobalStoreAndRef(Expr* E, llvm::StringRef prefix, - bool force) { - assert(E && "cannot infer type"); - return GlobalStoreAndRef(E, getNonConstType(E->getType(), m_Context, m_Sema), - prefix, force); -} -void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) { - if (isConstant || !needsUpdate) - return; - if (isInsideLoop) { - auto Push = cast(Result.getExpr()); - unsigned lastArg = Push->getNumArgs() - 1; - Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New).get()); - } else { - V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New), - direction::forward); + StmtDiff ReverseModeVisitor::GlobalStoreAndRef(Expr* E, + llvm::StringRef prefix, + bool force) { + assert(E && "cannot infer type"); + return GlobalStoreAndRef( + E, getNonConstType(E->getType(), m_Context, m_Sema), prefix, force); } -} -ReverseModeVisitor::DelayedStoreResult -ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E, llvm::StringRef prefix) { - assert(E && "must be provided"); - if (isa(E) /*!UsefulToStoreGlobal(E)*/) { - Expr* Cloned = Clone(E); - Expr::EvalResult evalRes; - bool isConst = - clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context); - return DelayedStoreResult{*this, StmtDiff{Cloned, Cloned}, - /*isConstant*/ isConst, - /*isInsideLoop*/ false, - /*needsUpdate=*/false}; - } - if (isInsideLoop) { - Expr* dummy = E; - auto CladTape = MakeCladTapeFor(dummy); - Expr* Push = CladTape.Push; - Expr* Pop = CladTape.Pop; - return DelayedStoreResult{*this, StmtDiff{Push, Pop}, - /*isConstant*/ false, - /*isInsideLoop*/ true, /*needsUpdate=*/true}; - } else { - Expr* Ref = BuildDeclRef(GlobalStoreImpl( - getNonConstType(E->getType(), m_Context, m_Sema), prefix)); - // Return reference to the declaration instead of original expression. - return DelayedStoreResult{*this, StmtDiff{Ref, Ref}, - /*isConstant*/ false, - /*isInsideLoop*/ false, /*needsUpdate=*/true}; + void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) { + if (isConstant || !needsUpdate) + return; + if (isInsideLoop) { + auto Push = cast(Result.getExpr()); + unsigned lastArg = Push->getNumArgs() - 1; + Push->setArg(lastArg, V.m_Sema.DefaultLvalueConversion(New).get()); + } else { + V.addToCurrentBlock(V.BuildOp(BO_Assign, Result.getExpr(), New), + direction::forward); + } } -} -ReverseModeVisitor::LoopCounter::LoopCounter(ReverseModeVisitor& RMV) - : m_RMV(RMV) { - ASTContext& C = m_RMV.m_Context; - if (RMV.isInsideLoop) { - auto zero = ConstantFolder::synthesizeLiteral(C.getSizeType(), C, - /*val=*/0); - auto counterTape = m_RMV.MakeCladTapeFor(zero); - m_Ref = counterTape.Last(); - m_Pop = counterTape.Pop; - m_Push = counterTape.Push; - } else { - m_Ref = m_RMV - .GlobalStoreAndRef(m_RMV.getZeroInit(C.IntTy), C.getSizeType(), - "_t", - /*force=*/true) - .getExpr(); + ReverseModeVisitor::DelayedStoreResult + ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E, + llvm::StringRef prefix) { + assert(E && "must be provided"); + if (isa(E) /*!UsefulToStoreGlobal(E)*/) { + Expr* Cloned = Clone(E); + Expr::EvalResult evalRes; + bool isConst = + clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context); + return DelayedStoreResult{*this, StmtDiff{Cloned, Cloned}, + /*isConstant*/ isConst, + /*isInsideLoop*/ false, + /*needsUpdate=*/ false}; + } + if (isInsideLoop) { + Expr* dummy = E; + auto CladTape = MakeCladTapeFor(dummy); + Expr* Push = CladTape.Push; + Expr* Pop = CladTape.Pop; + return DelayedStoreResult{*this, + StmtDiff{Push, Pop}, + /*isConstant*/ false, + /*isInsideLoop*/ true, /*needsUpdate=*/ true}; + } else { + Expr* Ref = BuildDeclRef(GlobalStoreImpl( + getNonConstType(E->getType(), m_Context, m_Sema), prefix)); + // Return reference to the declaration instead of original expression. + return DelayedStoreResult{*this, + StmtDiff{Ref, Ref}, + /*isConstant*/ false, + /*isInsideLoop*/ false, /*needsUpdate=*/ true}; + } } -} -StmtDiff ReverseModeVisitor::VisitWhileStmt(const WhileStmt* WS) { - LoopCounter loopCounter(*this); - if (loopCounter.getPush()) - addToCurrentBlock(loopCounter.getPush()); - - // begin scope for while statement - beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | - Scope::ContinueScope); - - llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); - isInsideLoop = true; - - Expr* condClone = (WS->getCond() ? Clone(WS->getCond()) : nullptr); - const VarDecl* condVarDecl = WS->getConditionVariable(); - StmtDiff condVarRes; - if (condVarDecl) - condVarRes = DifferentiateSingleStmt(WS->getConditionVariableDeclStmt()); - - // compute condition result object for the forward pass `while` - // statement. - Sema::ConditionResult condResult; - if (condVarDecl) { - Decl* condVarClone = cast(condVarRes.getStmt())->getSingleDecl(); - condResult = m_Sema.ActOnConditionVariable(condVarClone, noLoc, - Sema::ConditionKind::Boolean); - } else { - condResult = m_Sema.ActOnCondition(getCurrentScope(), noLoc, condClone, - Sema::ConditionKind::Boolean); + ReverseModeVisitor::LoopCounter::LoopCounter(ReverseModeVisitor& RMV) + : m_RMV(RMV) { + ASTContext& C = m_RMV.m_Context; + if (RMV.isInsideLoop) { + auto zero = ConstantFolder::synthesizeLiteral(C.getSizeType(), C, + /*val=*/0); + auto counterTape = m_RMV.MakeCladTapeFor(zero); + m_Ref = counterTape.Last(); + m_Pop = counterTape.Pop; + m_Push = counterTape.Push; + } else { + m_Ref = m_RMV + .GlobalStoreAndRef(m_RMV.getZeroInit(C.IntTy), + C.getSizeType(), "_t", + /*force=*/true) + .getExpr(); + } } - const Stmt* body = WS->getBody(); - StmtDiff bodyDiff = - DifferentiateLoopBody(body, loopCounter, condVarRes.getStmt_dx()); - // Create forward-pass `while` loop. - Stmt* forwardWS = - clad_compat::Sema_ActOnWhileStmt(m_Sema, condResult, bodyDiff.getStmt()) - .get(); - - // Create reverse-pass `while` loop. - Sema::ConditionResult CounterCondition = - loopCounter.getCounterConditionResult(); - Stmt* reverseWS = clad_compat::Sema_ActOnWhileStmt(m_Sema, CounterCondition, - bodyDiff.getStmt_dx()) - .get(); - // for while statement - endScope(); - Stmt* reverseBlock = reverseWS; - // If loop counter have to be popped then create a compound statement - // enclosing the reverse pass while statement and loop counter pop - // expression. - // - // Therefore, reverse pass code will look like this: - // { - // while (_t) { - // - // } - // clad::pop(_t); - // } - if (loopCounter.getPop()) { - beginBlock(direction::reverse); - addToCurrentBlock(loopCounter.getPop(), direction::reverse); - addToCurrentBlock(reverseWS, direction::reverse); - reverseBlock = endBlock(direction::reverse); - } - return {forwardWS, reverseBlock}; -} + StmtDiff ReverseModeVisitor::VisitWhileStmt(const WhileStmt* WS) { + LoopCounter loopCounter(*this); + if (loopCounter.getPush()) + addToCurrentBlock(loopCounter.getPush()); + + // begin scope for while statement + beginScope(Scope::DeclScope | Scope::ControlScope | Scope::BreakScope | + Scope::ContinueScope); + + llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); + isInsideLoop = true; + + Expr* condClone = (WS->getCond() ? Clone(WS->getCond()) : nullptr); + const VarDecl* condVarDecl = WS->getConditionVariable(); + StmtDiff condVarRes; + if (condVarDecl) + condVarRes = DifferentiateSingleStmt(WS->getConditionVariableDeclStmt()); + + // compute condition result object for the forward pass `while` + // statement. + Sema::ConditionResult condResult; + if (condVarDecl) { + Decl* condVarClone = cast(condVarRes.getStmt()) + ->getSingleDecl(); + condResult = m_Sema.ActOnConditionVariable(condVarClone, noLoc, + Sema::ConditionKind::Boolean); + } else { + condResult = m_Sema.ActOnCondition(getCurrentScope(), noLoc, condClone, + Sema::ConditionKind::Boolean); + } -StmtDiff ReverseModeVisitor::VisitDoStmt(const DoStmt* DS) { - LoopCounter loopCounter(*this); - if (loopCounter.getPush()) - addToCurrentBlock(loopCounter.getPush()); + const Stmt* body = WS->getBody(); + StmtDiff bodyDiff = DifferentiateLoopBody(body, loopCounter, + condVarRes.getStmt_dx()); + // Create forward-pass `while` loop. + Stmt* forwardWS = clad_compat::Sema_ActOnWhileStmt(m_Sema, condResult, + bodyDiff.getStmt()) + .get(); - // begin scope for do statement - beginScope(Scope::ContinueScope | Scope::BreakScope); + // Create reverse-pass `while` loop. + Sema::ConditionResult CounterCondition = loopCounter + .getCounterConditionResult(); + Stmt* reverseWS = clad_compat::Sema_ActOnWhileStmt(m_Sema, CounterCondition, + bodyDiff.getStmt_dx()) + .get(); + // for while statement + endScope(); + Stmt* reverseBlock = reverseWS; + // If loop counter have to be popped then create a compound statement + // enclosing the reverse pass while statement and loop counter pop + // expression. + // + // Therefore, reverse pass code will look like this: + // { + // while (_t) { + // + // } + // clad::pop(_t); + // } + if (loopCounter.getPop()) { + beginBlock(direction::reverse); + addToCurrentBlock(loopCounter.getPop(), direction::reverse); + addToCurrentBlock(reverseWS, direction::reverse); + reverseBlock = endBlock(direction::reverse); + } + return {forwardWS, reverseBlock}; + } - llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); - isInsideLoop = true; + StmtDiff ReverseModeVisitor::VisitDoStmt(const DoStmt* DS) { + LoopCounter loopCounter(*this); + if (loopCounter.getPush()) + addToCurrentBlock(loopCounter.getPush()); - Expr* clonedCond = (DS->getCond() ? Clone(DS->getCond()) : nullptr); + // begin scope for do statement + beginScope(Scope::ContinueScope | Scope::BreakScope); - const Stmt* body = DS->getBody(); - StmtDiff bodyDiff = DifferentiateLoopBody(body, loopCounter); + llvm::SaveAndRestore SaveIsInsideLoop(isInsideLoop); + isInsideLoop = true; - // Create forward-pass `do-while` statement. - Stmt* forwardDS = m_Sema - .ActOnDoStmt(/*DoLoc=*/noLoc, bodyDiff.getStmt(), - /*WhileLoc=*/noLoc, - /*CondLParen=*/noLoc, clonedCond, - /*CondRParen=*/noLoc) - .get(); + Expr* clonedCond = (DS->getCond() ? Clone(DS->getCond()) : nullptr); - // create reverse-pass `do-while` statement. - Expr* counterCondition = loopCounter.getCounterConditionResult().get().second; - Stmt* reverseDS = m_Sema - .ActOnDoStmt(/*DoLoc=*/noLoc, bodyDiff.getStmt_dx(), - /*WhileLoc=*/noLoc, - /*CondLParen=*/noLoc, counterCondition, - /*RCondRParen=*/noLoc) - .get(); - // for do-while statement - endScope(); - Stmt* reverseBlock = reverseDS; - // If loop counter have to be popped then create a compound statement - // enclosing the reverse pass while statement and loop counter pop - // expression. - // - // Therefore, reverse pass code will look like this: - // { - // do { - // - // } while (_t); - // clad::pop(_t); - // } - if (loopCounter.getPop()) { - beginBlock(direction::reverse); - addToCurrentBlock(loopCounter.getPop(), direction::reverse); - addToCurrentBlock(reverseDS, direction::reverse); - reverseBlock = endBlock(direction::reverse); - } - return {forwardDS, reverseBlock}; -} + const Stmt* body = DS->getBody(); + StmtDiff bodyDiff = DifferentiateLoopBody(body, loopCounter); -StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body, - LoopCounter& loopCounter, - Stmt* condVarDiff, - Stmt* forLoopIncDiff, - bool isForLoop) { - Expr* counterIncrement = loopCounter.getCounterIncrement(); - auto activeBreakContHandler = PushBreakContStmtHandler(); - activeBreakContHandler->BeginCFSwitchStmtScope(); - m_LoopBlock.push_back({}); - // differentiate loop body and add loop increment expression - // in the forward block. - StmtDiff bodyDiff = nullptr; - if (isa(body)) { - bodyDiff = Visit(body); - beginBlock(direction::forward); - addToCurrentBlock(counterIncrement); - for (Stmt* S : cast(bodyDiff.getStmt())->body()) - addToCurrentBlock(S); - bodyDiff = {endBlock(direction::forward), bodyDiff.getStmt_dx()}; - } else { - // for forward-pass loop statement body - beginScope(Scope::DeclScope); - beginBlock(direction::forward); - addToCurrentBlock(counterIncrement); - if (m_ExternalSource) - m_ExternalSource->ActBeforeDifferentiatingSingleStmtLoopBody(); - bodyDiff = DifferentiateSingleStmt(body, /*dfdS=*/nullptr); - addToCurrentBlock(bodyDiff.getStmt()); - if (m_ExternalSource) - m_ExternalSource->ActAfterProcessingSingleStmtBodyInVisitForLoop(); + // Create forward-pass `do-while` statement. + Stmt* forwardDS = m_Sema + .ActOnDoStmt(/*DoLoc=*/noLoc, bodyDiff.getStmt(), + /*WhileLoc=*/noLoc, + /*CondLParen=*/noLoc, clonedCond, + /*CondRParen=*/noLoc) + .get(); - Stmt* reverseBlock = unwrapIfSingleStmt(bodyDiff.getStmt_dx()); - bodyDiff = {endBlock(direction::forward), reverseBlock}; - // for forward-pass loop statement body + // create reverse-pass `do-while` statement. + Expr* + counterCondition = loopCounter.getCounterConditionResult().get().second; + Stmt* reverseDS = m_Sema + .ActOnDoStmt(/*DoLoc=*/noLoc, bodyDiff.getStmt_dx(), + /*WhileLoc=*/noLoc, + /*CondLParen=*/noLoc, counterCondition, + /*RCondRParen=*/noLoc) + .get(); + // for do-while statement endScope(); - } - Stmts revLoopBlock = m_LoopBlock.back(); - utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); - if (!revLoopBlock.empty()) - bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock)); - m_LoopBlock.pop_back(); - - /// Increment statement in the for-loop is only executed if the iteration - /// did not end with a break/continue statement. Therefore, forLoopIncDiff - /// should be inside the last switch case in the reverse pass. - if (forLoopIncDiff) { - if (bodyDiff.getStmt_dx()) { - bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( - m_Context, bodyDiff.getStmt_dx(), forLoopIncDiff)); + Stmt* reverseBlock = reverseDS; + // If loop counter have to be popped then create a compound statement + // enclosing the reverse pass while statement and loop counter pop + // expression. + // + // Therefore, reverse pass code will look like this: + // { + // do { + // + // } while (_t); + // clad::pop(_t); + // } + if (loopCounter.getPop()) { + beginBlock(direction::reverse); + addToCurrentBlock(loopCounter.getPop(), direction::reverse); + addToCurrentBlock(reverseDS, direction::reverse); + reverseBlock = endBlock(direction::reverse); + } + return {forwardDS, reverseBlock}; + } + + StmtDiff ReverseModeVisitor::DifferentiateLoopBody(const Stmt* body, + LoopCounter& loopCounter, + Stmt* condVarDiff, + Stmt* forLoopIncDiff, + bool isForLoop) { + Expr* counterIncrement = loopCounter.getCounterIncrement(); + auto activeBreakContHandler = PushBreakContStmtHandler(); + activeBreakContHandler->BeginCFSwitchStmtScope(); + m_LoopBlock.push_back({}); + // differentiate loop body and add loop increment expression + // in the forward block. + StmtDiff bodyDiff = nullptr; + if (isa(body)) { + bodyDiff = Visit(body); + beginBlock(direction::forward); + addToCurrentBlock(counterIncrement); + for (Stmt* S : cast(bodyDiff.getStmt())->body()) + addToCurrentBlock(S); + bodyDiff = {endBlock(direction::forward), bodyDiff.getStmt_dx()}; } else { - bodyDiff.updateStmtDx(forLoopIncDiff); + // for forward-pass loop statement body + beginScope(Scope::DeclScope); + beginBlock(direction::forward); + addToCurrentBlock(counterIncrement); + if (m_ExternalSource) + m_ExternalSource->ActBeforeDifferentiatingSingleStmtLoopBody(); + bodyDiff = DifferentiateSingleStmt(body, /*dfdS=*/nullptr); + addToCurrentBlock(bodyDiff.getStmt()); + if (m_ExternalSource) + m_ExternalSource->ActAfterProcessingSingleStmtBodyInVisitForLoop(); + + Stmt* reverseBlock = unwrapIfSingleStmt(bodyDiff.getStmt_dx()); + bodyDiff = {endBlock(direction::forward), reverseBlock}; + // for forward-pass loop statement body + endScope(); + } + Stmts revLoopBlock = m_LoopBlock.back(); + utils::AppendIndividualStmts(revLoopBlock, bodyDiff.getStmt_dx()); + if (!revLoopBlock.empty()) + bodyDiff.updateStmtDx(MakeCompoundStmt(revLoopBlock)); + m_LoopBlock.pop_back(); + + /// Increment statement in the for-loop is only executed if the iteration + /// did not end with a break/continue statement. Therefore, forLoopIncDiff + /// should be inside the last switch case in the reverse pass. + if (forLoopIncDiff) { + if (bodyDiff.getStmt_dx()) { + bodyDiff.updateStmtDx(utils::PrependAndCreateCompoundStmt( + m_Context, bodyDiff.getStmt_dx(), forLoopIncDiff)); + } else { + bodyDiff.updateStmtDx(forLoopIncDiff); + } } + + activeBreakContHandler->EndCFSwitchStmtScope(); + activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); + PopBreakContStmtHandler(); + + Expr* counterDecrement = loopCounter.getCounterDecrement(); + + // Create reverse pass loop body statements by arranging various + // differentiated statements in the correct order. + // Order used: + // + // 1) `for` loop increment differentiation statements + // 2) loop body differentiation statements + // 3) condition variable differentiation statements + // 4) counter decrement expression + beginBlock(direction::reverse); + // `for` loops have counter decrement expression in the + // loop iteration-expression. + if (!isForLoop) + addToCurrentBlock(counterDecrement, direction::reverse); + addToCurrentBlock(condVarDiff, direction::reverse); + addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse); + bodyDiff = {bodyDiff.getStmt(), + unwrapIfSingleStmt(endBlock(direction::reverse))}; + return bodyDiff; + } + + StmtDiff ReverseModeVisitor::VisitContinueStmt(const ContinueStmt* CS) { + beginBlock(direction::forward); + Stmt* newCS = m_Sema.ActOnContinueStmt(noLoc, getCurrentScope()).get(); + auto activeBreakContHandler = GetActiveBreakContStmtHandler(); + Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt(); + Stmt* pushExprToCurrentCase = activeBreakContHandler + ->CreateCFTapePushExprToCurrentCase(); + addToCurrentBlock(pushExprToCurrentCase); + addToCurrentBlock(newCS); + return {endBlock(direction::forward), CFCaseStmt}; } - activeBreakContHandler->EndCFSwitchStmtScope(); - activeBreakContHandler->UpdateForwAndRevBlocks(bodyDiff); - PopBreakContStmtHandler(); - - Expr* counterDecrement = loopCounter.getCounterDecrement(); - - // Create reverse pass loop body statements by arranging various - // differentiated statements in the correct order. - // Order used: - // - // 1) `for` loop increment differentiation statements - // 2) loop body differentiation statements - // 3) condition variable differentiation statements - // 4) counter decrement expression - beginBlock(direction::reverse); - // `for` loops have counter decrement expression in the - // loop iteration-expression. - if (!isForLoop) - addToCurrentBlock(counterDecrement, direction::reverse); - addToCurrentBlock(condVarDiff, direction::reverse); - addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse); - bodyDiff = {bodyDiff.getStmt(), - unwrapIfSingleStmt(endBlock(direction::reverse))}; - return bodyDiff; -} + StmtDiff ReverseModeVisitor::VisitBreakStmt(const BreakStmt* BS) { + beginBlock(direction::forward); + Stmt* newBS = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get(); + auto activeBreakContHandler = GetActiveBreakContStmtHandler(); + Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt(); + Stmt* pushExprToCurrentCase = activeBreakContHandler + ->CreateCFTapePushExprToCurrentCase(); + addToCurrentBlock(pushExprToCurrentCase); + addToCurrentBlock(newBS); + return {endBlock(direction::forward), CFCaseStmt}; + } + + Expr* ReverseModeVisitor::BreakContStmtHandler::CreateSizeTLiteralExpr( + std::size_t value) { + ASTContext& C = m_RMV.m_Context; + auto literalExpr = ConstantFolder::synthesizeLiteral(C.getSizeType(), C, + value); + return literalExpr; + } + + void ReverseModeVisitor::BreakContStmtHandler::InitializeCFTape() { + assert(!m_ControlFlowTape && "InitializeCFTape() should not be called if " + "m_ControlFlowTape is already initialized"); + + auto zeroLiteral = CreateSizeTLiteralExpr(0); + m_ControlFlowTape.reset( + new CladTapeResult(m_RMV.MakeCladTapeFor(zeroLiteral))); + } + + Expr* ReverseModeVisitor::BreakContStmtHandler::CreateCFTapePushExpr( + std::size_t value) { + Expr* pushDRE = m_RMV.GetCladTapePushDRE(); + Expr* callArgs[] = {m_ControlFlowTape->Ref, CreateSizeTLiteralExpr(value)}; + Expr* pushExpr = m_RMV.m_Sema + .ActOnCallExpr(m_RMV.getCurrentScope(), pushDRE, noLoc, + callArgs, noLoc) + .get(); + return pushExpr; + } -StmtDiff ReverseModeVisitor::VisitContinueStmt(const ContinueStmt* CS) { - beginBlock(direction::forward); - Stmt* newCS = m_Sema.ActOnContinueStmt(noLoc, getCurrentScope()).get(); - auto activeBreakContHandler = GetActiveBreakContStmtHandler(); - Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt(); - Stmt* pushExprToCurrentCase = - activeBreakContHandler->CreateCFTapePushExprToCurrentCase(); - addToCurrentBlock(pushExprToCurrentCase); - addToCurrentBlock(newCS); - return {endBlock(direction::forward), CFCaseStmt}; -} + void + ReverseModeVisitor::BreakContStmtHandler::BeginCFSwitchStmtScope() const { + m_RMV.beginScope(Scope::SwitchScope | Scope::ControlScope | + Scope::BreakScope | Scope::DeclScope); + } -StmtDiff ReverseModeVisitor::VisitBreakStmt(const BreakStmt* BS) { - beginBlock(direction::forward); - Stmt* newBS = m_Sema.ActOnBreakStmt(noLoc, getCurrentScope()).get(); - auto activeBreakContHandler = GetActiveBreakContStmtHandler(); - Stmt* CFCaseStmt = activeBreakContHandler->GetNextCFCaseStmt(); - Stmt* pushExprToCurrentCase = - activeBreakContHandler->CreateCFTapePushExprToCurrentCase(); - addToCurrentBlock(pushExprToCurrentCase); - addToCurrentBlock(newBS); - return {endBlock(direction::forward), CFCaseStmt}; -} + void ReverseModeVisitor::BreakContStmtHandler::EndCFSwitchStmtScope() const { + m_RMV.endScope(); + } -Expr* ReverseModeVisitor::BreakContStmtHandler::CreateSizeTLiteralExpr( - std::size_t value) { - ASTContext& C = m_RMV.m_Context; - auto literalExpr = - ConstantFolder::synthesizeLiteral(C.getSizeType(), C, value); - return literalExpr; -} + CaseStmt* ReverseModeVisitor::BreakContStmtHandler::GetNextCFCaseStmt() { + // End scope for currenly active case statement, if any. + if (!m_SwitchCases.empty()) + m_RMV.endScope(); -void ReverseModeVisitor::BreakContStmtHandler::InitializeCFTape() { - assert(!m_ControlFlowTape && "InitializeCFTape() should not be called if " - "m_ControlFlowTape is already initialized"); + ++m_CaseCounter; + auto counterLiteral = CreateSizeTLiteralExpr(m_CaseCounter); + CaseStmt* CS = clad_compat::CaseStmt_Create(m_RMV.m_Context, counterLiteral, + nullptr, noLoc, noLoc, noLoc); - auto zeroLiteral = CreateSizeTLiteralExpr(0); - m_ControlFlowTape.reset( - new CladTapeResult(m_RMV.MakeCladTapeFor(zeroLiteral))); -} + // Initialise switch case statements with null statement because it is + // necessary for switch case statements to have a substatement but it + // is possible that there are no statements after the corresponding + // break/continue statement. It's also easier to just set null statement + // as substatement instead of keeping track of switch cases and + // corresponding next statements. + CS->setSubStmt(m_RMV.m_Sema.ActOnNullStmt(noLoc).get()); -Expr* ReverseModeVisitor::BreakContStmtHandler::CreateCFTapePushExpr( - std::size_t value) { - Expr* pushDRE = m_RMV.GetCladTapePushDRE(); - Expr* callArgs[] = {m_ControlFlowTape->Ref, CreateSizeTLiteralExpr(value)}; - Expr* pushExpr = m_RMV.m_Sema - .ActOnCallExpr(m_RMV.getCurrentScope(), pushDRE, noLoc, - callArgs, noLoc) - .get(); - return pushExpr; -} + // begin scope for the new active switch case statement. + m_RMV.beginScope(Scope::DeclScope); + m_SwitchCases.push_back(CS); + return CS; + } -void ReverseModeVisitor::BreakContStmtHandler::BeginCFSwitchStmtScope() const { - m_RMV.beginScope(Scope::SwitchScope | Scope::ControlScope | - Scope::BreakScope | Scope::DeclScope); -} + Stmt* ReverseModeVisitor::BreakContStmtHandler:: + CreateCFTapePushExprToCurrentCase() { + if (!m_ControlFlowTape) + InitializeCFTape(); + return CreateCFTapePushExpr(m_CaseCounter); + } -void ReverseModeVisitor::BreakContStmtHandler::EndCFSwitchStmtScope() const { - m_RMV.endScope(); -} + void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks( + StmtDiff& bodyDiff) { + if (m_SwitchCases.empty()) + return; -CaseStmt* ReverseModeVisitor::BreakContStmtHandler::GetNextCFCaseStmt() { - // End scope for currenly active case statement, if any. - if (!m_SwitchCases.empty()) + // end scope for last switch case. m_RMV.endScope(); - ++m_CaseCounter; - auto counterLiteral = CreateSizeTLiteralExpr(m_CaseCounter); - CaseStmt* CS = clad_compat::CaseStmt_Create(m_RMV.m_Context, counterLiteral, - nullptr, noLoc, noLoc, noLoc); - - // Initialise switch case statements with null statement because it is - // necessary for switch case statements to have a substatement but it - // is possible that there are no statements after the corresponding - // break/continue statement. It's also easier to just set null statement - // as substatement instead of keeping track of switch cases and - // corresponding next statements. - CS->setSubStmt(m_RMV.m_Sema.ActOnNullStmt(noLoc).get()); - - // begin scope for the new active switch case statement. - m_RMV.beginScope(Scope::DeclScope); - m_SwitchCases.push_back(CS); - return CS; -} + // Add case statement in the beginning of the reverse block + // and corresponding push expression for this case statement + // at the end of the forward block to cover the case when no + // `break`/`continue` statements are hit. + auto lastSC = GetNextCFCaseStmt(); + auto pushExprToCurrentCase = CreateCFTapePushExprToCurrentCase(); + + Stmt *forwBlock, *revBlock; + + forwBlock = utils::AppendAndCreateCompoundStmt(m_RMV.m_Context, + bodyDiff.getStmt(), + pushExprToCurrentCase); + revBlock = utils::PrependAndCreateCompoundStmt(m_RMV.m_Context, + bodyDiff.getStmt_dx(), + lastSC); + + bodyDiff = {forwBlock, revBlock}; + + auto condResult = m_RMV.m_Sema.ActOnCondition(m_RMV.getCurrentScope(), + noLoc, m_ControlFlowTape->Pop, + Sema::ConditionKind::Switch); + SwitchStmt* CFSS = clad_compat::Sema_ActOnStartOfSwitchStmt(m_RMV.m_Sema, + nullptr, + condResult) + .getAs(); + // Registers all the switch cases + for (auto SC : m_SwitchCases) { + CFSS->addSwitchCase(SC); + } + m_RMV.m_Sema.ActOnFinishSwitchStmt(noLoc, CFSS, bodyDiff.getStmt_dx()); -Stmt* ReverseModeVisitor::BreakContStmtHandler:: - CreateCFTapePushExprToCurrentCase() { - if (!m_ControlFlowTape) - InitializeCFTape(); - return CreateCFTapePushExpr(m_CaseCounter); -} + bodyDiff = {bodyDiff.getStmt(), CFSS}; + } -void ReverseModeVisitor::BreakContStmtHandler::UpdateForwAndRevBlocks( - StmtDiff& bodyDiff) { - if (m_SwitchCases.empty()) - return; - - // end scope for last switch case. - m_RMV.endScope(); - - // Add case statement in the beginning of the reverse block - // and corresponding push expression for this case statement - // at the end of the forward block to cover the case when no - // `break`/`continue` statements are hit. - auto lastSC = GetNextCFCaseStmt(); - auto pushExprToCurrentCase = CreateCFTapePushExprToCurrentCase(); - - Stmt *forwBlock, *revBlock; - - forwBlock = utils::AppendAndCreateCompoundStmt( - m_RMV.m_Context, bodyDiff.getStmt(), pushExprToCurrentCase); - revBlock = utils::PrependAndCreateCompoundStmt(m_RMV.m_Context, - bodyDiff.getStmt_dx(), lastSC); - - bodyDiff = {forwBlock, revBlock}; - - auto condResult = m_RMV.m_Sema.ActOnCondition(m_RMV.getCurrentScope(), noLoc, - m_ControlFlowTape->Pop, - Sema::ConditionKind::Switch); - SwitchStmt* CFSS = clad_compat::Sema_ActOnStartOfSwitchStmt( - m_RMV.m_Sema, nullptr, condResult) - .getAs(); - // Registers all the switch cases - for (auto SC : m_SwitchCases) - CFSS->addSwitchCase(SC); - m_RMV.m_Sema.ActOnFinishSwitchStmt(noLoc, CFSS, bodyDiff.getStmt_dx()); - - bodyDiff = {bodyDiff.getStmt(), CFSS}; -} + void ReverseModeVisitor::AddExternalSource(ExternalRMVSource& source) { + if (!m_ExternalSource) + m_ExternalSource = new MultiplexExternalRMVSource(); + source.InitialiseRMV(*this); + m_ExternalSource->AddSource(source); + } -void ReverseModeVisitor::AddExternalSource(ExternalRMVSource& source) { - if (!m_ExternalSource) - m_ExternalSource = new MultiplexExternalRMVSource(); - source.InitialiseRMV(*this); - m_ExternalSource->AddSource(source); -} + StmtDiff ReverseModeVisitor::VisitCXXThisExpr(const CXXThisExpr* CTE) { + Expr* clonedCTE = Clone(CTE); + return {clonedCTE, m_ThisExprDerivative}; + } -StmtDiff ReverseModeVisitor::VisitCXXThisExpr(const CXXThisExpr* CTE) { - Expr* clonedCTE = Clone(CTE); - return {clonedCTE, m_ThisExprDerivative}; -} + // FIXME: Add support for differentiating calls to constructors. + // We currently assume that constructor arguments are non-differentiable. + StmtDiff + ReverseModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) { + llvm::SmallVector clonedArgs; + for (auto arg : CE->arguments()) { + auto argDiff = Visit(arg, dfdx()); + clonedArgs.push_back(argDiff.getExpr()); + } + Expr* clonedArgsE = nullptr; -// FIXME: Add support for differentiating calls to constructors. -// We currently assume that constructor arguments are non-differentiable. -StmtDiff ReverseModeVisitor::VisitCXXConstructExpr(const CXXConstructExpr* CE) { - llvm::SmallVector clonedArgs; - for (auto arg : CE->arguments()) { - auto argDiff = Visit(arg, dfdx()); - clonedArgs.push_back(argDiff.getExpr()); - } - Expr* clonedArgsE = nullptr; - - if (CE->getNumArgs() != 1) { - if (CE->isListInitialization()) { - clonedArgsE = m_Sema.ActOnInitList(noLoc, clonedArgs, noLoc).get(); - } else if (CE->getNumArgs() == 0) { - // ParenList is empty -- default initialisation. - // Passing empty parenList here will silently cause 'most vexing - // parse' issue. - return StmtDiff(); + if (CE->getNumArgs() != 1) { + if (CE->isListInitialization()) { + clonedArgsE = m_Sema.ActOnInitList(noLoc, clonedArgs, noLoc).get(); + } else { + if (CE->getNumArgs() == 0) { + // ParenList is empty -- default initialisation. + // Passing empty parenList here will silently cause 'most vexing + // parse' issue. + return StmtDiff(); + } else { + clonedArgsE = + m_Sema.ActOnParenListExpr(noLoc, noLoc, clonedArgs).get(); + } + } } else { - clonedArgsE = m_Sema.ActOnParenListExpr(noLoc, noLoc, clonedArgs).get(); + clonedArgsE = clonedArgs[0]; + } + // `CXXConstructExpr` node will be created automatically by passing these + // initialiser to higher level `ActOn`/`Build` Sema functions. + return {clonedArgsE}; + } + + StmtDiff ReverseModeVisitor::VisitMaterializeTemporaryExpr( + const clang::MaterializeTemporaryExpr* MTE) { + // `MaterializeTemporaryExpr` node will be created automatically if it is + // required by `ActOn`/`Build` Sema functions. + StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx()); + return MTEDiff; + } + + QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType, + QualType xType) { + + if (m_Mode == DiffMode::reverse) + assert(yType->isRealType() && + "yType should be a non-reference builtin-numerical scalar type!!"); + else if (m_Mode == DiffMode::experimental_pullback) + assert(yType.getNonReferenceType()->isRealType() && + "yType should be a builtin-numerical scalar type!!"); + QualType xValueType = utils::GetValueType(xType); + // derivative variables should always be of non-const type. + xValueType.removeLocalConst(); + QualType nonRefXValueType = xValueType.getNonReferenceType(); + return GetCladArrayRefOfType(nonRefXValueType); + } + + StmtDiff ReverseModeVisitor::VisitCXXStaticCastExpr( + const clang::CXXStaticCastExpr* SCE) { + StmtDiff subExprDiff = Visit(SCE->getSubExpr(), dfdx()); + return subExprDiff; + } + + clang::QualType ReverseModeVisitor::ComputeAdjointType(clang::QualType T) { + if (T->isReferenceType()) { + QualType TValueType = utils::GetValueType(T); + TValueType.removeLocalConst(); + return m_Context.getPointerType(TValueType); + } + if (isa(T) && !isa(T)) { + QualType adjointType = + GetCladArrayOfType(m_Context.getBaseElementType(T)); + return adjointType; } - } else { - clonedArgsE = clonedArgs[0]; + T.removeLocalConst(); + return T; } - // `CXXConstructExpr` node will be created automatically by passing these - // initialiser to higher level `ActOn`/`Build` Sema functions. - return {clonedArgsE}; -} -StmtDiff ReverseModeVisitor::VisitMaterializeTemporaryExpr( - const clang::MaterializeTemporaryExpr* MTE) { - // `MaterializeTemporaryExpr` node will be created automatically if it is - // required by `ActOn`/`Build` Sema functions. - StmtDiff MTEDiff = Visit(clad_compat::GetSubExpr(MTE), dfdx()); - return MTEDiff; -} + clang::QualType ReverseModeVisitor::ComputeParamType(clang::QualType T) { + QualType TValueType = utils::GetValueType(T); + TValueType.removeLocalConst(); + return GetCladArrayRefOfType(TValueType); + } -QualType ReverseModeVisitor::GetParameterDerivativeType(QualType yType, - QualType xType) { - - if (m_Mode == DiffMode::reverse) - assert(yType->isRealType() && - "yType should be a non-reference builtin-numerical scalar type!!"); - else if (m_Mode == DiffMode::experimental_pullback) - assert(yType.getNonReferenceType()->isRealType() && - "yType should be a builtin-numerical scalar type!!"); - QualType xValueType = utils::GetValueType(xType); - // derivative variables should always be of non-const type. - xValueType.removeLocalConst(); - QualType nonRefXValueType = xValueType.getNonReferenceType(); - return GetCladArrayRefOfType(nonRefXValueType); -} + llvm::SmallVector + ReverseModeVisitor::ComputeParamTypes(const DiffParams& diffParams) { + llvm::SmallVector paramTypes; + paramTypes.reserve(m_Function->getNumParams() * 2); + for (auto PVD : m_Function->parameters()) { + paramTypes.push_back(PVD->getType()); + } + // TODO: Add DiffMode::experimental_pullback support here as well. + if (m_Mode == DiffMode::reverse || + m_Mode == DiffMode::experimental_pullback) { + QualType effectiveReturnType = + m_Function->getReturnType().getNonReferenceType(); + if (m_Mode == DiffMode::experimental_pullback) { + // FIXME: Generally, we use the function's return type as the argument's + // derivative type. We cannot follow this strategy for `void` function + // return type. Thus, temporarily use `double` type as the placeholder + // type for argument derivatives. We should think of a more uniform and + // consistent solution to this problem. One effective strategy that may + // hold well: If we are differentiating a variable of type Y with + // respect to variable of type X, then the derivative should be of type + // X. Check this related issue for more details: + // https://github.com/vgvassilev/clad/issues/385 + if (effectiveReturnType->isVoidType()) + effectiveReturnType = m_Context.DoubleTy; + else + paramTypes.push_back(effectiveReturnType); + } -StmtDiff ReverseModeVisitor::VisitCXXStaticCastExpr( - const clang::CXXStaticCastExpr* SCE) { - StmtDiff subExprDiff = Visit(SCE->getSubExpr(), dfdx()); - return subExprDiff; -} + if (auto MD = dyn_cast(m_Function)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) { + QualType thisType = + clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); + paramTypes.push_back( + GetParameterDerivativeType(effectiveReturnType, thisType)); + } + } -clang::QualType ReverseModeVisitor::ComputeAdjointType(clang::QualType T) { - if (T->isReferenceType()) { - QualType TValueType = utils::GetValueType(T); - TValueType.removeLocalConst(); - return m_Context.getPointerType(TValueType); - } - if (isa(T) && !isa(T)) { - QualType adjointType = GetCladArrayOfType(m_Context.getBaseElementType(T)); - return adjointType; + for (auto PVD : m_Function->parameters()) { + auto it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) + paramTypes.push_back(ComputeParamType(PVD->getType())); + } + } else if (m_Mode == DiffMode::jacobian) { + std::size_t lastArgIdx = m_Function->getNumParams() - 1; + QualType derivativeParamType = + m_Function->getParamDecl(lastArgIdx)->getType(); + paramTypes.push_back(derivativeParamType); + } + return paramTypes; } - T.removeLocalConst(); - return T; -} -clang::QualType ReverseModeVisitor::ComputeParamType(clang::QualType T) { - QualType TValueType = utils::GetValueType(T); - TValueType.removeLocalConst(); - return GetCladArrayRefOfType(TValueType); -} + llvm::SmallVector + ReverseModeVisitor::BuildParams(DiffParams& diffParams) { + llvm::SmallVector params, paramDerivatives; + params.reserve(m_Function->getNumParams() + diffParams.size()); + auto derivativeFnType = cast(m_Derivative->getType()); + std::size_t dParamTypesIdx = m_Function->getNumParams(); -llvm::SmallVector -ReverseModeVisitor::ComputeParamTypes(const DiffParams& diffParams) { - llvm::SmallVector paramTypes; - paramTypes.reserve(m_Function->getNumParams() * 2); - for (auto PVD : m_Function->parameters()) - paramTypes.push_back(PVD->getType()); - // TODO: Add DiffMode::experimental_pullback support here as well. - if (m_Mode == DiffMode::reverse || - m_Mode == DiffMode::experimental_pullback) { - QualType effectiveReturnType = - m_Function->getReturnType().getNonReferenceType(); - if (m_Mode == DiffMode::experimental_pullback) { - // FIXME: Generally, we use the function's return type as the argument's - // derivative type. We cannot follow this strategy for `void` function - // return type. Thus, temporarily use `double` type as the placeholder - // type for argument derivatives. We should think of a more uniform and - // consistent solution to this problem. One effective strategy that may - // hold well: If we are differentiating a variable of type Y with - // respect to variable of type X, then the derivative should be of type - // X. Check this related issue for more details: - // https://github.com/vgvassilev/clad/issues/385 - if (effectiveReturnType->isVoidType()) - effectiveReturnType = m_Context.DoubleTy; - else - paramTypes.push_back(effectiveReturnType); + if (m_Mode == DiffMode::experimental_pullback && + !m_Function->getReturnType()->isVoidType()) { + ++dParamTypesIdx; } if (auto MD = dyn_cast(m_Function)) { const CXXRecordDecl* RD = MD->getParent(); - if (MD->isInstance() && !RD->isLambda()) { - QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); - paramTypes.push_back( - GetParameterDerivativeType(effectiveReturnType, thisType)); + if (!isVectorValued && MD->isInstance() && !RD->isLambda()) { + auto thisDerivativePVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"), + derivativeFnType->getParamType(dParamTypesIdx)); + paramDerivatives.push_back(thisDerivativePVD); + + if (thisDerivativePVD->getIdentifier()) + m_Sema.PushOnScopeChains(thisDerivativePVD, getCurrentScope(), + /*AddToContext=*/false); + + // This can instantiate an array_ref and needs a fake source location. + SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema); + Expr* deref = BuildOp(UnaryOperatorKind::UO_Deref, + BuildDeclRef(thisDerivativePVD), fakeLoc); + m_ThisExprDerivative = utils::BuildParenExpr(m_Sema, deref); + ++dParamTypesIdx; } } for (auto PVD : m_Function->parameters()) { - auto it = std::find(std::begin(diffParams), std::end(diffParams), PVD); - if (it != std::end(diffParams)) - paramTypes.push_back(ComputeParamType(PVD->getType())); - } - } else if (m_Mode == DiffMode::jacobian) { - std::size_t lastArgIdx = m_Function->getNumParams() - 1; - QualType derivativeParamType = - m_Function->getParamDecl(lastArgIdx)->getType(); - paramTypes.push_back(derivativeParamType); - } - return paramTypes; -} - -llvm::SmallVector -ReverseModeVisitor::BuildParams(DiffParams& diffParams) { - llvm::SmallVector params, paramDerivatives; - params.reserve(m_Function->getNumParams() + diffParams.size()); - auto derivativeFnType = cast(m_Derivative->getType()); - std::size_t dParamTypesIdx = m_Function->getNumParams(); + auto newPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), + PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); + params.push_back(newPVD); - if (m_Mode == DiffMode::experimental_pullback && - !m_Function->getReturnType()->isVoidType()) { - ++dParamTypesIdx; - } + if (newPVD->getIdentifier()) + m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), + /*AddToContext=*/false); - if (auto MD = dyn_cast(m_Function)) { - const CXXRecordDecl* RD = MD->getParent(); - if (!isVectorValued && MD->isInstance() && !RD->isLambda()) { - auto thisDerivativePVD = utils::BuildParmVarDecl( - m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"), - derivativeFnType->getParamType(dParamTypesIdx)); - paramDerivatives.push_back(thisDerivativePVD); + auto it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) { + *it = newPVD; + if (m_Mode == DiffMode::reverse || + m_Mode == DiffMode::experimental_pullback) { + QualType dType = derivativeFnType->getParamType(dParamTypesIdx); + IdentifierInfo* dII = + CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + auto dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, + PVD->getStorageClass()); + paramDerivatives.push_back(dPVD); + ++dParamTypesIdx; + + if (dPVD->getIdentifier()) + m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), + /*AddToContext=*/false); + + if (utils::isArrayOrPointerType(PVD->getType())) { + m_Variables[*it] = (Expr*)BuildDeclRef(dPVD); + } else { + QualType valueType = DetermineCladArrayValueType(dPVD->getType()); + m_Variables[*it] = BuildOp(UO_Deref, BuildDeclRef(dPVD), + m_Function->getLocation()); + // Add additional paranthesis if derivative is of record type + // because `*derivative.someField` will be incorrectly evaluated if + // the derived function is compiled standalone. + if (valueType->isRecordType()) + m_Variables[*it] = + utils::BuildParenExpr(m_Sema, m_Variables[*it]); + } + } + } + } - if (thisDerivativePVD->getIdentifier()) - m_Sema.PushOnScopeChains(thisDerivativePVD, getCurrentScope(), + if (m_Mode == DiffMode::experimental_pullback && + !m_Function->getReturnType()->isVoidType()) { + IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y"); + QualType pullbackType = + derivativeFnType->getParamType(m_Function->getNumParams()); + ParmVarDecl* pullbackPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, pullbackParamII, pullbackType); + paramDerivatives.insert(paramDerivatives.begin(), pullbackPVD); + + if (pullbackPVD->getIdentifier()) + m_Sema.PushOnScopeChains(pullbackPVD, getCurrentScope(), /*AddToContext=*/false); - // This can instantiate an array_ref and needs a fake source location. - SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema); - Expr* deref = BuildOp(UnaryOperatorKind::UO_Deref, - BuildDeclRef(thisDerivativePVD), fakeLoc); - m_ThisExprDerivative = utils::BuildParenExpr(m_Sema, deref); + m_Pullback = BuildDeclRef(pullbackPVD); ++dParamTypesIdx; } - } - - for (auto PVD : m_Function->parameters()) { - auto newPVD = utils::BuildParmVarDecl( - m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), - PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); - params.push_back(newPVD); - - if (newPVD->getIdentifier()) - m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), - /*AddToContext=*/false); - - auto it = std::find(std::begin(diffParams), std::end(diffParams), PVD); - if (it != std::end(diffParams)) { - *it = newPVD; - if (m_Mode == DiffMode::reverse || - m_Mode == DiffMode::experimental_pullback) { - QualType dType = derivativeFnType->getParamType(dParamTypesIdx); - IdentifierInfo* dII = - CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); - auto dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, - PVD->getStorageClass()); - paramDerivatives.push_back(dPVD); - ++dParamTypesIdx; - if (dPVD->getIdentifier()) - m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), - /*AddToContext=*/false); - - if (utils::isArrayOrPointerType(PVD->getType())) { - m_Variables[*it] = (Expr*)BuildDeclRef(dPVD); - } else { - QualType valueType = DetermineCladArrayValueType(dPVD->getType()); - m_Variables[*it] = - BuildOp(UO_Deref, BuildDeclRef(dPVD), m_Function->getLocation()); - // Add additional paranthesis if derivative is of record type - // because `*derivative.someField` will be incorrectly evaluated if - // the derived function is compiled standalone. - if (valueType->isRecordType()) - m_Variables[*it] = utils::BuildParenExpr(m_Sema, m_Variables[*it]); - } - } + if (m_Mode == DiffMode::jacobian) { + IdentifierInfo* II = CreateUniqueIdentifier("jacobianMatrix"); + // FIXME: Why are we taking storageClass of `params.front()`? + auto dPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, II, + derivativeFnType->getParamType(dParamTypesIdx), + params.front()->getStorageClass()); + paramDerivatives.push_back(dPVD); + if (dPVD->getIdentifier()) + m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), + /*AddToContext=*/false); } + params.insert(params.end(), paramDerivatives.begin(), + paramDerivatives.end()); + // FIXME: If we do not consider diffParams as an independent argument for + // jacobian mode, then we should keep diffParams list empty for jacobian + // mode and thus remove the if condition. + if (m_Mode == DiffMode::reverse || + m_Mode == DiffMode::experimental_pullback) + m_IndependentVars.insert(m_IndependentVars.end(), diffParams.begin(), + diffParams.end()); + return params; } - - if (m_Mode == DiffMode::experimental_pullback && - !m_Function->getReturnType()->isVoidType()) { - IdentifierInfo* pullbackParamII = CreateUniqueIdentifier("_d_y"); - QualType pullbackType = - derivativeFnType->getParamType(m_Function->getNumParams()); - ParmVarDecl* pullbackPVD = utils::BuildParmVarDecl( - m_Sema, m_Derivative, pullbackParamII, pullbackType); - paramDerivatives.insert(paramDerivatives.begin(), pullbackPVD); - - if (pullbackPVD->getIdentifier()) - m_Sema.PushOnScopeChains(pullbackPVD, getCurrentScope(), - /*AddToContext=*/false); - - m_Pullback = BuildDeclRef(pullbackPVD); - ++dParamTypesIdx; - } - - if (m_Mode == DiffMode::jacobian) { - IdentifierInfo* II = CreateUniqueIdentifier("jacobianMatrix"); - // FIXME: Why are we taking storageClass of `params.front()`? - auto dPVD = - utils::BuildParmVarDecl(m_Sema, m_Derivative, II, - derivativeFnType->getParamType(dParamTypesIdx), - params.front()->getStorageClass()); - paramDerivatives.push_back(dPVD); - if (dPVD->getIdentifier()) - m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), - /*AddToContext=*/false); - } - params.insert(params.end(), paramDerivatives.begin(), paramDerivatives.end()); - // FIXME: If we do not consider diffParams as an independent argument for - // jacobian mode, then we should keep diffParams list empty for jacobian - // mode and thus remove the if condition. - if (m_Mode == DiffMode::reverse || m_Mode == DiffMode::experimental_pullback) - m_IndependentVars.insert(m_IndependentVars.end(), diffParams.begin(), - diffParams.end()); - return params; -} } // end namespace clad diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index b5f6af7e6..62f826781 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -295,16 +295,20 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { void TBRAnalyzer::markLocation(const clang::Expr* E) { VarData* data = getExprVarData(E); - if (!data || findReq(data)) { + if (data) { /// FIXME: If any of the data's child nodes are required to store then data /// itself is stored. We might add an option to store separate fields. + bool& ToBeRec = TBRLocs[E->getBeginLoc()]; /// FIXME: Sometimes one location might correspond to multiple stores. /// For example, in ``(x*=y)=u`` x's location will first be marked as /// required to be stored (when passing *= operator) but then marked as not /// required to be stored (when passing = operator). Current method of /// marking locations does not allow to differentiate between these two. - TBRLocs.insert(E->getBeginLoc()); - } + ToBeRec = ToBeRec || findReq(data); + } else + /// If the current branch is going to be deleted then there is not point in + /// storing anything in it. + TBRLocs[E->getBeginLoc()] = true; } void TBRAnalyzer::setIsRequired(const clang::Expr* E, bool isReq) { @@ -719,11 +723,15 @@ void TBRAnalyzer::VisitCallExpr(const clang::CallExpr* CE) { // FIXME: this supports only DeclRefExpr const auto innerExpr = utils::GetInnermostReturnExpr(arg); if (passByRef) { - /// Mark SourceLocation as required to store for ref-type arguments. + /// Mark SourceLocation as required for ref-type arguments. if (isa(B) || isa(B)) { - TBRLocs.insert(arg->getBeginLoc()); + TBRLocs[arg->getBeginLoc()] = true; setIsRequired(arg, /*isReq=*/false); } + } else { + /// Mark SourceLocation as not required for non-ref-type arguments. + if (isa(B) || isa(B)) + TBRLocs[arg->getBeginLoc()] = false; } } resetMode(); @@ -747,9 +755,13 @@ void TBRAnalyzer::VisitCXXConstructExpr(const clang::CXXConstructExpr* CE) { if (passByRef) { /// Mark SourceLocation as required for ref-type arguments. if (isa(B) || isa(B)) { - TBRLocs.insert(arg->getBeginLoc()); + TBRLocs[arg->getBeginLoc()] = true; setIsRequired(arg, /*isReq=*/false); } + } else { + /// Mark SourceLocation as not required for non-ref-type arguments. + if (isa(B) || isa(B)) + TBRLocs[arg->getBeginLoc()] = false; } } resetMode();