Skip to content

Commit

Permalink
Add transitive_closure_dag function
Browse files Browse the repository at this point in the history
This commit adds a new function transitive_closure_dag() which is an
optimized method for computing the transitive closure for DAGs. In
support of this a new function descendants_at_distance() for finding the
nodes a fixed distance from a given source to both rustworkx and
rustworkx-core.

Related to: Qiskit#704
  • Loading branch information
mtreinish committed Oct 20, 2022
1 parent c073a21 commit dc16084
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 1 deletion.
4 changes: 4 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ Traversal
rustworkx.visit.BFSVisitor
rustworkx.visit.DijkstraVisitor
rustworkx.TopologicalSorter
rustworkx.descendants_at_distance

.. _dag-algorithms:

Expand All @@ -94,6 +95,7 @@ DAG Algorithms
rustworkx.dag_weighted_longest_path_length
rustworkx.is_directed_acyclic_graph
rustworkx.layers
rustworkx.transitive_closure_dag

.. _tree:

Expand Down Expand Up @@ -325,6 +327,7 @@ the functions from the explicitly typed based on the data type.
rustworkx.digraph_bfs_search
rustworkx.digraph_dijkstra_search
rustworkx.digraph_node_link_json
rustworkx.digraph_descendants_at_distance

.. _api-functions-pygraph:

Expand Down Expand Up @@ -379,6 +382,7 @@ typed API based on the data type.
rustworkx.graph_bfs_search
rustworkx.graph_dijkstra_search
rustworkx.graph_node_link_json
rustworkx.graph_descendants_at_distance

Exceptions
==========
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
features:
- |
Added a new function ``descendants_at_distance`` to the rustworkx-core
crate under the ``traversal`` module
- |
Added a new function, :func:`~.transitive_closure_dag`, which provides
an optimize method for computing the transitive closure of an input
DAG.
- |
Added a new function :func:`~.descendants_at_distance` which provides
a method to find the nodes at a fixed distance from a source in
a graph object.
44 changes: 44 additions & 0 deletions rustworkx-core/src/traversal/descendants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use hashbrown::HashSet;
use petgraph::visit::{IntoNeighborsDirected, NodeCount, Visitable};

/// Returns all nodes at a fixed `distance` from `source` in `G`.
/// Args:
/// `graph`:
/// `source`:
/// `distance`:
pub fn descendants_at_distance<G>(graph: G, source: G::NodeId, distance: usize) -> Vec<G::NodeId>
where
G: Visitable + IntoNeighborsDirected + NodeCount,
G::NodeId: std::cmp::Eq + std::hash::Hash,
{
let mut current_layer: Vec<G::NodeId> = vec![source];
let mut layers: usize = 0;
let mut visited: HashSet<G::NodeId> = HashSet::with_capacity(graph.node_count());
visited.insert(source);
while !current_layer.is_empty() && layers < distance {
let mut next_layer: Vec<G::NodeId> = Vec::new();
for node in current_layer {
for child in graph.neighbors_directed(node, petgraph::Outgoing) {
if !visited.contains(&child) {
visited.insert(child);
next_layer.push(child);
}
}
}
current_layer = next_layer;
layers += 1;
}
current_layer
}
2 changes: 2 additions & 0 deletions rustworkx-core/src/traversal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
//! Module for graph traversal algorithms.

mod bfs_visit;
mod descendants;
mod dfs_edges;
mod dfs_visit;
mod dijkstra_visit;

