diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 65211f24c..4a089c095 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -11,6 +11,7 @@ #include "ArrayRef.h" #include "BuiltinDerivatives.h" #include "CladConfig.h" +#include "DynamicGraph.h" #include "FunctionTraits.h" #include "Matrix.h" #include "NumericalDiff.h" diff --git a/include/clad/Differentiator/DynamicGraph.h b/include/clad/Differentiator/DynamicGraph.h index 86db4e6de..a21ebfabc 100644 --- a/include/clad/Differentiator/DynamicGraph.h +++ b/include/clad/Differentiator/DynamicGraph.h @@ -13,33 +13,35 @@ namespace clad { template class DynamicGraph { private: - // Storing nodes in the graph. The index of the node in the vector is used as - // a unique identifier for the node in the adjacency list. + /// Storing nodes in the graph. The index of the node in the vector is used as + /// a unique identifier for the node in the adjacency list. std::vector m_nodes; - // Store the nodes in the graph as an unordered map from the node to a boolean - // indicating whether the node is processed or not. The second element in the - // pair is the id of the node in the nodes vector. + /// Store the nodes in the graph as an unordered map from the node to a + /// boolean indicating whether the node is processed or not. The second + /// element in the pair is the id of the node in the nodes vector. std::unordered_map> m_nodeMap; - // Store the adjacency list for the graph. The adjacency list is a map from - // a node to the set of nodes that it has an edge to. We use integers inside - // the set to avoid copying the nodes. + /// Store the adjacency list for the graph. The adjacency list is a map from + /// a node to the set of nodes that it has an edge to. We use integers inside + /// the set to avoid copying the nodes. std::unordered_map> m_adjList; - // Set of source nodes in the graph. + /// Set of source nodes in the graph. std::set m_sources; - // Store the id of the node being processed right now. + /// Store the id of the node being processed right now. int m_currentId = -1; // -1 means no node is being processed. - // Maintain a queue of nodes to be processed next. + /// Maintain a queue of nodes to be processed next. std::queue m_toProcessQueue; public: DynamicGraph() = default; - // Add an edge from src to dest + /// @brief Add an edge from the source node to the destination node. + /// @param src + /// @param dest void addEdge(const T& src, const T& dest) { std::pair srcInfo = addNode(src); std::pair destInfo = addNode(dest); @@ -48,7 +50,13 @@ template class DynamicGraph { m_adjList[srcId].insert(destId); } - // Add a node to the graph + /// @brief Add a node to the graph. If the node is already present, return the + /// id of the node in the graph. If the node is a source node, add it to the + /// queue of nodes to be processed. + /// @param node + /// @param isSource + /// @return A pair of a boolean indicating whether the node is already + /// processed and the id of the node in the graph. std::pair addNode(const T& node, bool isSource = false) { if (m_nodeMap.find(node) == m_nodeMap.end()) { size_t id = m_nodes.size(); @@ -63,19 +71,23 @@ template class DynamicGraph { return m_nodeMap[node]; } - // Adds the edge from the current node to the destination node. + /// @brief Add an edge from the current node being processed to the + /// destination node. + /// @param dest void addEdgeToCurrentNode(const T& dest) { if (m_currentId != -1) addEdge(m_nodes[m_currentId], dest); } - // Set the current node to the node with the given id. + /// @brief Set the current node being processed. + /// @param node void setCurrentProcessingNode(const T& node) { if (m_nodeMap.find(node) != m_nodeMap.end()) m_currentId = m_nodeMap[node].second; } - // Mark the current node as processed. + /// @brief Mark the current node being processed as processed and add the + /// destination nodes to the queue of nodes to be processed. void markCurrentNodeProcessed() { if (m_currentId != -1) { m_nodeMap[m_nodes[m_currentId]].first = true; @@ -86,20 +98,10 @@ template class DynamicGraph { m_currentId = -1; } - // Get the nodes in the graph. + /// @brief Get the nodes in the graph. std::vector getNodes() { return m_nodes; } - // Check if two nodes are connected in the graph. - bool isConnected(const T& src, const T& dest) { - if (m_nodeMap.find(src) == m_nodeMap.end() || - m_nodeMap.find(dest) == m_nodeMap.end()) - return false; - size_t srcId = m_nodeMap[src].second; - size_t destId = m_nodeMap[dest].second; - return m_adjList[srcId].find(destId) != m_adjList[srcId].end(); - } - - // Print the graph in a human-readable format. + /// @brief Print the nodes and edges in the graph. void print() { // First print the nodes with their insertion order. for (const T& node : m_nodes) { @@ -118,33 +120,9 @@ template class DynamicGraph { std::cout << i << " -> " << dest << "\n"; } - // Topological sort of the directed graph. If the graph is not a DAG, the - // result will be a partial order. Use a recursive dfs heler function to - // implement the topological sort. If a->b, then a will come before b in the - // topological sort. In reverseOrder mode, the result will be in reverse - // topological order, i.e a->b, then b will come before a in the result. - std::vector topologicalSort(bool reverseOrder = false) { - std::vector res; - std::unordered_set visited; - - std::function dfs = [&](size_t node) -> void { - visited.insert(node); - for (size_t dest : m_adjList[node]) - if (visited.find(dest) == visited.end()) - dfs(dest); - res.push_back(m_nodes[node]); - }; - for (size_t source : m_sources) - if (visited.find(source) == visited.end()) - dfs(source); - - if (reverseOrder) - return res; - std::reverse(res.begin(), res.end()); - return res; - } - - // Get the next to process node from the queue of nodes to be processed. + /// @brief Get the next node to be processed from the queue of nodes to be + /// processed. + /// @return The next node to be processed. T getNextToProcessNode() { if (m_toProcessQueue.empty()) return T(); diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 5b7c2e623..07a7993c1 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -685,7 +685,7 @@ namespace clad { llvm::SaveAndRestore saveTopMost = m_TopMostFD; m_TopMostFD = FD; TraverseDecl(derivedFD); - m_DiffRequestGraph.addNode(request, true /*isSource*/); + m_DiffRequestGraph.addNode(request, /*isSource=*/true); } /*else if (m_TopMostFD) { // If another function is called inside differentiated function, diff --git a/test/Misc/DynamicGraph.C b/test/Misc/DynamicGraph.C deleted file mode 100644 index a11ee9607..000000000 --- a/test/Misc/DynamicGraph.C +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: %cladclang %s -I%S/../../include -oGraph.out 2>&1 -// RUN: ./Graph.out | FileCheck -check-prefix=CHECK-EXEC %s -// CHECK-NOT: {{.*error|warning|note:.*}} - -#include "clad/Differentiator/DynamicGraph.h" -#include -#include - -// Custom type for representing nodes in the graph. -struct Node { - std::string name; - int id; - - Node(std::string name, int id) : name(name), id(id) {} - - bool operator==(const Node& other) const { - return name == other.name && id == other.id; - } - - // string operator for printing the node. - operator std::string() const { - return name + std::to_string(id); - } -}; - -// Specialize std::hash for the Node type. -template<> -struct std::hash { - std::size_t operator()(const Node& n) const { - return std::hash()(n.name) ^ std::hash()(n.id); - } -}; - -int main () { - clad::DynamicGraph G; - for (int i = 0; i < 6; i++) { - Node n("node", i); - if (i == 0) { - G.addNode(n, true/*isSource*/); - } - Node m("node", i + 1); - G.addEdge(n, m); - } - std::vector nodes = G.getNodes(); - std::cout << "Nodes in the graph: " << nodes.size() << "\n"; - // CHECK-EXEC: Nodes in the graph: 7 - - // edge from node 0 to node 3 and node 4 to node 0. - G.addEdge(nodes[0], nodes[3]); - G.addEdge(nodes[4], nodes[0]); - - G.print(); - // CHECK-EXEC: node0: #0 (source), (unprocessed) - // CHECK-EXEC-NEXT: node1: #1, (unprocessed) - // CHECK-EXEC-NEXT: node2: #2, (unprocessed) - // CHECK-EXEC-NEXT: node3: #3, (unprocessed) - // CHECK-EXEC-NEXT: node4: #4, (unprocessed) - // CHECK-EXEC-NEXT: node5: #5, (unprocessed) - // CHECK-EXEC-NEXT: node6: #6, (unprocessed) - // CHECK-EXEC-NEXT: 0 -> 1 - // CHECK-EXEC-NEXT: 0 -> 3 - // CHECK-EXEC-NEXT: 1 -> 2 - // CHECK-EXEC-NEXT: 2 -> 3 - // CHECK-EXEC-NEXT: 3 -> 4 - // CHECK-EXEC-NEXT: 4 -> 0 - // CHECK-EXEC-NEXT: 4 -> 5 - // CHECK-EXEC-NEXT: 5 -> 6 - return 0; -} - diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index 6100f7346..a8c5a946f 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -40,3 +40,5 @@ if (Kokkos_FOUND) set(CMAKE_CXX_STANDARD_REQUIRED TRUE) add_subdirectory(Kokkos) endif(Kokkos_FOUND) + +add_subdirectory(Misc) diff --git a/unittests/Misc/CMakeLists.txt b/unittests/Misc/CMakeLists.txt new file mode 100644 index 000000000..dc13b2417 --- /dev/null +++ b/unittests/Misc/CMakeLists.txt @@ -0,0 +1,4 @@ +add_clad_unittest(MiscTests + main.cpp + DynamicGraph.cpp +) diff --git a/unittests/Misc/DynamicGraph.cpp b/unittests/Misc/DynamicGraph.cpp new file mode 100644 index 000000000..6954a6698 --- /dev/null +++ b/unittests/Misc/DynamicGraph.cpp @@ -0,0 +1,67 @@ +#include "clad/Differentiator/Differentiator.h" + +#include +#include + +#include "gtest/gtest.h" + +struct Node { + std::string name; + int id; + + Node(std::string name, int id) : name(name), id(id) {} + + bool operator==(const Node& other) const { + return name == other.name && id == other.id; + } + + // String operator for printing the node. + operator std::string() const { return name + std::to_string(id); } +}; + +// Specialize std::hash for the Node type. +template <> struct std::hash { + std::size_t operator()(const Node& n) const { + return std::hash()(n.name) ^ std::hash()(n.id); + } +}; + +TEST(DynamicGraphTest, Printing) { + clad::DynamicGraph G; + for (int i = 0; i < 6; i++) { + Node n("node", i); + if (i == 0) + G.addNode(n, /*isSource=*/true); + Node m("node", i + 1); + G.addEdge(n, m); + } + std::vector nodes = G.getNodes(); + EXPECT_EQ(nodes.size(), 7); + + // Edge from node 0 to node 3 and node 4 to node 0. + G.addEdge(nodes[0], nodes[3]); + G.addEdge(nodes[4], nodes[0]); + + // Check the printed output. + std::stringstream ss; + std::streambuf* coutbuf = std::cout.rdbuf(); + std::cout.rdbuf(ss.rdbuf()); + G.print(); + std::cout.rdbuf(coutbuf); + std::string expectedOutput = "node0: #0 (source), (unprocessed)\n" + "node1: #1, (unprocessed)\n" + "node2: #2, (unprocessed)\n" + "node3: #3, (unprocessed)\n" + "node4: #4, (unprocessed)\n" + "node5: #5, (unprocessed)\n" + "node6: #6, (unprocessed)\n" + "0 -> 1\n" + "0 -> 3\n" + "1 -> 2\n" + "2 -> 3\n" + "3 -> 4\n" + "4 -> 0\n" + "4 -> 5\n" + "5 -> 6\n"; + EXPECT_EQ(ss.str(), expectedOutput); +} diff --git a/unittests/Misc/main.cpp b/unittests/Misc/main.cpp new file mode 100644 index 000000000..b936444aa --- /dev/null +++ b/unittests/Misc/main.cpp @@ -0,0 +1,6 @@ +#include + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}