Skip to content

Commit

Permalink
Generate all connected subgraphs of size k (#1104)
Browse files Browse the repository at this point in the history
* Create test_all_connected_subgraphs.py

* interface ok

* up init

* Update mod.rs

* up

* lint fmt

* clippy

* stub...

* updated reno and shortened tests

* update to while let

---------

Co-authored-by: Ivan Carvalho <[email protected]>
  • Loading branch information
sbrandhsn and IvanIsCoding authored Mar 3, 2024
1 parent efb296e commit 6e37baf
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 0 deletions.
10 changes: 10 additions & 0 deletions releasenotes/notes/connected-subgraphs-b6fd8d5e37276240.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
features:
- |
Added a function :func:`~rustworkx.connected_subgraphs` to enumerate all connected subgraphs of size ``k`` with
polynomial delay for undirected graphs. This improves upon the brute-force method by two orders of magnitude for
sparse graphs such as heavy-hex, making it possible to handle larger graphs and larger values of ``k``. The method is
based on "Enumerating Connected Induced Subgraphs: Improved Delay and Experimental Comparison" by Christian
Komusiewicz and Frank Sommer. In particular, the procedure ``Simple`` is implemented. Further runtime improvements can
be gained by parallelizing over each recursion or by following the discussion in Lemma 4 of the above work and thus
implementing the intermediate sets ``X`` and ``P`` more efficiently.
1 change: 1 addition & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ from .rustworkx import graph_katz_centrality as graph_katz_centrality
from .rustworkx import graph_greedy_color as graph_greedy_color
from .rustworkx import graph_greedy_edge_color as graph_greedy_edge_color
from .rustworkx import graph_is_bipartite as graph_is_bipartite
from .rustworkx import connected_subgraphs as connected_subgraphs
from .rustworkx import digraph_is_bipartite as digraph_is_bipartite
from .rustworkx import graph_two_color as graph_two_color
from .rustworkx import digraph_two_color as digraph_two_color
Expand Down
1 change: 1 addition & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def stoer_wagner_min_cut(
def simple_cycles(graph: PyDiGraph, /) -> Iterator[NodeIndices]: ...
def graph_isolates(graph: PyGraph) -> NodeIndices: ...
def digraph_isolates(graph: PyDiGraph) -> NodeIndices: ...
# Returns the node-index lists of all connected subgraphs of ``graph`` with exactly ``k`` nodes.
def connected_subgraphs(graph: PyGraph, k: int, /) -> list[list[int]]: ...

# DAG Algorithms

Expand Down
23 changes: 23 additions & 0 deletions src/connectivity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

mod all_pairs_all_simple_paths;
mod johnson_simple_cycles;
mod subgraphs;

use super::{
digraph, get_edge_iter_with_weights, graph, score, weight_callable, InvalidNode, NullGraph,
Expand All @@ -39,6 +40,7 @@ use crate::iterators::{
};
use crate::{EdgeType, StablePyGraph};

use crate::graph::PyGraph;
use rustworkx_core::coloring::two_color;
use rustworkx_core::connectivity;

Expand Down Expand Up @@ -659,6 +661,27 @@ pub fn digraph_all_pairs_all_simple_paths(
))
}

/// Return all the connected subgraphs (as a list of node indices) with exactly k nodes
///
///
/// :param PyGraph graph: The graph to find all connected subgraphs in
/// :param int k: The exact number of nodes in each returned connected subgraph.
///
/// :returns: A list of connected subgraphs with k nodes, represented by their node indices
///
/// :raises ValueError: If ``k`` is larger than the number of nodes in ``graph``
#[pyfunction]
#[pyo3(text_signature = "(graph, k, /)")]
pub fn connected_subgraphs(graph: &PyGraph, k: usize) -> PyResult<Vec<Vec<usize>>> {
    // ``k == node_count`` is a valid request (the whole graph may be connected),
    // so only strictly larger values are rejected.
    if k > graph.node_count() {
        return Err(PyValueError::new_err(
            "Value for k must be <= node count in input graph",
        ));
    }

    Ok(subgraphs::k_connected_subgraphs(&graph.graph, k))
}

/// Return all the simple paths between all pairs of nodes in the graph
///
/// This function is multithreaded and will launch a thread pool with threads
Expand Down
83 changes: 83 additions & 0 deletions src/connectivity/subgraphs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use crate::StablePyGraph;
use hashbrown::HashSet;
use petgraph::stable_graph::NodeIndex;
use petgraph::EdgeType;
use std::cmp::max;

// Implemented after ``Simple`` from
// "Enumerating Connected Induced Subgraphs: Improved Delay and Experimental Comparison"
// by Christian Komusiewicz and Frank Sommer. Some more runtime and space improvements can be
// gained by implementing the data structure defined for Lemma 4 (essentially a more efficient
// way of tracking the sets ``x`` and ``p``).
//
// Returns one ``Vec`` of node indices for every connected node set of exactly ``k`` nodes.
// The input graph is not modified: the enumeration works on a clone from which each start
// node is removed once all subgraphs containing it have been reported.
pub fn k_connected_subgraphs<Ty: EdgeType + Sync>(
    graph: &StablePyGraph<Ty>,
    k: usize,
) -> Vec<Vec<usize>> {
    let mut connected_subgraphs = Vec::new();
    let mut graph = graph.clone();

    while let Some(v) = graph.node_indices().next() {
        // Fewer than ``k`` nodes remain, so no further subgraph of size ``k`` can exist.
        if graph.node_count() < max(k, 1) {
            break;
        }

        let mut p: HashSet<NodeIndex> = HashSet::new();
        p.insert(v);
        // A self-loop makes ``v`` its own neighbor. It must not enter the extension set:
        // re-selecting ``v`` would leave ``p`` unchanged and replay the same search,
        // which can report the same subgraph more than once.
        let mut x: HashSet<NodeIndex> = graph.neighbors(v).filter(|&u| u != v).collect();
        simple_enum(&mut p, &mut x, &graph, &mut connected_subgraphs, k);
        // Every subgraph containing ``v`` has been reported; drop it so later start
        // nodes cannot produce it again.
        graph.remove_node(v);
    }
    connected_subgraphs
}

// Recursive step of ``Simple``: tries to grow the node set ``p`` to exactly ``k`` nodes by
// adding candidates taken from the extension set ``x``.
//
// * ``p`` - nodes selected so far (only read here; each recursion works on an extended copy)
// * ``x`` - candidate nodes that may still be added to ``p``; consumed by this call
// * ``graph`` - graph the neighborhoods are looked up in
// * ``subgraphs`` - output list; every completed node set is appended as a ``Vec`` of indices
// * ``k`` - number of nodes a reported subgraph must have
//
// Returns ``true`` if this call, or any recursion below it, reported at least one subgraph.
fn simple_enum<Ty: EdgeType + Sync>(
    p: &mut HashSet<NodeIndex>,
    x: &mut HashSet<NodeIndex>,
    graph: &StablePyGraph<Ty>,
    subgraphs: &mut Vec<Vec<usize>>,
    k: usize,
) -> bool {
    if p.len() == k {
        subgraphs.push(p.iter().map(|n| n.index()).collect::<Vec<usize>>());
        return true;
    }
    // Nothing left to extend with: skip the neighborhood computation below.
    if x.is_empty() {
        return false;
    }

    // N[P]: ``p`` together with all neighbors of its members. ``p`` does not change inside
    // the loop, so this set is built once instead of once per candidate.
    let mut np: HashSet<NodeIndex> = p.clone();
    np.extend(p.iter().flat_map(|n| graph.neighbors(*n)));

    let mut is_leaf_node: bool = false;
    while let Some(u) = x.iter().next().cloned() {
        x.remove(&u);

        let nu: HashSet<NodeIndex> = graph.neighbors(u).collect();
        // X' = X ∪ (N(u) \ N[P])
        let mut x_next: HashSet<NodeIndex> = x
            .union(&nu.difference(&np).cloned().collect())
            .cloned()
            .collect();
        let mut p_next: HashSet<NodeIndex> = p.clone();
        p_next.insert(u);

        if !simple_enum(&mut p_next, &mut x_next, graph, subgraphs, k) {
            // Pruning step of ``Simple``: once a branch reports nothing, the remaining
            // candidates of this call are not explored.
            break;
        }
        is_leaf_node = true;
    }
    is_leaf_node
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(chain_decomposition))?;
m.add_wrapped(wrap_pyfunction!(graph_isolates))?;
m.add_wrapped(wrap_pyfunction!(digraph_isolates))?;
m.add_wrapped(wrap_pyfunction!(connected_subgraphs))?;
m.add_wrapped(wrap_pyfunction!(is_planar))?;
m.add_wrapped(wrap_pyfunction!(read_graphml))?;
m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?;
Expand Down
97 changes: 97 additions & 0 deletions tests/graph/test_all_connected_subgraphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import itertools
import unittest
import rustworkx


def bruteforce(g, k):
    """Reference enumeration: test every ``k``-combination of node indices for connectivity.

    For each combination the induced subgraph is built with ``g.subgraph`` and, if
    ``rustworkx.is_connected`` accepts it, ``list(sg.nodes())`` is recorded.
    NOTE(review): the recorded values come from ``sg.nodes()``; the tests in this file
    create nodes from ``list(range(8))``, so they presumably coincide with the original
    node indices - confirm before reusing with other payloads.
    """
    found = []
    for selected_nodes in itertools.combinations(g.node_indices(), k):
        candidate = g.subgraph(selected_nodes)
        if rustworkx.is_connected(candidate):
            found.append(list(candidate.nodes()))
    return found


class TestGraphAllConnectedSubgraphs(unittest.TestCase):
    """Compare ``rustworkx.connected_subgraphs`` with the brute-force reference."""

    def setUp(self):
        super().setUp()
        # Two 4-cycles (0-1-2-3 and 4-5-6-7) joined by the edge (0, 4).
        self.edges = [(0, 1), (1, 2), (2, 3), (0, 3), (0, 4), (4, 5), (4, 7), (7, 6), (5, 6)]
        self.nodes = list(range(8))
        reference = rustworkx.PyGraph()
        reference.add_nodes_from(self.nodes)
        reference.add_edges_from_no_data(self.edges)
        # Expected result for every subgraph size from 1 up to the full node count.
        self.expected_subgraphs = {}
        for size in range(1, 9):
            self.expected_subgraphs[size] = list(bruteforce(reference, size))

    def _fixture_graph(self):
        # Fresh copy of the 8-node fixture graph so tests never share state.
        graph = rustworkx.PyGraph()
        graph.add_nodes_from(self.nodes)
        graph.add_edges_from_no_data(self.edges)
        return graph

    def test_empty_graph(self):
        # A graph without nodes and k == 0 yields no subgraphs.
        found = rustworkx.connected_subgraphs(rustworkx.PyGraph(), 0)
        self.assertConnectedSubgraphsEqual(found, [])

    def test_empty_graph_2(self):
        # k == 0 on a populated graph also yields no subgraphs.
        found = rustworkx.connected_subgraphs(self._fixture_graph(), 0)
        self.assertConnectedSubgraphsEqual(found, [])

    def test_size_one_subgraphs(self):
        found = rustworkx.connected_subgraphs(self._fixture_graph(), 1)
        self.assertConnectedSubgraphsEqual(found, self.expected_subgraphs[1])

    def test_sized_subgraphs(self):
        graph = self._fixture_graph()
        for size in range(2, 9):
            with self.subTest(subgraph_size=size):
                found = rustworkx.connected_subgraphs(graph, size)
                self.assertConnectedSubgraphsEqual(found, self.expected_subgraphs[size])

    def test_unique_subgraphs(self):
        # No node set may be reported more than once.
        graph = self._fixture_graph()
        for size in range(2, 9):
            with self.subTest(subgraph_size=size):
                found = rustworkx.connected_subgraphs(graph, size)
                distinct = {tuple(sorted(nodes)) for nodes in found}
                self.assertEqual(len(found), len(distinct))

    def test_disconnected_graph(self):
        # Components: a triangle (0, 1, 2) and a single edge (3, 4).
        graph = rustworkx.PyGraph()
        graph.add_nodes_from([0, 1, 2, 3, 4])
        for source, target in [(0, 1), (1, 2), (0, 2), (3, 4)]:
            graph.add_edge(source, target, None)

        singles = rustworkx.connected_subgraphs(graph, 1)
        self.assertConnectedSubgraphsEqual(singles, [[n] for n in graph.nodes()])

        pairs = rustworkx.connected_subgraphs(graph, 2)
        self.assertConnectedSubgraphsEqual(pairs, graph.edge_list())

        triples = rustworkx.connected_subgraphs(graph, 3)
        self.assertConnectedSubgraphsEqual(triples, [[0, 1, 2]])

        quads = rustworkx.connected_subgraphs(graph, 4)
        self.assertConnectedSubgraphsEqual(quads, [])

    def assertConnectedSubgraphsEqual(self, subgraphs, expected):
        # Compare as sets of sorted tuples: neither the order of the subgraphs nor the
        # order of nodes inside a subgraph matters.
        actual_sets = {tuple(sorted(nodes)) for nodes in subgraphs}
        expected_sets = {tuple(sorted(nodes)) for nodes in expected}
        self.assertEqual(actual_sets, expected_sets)

0 comments on commit 6e37baf

Please sign in to comment.