From dc16084ef9a9f93616e54280a719d9cf46a23f70 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Thu, 20 Oct 2022 12:50:54 -0400 Subject: [PATCH] Add transitive_closure_dag function This commit adds a new function transitive_closure_dag() which is an optimized method for computing the transitive closure for DAGs. In support of this, a new function descendants_at_distance() for finding the nodes at a fixed distance from a given source is added to both rustworkx and rustworkx-core. Related to: #704 --- docs/source/api.rst | 4 ++ ...ansitive-closure-dag-3fb45113d552f007.yaml | 13 ++++ rustworkx-core/src/traversal/descendants.rs | 44 +++++++++++++ rustworkx-core/src/traversal/mod.rs | 2 + rustworkx/__init__.py | 34 ++++++++++ src/dag_algo/mod.rs | 49 ++++++++++++++ src/lib.rs | 3 + src/traversal/mod.rs | 66 ++++++++++++++++++- .../digraph/test_transitive_closure.py | 33 ++++++++++ 9 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml create mode 100644 rustworkx-core/src/traversal/descendants.rs create mode 100644 tests/rustworkx_tests/digraph/test_transitive_closure.py diff --git a/docs/source/api.rst b/docs/source/api.rst index 8368188c2..3c6853736 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -79,6 +79,7 @@ Traversal rustworkx.visit.BFSVisitor rustworkx.visit.DijkstraVisitor rustworkx.TopologicalSorter + rustworkx.descendants_at_distance .. _dag-algorithms: @@ -94,6 +95,7 @@ DAG Algorithms rustworkx.dag_weighted_longest_path_length rustworkx.is_directed_acyclic_graph rustworkx.layers + rustworkx.transitive_closure_dag .. _tree: @@ -325,6 +327,7 @@ the functions from the explicitly typed based on the data type. rustworkx.digraph_bfs_search rustworkx.digraph_dijkstra_search rustworkx.digraph_node_link_json + rustworkx.digraph_descendants_at_distance .. _api-functions-pygraph: @@ -379,6 +382,7 @@ typed API based on the data type. 
rustworkx.graph_bfs_search rustworkx.graph_dijkstra_search rustworkx.graph_node_link_json + rustworkx.graph_descendants_at_distance Exceptions ========== diff --git a/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml b/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml new file mode 100644 index 000000000..cc5099212 --- /dev/null +++ b/releasenotes/notes/add-transitive-closure-dag-3fb45113d552f007.yaml @@ -0,0 +1,13 @@ +--- +features: + - | + Added a new function ``descendants_at_distance`` to the rustworkx-core + crate under the ``traversal`` module. + - | + Added a new function, :func:`~.transitive_closure_dag`, which provides + an optimized method for computing the transitive closure of an input + DAG. + - | + Added a new function :func:`~.descendants_at_distance` which provides + a method to find the nodes at a fixed distance from a source in + a graph object. diff --git a/rustworkx-core/src/traversal/descendants.rs b/rustworkx-core/src/traversal/descendants.rs new file mode 100644 index 000000000..67bf6bf60 --- /dev/null +++ b/rustworkx-core/src/traversal/descendants.rs @@ -0,0 +1,44 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use hashbrown::HashSet; +use petgraph::visit::{IntoNeighborsDirected, NodeCount, Visitable}; + +/// Returns all nodes at a fixed `distance` from `source` in `graph`. 
+/// Args: +/// `graph`: The graph to search; only outgoing neighbors of each node are followed. +/// `source`: The node the distance is measured from. +/// `distance`: The number of breadth-first layers to advance from `source`; `0` yields only `source`. +pub fn descendants_at_distance(graph: G, source: G::NodeId, distance: usize) -> Vec +where + G: Visitable + IntoNeighborsDirected + NodeCount, + G::NodeId: std::cmp::Eq + std::hash::Hash, +{ + let mut current_layer: Vec = vec![source]; + let mut layers: usize = 0; + let mut visited: HashSet = HashSet::with_capacity(graph.node_count()); + visited.insert(source); + while !current_layer.is_empty() && layers < distance { + let mut next_layer: Vec = Vec::new(); + for node in current_layer { + for child in graph.neighbors_directed(node, petgraph::Outgoing) { + if !visited.contains(&child) { + visited.insert(child); + next_layer.push(child); + } + } + } + current_layer = next_layer; + layers += 1; + } + current_layer +} diff --git a/rustworkx-core/src/traversal/mod.rs b/rustworkx-core/src/traversal/mod.rs index 2a6254481..b531075da 100644 --- a/rustworkx-core/src/traversal/mod.rs +++ b/rustworkx-core/src/traversal/mod.rs @@ -13,11 +13,13 @@ //! Module for graph traversal algorithms. 
mod bfs_visit; +mod descendants; mod dfs_edges; mod dfs_visit; mod dijkstra_visit; pub use bfs_visit::{breadth_first_search, BfsEvent}; +pub use descendants::descendants_at_distance; pub use dfs_edges::dfs_edges; pub use dfs_visit::{depth_first_search, DfsEvent}; pub use dijkstra_visit::{dijkstra_search, DijkstraEvent}; diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 6d82d41f2..313d2831c 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -2382,3 +2382,37 @@ def _graph_node_link_json(graph, path=None, graph_attrs=None, node_attrs=None, e return graph_node_link_json( graph, path=path, graph_attrs=graph_attrs, node_attrs=node_attrs, edge_attrs=edge_attrs ) + + +@functools.singledispatch +def descendants_at_distance(graph, source, distance): + """Returns all nodes at a fixed distance from ``source`` in ``graph`` + + :param graph: The graph to find the descendants in + :param int source: The node index to find the descendants from + :param int distance: The distance from ``source`` + + :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``. 
+ :rtype: NodeIndices + + For example:: + + import rustworkx as rx + + graph = rx.generators.path_graph(5) + res = rx.descendants_at_distance(graph, 2, 2) + print(res) + + will return: ``[0, 4]`` + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@descendants_at_distance.register(PyDiGraph) +def _digraph_descendants_at_distance(graph, source, distance): + return digraph_descendants_at_distance(graph, source, distance) + + +@descendants_at_distance.register(PyGraph) +def _graph_descendants_at_distance(graph, source, distance): + return graph_descendants_at_distance(graph, source, distance) diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index dbfcf4702..bf6e0f2a3 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -29,6 +29,8 @@ use petgraph::graph::NodeIndex; use petgraph::prelude::*; use petgraph::visit::NodeCount; +use rustworkx_core::traversal::descendants_at_distance; + /// Find the longest path in a DAG /// /// :param PyDiGraph graph: The graph to find the longest path on. The input @@ -634,3 +636,50 @@ pub fn collect_bicolor_runs( Ok(block_list) } + +/// Return the transitive closure of a directed acyclic graph +/// +/// The transitive closure of :math:`G = (V, E)` is a graph :math:`G+ = (V, E+)` +/// such that for all pairs of :math:`v, w` in :math:`V` there is an edge +/// :math:`(v, w)` in :math:`E+` if and only if there is a non-null path from +/// :math:`v` to :math:`w` in :math:`G`. +/// +/// :param PyDiGraph graph: The input DAG to compute the transitive closure of +/// :param list topological_order: An optional topological order for ``graph`` +/// which represents the order the graph will be traversed in computing +/// the transitive closure. If one is not provided (or it is explicitly +/// set to ``None``) a topological order will be computed by this function. 
+/// +/// :returns: The transitive closure of ``graph`` +/// :rtype: PyDiGraph +/// +/// :raises DAGHasCycle: If the input ``graph`` is not acyclic +#[pyfunction] +#[pyo3(text_signature = "(graph, /, topological_order=None)")] +pub fn transitive_closure_dag( + py: Python, + graph: &digraph::PyDiGraph, + topological_order: Option>, +) -> PyResult { + let node_order: Vec = match topological_order { + Some(topo_order) => topo_order.into_iter().map(NodeIndex::new).collect(), + None => match algo::toposort(&graph.graph, None) { + Ok(nodes) => nodes, + Err(_err) => return Err(DAGHasCycle::new_err("Topological Sort encountered a cycle")), + }, + }; + let mut out_graph = graph.graph.clone(); + for node in node_order.into_iter().rev() { + for descendant in descendants_at_distance(&out_graph, node, 2) { + out_graph.add_edge(node, descendant, py.None()); + } + } + Ok(digraph::PyDiGraph { + graph: out_graph, + cycle_state: algo::DfsSpace::default(), + check_cycle: false, + node_removed: false, + multigraph: true, + attrs: py.None(), + }) +} diff --git a/src/lib.rs b/src/lib.rs index a9ad4f4dd..ad449de46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -477,6 +477,9 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(read_graphml))?; m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?; + m.add_wrapped(wrap_pyfunction!(transitive_closure_dag))?; + m.add_wrapped(wrap_pyfunction!(graph_descendants_at_distance))?; + m.add_wrapped(wrap_pyfunction!(digraph_descendants_at_distance))?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/traversal/mod.rs b/src/traversal/mod.rs index 8f27cad86..0709322bb 100644 --- a/src/traversal/mod.rs +++ b/src/traversal/mod.rs @@ -19,7 +19,7 @@ use dfs_visit::{dfs_handler, PyDfsVisitor}; use dijkstra_visit::{dijkstra_handler, PyDijkstraVisitor}; use rustworkx_core::traversal::{ - breadth_first_search, depth_first_search, dfs_edges, 
dijkstra_search, + breadth_first_search, depth_first_search, descendants_at_distance, dfs_edges, dijkstra_search, }; use super::{digraph, graph, iterators, CostFn}; @@ -707,3 +707,67 @@ pub fn graph_dijkstra_search( Ok(()) } + +/// Returns all nodes at a fixed distance from ``source`` in ``graph`` +/// +/// :param PyGraph graph: The graph to find the descendants in +/// :param int source: The node index to find the descendants from +/// :param int distance: The distance from ``source`` +/// +/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``. +/// :rtype: NodeIndices +/// For example:: +/// +/// import rustworkx as rx +/// +/// graph = rx.generators.path_graph(5) +/// res = rx.descendants_at_distance(graph, 2, 2) +/// print(res) +/// +/// will return: ``[0, 4]`` +#[pyfunction] +pub fn graph_descendants_at_distance( + graph: &graph::PyGraph, + source: usize, + distance: usize, +) -> iterators::NodeIndices { + let source = NodeIndex::new(source); + iterators::NodeIndices { + nodes: descendants_at_distance(&graph.graph, source, distance) + .into_iter() + .map(|x| x.index()) + .collect(), + } +} + +/// Returns all nodes at a fixed distance from ``source`` in ``graph`` +/// +/// :param PyDiGraph graph: The graph to find the descendants in +/// :param int source: The node index to find the descendants from +/// :param int distance: The distance from ``source`` +/// +/// :returns: The node indices of the nodes ``distance`` from ``source`` in ``graph``. 
+/// :rtype: NodeIndices +/// For example:: +/// +/// import rustworkx as rx +/// +/// graph = rx.generators.directed_path_graph(5) +/// res = rx.descendants_at_distance(graph, 2, 2) +/// print(res) +/// +/// will return: ``[4]`` +#[pyfunction] +pub fn digraph_descendants_at_distance( + graph: &digraph::PyDiGraph, + source: usize, + distance: usize, +) -> iterators::NodeIndices { + let source = NodeIndex::new(source); + iterators::NodeIndices { + nodes: descendants_at_distance(&graph.graph, source, distance) + .into_iter() + .map(|x| x.index()) + .collect(), + } +} diff --git a/tests/rustworkx_tests/digraph/test_transitive_closure.py b/tests/rustworkx_tests/digraph/test_transitive_closure.py new file mode 100644 index 000000000..cf707bf55 --- /dev/null +++ b/tests/rustworkx_tests/digraph/test_transitive_closure.py @@ -0,0 +1,33 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import unittest + +import rustworkx as rx + + +class TestTransitivity(unittest.TestCase): + def test_path_graph(self): + graph = rx.generators.directed_path_graph(4) + transitive_closure = rx.transitive_closure_dag(graph) + expected_edge_list = [(0, 1), (1, 2), (2, 3), (1, 3), (0, 3), (0, 2)] + self.assertEqual(transitive_closure.edge_list(), expected_edge_list) + + def test_invalid_type(self): + with self.assertRaises(TypeError): + rx.transitive_closure_dag(rx.PyGraph()) + + def test_cycle_error(self): + graph = rx.PyDiGraph() + graph.extend_from_edge_list([(0, 1), (1, 0)]) + with self.assertRaises(rx.DAGHasCycle): + rx.transitive_closure_dag(graph)