pub use bfs_visit::{breadth_first_search, BfsEvent};
pub use descendants::descendants_at_distance;
pub use dfs_edges::dfs_edges;
pub use dfs_visit::{depth_first_search, DfsEvent};
pub use dijkstra_visit::{dijkstra_search, DijkstraEvent};
Expand Down
34 changes: 34 additions & 0 deletions rustworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2382,3 +2382,37 @@ def _graph_node_link_json(graph, path=None, graph_attrs=None, node_attrs=None, e
return graph_node_link_json(
graph, path=path, graph_attrs=graph_attrs, node_attrs=node_attrs, edge_attrs=edge_attrs
)


@functools.singledispatch
def descendants_at_distance(graph, source, distance):
"""Returns all nodes at a fixed distance from ``source`` in ``graph``
:param graph: The graph to find the descendants in
:param int source: The node index to find the descendants from
:param int distance: The distance from ``source``
:returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``.
:rtype: NodeIndices
For example::
import rustworkx as rx
graph = rx.generators.path_graph(5)
res = rx.descendants_at_distance(graph, 2, 2)
print(res)
will return: ``[0, 4]``
"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))


@descendants_at_distance.register(PyDiGraph)
def _digraph_descendants_at_distance(graph, source, distance):
return digraph_descendants_at_distance(graph, source, distance)


@descendants_at_distance.register(PyGraph)
def _graph_descendants_at_distance(graph, source, distance):
return graph_descendants_at_distance(graph, source, distance)
49 changes: 49 additions & 0 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use petgraph::graph::NodeIndex;
use petgraph::prelude::*;
use petgraph::visit::NodeCount;

use rustworkx_core::traversal::descendants_at_distance;

/// Find the longest path in a DAG
///
/// :param PyDiGraph graph: The graph to find the longest path on. The input
Expand Down Expand Up @@ -634,3 +636,50 @@ pub fn collect_bicolor_runs(

Ok(block_list)
}

/// Return the transitive closure of a directed acyclic graph
///
/// The transitive closure of :math:`G = (V, E)` is a graph :math:`G+ = (V, E+)`
/// such that for all pairs of :math:`v, w` in :math:`V` there is an edge
/// :math:`(v, w) in :math:`E+` if and only if there is a non-null path from
/// :math:`v` to :math:`w` in :math:`G`.
///
/// :param PyDiGraph graph: The input DAG to compute the transitive closure of
/// :param list topological_order: An optional topological order for ``graph``
/// which represents the order the graph will be traversed in computing
/// the transitive closure. If one is not provided (or it is explicitly
/// set to ``None``) a topological order will be computed by this function.
///
/// :returns: The transitive closure of ``graph``
/// :rtype: PyDiGraph
///
/// :raises DAGHasCycle: If the input ``graph`` is not acyclic
#[pyfunction]
#[pyo3(text_signature = "(graph, / topological_order=None)")]
pub fn transitive_closure_dag(
py: Python,
graph: &digraph::PyDiGraph,
topological_order: Option<Vec<usize>>,
) -> PyResult<digraph::PyDiGraph> {
let node_order: Vec<NodeIndex> = match topological_order {
Some(topo_order) => topo_order.into_iter().map(NodeIndex::new).collect(),
None => match algo::toposort(&graph.graph, None) {
Ok(nodes) => nodes,
Err(_err) => return Err(DAGHasCycle::new_err("Topological Sort encountered a cycle")),
},
};
let mut out_graph = graph.graph.clone();
for node in node_order.into_iter().rev() {
for descendant in descendants_at_distance(&out_graph, node, 2) {
out_graph.add_edge(node, descendant, py.None());
}
}
Ok(digraph::PyDiGraph {
graph: out_graph,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph: true,
attrs: py.None(),
})
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(read_graphml))?;
m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(transitive_closure_dag))?;
m.add_wrapped(wrap_pyfunction!(graph_descendants_at_distance))?;
m.add_wrapped(wrap_pyfunction!(digraph_descendants_at_distance))?;
m.add_class::<digraph::PyDiGraph>()?;
m.add_class::<graph::PyGraph>()?;
m.add_class::<toposort::TopologicalSorter>()?;
Expand Down
66 changes: 65 additions & 1 deletion src/traversal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use dfs_visit::{dfs_handler, PyDfsVisitor};
use dijkstra_visit::{dijkstra_handler, PyDijkstraVisitor};

use rustworkx_core::traversal::{
breadth_first_search, depth_first_search, dfs_edges, dijkstra_search,
breadth_first_search, depth_first_search, descendants_at_distance, dfs_edges, dijkstra_search,
};

use super::{digraph, graph, iterators, CostFn};
Expand Down Expand Up @@ -707,3 +707,67 @@ pub fn graph_dijkstra_search(

Ok(())
}

/// Returns all nodes at a fixed distance from ``source`` in ``graph``
///
/// :param PyGraph graph: The graph to find the descendants in
/// :param int source: The node index to find the descendants from
/// :param int distance: The distance from ``source``
///
/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``.
/// :rtype: NodeIndices
/// For example::
///
/// import rustworkx as rx
///
/// graph = rx.generators.path_graph(5)
/// res = rx.descendants_at_distance(graph, 2, 2)
/// print(res)
///
/// will return: ``[0, 4]``
#[pyfunction]
pub fn graph_descendants_at_distance(
graph: graph::PyGraph,
source: usize,
distance: usize,
) -> iterators::NodeIndices {
let source = NodeIndex::new(source);
iterators::NodeIndices {
nodes: descendants_at_distance(&graph.graph, source, distance)
.into_iter()
.map(|x| x.index())
.collect(),
}
}

/// Returns all nodes at a fixed distance from ``source`` in ``graph``
///
/// :param PyDiGraph graph: The graph to find the descendants in
/// :param int source: The node index to find the descendants from
/// :param int distance: The distance from ``source``
///
/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``.
/// :rtype: NodeIndices
/// For example::
///
/// import rustworkx as rx
///
/// graph = rx.generators.directed_path_graph(5)
/// res = rx.descendants_at_distance(graph, 2, 2)
/// print(res)
///
/// will return: ``[4]``
#[pyfunction]
pub fn digraph_descendants_at_distance(
graph: digraph::PyDiGraph,
source: usize,
distance: usize,
) -> iterators::NodeIndices {
let source = NodeIndex::new(source);
iterators::NodeIndices {
nodes: descendants_at_distance(&graph.graph, source, distance)
.into_iter()
.map(|x| x.index())
.collect(),
}
}
33 changes: 33 additions & 0 deletions tests/rustworkx_tests/digraph/test_transitive_closure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import rustworkx as rx


class TestTransitivity(unittest.TestCase):
def test_path_graph(self):
graph = rx.generators.directed_path_graph(4)
transitive_closure = rx.transitive_closure_dag(graph)
expected_edge_list = [(0, 1), (1, 2), (2, 3), (1, 3), (0, 3), (0, 2)]
self.assertEqual(transitive_closure.edge_list(), expected_edge_list)

def test_invalid_type(self):
with self.assertRaises(TypeError):
rx.transitive_closure_dag(rx.PyGraph())

def test_cycle_error(self):
graph = rx.PyDiGraph()
graph.extend_from_edge_list([(0, 1), (1, 0)])
with self.assertRaises(rx.DAGHasCycle):
rx.transitive_closure_dag(graph)

0 comments on commit dc16084

Please sign in to comment.