diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index d51280bb1..0372e3148 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -134,7 +134,9 @@ namespace clad { } public: - DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P); + /// Graph for the differentiation requests. + clad::Graph& m_DiffRequestGraph; + DerivativeBuilder(clang::Sema& S, plugin::CladPlugin& P, clad::Graph& DRG); ~DerivativeBuilder(); /// Reset the model use for error estimation (if any). /// \param[in] estModel The error estimation model, can be either diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index 187810364..5a7d97253 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -2,6 +2,7 @@ #define CLAD_DIFF_PLANNER_H #include "clad/Differentiator/DiffMode.h" +#include "clad/Differentiator/Graph.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" #include "clang/AST/RecursiveASTVisitor.h" #include "llvm/ADT/SmallSet.h" @@ -90,6 +91,49 @@ struct DiffRequest { /// 3) If no argument is provided, a default argument is used. The /// function will be differentiated w.r.t. to its every parameter. void UpdateDiffParamsInfo(clang::Sema& semaRef); + + /// Define the == operator for DiffRequest. 
+ bool operator==(const DiffRequest& other) const { + // either function match or previous declaration match + return (Function == other.Function || + Function->getPreviousDecl() == other.Function || + Function == other.Function->getPreviousDecl()) && + BaseFunctionName == other.BaseFunctionName && + CurrentDerivativeOrder == other.CurrentDerivativeOrder && + RequestedDerivativeOrder == other.RequestedDerivativeOrder && + CallContext == other.CallContext && + Args == other.Args && + Mode == other.Mode && + CallUpdateRequired == other.CallUpdateRequired && + EnableTBRAnalysis == other.EnableTBRAnalysis && + DVI == other.DVI && + use_enzyme == other.use_enzyme && + DerivedFDPrototype == other.DerivedFDPrototype && + DeclarationOnly == other.DeclarationOnly; + } + + // String operator for printing the node. + operator std::string() const { + std::string res = BaseFunctionName + "__order_" + std::to_string(CurrentDerivativeOrder) + "__mode_"; + switch (Mode) { + case DiffMode::forward: + res += "forward"; + break; + case DiffMode::reverse: + res += "reverse"; + break; + case DiffMode::vector_forward_mode: + res += "vector_forward_mode"; + break; + case DiffMode::experimental_pushforward: + res += "pushforward"; + break; + case DiffMode::experimental_pullback: + res += "pullback"; + break; + } + return res; + } }; using DiffSchedule = llvm::SmallVector; @@ -110,17 +154,25 @@ struct DiffRequest { /// DiffSchedule& m_DiffPlans; + /// Graph to store the dependencies between different requests. + /// + clad::Graph& m_DiffRequestGraph; + /// If set it means that we need to find the called functions and /// add them for implicit diff. /// const clang::FunctionDecl* m_TopMostFD = nullptr; + + /// The parent request for the current request. 
#ifndef CLAD_GRAPH_H
#define CLAD_GRAPH_H

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace clad {
/// A small directed graph keyed by values of type T.
///
/// T must be usable as a key of std::unordered_map (std::hash<T> and
/// operator==). print() additionally requires T to be convertible to
/// std::string.
///
/// Every node gets an integer id equal to its insertion index. Ids are never
/// reused: removing a node only marks it as absent and drops its edges.
template <typename T> class Graph {
private:
  /// All nodes ever inserted; the index in this vector is the node id.
  std::vector<T> nodes;

  /// Maps a node to {still present in the graph, node id}.
  std::unordered_map<T, std::pair<bool, std::size_t>> nodeMap;

  /// Outgoing edges: node id -> ids of the nodes it points to. Ids are stored
  /// instead of nodes to avoid copying the nodes.
  std::unordered_map<std::size_t, std::set<std::size_t>> adjList;

  /// Incoming edges: node id -> ids of the nodes that point to it.
  std::unordered_map<std::size_t, std::set<std::size_t>> revAdjList;

  /// Ids of the source nodes, i.e. the roots used by removeNonReachable()
  /// and topologicalSort().
  std::set<std::size_t> sources;

  /// Insert \p node if it is unknown, or mark it as present again if it was
  /// removed earlier.
  /// \returns the id of the node.
  std::size_t ensureNode(const T& node) {
    auto it = nodeMap.find(node);
    if (it != nodeMap.end()) {
      it->second.first = true;
      return it->second.second;
    }
    const std::size_t id = nodes.size();
    nodes.push_back(node);
    nodeMap.emplace(node, std::make_pair(true, id));
    adjList.emplace(id, std::set<std::size_t>{});
    revAdjList.emplace(id, std::set<std::size_t>{});
    return id;
  }

  /// \returns true if \p node (which must have been inserted before) has not
  /// been removed.
  bool isPresent(const T& node) const { return nodeMap.at(node).first; }

public:
  Graph() = default;

  /// Add an edge from \p src to \p dest, inserting either node if needed.
  void addEdge(const T& src, const T& dest) {
    const std::size_t srcId = ensureNode(src);
    const std::size_t destId = ensureNode(dest);
    adjList[srcId].insert(destId);
    revAdjList[destId].insert(srcId);
  }

  /// Add a node to the graph.
  /// \param[in] node The node to add; re-adding a removed node revives it.
  /// \param[in] isSource If true the node is registered as a source. This
  /// also applies to nodes that already exist, so a node first seen as the
  /// target of an edge can later be promoted to a source.
  void addNode(const T& node, bool isSource = false) {
    const std::size_t id = ensureNode(node);
    if (isSource)
      sources.insert(id);
  }

  /// Remove a node from the graph together with all edges to and from it.
  /// A removed node also stops being a source, so it can no longer show up
  /// in topologicalSort() or act as a root in removeNonReachable().
  /// Unknown nodes are ignored.
  void removeNode(const T& node) {
    auto it = nodeMap.find(node);
    if (it == nodeMap.end())
      return;
    const std::size_t id = it->second.second;
    it->second.first = false;
    sources.erase(id);
    for (std::size_t destId : adjList[id])
      revAdjList[destId].erase(id);
    adjList[id].clear();
    for (std::size_t srcId : revAdjList[id])
      adjList[srcId].erase(id);
    revAdjList[id].clear();
  }

  /// \returns the nodes still present in the graph, in insertion order.
  std::vector<T> getNodes() const {
    std::vector<T> res;
    for (const T& node : nodes)
      if (isPresent(node))
        res.push_back(node);
    return res;
  }

  /// Check whether there is a direct edge from \p src to \p dest.
  /// \returns false if either node was never inserted.
  bool isConnected(const T& src, const T& dest) const {
    auto srcIt = nodeMap.find(src);
    auto destIt = nodeMap.find(dest);
    if (srcIt == nodeMap.end() || destIt == nodeMap.end())
      return false;
    const std::set<std::size_t>& successors = adjList.at(srcIt->second.second);
    return successors.find(destIt->second.second) != successors.end();
  }

  /// Print the graph to std::cout in a human-readable format: first the
  /// present nodes with their ids (sources are tagged), then the edges as
  /// `srcId -> destId`.
  void print() const {
    for (const T& node : nodes) {
      const std::pair<bool, std::size_t>& nodeInfo = nodeMap.at(node);
      if (!nodeInfo.first)
        continue;
      std::cout << static_cast<std::string>(node) << ": #" << nodeInfo.second;
      if (sources.find(nodeInfo.second) != sources.end())
        std::cout << " (source)";
      std::cout << "\n";
    }
    for (std::size_t i = 0; i < nodes.size(); ++i) {
      if (!isPresent(nodes[i]))
        continue;
      for (std::size_t dest : adjList.at(i))
        std::cout << i << " -> " << dest << "\n";
    }
  }

  /// Remove every node that cannot be reached from a source node.
  void removeNonReachable() {
    std::unordered_set<std::size_t> visited(sources.begin(), sources.end());
    std::vector<std::size_t> stack(sources.begin(), sources.end());
    while (!stack.empty()) {
      const std::size_t node = stack.back();
      stack.pop_back();
      for (std::size_t dest : adjList[node]) {
        if (visited.insert(dest).second)
          stack.push_back(dest);
      }
    }
    // removeNode() only flips flags and edits the edge sets; it never inserts
    // into or erases from nodeMap, so iterating nodeMap here is safe.
    for (const auto& entry : nodeMap) {
      if (entry.second.first &&
          visited.find(entry.second.second) == visited.end())
        removeNode(entry.first);
    }
  }

  /// Topologically sort the nodes reachable from the sources using a
  /// recursive depth-first search helper.
  ///
  /// If a->b, then a comes before b in the result. Each reachable node is
  /// emitted exactly once; if the graph contains a cycle the result is only
  /// a partial order.
  ///
  /// \param[in] reverseOrder If true the result is reversed, i.e. for a->b,
  /// b comes before a.
  std::vector<T> topologicalSort(bool reverseOrder = false) const {
    std::vector<T> res;
    std::unordered_set<std::size_t> visited;

    std::function<void(std::size_t)> dfs = [&](std::size_t node) -> void {
      visited.insert(node);
      for (std::size_t dest : adjList.at(node))
        if (visited.find(dest) == visited.end())
          dfs(dest);
      // Post-order: a node is emitted after everything it points to.
      res.push_back(nodes[node]);
    };
    for (std::size_t source : sources)
      if (visited.find(source) == visited.end())
        dfs(source);

    if (!reverseOrder)
      std::reverse(res.begin(), res.end());
    return res;
  }
};
} // end namespace clad

#endif // CLAD_GRAPH_H
(Interval.empty()) @@ -564,7 +564,7 @@ namespace clad { // In that case we should ask the enclosing ast nodes for a source // location and check if it is within range. SourceLocation endLoc = E->getEndLoc(); - if (endLoc.isInvalid() || !isInInterval(endLoc)) + if (endLoc.isInvalid()/* || !isInInterval(endLoc)*/) return true; FunctionDecl* FD = E->getDirectCallee(); @@ -671,6 +671,10 @@ namespace clad { request.VerboseDiags = true; request.Args = E->getArg(1); auto derivedFD = cast(DRE->getDecl()); + if (derivedFD->isDefined()) { + llvm :: outs () << "Function is defined, " << derivedFD->getNameAsString() << "\n"; + derivedFD = derivedFD->getDefinition(); + } request.Function = derivedFD; request.BaseFunctionName = utils::ComputeEffectiveFnName(request.Function); @@ -682,14 +686,88 @@ namespace clad { assert(!m_TopMostFD && "nested clad::differentiate/gradient are not yet supported"); llvm::SaveAndRestore saveTopMost = m_TopMostFD; + llvm::SaveAndRestore saveRequest = m_ParentRequest; m_TopMostFD = FD; + m_ParentRequest = request; + m_DiffRequestGraph.addNode(request, true /*isSource*/); TraverseDecl(derivedFD); m_DiffPlans.push_back(std::move(request)); } - /*else if (m_TopMostFD) { - // If another function is called inside differentiated function, - // this will be handled by Forward/ReverseModeVisitor::Derive. - }*/ + else if (m_TopMostFD) { + // Check if the function call is marked as non-differentiable. + if (clad::utils::hasNonDifferentiableAttribute(E)) + return true; + + // If the function has no args and is not a member function call then we + // assume that it is not related to independent variables and does not + // contribute to gradient. + if ((FD->getNumParams() == 0U) && !isa(E) && + !isa(E)) + return true; + + // Check if function call has all args evaluatable at compile time. 
+ if (!isa(E) && !isa(E)) { + bool allArgsAreConstantLiterals = true; + for (const Expr* arg : E->arguments()) { + // if it's of type MaterializeTemporaryExpr, then check its + // subexpression. + if (const auto* MTE = dyn_cast(arg)) + arg = clad_compat::GetSubExpr(MTE)->IgnoreImpCasts(); + if (!arg->isEvaluatable(m_Sema.getASTContext())) { + allArgsAreConstantLiterals = false; + break; + } + } + if (allArgsAreConstantLiterals) + return true; + } + + // Check if the function is a memory allocation or deallocation function. + // These either have pushforwards or special handling. + if (utils::IsMemoryFunction(FD) || utils::IsMemoryDeallocationFunction(FD)) + return true; + + // Means another function is called inside the differentiated function. + // We need to check if this function is a candidate for differentiation. + // If it is, we need to add a dependency between the parent request and + // this request. + DiffRequest request{}; + switch (m_ParentRequest.Mode) { + case DiffMode::forward: + case DiffMode::experimental_pushforward: + request.Mode = DiffMode::experimental_pushforward; + request.Function = FD; + request.BaseFunctionName = clad::utils::ComputeEffectiveFnName(FD); + request.VerboseDiags = false; + break; + case DiffMode::reverse: + case DiffMode::experimental_pullback: + request.Mode = DiffMode::experimental_pullback; + request.Function = FD; + request.BaseFunctionName = clad::utils::ComputeEffectiveFnName(FD); + request.VerboseDiags = false; + request.EnableTBRAnalysis = m_ParentRequest.EnableTBRAnalysis; + break; + case DiffMode::vector_forward_mode: + case DiffMode::experimental_vector_pushforward: + request.Mode = DiffMode::experimental_vector_pushforward; + request.Function = FD; + request.BaseFunctionName = clad::utils::ComputeEffectiveFnName(FD); + request.VerboseDiags = false; + break; + } + if (!m_DiffRequestGraph.isConnected(m_ParentRequest, request)) { + m_DiffRequestGraph.addEdge(m_ParentRequest, request); + llvm::SaveAndRestore 
saveRequest = m_ParentRequest; + m_ParentRequest = request; + auto derivedFD = cast(FD); + if (derivedFD->isDefined()) { + llvm :: outs () << "Function is defined, " << derivedFD->getNameAsString() << "\n"; + derivedFD = derivedFD->getDefinition(); + } + TraverseDecl(derivedFD); + } + } return true; // return false to abort visiting. } } // end namespace diff --git a/test/Misc/Graph.C b/test/Misc/Graph.C new file mode 100644 index 000000000..521fb209c --- /dev/null +++ b/test/Misc/Graph.C @@ -0,0 +1,68 @@ +// RUN: %cladclang %s -I%S/../../include -oGraph.out 2>&1 +// RUN: ./Graph.out | FileCheck -check-prefix=CHECK-EXEC %s +// CHECK-NOT: {{.*error|warning|note:.*}} + +#include "clad/Differentiator/Graph.h" +#include +#include + +// Custom type for representing nodes in the graph. +struct Node { + std::string name; + int id; + + Node(std::string name, int id) : name(name), id(id) {} + + bool operator==(const Node& other) const { + return name == other.name && id == other.id; + } + + // string operator for printing the node. + operator std::string() const { + return name + std::to_string(id); + } +}; + +// Specialize std::hash for the Node type. +template<> +struct std::hash { + std::size_t operator()(const Node& n) const { + return std::hash()(n.name) ^ std::hash()(n.id); + } +}; + +int main () { + clad::Graph G; + for (int i = 0; i < 6; i++) { + Node n("node", i); + if (i == 0) { + G.addNode(n, true/*isSource*/); + } + Node m("node", i + 1); + G.addEdge(n, m); + } + std::vector nodes = G.getNodes(); + std::cout << "Nodes in the graph: " << nodes.size() << "\n"; + // CHECK-EXEC: Nodes in the graph: 7 + + // edge from node 0 to node 3 and node 4 to node 0. 
+ G.addEdge(nodes[0], nodes[3]); + G.addEdge(nodes[4], nodes[0]); + std::vector nodes2 = G.getNodes(); + std::cout << "Nodes in the graph: " << nodes2.size() << "\n"; + // CHECK-EXEC: Nodes in the graph: 7 + + // remove node 4 + G.removeNode(nodes[4]); + G.removeNonReachable(); // removes node 5 and 6 + G.print(); + // CHECK-EXEC: node0: #0 (source) + // CHECK-EXEC-NEXT: node1: #1 + // CHECK-EXEC-NEXT: node2: #2 + // CHECK-EXEC-NEXT: node3: #3 + // CHECK-EXEC-NEXT: 0 -> 1 + // CHECK-EXEC-NEXT: 0 -> 3 + // CHECK-EXEC-NEXT: 1 -> 2 + // CHECK-EXEC-NEXT: 2 -> 3 +} + diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index f4c4c1c32..5b0bf35a5 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -122,11 +122,11 @@ namespace clad { Sema& S = m_CI.getSema(); if (!m_DerivativeBuilder) - m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this)); + m_DerivativeBuilder.reset(new DerivativeBuilder(S, *this, m_DiffRequestGraph)); RequestOptions opts{}; SetRequestOptions(opts); - DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, S, opts); + DiffCollector collector(DGR, CladEnabledRange, m_DiffSchedule, m_DiffRequestGraph, S, opts); } FunctionDecl* CladPlugin::ProcessDiffRequest(DiffRequest& request) { @@ -291,6 +291,13 @@ namespace clad { return nullptr; } + void CladPlugin::ProcessHandleTopLevelDeclCalls() { + for (auto DelayedCall : m_DelayedCalls) { + if (DelayedCall.m_Kind == CallKind::HandleTopLevelDecl) + HandleTopLevelDeclForClad(DelayedCall.m_DGR); + } + } + void CladPlugin::SendToMultiplexer() { for (auto DelayedCall : m_DelayedCalls) { DeclGroupRef& D = DelayedCall.m_DGR; @@ -390,6 +397,9 @@ namespace clad { } void CladPlugin::HandleTranslationUnit(ASTContext& C) { + // Collect the requested derivatives graph. + ProcessHandleTopLevelDeclCalls(); + Sema& S = m_CI.getSema(); // Restore the TUScope that became a 0 in Sema::ActOnEndOfTranslationUnit. 
S.TUScope = m_StoredTUScope; @@ -405,6 +415,11 @@ namespace clad { DiffRequest request = m_DiffSchedule[i]; ProcessDiffRequest(request); } + m_DiffRequestGraph.print(); + std::vector nodes = m_DiffRequestGraph.topologicalSort(true/*reverseOrder*/); + for (auto& request : nodes) { + llvm :: outs () << (std::string)(request) << "\n"; + } // Put the TUScope in a consistent state after clad is done. S.TUScope = nullptr; // Force emission of the produced pending template instantiations. diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index b7f6075c2..c032d093b 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -128,6 +128,7 @@ namespace clad { CladTimerGroup m_CTG; DerivedFnCollector m_DFC; DiffSchedule m_DiffSchedule; + Graph m_DiffRequestGraph; enum class CallKind { HandleCXXStaticMemberVarInstantiation, HandleTopLevelDecl, @@ -192,7 +193,6 @@ namespace clad { AppendDelayed({CallKind::HandleCXXStaticMemberVarInstantiation, D}); } bool HandleTopLevelDecl(clang::DeclGroupRef D) override { - HandleTopLevelDeclForClad(D); AppendDelayed({CallKind::HandleTopLevelDecl, D}); return true; // happyness, continue parsing } @@ -266,6 +266,8 @@ namespace clad { m_DiffSchedule.push_back(request); } + clad::Graph& GetDiffRequestGraph() { return m_DiffRequestGraph; } + // FIXME: We should hide ProcessDiffRequest when we implement proper // handling of the differentiation plans. clang::FunctionDecl* ProcessDiffRequest(DiffRequest& request); @@ -275,6 +277,7 @@ namespace clad { assert(!m_HasMultiplexerProcessedDelayedCalls); m_DelayedCalls.push_back(DCI); } + void ProcessHandleTopLevelDeclCalls(); void SendToMultiplexer(); bool CheckBuiltins(); void SetRequestOptions(RequestOptions& opts) const;