diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index efe369500..4b43c5798 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -38,9 +38,6 @@ namespace clad { class CladPlugin; clang::FunctionDecl* ProcessDiffRequest(CladPlugin& P, DiffRequest& request); - // FIXME: This function should be removed and the entire plans array - // should be somehow made accessible to all the visitors. - void AddRequestToSchedule(CladPlugin& P, const DiffRequest& request); } // namespace plugin } // namespace clad @@ -87,6 +84,7 @@ namespace clad { plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; const DerivedFnCollector& m_DFC; + clad::DynamicGraph& m_DiffRequestGraph; std::unique_ptr m_NodeCloner; clang::NamespaceDecl* m_BuiltinDerivativesNSD; /// A reference to the model to use for error estimation (if any). @@ -137,7 +135,8 @@ namespace clad { public: DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC); + const DerivedFnCollector& DFC, + clad::DynamicGraph& DRG); ~DerivativeBuilder(); /// Reset the model use for error estimation (if any). /// \param[in] estModel The error estimation model, can be either @@ -172,6 +171,16 @@ namespace clad { /// /// \returns The derived function if found, nullptr otherwise. clang::FunctionDecl* FindDerivedFunction(const DiffRequest& request); + /// Add edge from current request to the given request in the DiffRequest + /// graph. + /// + /// \param[in] request The request to add the edge to. + void AddEdgeToGraph(const DiffRequest& request); + /// Add edge between two requests in the DiffRequest graph. + /// + /// \param[in] from The source request. + /// \param[in] to The destination request. + void AddEdgeToGraph(const DiffRequest& from, const DiffRequest& to); }; } // end namespace clad diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index 919d22c80..20676fd31 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -15,6 +15,14 @@ enum class DiffMode { reverse_mode_forward_pass, error_estimation }; + +/// Returns true if the given mode is a pullback/pushforward mode. +inline bool IsPullbackOrPushforwardMode(DiffMode mode) { + return mode == DiffMode::experimental_pushforward || + mode == DiffMode::experimental_pullback || + mode == DiffMode::experimental_vector_pushforward || + mode == DiffMode::reverse_mode_forward_pass; +} } #endif diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 187810364..028a9ff0d 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -1,10 +1,11 @@ #ifndef CLAD_DIFF_PLANNER_H #define CLAD_DIFF_PLANNER_H -#include "clad/Differentiator/DiffMode.h" -#include "clad/Differentiator/ParseDiffArgsTypes.h" #include "clang/AST/RecursiveASTVisitor.h" #include "llvm/ADT/SmallSet.h" +#include "clad/Differentiator/DiffMode.h" +#include "clad/Differentiator/DynamicGraph.h" +#include "clad/Differentiator/ParseDiffArgsTypes.h" namespace clang { class ASTContext; @@ -90,9 +91,47 @@ struct DiffRequest { /// 3) If no argument is provided, a default argument is used. The /// function will be differentiated w.r.t. to its every parameter. void UpdateDiffParamsInfo(clang::Sema& semaRef); + + /// Define the == operator for DiffRequest. + bool operator==(const DiffRequest& other) const { + // either function match or previous declaration match + return (Function == other.Function || + Function->getPreviousDecl() == other.Function || + Function == other.Function->getPreviousDecl()) && + BaseFunctionName == other.BaseFunctionName && + CurrentDerivativeOrder == other.CurrentDerivativeOrder && + RequestedDerivativeOrder == other.RequestedDerivativeOrder && + CallContext == other.CallContext && Args == other.Args && + Mode == other.Mode && EnableTBRAnalysis == other.EnableTBRAnalysis && + DVI == other.DVI && use_enzyme == other.use_enzyme && + DeclarationOnly == other.DeclarationOnly; + } + + // String operator for printing the node. + operator std::string() const { + std::string res = BaseFunctionName + "__order_" + + std::to_string(CurrentDerivativeOrder) + "__mode_"; + switch (Mode) { + case DiffMode::forward: + res += "forward"; + break; + case DiffMode::reverse: + res += "reverse"; + break; + case DiffMode::vector_forward_mode: + res += "vector_forward_mode"; + break; + case DiffMode::experimental_pushforward: + res += "pushforward"; + break; + case DiffMode::experimental_pullback: + res += "pullback"; + break; + } + return res; + } }; - using DiffSchedule = llvm::SmallVector; using DiffInterval = std::vector; struct RequestOptions { @@ -106,9 +145,9 @@ struct DiffRequest { /// DiffInterval& m_Interval; - /// The diff step-by-step plan for differentiation. + /// Graph to store the dependencies between different requests. /// - DiffSchedule& m_DiffPlans; + clad::DynamicGraph& m_DiffRequestGraph; /// If set it means that we need to find the called functions and /// add them for implicit diff. @@ -120,7 +159,8 @@ struct DiffRequest { public: DiffCollector(clang::DeclGroupRef DGR, DiffInterval& Interval, - DiffSchedule& plans, clang::Sema& S, RequestOptions& opts); + clad::DynamicGraph& requestGraph, clang::Sema& S, + RequestOptions& opts); bool VisitCallExpr(clang::CallExpr* E); private: @@ -128,4 +168,15 @@ struct DiffRequest { }; } +// Define the hash function for DiffRequest. +template <> struct std::hash { + std::size_t operator()(const clad::DiffRequest& DR) const { + // Use the function pointer as the hash of the DiffRequest, it + // is sufficient to break a reasonable number of collisions. + if (DR.Function->getPreviousDecl()) + return std::hash{}(DR.Function->getPreviousDecl()); + return std::hash{}(DR.Function); + } +}; + #endif diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h new file mode 100644 index 000000000..c06248960 --- /dev/null +++ b/include/clad/Differentiator/DynamicGraph.h @@ -0,0 +1,166 @@ +#ifndef CLAD_DYNAMICGRAPH_H +#define CLAD_DYNAMICGRAPH_H + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace clad { +template class DynamicGraph { +private: + // Storing nodes in the graph. The index of the node in the vector is used as + // a unique identifier for the node in the adjacency list. + std::vector nodes; + + // Store the nodes in the graph as an unordered map from the node to a boolean + // indicating whether the node is processed or not. The second element in the + // pair is the id of the node in the nodes vector. + std::unordered_map> nodeMap; + + // Store the adjacency list for the graph. The adjacency list is a map from + // a node to the set of nodes that it has an edge to. We use integers inside + // the set to avoid copying the nodes. + std::unordered_map> adjList; + + // Store the reverse adjacency list for the graph. The reverse adjacency list + // is a map from a node to the set of nodes that have an edge to it. We use + // integers inside the set to avoid copying the nodes. + std::unordered_map> revAdjList; + + // Set of source nodes in the graph. + std::set sources; + + // Store the id of the node being processed right now. + int currentId = -1; // -1 means no node is being processed. + + // Maintain a queue of nodes to be processed next. + std::queue toProcessQueue; + +public: + DynamicGraph() = default; + + // Add an edge from src to dest + void addEdge(const T& src, const T& dest) { + std::pair srcInfo = addNode(src); + std::pair destInfo = addNode(dest); + size_t srcId = srcInfo.second; + size_t destId = destInfo.second; + adjList[srcId].insert(destId); + revAdjList[destId].insert(srcId); + } + + // Add a node to the graph + std::pair addNode(const T& node, bool isSource = false) { + if (nodeMap.find(node) == nodeMap.end()) { + size_t id = nodes.size(); + nodes.push_back(node); + nodeMap[node] = {false, id}; // node is not processed yet. + adjList[id] = {}; + revAdjList[id] = {}; + if (isSource) { + sources.insert(id); + toProcessQueue.push(id); + } + } + return nodeMap[node]; + } + + // Adds the edge from the current node to the destination node. + void addEdgeToCurrentNode(const T& dest) { + if (currentId == -1) + return; + addEdge(nodes[currentId], dest); + } + + // Set the current node to the node with the given id. + void setCurrentProcessingNode(const T& node) { + if (nodeMap.find(node) != nodeMap.end()) + currentId = nodeMap[node].second; + } + + // Mark the current node as processed. + void markCurrentNodeProcessed() { + if (currentId != -1) { + nodeMap[nodes[currentId]].first = true; + for (size_t destId : adjList[currentId]) + if (!nodeMap[nodes[destId]].first) + toProcessQueue.push(destId); + } + currentId = -1; + } + + // Get the nodes in the graph. + std::vector getNodes() { return nodes; } + + // Check if two nodes are connected in the graph. + bool isConnected(const T& src, const T& dest) { + if (nodeMap.find(src) == nodeMap.end() || + nodeMap.find(dest) == nodeMap.end()) + return false; + size_t srcId = nodeMap[src].second; + size_t destId = nodeMap[dest].second; + return adjList[srcId].find(destId) != adjList[srcId].end(); + } + + // Print the graph in a human-readable format. + void print() { + // First print the nodes with their insertion order. + for (const T& node : nodes) { + std::pair nodeInfo = nodeMap[node]; + std::cout << (std::string)node << ": #" << nodeInfo.second; + if (sources.find(nodeInfo.second) != sources.end()) + std::cout << " (source)"; + if (nodeInfo.first) + std::cout << ", (done)\n"; + else + std::cout << ", (unprocessed)\n"; + } + // Then print the edges. + for (int i = 0; i < nodes.size(); i++) + for (size_t dest : adjList[i]) + std::cout << i << " -> " << dest << "\n"; + } + + // Topological sort of the directed graph. If the graph is not a DAG, the + // result will be a partial order. Use a recursive dfs heler function to + // implement the topological sort. If a->b, then a will come before b in the + // topological sort. In reverseOrder mode, the result will be in reverse + // topological order, i.e a->b, then b will come before a in the result. + std::vector topologicalSort(bool reverseOrder = false) { + std::vector res; + std::unordered_set visited; + + std::function dfs = [&](size_t node) -> void { + visited.insert(node); + for (size_t dest : adjList[node]) + if (visited.find(dest) == visited.end()) + dfs(dest); + res.push_back(nodes[node]); + }; + for (size_t source : sources) + if (visited.find(source) == visited.end()) + dfs(source); + + if (reverseOrder) + return res; + std::reverse(res.begin(), res.end()); + return res; + } + + // Get the next to process node from the queue of nodes to be processed. + T getNextToProcessNode() { + if (toProcessQueue.empty()) + return T(); + size_t nextId = toProcessQueue.front(); + toProcessQueue.pop(); + return nodes[nextId]; + } +}; +} // end namespace clad + +#endif // CLAD_DYNAMICGRAPH_H \ No newline at end of file diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index acad4b3a1..01e4602c0 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -1168,8 +1168,8 @@ StmtDiff BaseForwardModeVisitor::VisitCallExpr(const CallExpr* CE) { // into the queue. pushforwardFnRequest.DeclarationOnly = false; pushforwardFnRequest.DerivedFDPrototype = pushforwardFD; - plugin::AddRequestToSchedule(m_CladPlugin, pushforwardFnRequest); } + m_Builder.AddEdgeToGraph(std::move(pushforwardFnRequest)); if (pushforwardFD) { if (baseDiff.getExpr()) { diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index 3a6c57ec0..35a252c60 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -37,8 +37,10 @@ using namespace clang; namespace clad { DerivativeBuilder::DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, - const DerivedFnCollector& DFC) + const DerivedFnCollector& DFC, + clad::DynamicGraph& G) : m_Sema(S), m_CladPlugin(P), m_Context(S.getASTContext()), m_DFC(DFC), + m_DiffRequestGraph(G), m_NodeCloner(new utils::StmtClone(m_Sema, m_Context)), m_BuiltinDerivativesNSD(nullptr), m_NumericalDiffNSD(nullptr) {} @@ -292,15 +294,25 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { assert(FD && "Must not be null."); // If FD is only a declaration, try to find its definition. if (!FD->getDefinition()) { - if (request.VerboseDiags) - diag(DiagnosticsEngine::Error, - request.CallContext ? request.CallContext->getBeginLoc() : noLoc, - "attempted differentiation of function '%0', which does not have a " - "definition", { FD->getNameAsString() }); - return {}; + // If only declaration is requested, allow this for non + // pullback/pushforward modes. For ex, this is required for Hessian - + // where we have forward mode followed by reverse mode, but we only need + // the declaration of the forward mode initially. + if (!request.DeclarationOnly || + IsPullbackOrPushforwardMode(request.Mode)) { + if (request.VerboseDiags) + diag(DiagnosticsEngine::Error, + request.CallContext ? request.CallContext->getBeginLoc() : noLoc, + "attempted differentiation of function '%0', which does not " + "have a " + "definition", + {FD->getNameAsString()}); + return {}; + } } - FD = FD->getDefinition(); + if (!request.DeclarationOnly) + FD = FD->getDefinition(); // check if the function is non-differentiable. if (clad::utils::hasNonDifferentiableAttribute(FD)) { @@ -388,4 +400,13 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) { return DFI.DerivedFn(); return nullptr; } + + void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& request) { + m_DiffRequestGraph.addEdgeToCurrentNode(request); + } + + void DerivativeBuilder::AddEdgeToGraph(const DiffRequest& from, + const DiffRequest& to) { + m_DiffRequestGraph.addEdge(from, to); + } }// end namespace clad diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index e4d76d0d4..56f7ef28b 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -232,10 +232,10 @@ namespace clad { } DiffCollector::DiffCollector(DeclGroupRef DGR, DiffInterval& Interval, - DiffSchedule& plans, clang::Sema& S, - RequestOptions& opts) - : m_Interval(Interval), m_DiffPlans(plans), m_TopMostFD(nullptr), - m_Sema(S), m_Options(opts) { + clad::DynamicGraph& requestGraph, + clang::Sema& S, RequestOptions& opts) + : m_Interval(Interval), m_DiffRequestGraph(requestGraph), + m_TopMostFD(nullptr), m_Sema(S), m_Options(opts) { if (Interval.empty()) return; @@ -300,7 +300,8 @@ namespace clad { auto& C = semaRef.getASTContext(); const Expr* diffArgs = Args; const FunctionDecl* FD = Function; - FD = FD->getDefinition(); + if (!DeclarationOnly) + FD = FD->getDefinition(); if (!diffArgs || !FD) { return; } @@ -684,7 +685,7 @@ namespace clad { llvm::SaveAndRestore saveTopMost = m_TopMostFD; m_TopMostFD = FD; TraverseDecl(derivedFD); - m_DiffPlans.push_back(std::move(request)); + m_DiffRequestGraph.addNode(request, true /*isSource*/); } /*else if (m_TopMostFD) { // If another function is called inside differentiated function, diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index 7ddc3f730..e07cfcdbe 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -49,8 +49,9 @@ namespace clad { /// Derives the function w.r.t both forward and reverse mode and returns the /// FunctionDecl obtained from reverse mode differentiation - static FunctionDecl* DeriveUsingForwardAndReverseMode(Sema& SemaRef, - clad::plugin::CladPlugin& CP, DiffRequest IndependentArgRequest, + static FunctionDecl* DeriveUsingForwardAndReverseMode( + Sema& SemaRef, clad::plugin::CladPlugin& CP, + clad::DerivativeBuilder& Builder, DiffRequest IndependentArgRequest, const Expr* ForwardModeArgs, const Expr* ReverseModeArgs) { // Derives function once in forward mode w.r.t to ForwardModeArgs IndependentArgRequest.Args = ForwardModeArgs; @@ -60,24 +61,29 @@ namespace clad { // FIXME: Find a way to do this without accessing plugin namespace functions FunctionDecl* firstDerivative = plugin::ProcessDiffRequest(CP, IndependentArgRequest); + Builder.AddEdgeToGraph(std::move(IndependentArgRequest)); // Further derives function w.r.t to ReverseModeArgs - IndependentArgRequest.Mode = DiffMode::reverse; - IndependentArgRequest.Function = firstDerivative; - IndependentArgRequest.Args = ReverseModeArgs; - IndependentArgRequest.BaseFunctionName = firstDerivative->getNameAsString(); - IndependentArgRequest.UpdateDiffParamsInfo(SemaRef); + DiffRequest ReverseModeRequest{}; + ReverseModeRequest.Mode = DiffMode::reverse; + ReverseModeRequest.Function = firstDerivative; + ReverseModeRequest.Args = ReverseModeArgs; + ReverseModeRequest.BaseFunctionName = firstDerivative->getNameAsString(); + ReverseModeRequest.UpdateDiffParamsInfo(SemaRef); - // Derive declaration of the the forward mode derivative. - IndependentArgRequest.DeclarationOnly = true; FunctionDecl* secondDerivative = - plugin::ProcessDiffRequest(CP, IndependentArgRequest); - - // Add the request to derive the definition of the forward mode derivative - // to the schedule. - IndependentArgRequest.DeclarationOnly = false; - IndependentArgRequest.DerivedFDPrototype = secondDerivative; - plugin::AddRequestToSchedule(CP, IndependentArgRequest); + Builder.FindDerivedFunction(ReverseModeRequest); + if (!secondDerivative) { + // Derive declaration of the the reverse mode derivative. + ReverseModeRequest.DeclarationOnly = true; + secondDerivative = plugin::ProcessDiffRequest(CP, ReverseModeRequest); + + // Add the request to derive the definition of the reverse mode derivative + // to the schedule. + ReverseModeRequest.DeclarationOnly = false; + ReverseModeRequest.DerivedFDPrototype = secondDerivative; + } + Builder.AddEdgeToGraph(std::move(ReverseModeRequest)); return secondDerivative; } @@ -177,9 +183,9 @@ namespace clad { PVD->getNameAsString() + "[" + std::to_string(i) + "]"; auto ForwardModeIASL = CreateStringLiteral(m_Context, independentArgString); - auto DFD = - DeriveUsingForwardAndReverseMode(m_Sema, m_CladPlugin, request, - ForwardModeIASL, request.Args); + auto DFD = DeriveUsingForwardAndReverseMode( + m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, + request.Args); secondDerivativeColumns.push_back(DFD); } @@ -190,9 +196,9 @@ namespace clad { // then in reverse mode w.r.t to all requested args auto ForwardModeIASL = CreateStringLiteral(m_Context, PVD->getNameAsString()); - auto DFD = DeriveUsingForwardAndReverseMode(m_Sema, m_CladPlugin, - request, ForwardModeIASL, - request.Args); + auto DFD = DeriveUsingForwardAndReverseMode( + m_Sema, m_CladPlugin, m_Builder, request, ForwardModeIASL, + request.Args); secondDerivativeColumns.push_back(DFD); } } diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index d1c69c64e..4377d8704 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1784,6 +1784,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActBeforeDifferentiatingCallExpr( pullbackCallArgs, PreCallStmts, dfdx()); + // Overloaded derivative was not found, request the CladPlugin to // derive the called function. DiffRequest pullbackRequest{}; @@ -1812,7 +1813,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // function. pullbackRequest.DeclarationOnly = false; pullbackRequest.DerivedFDPrototype = pullbackFD; - plugin::AddRequestToSchedule(m_CladPlugin, pullbackRequest); } else { // FIXME: Error estimation currently uses singleton objects - // m_ErrorEstHandler and m_EstModel, which is cleared after each @@ -1822,6 +1822,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); } } + m_Builder.AddEdgeToGraph(std::move(pullbackRequest)); // Clad failed to derive it. // FIXME: Add support for reference arguments to the numerical diff. If @@ -1923,8 +1924,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // function. calleeFnForwPassReq.DeclarationOnly = false; calleeFnForwPassReq.DerivedFDPrototype = calleeFnForwPassFD; - plugin::AddRequestToSchedule(m_CladPlugin, calleeFnForwPassReq); } + m_Builder.AddEdgeToGraph(calleeFnForwPassReq); assert(calleeFnForwPassFD && "Clad failed to generate callee function forward pass function"); diff --git a/test/Misc/DynamicGraph.C b/test/Misc/DynamicGraph.C new file mode 100644 index 000000000..a11ee9607 --- /dev/null +++ b/test/Misc/DynamicGraph.C @@ -0,0 +1,70 @@ +// RUN: %cladclang %s -I%S/../../include -oGraph.out 2>&1 +// RUN: ./Graph.out | FileCheck -check-prefix=CHECK-EXEC %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/DynamicGraph.h" +#include +#include + +// Custom type for representing nodes in the graph. +struct Node { + std::string name; + int id; + + Node(std::string name, int id) : name(name), id(id) {} + + bool operator==(const Node& other) const { + return name == other.name && id == other.id; + } + + // string operator for printing the node. + operator std::string() const { + return name + std::to_string(id); + } +}; + +// Specialize std::hash for the Node type. +template<> +struct std::hash { + std::size_t operator()(const Node& n) const { + return std::hash()(n.name) ^ std::hash()(n.id); + } +}; + +int main () { + clad::DynamicGraph G; + for (int i = 0; i < 6; i++) { + Node n("node", i); + if (i == 0) { + G.addNode(n, true/*isSource*/); + } + Node m("node", i + 1); + G.addEdge(n, m); + } + std::vector nodes = G.getNodes(); + std::cout << "Nodes in the graph: " << nodes.size() << "\n"; + // CHECK-EXEC: Nodes in the graph: 7 + + // edge from node 0 to node 3 and node 4 to node 0. + G.addEdge(nodes[0], nodes[3]); + G.addEdge(nodes[4], nodes[0]); + + G.print(); + // CHECK-EXEC: node0: #0 (source), (unprocessed) + // CHECK-EXEC-NEXT: node1: #1, (unprocessed) + // CHECK-EXEC-NEXT: node2: #2, (unprocessed) + // CHECK-EXEC-NEXT: node3: #3, (unprocessed) + // CHECK-EXEC-NEXT: node4: #4, (unprocessed) + // CHECK-EXEC-NEXT: node5: #5, (unprocessed) + // CHECK-EXEC-NEXT: node6: #6, (unprocessed) + // CHECK-EXEC-NEXT: 0 -> 1 + // CHECK-EXEC-NEXT: 0 -> 3 + // CHECK-EXEC-NEXT: 1 -> 2 + // CHECK-EXEC-NEXT: 2 -> 3 + // CHECK-EXEC-NEXT: 3 -> 4 + // CHECK-EXEC-NEXT: 4 -> 0 + // CHECK-EXEC-NEXT: 4 -> 5 + // CHECK-EXEC-NEXT: 5 -> 6 + return 0; +} + diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index 3f2ba506a..d13ce528c 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -122,11 +122,13 @@ namespace clad { Sema& S = m_CI.getSema(); if (!m_DerivativeBuilder) - m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this, m_DFC)); + m_DerivativeBuilder.reset( + new DerivativeBuilder(S, *this, m_DFC, m_DiffRequestGraph)); RequestOptions opts{}; SetRequestOptions(opts); - DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, S, opts); + DiffCollector collector(DGR, CladEnabledRange, m_DiffRequestGraph, S, + opts); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { @@ -397,13 +399,12 @@ namespace clad { Sema::GlobalEagerInstantiationScope GlobalInstantiations(S, Enabled); Sema::LocalEagerInstantiationScope LocalInstantiations(S); - // Use index based loop to avoid iterator invalidation as - // ProcessDiffRequest might add more requests to m_DiffSchedule. - for (size_t i = 0; i < m_DiffSchedule.size(); ++i) { - // make a copy of the request to avoid invalidating the reference - // when ProcessDiffRequest adds more requests to m_DiffSchedule. - DiffRequest request = m_DiffSchedule[i]; + DiffRequest request = m_DiffRequestGraph.getNextToProcessNode(); + while (request.Function != nullptr) { + m_DiffRequestGraph.setCurrentProcessingNode(request); ProcessDiffRequest(request); + m_DiffRequestGraph.markCurrentNodeProcessed(); + request = m_DiffRequestGraph.getNextToProcessNode(); } // Put the TUScope in a consistent state after clad is done. S.TUScope = nullptr; diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index c69016681..7b570afd8 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -102,7 +102,7 @@ class CladTimerGroup { bool m_HasRuntime = false; CladTimerGroup m_CTG; DerivedFnCollector m_DFC; - DiffSchedule m_DiffSchedule; + DynamicGraph m_DiffRequestGraph; enum class CallKind { HandleCXXStaticMemberVarInstantiation, HandleTopLevelDecl, @@ -237,10 +237,6 @@ class CladTimerGroup { m_Multiplexer->ForgetSema(); } - void AddRequestToSchedule(const DiffRequest& request) { - m_DiffSchedule.push_back(request); - } - // FIXME: We should hide ProcessDiffRequest when we implement proper // handling of the differentiation plans. clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); @@ -267,10 +263,6 @@ class CladTimerGroup { return P.ProcessDiffRequest(request); } - void AddRequestToSchedule(CladPlugin& P, const DiffRequest& request) { - P.AddRequestToSchedule(request); - } - template class Action : public clang::PluginASTAction { private: