Skip to content

Commit

Permalink
Improve DynamicGraph comments and unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed May 3, 2024
1 parent b2ed85b commit 1ee7b35
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 126 deletions.
1 change: 1 addition & 0 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "ArrayRef.h"
#include "BuiltinDerivatives.h"
#include "CladConfig.h"
#include "DynamicGraph.h"
#include "FunctionTraits.h"
#include "Matrix.h"
#include "NumericalDiff.h"
Expand Down
88 changes: 33 additions & 55 deletions include/clad/Differentiator/DynamicGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,35 @@
namespace clad {
template <typename T> class DynamicGraph {
private:
// Storing nodes in the graph. The index of the node in the vector is used as
// a unique identifier for the node in the adjacency list.
/// Storing nodes in the graph. The index of the node in the vector is used as
/// a unique identifier for the node in the adjacency list.
std::vector<T> m_nodes;

// Store the nodes in the graph as an unordered map from the node to a boolean
// indicating whether the node is processed or not. The second element in the
// pair is the id of the node in the nodes vector.
/// Store the nodes in the graph as an unordered map from the node to a
/// boolean indicating whether the node is processed or not. The second
/// element in the pair is the id of the node in the nodes vector.
std::unordered_map<T, std::pair<bool, size_t>> m_nodeMap;

// Store the adjacency list for the graph. The adjacency list is a map from
// a node to the set of nodes that it has an edge to. We use integers inside
// the set to avoid copying the nodes.
/// Store the adjacency list for the graph. The adjacency list is a map from
/// a node to the set of nodes that it has an edge to. We use integers inside
/// the set to avoid copying the nodes.
std::unordered_map<size_t, std::set<size_t>> m_adjList;

// Set of source nodes in the graph.
/// Set of source nodes in the graph.
std::set<size_t> m_sources;

// Store the id of the node being processed right now.
/// Store the id of the node being processed right now.
int m_currentId = -1; // -1 means no node is being processed.

// Maintain a queue of nodes to be processed next.
/// Maintain a queue of nodes to be processed next.
std::queue<size_t> m_toProcessQueue;

public:
DynamicGraph() = default;

// Add an edge from src to dest
/// Add an edge from the source node to the destination node.
/// \param src
/// \param dest
void addEdge(const T& src, const T& dest) {
std::pair<bool, size_t> srcInfo = addNode(src);
std::pair<bool, size_t> destInfo = addNode(dest);
Expand All @@ -48,7 +50,13 @@ template <typename T> class DynamicGraph {
m_adjList[srcId].insert(destId);
}

// Add a node to the graph
/// Add a node to the graph. If the node is already present, return the
/// id of the node in the graph. If the node is a source node, add it to the
/// queue of nodes to be processed.
/// \param node
/// \param isSource
/// \returns A pair of a boolean indicating whether the node is already
/// processed and the id of the node in the graph.
std::pair<bool, size_t> addNode(const T& node, bool isSource = false) {
if (m_nodeMap.find(node) == m_nodeMap.end()) {
size_t id = m_nodes.size();
Expand All @@ -63,19 +71,23 @@ template <typename T> class DynamicGraph {
return m_nodeMap[node];
}

// Adds the edge from the current node to the destination node.
/// Add an edge from the current node being processed to the
/// destination node.
/// \param dest
void addEdgeToCurrentNode(const T& dest) {
if (m_currentId != -1)
addEdge(m_nodes[m_currentId], dest);
}

// Set the current node to the node with the given id.
/// Set the current node being processed.
/// \param node
void setCurrentProcessingNode(const T& node) {
if (m_nodeMap.find(node) != m_nodeMap.end())
m_currentId = m_nodeMap[node].second;
}

// Mark the current node as processed.
/// Mark the current node being processed as processed and add the
/// destination nodes to the queue of nodes to be processed.
void markCurrentNodeProcessed() {
if (m_currentId != -1) {
m_nodeMap[m_nodes[m_currentId]].first = true;
Expand All @@ -86,20 +98,10 @@ template <typename T> class DynamicGraph {
m_currentId = -1;
}

// Get the nodes in the graph.
/// Get the nodes in the graph.
std::vector<T> getNodes() { return m_nodes; }

// Check if two nodes are connected in the graph.
bool isConnected(const T& src, const T& dest) {
if (m_nodeMap.find(src) == m_nodeMap.end() ||
m_nodeMap.find(dest) == m_nodeMap.end())
return false;
size_t srcId = m_nodeMap[src].second;
size_t destId = m_nodeMap[dest].second;
return m_adjList[srcId].find(destId) != m_adjList[srcId].end();
}

// Print the graph in a human-readable format.
/// Print the nodes and edges in the graph.
void print() {
// First print the nodes with their insertion order.
for (const T& node : m_nodes) {
Expand All @@ -118,33 +120,9 @@ template <typename T> class DynamicGraph {
std::cout << i << " -> " << dest << "\n";
}

// Topological sort of the directed graph. If the graph is not a DAG, the
// result will be a partial order. Use a recursive dfs heler function to
// implement the topological sort. If a->b, then a will come before b in the
// topological sort. In reverseOrder mode, the result will be in reverse
// topological order, i.e a->b, then b will come before a in the result.
std::vector<T> topologicalSort(bool reverseOrder = false) {
std::vector<T> res;
std::unordered_set<size_t> visited;

std::function<void(size_t)> dfs = [&](size_t node) -> void {
visited.insert(node);
for (size_t dest : m_adjList[node])
if (visited.find(dest) == visited.end())
dfs(dest);
res.push_back(m_nodes[node]);
};
for (size_t source : m_sources)
if (visited.find(source) == visited.end())
dfs(source);

if (reverseOrder)
return res;
std::reverse(res.begin(), res.end());
return res;
}

// Get the next to process node from the queue of nodes to be processed.
/// Get the next node to be processed from the queue of nodes to be
/// processed.
/// \returns The next node to be processed.
T getNextToProcessNode() {
if (m_toProcessQueue.empty())
return T();
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ namespace clad {
llvm::SaveAndRestore<const FunctionDecl*> saveTopMost = m_TopMostFD;
m_TopMostFD = FD;
TraverseDecl(derivedFD);
m_DiffRequestGraph.addNode(request, true /*isSource*/);
m_DiffRequestGraph.addNode(request, /*isSource=*/true);
}
/*else if (m_TopMostFD) {
// If another function is called inside differentiated function,
Expand Down
70 changes: 0 additions & 70 deletions test/Misc/DynamicGraph.C

This file was deleted.

2 changes: 2 additions & 0 deletions unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,5 @@ if (Kokkos_FOUND)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
add_subdirectory(Kokkos)
endif(Kokkos_FOUND)

add_subdirectory(Misc)
4 changes: 4 additions & 0 deletions unittests/Misc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
add_clad_unittest(MiscTests
main.cpp
DynamicGraph.cpp
)
67 changes: 67 additions & 0 deletions unittests/Misc/DynamicGraph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "clad/Differentiator/Differentiator.h"

#include <iostream>
#include <string>

#include "gtest/gtest.h"

struct Node {
std::string name;
int id;

Node(std::string name, int id) : name(name), id(id) {}

bool operator==(const Node& other) const {
return name == other.name && id == other.id;
}

// String operator for printing the node.
operator std::string() const { return name + std::to_string(id); }
};

// Specialize std::hash for the Node type.
template <> struct std::hash<Node> {
std::size_t operator()(const Node& n) const {
return std::hash<std::string>()(n.name) ^ std::hash<int>()(n.id);
}
};

TEST(DynamicGraphTest, Printing) {
clad::DynamicGraph<Node> G;
for (int i = 0; i < 6; i++) {
Node n("node", i);
if (i == 0)
G.addNode(n, /*isSource=*/true);
Node m("node", i + 1);
G.addEdge(n, m);
}
std::vector<Node> nodes = G.getNodes();
EXPECT_EQ(nodes.size(), 7);

// Edge from node 0 to node 3 and node 4 to node 0.
G.addEdge(nodes[0], nodes[3]);
G.addEdge(nodes[4], nodes[0]);

// Check the printed output.
std::stringstream ss;
std::streambuf* coutbuf = std::cout.rdbuf();
std::cout.rdbuf(ss.rdbuf());
G.print();
std::cout.rdbuf(coutbuf);
std::string expectedOutput = "node0: #0 (source), (unprocessed)\n"
"node1: #1, (unprocessed)\n"
"node2: #2, (unprocessed)\n"
"node3: #3, (unprocessed)\n"
"node4: #4, (unprocessed)\n"
"node5: #5, (unprocessed)\n"
"node6: #6, (unprocessed)\n"
"0 -> 1\n"
"0 -> 3\n"
"1 -> 2\n"
"2 -> 3\n"
"3 -> 4\n"
"4 -> 0\n"
"4 -> 5\n"
"5 -> 6\n";
EXPECT_EQ(ss.str(), expectedOutput);
}
6 changes: 6 additions & 0 deletions unittests/Misc/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include <gtest/gtest.h>

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

0 comments on commit 1ee7b35

Please sign in to comment.