diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index efe369500..4b43c5798 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -38,9 +38,6 @@ namespace clad { class CladPlugin; clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P, DiffRequest& request); - // FIXME: This function should be removed and the entire plans array - // should be somehow made accessible to all the visitors. - void AddRequestToSchedule(CladPlugin& P, const DiffRequest& request); } // namespace plugin } // namespace clad @@ -87,6 +84,7 @@ namespace clad { plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; const DerivedFnCollector& m_DFC; + clad::DynamicGraph& m_DiffRequestGraph; std::unique_ptr m_NodeCloner; clang::NamespaceDecl* m_BuiltinDerivativesNSD; /// A reference to the model to use for error estimation (if any). @@ -137,7 +135,8 @@ namespace clad { public: DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC); + const DerivedFnCollector& DFC, + clad::DynamicGraph& DRG); ~DerivativeBuilder(); /// Reset the model use for error estimation (if any). /// \param[in] estModel The error estimation model, can be either @@ -172,6 +171,16 @@ namespace clad { /// /// \returns The derived function if found, nullptr otherwise. clang::FunctionDecl* FindDerivedFunction(const DiffRequest& request); + /// Add edge from current request to the given request in the DiffRequest + /// graph. + /// + /// \param[in] request The request to add the edge to. + void AddEdgeToGraph(const DiffRequest& request); + /// Add edge between two requests in the DiffRequest graph. + /// + /// \param[in] from The source request. + /// \param[in] to The destination request. + void AddEdgeToGraph(const DiffRequest& from, const DiffRequest& to); }; } // end namespace clad diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index 919d22c80..c7418eb6f 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -15,6 +15,42 @@ enum class DiffMode { reverse_mode_forward_pass, error_estimation }; + +/// Convert enum value to string. +inline const char* DiffModeToString(DiffMode mode) { + switch (mode) { + case DiffMode::forward: + return "forward"; + case DiffMode::vector_forward_mode: + return "vector_forward_mode"; + case DiffMode::experimental_pushforward: + return "pushforward"; + case DiffMode::experimental_pullback: + return "pullback"; + case DiffMode::experimental_vector_pushforward: + return "vector_pushforward"; + case DiffMode::reverse: + return "reverse"; + case DiffMode::hessian: + return "hessian"; + case DiffMode::jacobian: + return "jacobian"; + case DiffMode::reverse_mode_forward_pass: + return "reverse_mode_forward_pass"; + case DiffMode::error_estimation: + return "error_estimation"; + default: + return "unknown"; + } +} + +/// Returns true if the given mode is a pullback/pushforward mode. +inline bool IsPullbackOrPushforwardMode(DiffMode mode) { + return mode == DiffMode::experimental_pushforward || + mode == DiffMode::experimental_pullback || + mode == DiffMode::experimental_vector_pushforward || + mode == DiffMode::reverse_mode_forward_pass; +} } #endif diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 187810364..1a70c90f2 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -1,10 +1,11 @@ #ifndef CLAD_DIFF_PLANNER_H #define CLAD_DIFF_PLANNER_H -#include "clad/Differentiator/DiffMode.h" -#include "clad/Differentiator/ParseDiffArgsTypes.h" #include "clang/AST/RecursiveASTVisitor.h" #include "llvm/ADT/SmallSet.h" +#include "clad/Differentiator/DiffMode.h" +#include "clad/Differentiator/DynamicGraph.h" +#include "clad/Differentiator/ParseDiffArgsTypes.h" namespace clang { class ASTContext; @@ -90,9 +91,33 @@ struct DiffRequest { /// 3) If no argument is provided, a default argument is used. The /// function will be differentiated w.r.t. to its every parameter. void UpdateDiffParamsInfo(clang::Sema& semaRef); + + /// Define the == operator for DiffRequest. + bool operator==(const DiffRequest& other) const { + // either function match or previous declaration match + return (Function == other.Function || + Function->getPreviousDecl() == other.Function || + Function == other.Function->getPreviousDecl()) && + BaseFunctionName == other.BaseFunctionName && + CurrentDerivativeOrder == other.CurrentDerivativeOrder && + RequestedDerivativeOrder == other.RequestedDerivativeOrder && + CallContext == other.CallContext && Args == other.Args && + Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && + DVI == other.DVI && use_enzyme == other.use_enzyme && + DeclarationOnly == other.DeclarationOnly; + } + + // String operator for printing the node. + operator std::string() const { + std::string res = BaseFunctionName + "__order_" + + std::to_string(CurrentDerivativeOrder) + "__mode_" + + DiffModeToString(Mode); + if (EnableTBRAnalysis) + res += "__TBR"; + return res; + } }; - using DiffSchedule = llvm::SmallVector; using DiffInterval = std::vector; struct RequestOptions { @@ -106,9 +131,9 @@ struct DiffRequest { /// DiffInterval& m_Interval; - /// The diff step-by-step plan for differentiation. + /// Graph to store the dependencies between different requests. /// - DiffSchedule& m_DiffPlans; + clad::DynamicGraph& m_DiffRequestGraph; /// If set it means that we need to find the called functions and /// add them for implicit diff. @@ -120,7 +145,8 @@ struct DiffRequest { public: DiffCollector(clang::DeclGroupRef DGR, DiffInterval& Interval, - DiffSchedule& plans, clang::Sema& S, RequestOptions& opts); + clad::DynamicGraph& requestGraph, clang::Sema& S, + RequestOptions& opts); bool VisitCallExpr(clang::CallExpr* E); private: @@ -128,4 +154,15 @@ struct DiffRequest { }; } +// Define the hash function for DiffRequest. +template <> struct std::hash { + std::size_t operator()(const clad::DiffRequest& DR) const { + // Use the function pointer as the hash of the DiffRequest, it + // is sufficient to break a reasonable number of collisions. + if (DR.Function->getPreviousDecl()) + return std::hash{}(DR.Function->getPreviousDecl()); + return std::hash{}(DR.Function); + } +}; + #endif diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 65211f24c..4a089c095 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -11,6 +11,7 @@ #include "ArrayRef.h" #include "BuiltinDerivatives.h" #include "CladConfig.h" +#include "DynamicGraph.h" #include "FunctionTraits.h" #include "Matrix.h" #include "NumericalDiff.h" diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h new file mode 100644 index 000000000..0e528d671 --- /dev/null +++ b/include/clad/Differentiator/DynamicGraph.h @@ -0,0 +1,137 @@ +#ifndef CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H +#define CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace clad { +template class DynamicGraph { +private: + /// Storing nodes in the graph. The index of the node in the vector is used as + /// a unique identifier for the node in the adjacency list. + std::vector m_nodes; + + /// Store the nodes in the graph as an unordered map from the node to a + /// boolean indicating whether the node is processed or not. The second + /// element in the pair is the id of the node in the nodes vector. + std::unordered_map> m_nodeMap; + + /// Store the adjacency list for the graph. The adjacency list is a map from + /// a node to the set of nodes that it has an edge to. We use integers inside + /// the set to avoid copying the nodes. + std::unordered_map> m_adjList; + + /// Set of source nodes in the graph. + std::set m_sources; + + /// Store the id of the node being processed right now. + int m_currentId = -1; // -1 means no node is being processed. + + /// Maintain a queue of nodes to be processed next. + std::queue m_toProcessQueue; + +public: + DynamicGraph() = default; + + /// Add an edge from the source node to the destination node. + /// \param src + /// \param dest + void addEdge(const T& src, const T& dest) { + std::pair srcInfo = addNode(src); + std::pair destInfo = addNode(dest); + size_t srcId = srcInfo.second; + size_t destId = destInfo.second; + m_adjList[srcId].insert(destId); + } + + /// Add a node to the graph. If the node is already present, return the + /// id of the node in the graph. If the node is a source node, add it to the + /// queue of nodes to be processed. + /// \param node + /// \param isSource + /// \returns A pair of a boolean indicating whether the node is already + /// processed and the id of the node in the graph. + std::pair addNode(const T& node, bool isSource = false) { + if (m_nodeMap.find(node) == m_nodeMap.end()) { + size_t id = m_nodes.size(); + m_nodes.push_back(node); + m_nodeMap[node] = {false, id}; // node is not processed yet. + m_adjList[id] = {}; + if (isSource) { + m_sources.insert(id); + m_toProcessQueue.push(id); + } + } + return m_nodeMap[node]; + } + + /// Add an edge from the current node being processed to the + /// destination node. + /// \param dest + void addEdgeToCurrentNode(const T& dest) { + if (m_currentId == -1) + return; + addEdge(m_nodes[m_currentId], dest); + } + + /// Set the current node being processed. + /// \param node + void setCurrentProcessingNode(const T& node) { + if (m_nodeMap.find(node) != m_nodeMap.end()) + m_currentId = m_nodeMap[node].second; + } + + /// Mark the current node being processed as processed and add the + /// destination nodes to the queue of nodes to be processed. + void markCurrentNodeProcessed() { + if (m_currentId != -1) { + m_nodeMap[m_nodes[m_currentId]].first = true; + for (size_t destId : m_adjList[m_currentId]) + if (!m_nodeMap[m_nodes[destId]].first) + m_toProcessQueue.push(destId); + } + m_currentId = -1; + } + + /// Get the nodes in the graph. + std::vector getNodes() { return m_nodes; } + + /// Print the nodes and edges in the graph. + void print() { + // First print the nodes with their insertion order. + for (const T& node : m_nodes) { + std::pair nodeInfo = m_nodeMap[node]; + std::cout << (std::string)node << ": #" << nodeInfo.second; + if (m_sources.find(nodeInfo.second) != m_sources.end()) + std::cout << " (source)"; + if (nodeInfo.first) + std::cout << ", (done)\n"; + else + std::cout << ", (unprocessed)\n"; + } + // Then print the edges. + for (int i = 0; i < m_nodes.size(); i++) + for (size_t dest : m_adjList[i]) + std::cout << i << " -> " << dest << "\n"; + } + + /// Get the next node to be processed from the queue of nodes to be + /// processed. + /// \returns The next node to be processed. + T getNextToProcessNode() { + if (m_toProcessQueue.empty()) + return T(); + size_t nextId = m_toProcessQueue.front(); + m_toProcessQueue.pop(); + return m_nodes[nextId]; + } +}; +} // end namespace clad + +#endif // CLAD_DIFFERENTIATOR_DYNAMICGRAPH_H \ No newline at end of file diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index acad4b3a1..5beab0cb3 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1168,8 +1168,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // into the queue. pushforwardFnRequest.DeclarationOnly = false; pushforwardFnRequest.DerivedFDPrototype = pushforwardFD; - plugin::AddRequestToSchedule(m_CladPlugin, pushforwardFnRequest); } + m_Builder.AddEdgeToGraph(pushforwardFnRequest); if (pushforwardFD) { if (baseDiff.getExpr()) { diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 3a6c57ec0..35a252c60 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -37,8 +37,10 @@ using namespace clang; namespace clad { DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC) + const DerivedFnCollector& DFC, + clad::DynamicGraph& G) : m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC), + m_DiffRequestGraph(G), m_NodeCloner(new utils::StmtClone(m_Sema, m_Context)), m_BuiltinDerivativesNSD(nullptr), m_NumericalDiffNSD(nullptr) {} @@ -292,15 +294,25 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { assert(FD && "Must not be null."); // If FD is only a declaration, try to find its definition. if (!FD->getDefinition()) { - if (request.VerboseDiags) - diag(DiagnosticsEngine::Error, - request.CallContext ? request.CallContext->getBeginLoc() : noLoc, - "attempted differentiation of function '%0', which does not have a " - "definition", { FD->getNameAsString() }); - return {}; + // If only declaration is requested, allow this for non + // pullback/pushforward modes. For ex, this is required for Hessian - + // where we have forward mode followed by reverse mode, but we only need + // the declaration of the forward mode initially. + if (!request.DeclarationOnly || + IsPullbackOrPushforwardMode(request.Mode)) { + if (request.VerboseDiags) + diag(DiagnosticsEngine::Error, + request.CallContext ? request.CallContext->getBeginLoc() : noLoc, + "attempted differentiation of function '%0', which does not " + "have a " + "definition", + {FD->getNameAsString()}); + return {}; + } } - FD = FD->getDefinition(); + if (!request.DeclarationOnly) + FD = FD->getDefinition(); // check if the function is non-differentiable. if (clad::utils::hasNonDifferentiableAttribute(FD)) { @@ -388,4 +400,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { return DFI.DerivedFn(); return nullptr; } + + void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& request) { + m_DiffRequestGraph.addEdgeToCurrentNode(request); + } + + void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& from, + const DiffRequest& to) { + m_DiffRequestGraph.addEdge(from, to); + } }// end namespace clad diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index e4d76d0d4..07a7993c1 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -232,10 +232,10 @@ namespace clad { } DiffCollector::DiffCollector(DeclGroupRef DGR, DiffInterval& Interval, - DiffSchedule& plans, clang::Sema& S, - RequestOptions& opts) - : m_Interval(Interval), m_DiffPlans(plans), m_TopMostFD(nullptr), - m_Sema(S), m_Options(opts) { + clad::DynamicGraph& requestGraph, + clang::Sema& S, RequestOptions& opts) + : m_Interval(Interval), m_DiffRequestGraph(requestGraph), m_Sema(S), + m_Options(opts) { if (Interval.empty()) return; @@ -300,7 +300,8 @@ namespace clad { auto& C = semaRef.getASTContext(); const Expr* diffArgs = Args; const FunctionDecl* FD = Function; - FD = FD->getDefinition(); + if (!DeclarationOnly) + FD = FD->getDefinition(); if (!diffArgs || !FD) { return; } @@ -684,7 +685,7 @@ namespace clad { llvm::SaveAndRestore saveTopMost = m_TopMostFD; m_TopMostFD = FD; TraverseDecl(derivedFD); - m_DiffPlans.push_back(std::move(request)); + m_DiffRequestGraph.addNode(request, /*isSource=*/true); } /*else if (m_TopMostFD) { // If another function is called inside differentiated function, diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 7ddc3f730..6fe5ecbfd 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -49,8 +49,9 @@ namespace clad { /// Derives the function w.r.t both forward and reverse mode and returns the /// FunctionDecl obtained from reverse mode differentiation - static FunctionDecl* DeriveUsingForwardAndReverseMode(Sema& SemaRef, - clad::plugin::CladPlugin& CP, DiffRequest IndependentArgRequest, + static FunctionDecl* DeriveUsingForwardAndReverseMode( + Sema& SemaRef, clad::plugin::CladPlugin& CP, + clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest, const Expr* ForwardModeArgs, const Expr* ReverseModeArgs) { // Derives function once in forward mode w.r.t to ForwardModeArgs IndependentArgRequest.Args = ForwardModeArgs; @@ -59,25 +60,40 @@ namespace clad { IndependentArgRequest.UpdateDiffParamsInfo(SemaRef); // FIXME: Find a way to do this without accessing plugin namespace functions FunctionDecl* firstDerivative = - plugin::ProcessDiffRequest(CP, IndependentArgRequest); + Builder.FindDerivedFunction(IndependentArgRequest); + if (!firstDerivative) { + // Derive declaration of the the forward mode derivative. + IndependentArgRequest.DeclarationOnly = true; + firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest); + + // Add the request to derive the definition of the forward mode derivative + // to the schedule. + IndependentArgRequest.DeclarationOnly = false; + IndependentArgRequest.DerivedFDPrototype = firstDerivative; + } + Builder.AddEdgeToGraph(IndependentArgRequest); // Further derives function w.r.t to ReverseModeArgs - IndependentArgRequest.Mode = DiffMode::reverse; - IndependentArgRequest.Function = firstDerivative; - IndependentArgRequest.Args = ReverseModeArgs; - IndependentArgRequest.BaseFunctionName = firstDerivative->getNameAsString(); - IndependentArgRequest.UpdateDiffParamsInfo(SemaRef); + DiffRequest ReverseModeRequest{}; + ReverseModeRequest.Mode = DiffMode::reverse; + ReverseModeRequest.Function = firstDerivative; + ReverseModeRequest.Args = ReverseModeArgs; + ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString(); + ReverseModeRequest.UpdateDiffParamsInfo(SemaRef); - // Derive declaration of the the forward mode derivative. - IndependentArgRequest.DeclarationOnly = true; FunctionDecl* secondDerivative = - plugin::ProcessDiffRequest(CP, IndependentArgRequest); - - // Add the request to derive the definition of the forward mode derivative - // to the schedule. - IndependentArgRequest.DeclarationOnly = false; - IndependentArgRequest.DerivedFDPrototype = secondDerivative; - plugin::AddRequestToSchedule(CP, IndependentArgRequest); + Builder.FindDerivedFunction(ReverseModeRequest); + if (!secondDerivative) { + // Derive declaration of the the reverse mode derivative. + ReverseModeRequest.DeclarationOnly = true; + secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest); + + // Add the request to derive the definition of the reverse mode derivative + // to the schedule. + ReverseModeRequest.DeclarationOnly = false; + ReverseModeRequest.DerivedFDPrototype = secondDerivative; + } + Builder.AddEdgeToGraph(ReverseModeRequest); return secondDerivative; } @@ -177,9 +193,9 @@ namespace clad { PVD->getNameAsString() + "[" + std::to_string(i) + "]"; auto ForwardModeIASL = CreateStringLiteral(m_Context, independentArgString); - auto DFD = - DeriveUsingForwardAndReverseMode(m_Sema, m_CladPlugin, request, - ForwardModeIASL, request.Args); + auto* DFD = DeriveUsingForwardAndReverseMode( + m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, + request.Args); secondDerivativeColumns.push_back(DFD); } @@ -190,9 +206,9 @@ namespace clad { // then in reverse mode w.r.t to all requested args auto ForwardModeIASL = CreateStringLiteral(m_Context, PVD->getNameAsString()); - auto DFD = DeriveUsingForwardAndReverseMode(m_Sema, m_CladPlugin, - request, ForwardModeIASL, - request.Args); + auto* DFD = DeriveUsingForwardAndReverseMode( + m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, + request.Args); secondDerivativeColumns.push_back(DFD); } } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d1c69c64e..32febed08 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1784,6 +1784,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActBeforeDifferentiatingCallExpr( pullbackCallArgs, PreCallStmts, dfdx()); + // Overloaded derivative was not found, request the CladPlugin to // derive the called function. DiffRequest pullbackRequest{}; @@ -1812,7 +1813,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // function. pullbackRequest.DeclarationOnly = false; pullbackRequest.DerivedFDPrototype = pullbackFD; - plugin::AddRequestToSchedule(m_CladPlugin, pullbackRequest); } else { // FIXME: Error estimation currently uses singleton objects - // m_ErrorEstHandler and m_EstModel, which is cleared after each @@ -1822,6 +1822,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); } } + m_Builder.AddEdgeToGraph(pullbackRequest); // Clad failed to derive it. // FIXME: Add support for reference arguments to the numerical diff. If @@ -1923,8 +1924,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // function. calleeFnForwPassReq.DeclarationOnly = false; calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD; - plugin::AddRequestToSchedule(m_CladPlugin, calleeFnForwPassReq); } + m_Builder.AddEdgeToGraph(calleeFnForwPassReq); assert(calleeFnForwPassFD && "Clad failed to generate callee function forward pass function"); diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index a99c61ce3..e2c8fb1d2 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -122,11 +122,13 @@ namespace clad { Sema& S = m_CI.getSema(); if (!m_DerivativeBuilder) - m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this, m_DFC)); + m_DerivativeBuilder = std::make_unique( + S, *this, m_DFC, m_DiffRequestGraph); RequestOptions opts{}; SetRequestOptions(opts); - DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, S, opts); + DiffCollector collector(DGR, CladEnabledRange, m_DiffRequestGraph, S, + opts); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { @@ -400,13 +402,12 @@ namespace clad { Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled); Sema::LocalEagerInstantiationScope LocalInstantiations(S); - // Use index based loop to avoid iterator invalidation as - // ProcessDiffRequest might add more requests to m_DiffSchedule. - for (size_t i = 0; i < m_DiffSchedule.size(); ++i) { - // make a copy of the request to avoid invalidating the reference - // when ProcessDiffRequest adds more requests to m_DiffSchedule. - DiffRequest request = m_DiffSchedule[i]; + DiffRequest request = m_DiffRequestGraph.getNextToProcessNode(); + while (request.Function != nullptr) { + m_DiffRequestGraph.setCurrentProcessingNode(request); ProcessDiffRequest(request); + m_DiffRequestGraph.markCurrentNodeProcessed(); + request = m_DiffRequestGraph.getNextToProcessNode(); } // Put the TUScope in a consistent state after clad is done. S.TUScope = nullptr; diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 4d569c1cd..5543b1b2d 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -102,7 +102,7 @@ class CladTimerGroup { bool m_HasRuntime = false; CladTimerGroup m_CTG; DerivedFnCollector m_DFC; - DiffSchedule m_DiffSchedule; + DynamicGraph m_DiffRequestGraph; enum class CallKind { HandleCXXStaticMemberVarInstantiation, HandleTopLevelDecl, @@ -237,10 +237,6 @@ class CladTimerGroup { m_Multiplexer->ForgetSema(); } - void AddRequestToSchedule(const DiffRequest& request) { - m_DiffSchedule.push_back(request); - } - // FIXME: We should hide ProcessDiffRequest when we implement proper // handling of the differentiation plans. clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); @@ -271,10 +267,6 @@ class CladTimerGroup { return P.ProcessDiffRequest(request); } - void AddRequestToSchedule(CladPlugin& P, const DiffRequest& request) { - P.AddRequestToSchedule(request); - } - template class Action : public clang::PluginASTAction { private: diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index 6100f7346..a8c5a946f 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -40,3 +40,5 @@ if (Kokkos_FOUND) set(CMAKE_CXX_STANDARD_REQUIRED TRUE) add_subdirectory(Kokkos) endif(Kokkos_FOUND) + +add_subdirectory(Misc) diff --git a/unittests/Misc/CMakeLists.txt b/unittests/Misc/CMakeLists.txt new file mode 100644 index 000000000..dc13b2417 --- /dev/null +++ b/unittests/Misc/CMakeLists.txt @@ -0,0 +1,4 @@ +add_clad_unittest(MiscTests + main.cpp + DynamicGraph.cpp +) diff --git a/unittests/Misc/DynamicGraph.cpp b/unittests/Misc/DynamicGraph.cpp new file mode 100644 index 000000000..6954a6698 --- /dev/null +++ b/unittests/Misc/DynamicGraph.cpp @@ -0,0 +1,67 @@ +#include "clad/Differentiator/Differentiator.h" + +#include +#include + +#include "gtest/gtest.h" + +struct Node { + std::string name; + int id; + + Node(std::string name, int id) : name(name), id(id) {} + + bool operator==(const Node& other) const { + return name == other.name && id == other.id; + } + + // String operator for printing the node. + operator std::string() const { return name + std::to_string(id); } +}; + +// Specialize std::hash for the Node type. +template <> struct std::hash { + std::size_t operator()(const Node& n) const { + return std::hash()(n.name) ^ std::hash()(n.id); + } +}; + +TEST(DynamicGraphTest, Printing) { + clad::DynamicGraph G; + for (int i = 0; i < 6; i++) { + Node n("node", i); + if (i == 0) + G.addNode(n, /*isSource=*/true); + Node m("node", i + 1); + G.addEdge(n, m); + } + std::vector nodes = G.getNodes(); + EXPECT_EQ(nodes.size(), 7); + + // Edge from node 0 to node 3 and node 4 to node 0. + G.addEdge(nodes[0], nodes[3]); + G.addEdge(nodes[4], nodes[0]); + + // Check the printed output. + std::stringstream ss; + std::streambuf* coutbuf = std::cout.rdbuf(); + std::cout.rdbuf(ss.rdbuf()); + G.print(); + std::cout.rdbuf(coutbuf); + std::string expectedOutput = "node0: #0 (source), (unprocessed)\n" + "node1: #1, (unprocessed)\n" + "node2: #2, (unprocessed)\n" + "node3: #3, (unprocessed)\n" + "node4: #4, (unprocessed)\n" + "node5: #5, (unprocessed)\n" + "node6: #6, (unprocessed)\n" + "0 -> 1\n" + "0 -> 3\n" + "1 -> 2\n" + "2 -> 3\n" + "3 -> 4\n" + "4 -> 0\n" + "4 -> 5\n" + "5 -> 6\n"; + EXPECT_EQ(ss.str(), expectedOutput); +} diff --git a/unittests/Misc/main.cpp b/unittests/Misc/main.cpp new file mode 100644 index 000000000..b936444aa --- /dev/null +++ b/unittests/Misc/main.cpp @@ -0,0 +1,6 @@ +#include + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}