From 65043ebdd39319884e9485c053a1c743ba1cfee8 Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Mon, 19 Jun 2023 17:09:33 +0900 Subject: [PATCH 01/13] Implement deduplicate edge --- rustworkx-core/src/steiner_tree.rs | 55 ++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 rustworkx-core/src/steiner_tree.rs diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs new file mode 100644 index 000000000..dced2e879 --- /dev/null +++ b/rustworkx-core/src/steiner_tree.rs @@ -0,0 +1,55 @@ +use hashbrown::HashMap; +use petgraph::stable_graph::{NodeIndex, StableGraph}; +use petgraph::visit::{EdgeRef, GraphBase, IntoEdgeReferences}; +use petgraph::Directed; +use std::cmp::Ordering; + +fn deduplicate_edges( + out_graph: &mut StableGraph<(), W, Directed>, + weight_fn: &mut F, +) -> Result<(), E> +where + W: Clone, + F: FnMut(&W) -> Result, +{ + //if out_graph.multigraph { + if true { + // Find all edges between nodes + let mut duplicate_map: HashMap< + [NodeIndex; 2], + Vec<( as GraphBase>::EdgeId, W)>, + > = HashMap::new(); + for edge in out_graph.edge_references() { + if duplicate_map.contains_key(&[edge.source(), edge.target()]) { + duplicate_map + .get_mut(&[edge.source(), edge.target()]) + .unwrap() + .push((edge.id(), edge.weight().clone())); + } else if duplicate_map.contains_key(&[edge.target(), edge.source()]) { + duplicate_map + .get_mut(&[edge.target(), edge.source()]) + .unwrap() + .push((edge.id(), edge.weight().clone())); + } else { + duplicate_map.insert( + [edge.source(), edge.target()], + vec![(edge.id(), edge.weight().clone())], + ); + } + } + // For a node pair with > 1 edge find minimum edge and remove others + for edges_raw in duplicate_map.values().filter(|x| x.len() > 1) { + let mut edges: Vec<( as GraphBase>::EdgeId, f64)> = + Vec::with_capacity(edges_raw.len()); + for edge in edges_raw { + let w = weight_fn(&edge.1)?; + edges.push((edge.0, w)); + } + edges.sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(Ordering::Less)); + edges[1..].iter().for_each(|x| { + out_graph.remove_edge(x.0); + }); + } + } + Ok(()) +} From a718ee0b5be37ca3b637b0835c4ac0a53ed85cec Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Mon, 19 Jun 2023 17:11:51 +0900 Subject: [PATCH 02/13] Move MetricClousureEdge --- rustworkx-core/src/steiner_tree.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index dced2e879..f08c65798 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -4,6 +4,14 @@ use petgraph::visit::{EdgeRef, GraphBase, IntoEdgeReferences}; use petgraph::Directed; use std::cmp::Ordering; +struct MetricClosureEdge { + source: usize, + target: usize, + distance: f64, + path: Vec, +} + + fn deduplicate_edges( out_graph: &mut StableGraph<(), W, Directed>, weight_fn: &mut F, From 2ba4b1d0a69e2dc8dd066c43d861615fe353a755 Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Thu, 29 Jun 2023 17:53:46 +0900 Subject: [PATCH 03/13] Add definition of _metric_clousure_edges --- rustworkx-core/src/steiner_tree.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index f08c65798..d5de8f8ff 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -11,6 +11,19 @@ struct MetricClosureEdge { path: Vec, } +fn _metric_closure_edges( + graph: &StableGraph<(), W, Directed>, + weight_fn: &mut F, +) -> Result, E> { + let node_count = graph.node_count(); + if node_count == 0 { + return Ok(Vec::new()); + } + // TODO implemented + panic!("not implemented"); +} + + fn deduplicate_edges( out_graph: &mut StableGraph<(), W, Directed>, From 1ee9fc6f8126e9b7fc486807e299a299a0c78b32 Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Thu, 29 Jun 2023 18:35:51 +0900 Subject: [PATCH 04/13] Add definition of metric_clousure --- rustworkx-core/src/steiner_tree.rs | 33 ++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index d5de8f8ff..1ad113513 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -10,6 +10,39 @@ struct MetricClosureEdge { distance: f64, path: Vec, } +/// Return the metric closure of a graph +/// +/// The metric closure of a graph is the complete graph in which each edge is +/// weighted by the shortest path distance between the nodes in the graph. +/// +/// :param PyGraph graph: The input graph to find the metric closure for +/// :param weight_fn: A callable object that will be passed an edge's +/// weight/data payload and expected to return a ``float``. For example, +/// you can use ``weight_fn=float`` to cast every weight as a float +/// +/// :return: A metric closure graph from the input graph +/// :rtype: PyGraph +/// :raises ValueError: when an edge weight with NaN or negative value +/// is provided. +pub fn metric_closure( + graph: &StableGraph<(), W, Directed>, + weight_fn: &mut F, +) -> Result, E> +where + W: Clone, +{ + let mut out_graph: StableGraph<(), W, Directed> = graph.clone(); + out_graph.clear_edges(); + // let edges = _metric_closure_edges(graph, weight_fn)?; + //for edge in edges { + // out_graph.add_edge( + // NodeIndex::new(edge.source), + // NodeIndex::new(edge.target), + // edge.distance, + //); + //} + Ok(out_graph) +} fn _metric_closure_edges( graph: &StableGraph<(), W, Directed>, From df973f3b53fd7ba696325e88ac77a4d11824168c Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Thu, 29 Jun 2023 18:52:54 +0900 Subject: [PATCH 05/13] Add definition of fast_metric_edges --- rustworkx-core/src/steiner_tree.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index 1ad113513..29b570bdb 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -1,15 +1,19 @@ use hashbrown::HashMap; use petgraph::stable_graph::{NodeIndex, StableGraph}; +use petgraph::visit::NodeIndexable; use petgraph::visit::{EdgeRef, GraphBase, IntoEdgeReferences}; use petgraph::Directed; use std::cmp::Ordering; +use crate::petgraph::unionfind::UnionFind; + struct MetricClosureEdge { source: usize, target: usize, distance: f64, path: Vec, } + /// Return the metric closure of a graph /// /// The metric closure of a graph is the complete graph in which each edge is @@ -56,6 +60,25 @@ fn _metric_closure_edges( panic!("not implemented"); } +/// Computes the shortest path between all pairs `(s, t)` of the given `terminal_nodes` +/// *provided* that: +/// - there is an edge `(u, v)` in the graph and path pass through this edge. +/// - node `s` is the closest node to `u` among all `terminal_nodes` +/// - node `t` is the closest node to `v` among all `terminal_nodes` +/// and wraps the result inside a `MetricClosureEdge` +/// +/// For example, if all vertices are terminals, it returns the original edges of the graph. +fn fast_metric_edges( + graph: &mut StableGraph<(), W, Directed>, + terminal_nodes: Vec, + weight_fn: &mut F, +) -> Result, E> +where + W: Clone, + F: FnMut(&W) -> Result, +{ + Ok(Vec::new()) +} fn deduplicate_edges( From 86fdc60a6dbb07f39a443f1df1dfea848f199a70 Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Thu, 29 Jun 2023 18:53:03 +0900 Subject: [PATCH 06/13] Add definition of steiner_tree --- rustworkx-core/src/steiner_tree.rs | 58 ++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index 29b570bdb..73857d5b2 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -80,6 +80,64 @@ where Ok(Vec::new()) } +/// Return an approximation to the minimum Steiner tree of a graph. +/// +/// The minimum tree of ``graph`` with regard to a set of ``terminal_nodes`` +/// is a tree within ``graph`` that spans those nodes and has a minimum size +/// (measured as the sum of edge weights) amoung all such trees. +/// +/// The minimum steiner tree can be approximated by computing the minimum +/// spanning tree of the subgraph of the metric closure of ``graph`` induced +/// by the terminal nodes, where the metric closure of ``graph`` is the +/// complete graph in which each edge is weighted by the shortest path distance +/// between nodes in ``graph``. +/// +/// This algorithm [1]_ produces a tree whose weight is within a +/// :math:`(2 - (2 / t))` factor of the weight of the optimal Steiner tree +/// where :math:`t` is the number of terminal nodes. The algorithm implemented +/// here is due to [2]_ . It avoids computing all pairs shortest paths but rather +/// reduces the problem to a single source shortest path and a minimum spanning tree +/// problem. +/// +/// :param PyGraph graph: The graph to compute the minimum Steiner tree for +/// :param list terminal_nodes: The list of node indices for which the Steiner +/// tree is to be computed for. +/// :param weight_fn: A callable object that will be passed an edge's +/// weight/data payload and expected to return a ``float``. For example, +/// you can use ``weight_fn=float`` to cast every weight as a float. +/// +/// :returns: An approximation to the minimal steiner tree of ``graph`` induced +/// by ``terminal_nodes``. +/// :rtype: PyGraph +/// :raises ValueError: when an edge weight with NaN or negative value +/// is provided. +/// +/// .. [1] Kou, Markowsky & Berman, +/// "A fast algorithm for Steiner trees" +/// Acta Informatica 15, 141–145 (1981). +/// https://link.springer.com/article/10.1007/BF00288961 +/// .. [2] Kurt Mehlhorn, +/// "A faster approximation algorithm for the Steiner problem in graphs" +/// https://doi.org/10.1016/0020-0190(88)90066-X +pub fn steiner_tree( + graph: &mut StableGraph<(), W, Directed>, + terminal_nodes: Vec, + weight_fn: &mut F, + //) -> Result, E> +) -> Result<(), E> +where + W: Clone, + F: FnMut(&W) -> Result, +{ + let mut edge_list = fast_metric_edges(graph, terminal_nodes, &mut weight_fn)?; + let mut subgraphs = UnionFind::::new(graph.node_bound()); + edge_list.par_sort_unstable_by(|a, b| { + let weight_a = (a.distance, a.source, a.target); + let weight_b = (b.distance, b.source, b.target); + weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) + }); + Ok(()) +} fn deduplicate_edges( out_graph: &mut StableGraph<(), W, Directed>, From ad114be12a41cd5004763a461404c0a124f7ab13 Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Thu, 29 Jun 2023 19:07:12 +0900 Subject: [PATCH 07/13] Add mod steiner_tree to lib.rs --- rustworkx-core/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index a432d891a..e5d38eb58 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -92,6 +92,8 @@ mod min_scored; pub mod token_swapper; pub mod utils; +pub mod steiner_tree; + // re-export petgraph so there is a consistent version available to users and // then only need to require rustworkx-core in their dependencies pub use petgraph; From 5790bc3ddad084d1b0cf9e632ba7b552ac6091da Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Fri, 30 Jun 2023 16:16:35 +0900 Subject: [PATCH 08/13] Add necessary crate --- rustworkx-core/src/steiner_tree.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index 73857d5b2..d9edf5db8 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -3,6 +3,7 @@ use petgraph::stable_graph::{NodeIndex, StableGraph}; use petgraph::visit::NodeIndexable; use petgraph::visit::{EdgeRef, GraphBase, IntoEdgeReferences}; use petgraph::Directed; +use rayon::prelude::ParallelSliceMut; use std::cmp::Ordering; use crate::petgraph::unionfind::UnionFind; From 162a9fa064a65fea8f1846c6477c99282777ae47 Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Fri, 30 Jun 2023 16:17:13 +0900 Subject: [PATCH 09/13] correct type --- rustworkx-core/src/steiner_tree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index d9edf5db8..e5bf35f37 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -130,7 +130,7 @@ where W: Clone, F: FnMut(&W) -> Result, { - let mut edge_list = fast_metric_edges(graph, terminal_nodes, &mut weight_fn)?; + let mut edge_list = fast_metric_edges(graph, terminal_nodes, weight_fn)?; let mut subgraphs = UnionFind::::new(graph.node_bound()); edge_list.par_sort_unstable_by(|a, b| { let weight_a = (a.distance, a.source, a.target); From f1e5457a9985d9055c568b1c2632de3773c48332 Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Fri, 14 Jul 2023 13:25:52 +0900 Subject: [PATCH 10/13] change MetricClousureEdge to generic --- rustworkx-core/src/steiner_tree.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index e5bf35f37..faf507251 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -8,10 +8,10 @@ use std::cmp::Ordering; use crate::petgraph::unionfind::UnionFind; -struct MetricClosureEdge { +struct MetricClosureEdge { source: usize, target: usize, - distance: f64, + distance: W, path: Vec, } @@ -38,21 +38,21 @@ where { let mut out_graph: StableGraph<(), W, Directed> = graph.clone(); out_graph.clear_edges(); - // let edges = _metric_closure_edges(graph, weight_fn)?; - //for edge in edges { - // out_graph.add_edge( - // NodeIndex::new(edge.source), - // NodeIndex::new(edge.target), - // edge.distance, - //); - //} + let edges = _metric_closure_edges(graph, weight_fn)?; + for edge in edges { + out_graph.add_edge( + NodeIndex::new(edge.source), + NodeIndex::new(edge.target), + edge.distance, + ); + } Ok(out_graph) } fn _metric_closure_edges( graph: &StableGraph<(), W, Directed>, weight_fn: &mut F, -) -> Result, E> { +) -> Result>, E> { let node_count = graph.node_count(); if node_count == 0 { return Ok(Vec::new()); @@ -73,7 +73,7 @@ fn fast_metric_edges( graph: &mut StableGraph<(), W, Directed>, terminal_nodes: Vec, weight_fn: &mut F, -) -> Result, E> +) -> Result>, E> where W: Clone, F: FnMut(&W) -> Result, @@ -137,7 +137,7 @@ where let weight_b = (b.distance, b.source, b.target); weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) }); - Ok(()) + Ok(out) } fn deduplicate_edges( From 54d2d0ea5b9bfd8b81212fe46eda85bb18aea1ae Mon Sep 17 00:00:00 2001 From: Ryuhei Yoshida Date: Thu, 27 Jul 2023 17:19:16 +0900 Subject: [PATCH 11/13] Update steiner_tree --- rustworkx-core/src/steiner_tree.rs | 158 +++++++++++++++++++++++++++-- 1 file changed, 150 insertions(+), 8 deletions(-) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index faf507251..3ad1af55d 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -1,14 +1,18 @@ -use hashbrown::HashMap; -use petgraph::stable_graph::{NodeIndex, StableGraph}; +use hashbrown::{HashMap, HashSet}; +use num_traits::Float; +use petgraph::stable_graph::{EdgeReference, NodeIndex, StableGraph}; use petgraph::visit::NodeIndexable; use petgraph::visit::{EdgeRef, GraphBase, IntoEdgeReferences}; use petgraph::Directed; use rayon::prelude::ParallelSliceMut; use std::cmp::Ordering; +use crate::dictmap::{DictMap, InitWithHasher}; use crate::petgraph::unionfind::UnionFind; +use crate::shortest_path::dijkstra; +use crate::utils::pairwise; -struct MetricClosureEdge { +pub struct MetricClosureEdge { source: usize, target: usize, distance: W, @@ -57,6 +61,7 @@ fn _metric_closure_edges( if node_count == 0 { return Ok(Vec::new()); } + // TODO implemented panic!("not implemented"); } @@ -75,9 +80,82 @@ fn fast_metric_edges( weight_fn: &mut F, ) -> Result>, E> where - W: Clone, + W: Clone + + std::ops::Add + + std::default::Default + + std::marker::Copy + + std::cmp::PartialOrd + + std::fmt::Debug, F: FnMut(&W) -> Result, { + // temporarily add a ``dummy`` node, connect it with + // all the terminal nodes and find all the shortest paths + // starting from ``dummy`` node. + let dummy = graph.add_node(()); + for node in terminal_nodes { + graph.add_edge(dummy, NodeIndex::new(node), None); + } + let cost_fn = |edge: EdgeReference<'_, W>| -> Result { + if edge.source() != dummy && edge.target() != dummy { + let weight: f64 = weight_fn(edge.weight())?; + is_valid_weight(weight) + } else { + Ok(W::zero()) + } + }; + let mut paths = DictMap::with_capacity(graph.node_count()); + let mut distance: DictMap = + dijkstra(&*graph, dummy, None, cost_fn, Some(&mut paths))?; + paths.remove(&dummy); + distance.remove(&dummy); + graph.remove_node(dummy); + + // ``partition[u]`` holds the terminal node closest to node ``u``. + let mut partition: Vec = vec![std::usize::MAX; graph.node_bound()]; + for (u, path) in paths.iter() { + let u = u.index(); + partition[u] = path[1].index(); + } + + let mut out_edges: Vec> = Vec::with_capacity(graph.edge_count()); + for edge in graph.edge_references() { + let source = edge.source(); + let target = edge.target(); + // assert that ``source`` is reachable from a terminal node. + if distance.contains_key(&source) { + let weight: W = distance[&source] + cost_fn(edge)? + distance[&target]; + let mut path: Vec = paths[&source].iter().skip(1).map(|x| x.index()).collect(); + path.append( + &mut paths[&target] + .iter() + .skip(1) + .rev() + .map(|x| x.index()) + .collect(), + ); + + let source = source.index(); + let target = target.index(); + + let mut source = partition[source]; + let mut target = partition[target]; + + match source.cmp(&target) { + Ordering::Equal => continue, + Ordering::Greater => std::mem::swap(&mut source, &mut target), + _ => {} + } + + out_edges.push(MetricClosureEdge { + source, + target, + distance: weight, + path, + }); + } + } + + //TODO Ok(Vec::new()) } @@ -124,11 +202,16 @@ pub fn steiner_tree( graph: &mut StableGraph<(), W, Directed>, terminal_nodes: Vec, weight_fn: &mut F, - //) -> Result, E> -) -> Result<(), E> +) -> Result, E> where - W: Clone, + W: Copy + + Clone + + PartialOrd + + std::fmt::Debug + + std::default::Default + + std::ops::Add, F: FnMut(&W) -> Result, + MetricClosureEdge: Send, { let mut edge_list = fast_metric_edges(graph, terminal_nodes, weight_fn)?; let mut subgraphs = UnionFind::::new(graph.node_bound()); @@ -137,7 +220,51 @@ where let weight_b = (b.distance, b.source, b.target); weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) }); - Ok(out) + let mut mst_edges: Vec> = Vec::new(); + for float_edge_pair in edge_list { + let u = float_edge_pair.source; + let v = float_edge_pair.target; + if subgraphs.union(u, v) { + mst_edges.push(float_edge_pair); + } + } + //TODO implement error + // assert that the terminal nodes are connected. + //if !terminal_nodes.is_empty() && mst_edges.len() != terminal_nodes.len() - 1 { + //return Err(PyValueError::new_err( "The terminal nodes in the input graph must belong to the same connected component. The steiner tree is not defined for a graph with unconnected terminal nodes",)); + //} + // Generate the output graph from the MST + let out_edge_list: Vec<[usize; 2]> = mst_edges + .into_iter() + .flat_map(|edge| pairwise(edge.path)) + .filter_map(|x| x.0.map(|a| [a, x.1])) + .collect(); + let out_edges: HashSet<(usize, usize)> = out_edge_list.iter().map(|x| (x[0], x[1])).collect(); + let mut out_graph = graph.clone(); + let out_nodes: HashSet = out_edge_list + .iter() + .flat_map(|x| x.iter()) + .copied() + .map(NodeIndex::new) + .collect(); + for node in graph + .node_indices() + .filter(|node| !out_nodes.contains(node)) + { + out_graph.remove_node(node); + // out_graph.node_removed = true; + } + for edge in graph.edge_references().filter(|edge| { + let source = edge.source().index(); + let target = edge.target().index(); + !out_edges.contains(&(source, target)) && !out_edges.contains(&(target, source)) + }) { + out_graph.remove_edge(edge.id()); + } + // Deduplicate potential duplicate edges + deduplicate_edges(&mut out_graph, weight_fn)?; + + Ok(out_graph) } fn deduplicate_edges( @@ -189,3 +316,18 @@ where } Ok(()) } + +#[inline] +fn is_valid_weight(val: W) -> Result { + if val.is_sign_negative() { + return Err(E); + //return Err(E "Negative weights not supported."); + } + + if val.is_nan() { + return Err(E); + //return Err(E "NaN weights not supported."); + } + + Ok(val) +} From 9ccdca1f03cff3e2cc5ea30f962f662935c1a1dc Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Mon, 19 Feb 2024 14:05:01 -0500 Subject: [PATCH 12/13] Finish rustworkx-core port This commit finishes the rustworkx-core port of the steiner tree and metric closure functions. In particular this was especially tricky because the petgraph traits make it exceedingly difficult to have a generic function that takes in a graph for analysis and modifies it as most of the traits required for visiting/iteration that are used for analysis are only defined on borrowed graphs, and the limited traits for modifying graphs are defined on owned graph types. This causes a conflict where you can't easily express that a generic type G created in a function from a user input is both mutated using a trait and analyzed as there is a type mismatch between G and &G. After spending far too long to fail to find a pattern to express this, I opted to just use a discrete type for the return and leave the actual graph mutation up to the rustworkx-core user because we're lacking the ability to cleanly express what is needed via petgraph. --- rustworkx-core/src/steiner_tree.rs | 623 ++++++++++++++++++++--------- src/steiner_tree.rs | 283 +++---------- tests/graph/test_steiner_tree.py | 2 +- 3 files changed, 494 insertions(+), 414 deletions(-) diff --git a/rustworkx-core/src/steiner_tree.rs b/rustworkx-core/src/steiner_tree.rs index 3ad1af55d..bca9e40a6 100644 --- a/rustworkx-core/src/steiner_tree.rs +++ b/rustworkx-core/src/steiner_tree.rs @@ -1,21 +1,120 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use std::cmp::{Eq, Ordering}; +use std::convert::Infallible; +use std::hash::Hash; + use hashbrown::{HashMap, HashSet}; -use num_traits::Float; -use petgraph::stable_graph::{EdgeReference, NodeIndex, StableGraph}; -use petgraph::visit::NodeIndexable; -use petgraph::visit::{EdgeRef, GraphBase, IntoEdgeReferences}; -use petgraph::Directed; -use rayon::prelude::ParallelSliceMut; -use std::cmp::Ordering; +use rayon::prelude::*; + +use petgraph::stable_graph::{EdgeIndex, NodeIndex, StableGraph}; +use petgraph::unionfind::UnionFind; +use petgraph::visit::{ + EdgeCount, EdgeIndexable, EdgeRef, GraphProp, IntoEdgeReferences, IntoEdges, + IntoNodeIdentifiers, IntoNodeReferences, NodeCount, NodeIndexable, NodeRef, Visitable, +}; +use petgraph::Undirected; -use crate::dictmap::{DictMap, InitWithHasher}; -use crate::petgraph::unionfind::UnionFind; +use crate::dictmap::*; use crate::shortest_path::dijkstra; use crate::utils::pairwise; -pub struct MetricClosureEdge { +type AllPairsDijkstraReturn = HashMap>, DictMap)>; + +fn all_pairs_dijkstra_shortest_paths( + graph: G, + mut weight_fn: F, +) -> Result +where + G: NodeIndexable + + IntoNodeIdentifiers + + EdgeCount + + NodeCount + + EdgeIndexable + + Visitable + + Sync + + IntoEdges, + G::NodeId: Eq + Hash + Send, + G::EdgeId: Eq + Hash + Send, + F: FnMut(G::EdgeRef) -> Result, +{ + if graph.node_count() == 0 { + return Ok(HashMap::new()); + } else if graph.edge_count() == 0 { + return Ok(graph + .node_identifiers() + .map(|x| { + ( + NodeIndexable::to_index(&graph, x), + (DictMap::new(), DictMap::new()), + ) + }) + .collect()); + } + let mut edge_weights: Vec> = vec![None; graph.edge_bound()]; + for edge in graph.edge_references() { + let index = EdgeIndexable::to_index(&graph, edge.id()); + edge_weights[index] = Some(weight_fn(edge)?); + } + let edge_cost = |e: G::EdgeRef| -> Result { + Ok(edge_weights[EdgeIndexable::to_index(&graph, e.id())].unwrap()) + }; + + let node_indices: Vec = graph + .node_identifiers() + .map(|n| NodeIndexable::to_index(&graph, n)) + .collect(); + Ok(node_indices + .into_par_iter() + .map(|x| { + let mut paths: DictMap> = + DictMap::with_capacity(graph.node_count()); + let distances: DictMap = dijkstra( + graph, + NodeIndexable::from_index(&graph, x), + None, + edge_cost, + Some(&mut paths), + ) + .unwrap(); + ( + x, + ( + paths + .into_iter() + .map(|(k, v)| { + ( + NodeIndexable::to_index(&graph, k), + v.into_iter() + .map(|n| NodeIndexable::to_index(&graph, n)) + .collect(), + ) + }) + .collect(), + distances + .into_iter() + .map(|(k, v)| (NodeIndexable::to_index(&graph, k), v)) + .collect(), + ), + ) + }) + .collect()) +} + +struct MetricClosureEdge { source: usize, target: usize, - distance: W, + distance: f64, path: Vec, } @@ -24,46 +123,178 @@ pub struct MetricClosureEdge { /// The metric closure of a graph is the complete graph in which each edge is /// weighted by the shortest path distance between the nodes in the graph. /// -/// :param PyGraph graph: The input graph to find the metric closure for -/// :param weight_fn: A callable object that will be passed an edge's -/// weight/data payload and expected to return a ``float``. For example, -/// you can use ``weight_fn=float`` to cast every weight as a float +/// Arguments: +/// `graph`: The input graph to compute the metric closure for +/// `weight_fn`: A callable weight function that will be passed an edge reference +/// for each edge in the graph and it is expected to return a `Result` +/// which if it doesn't error represents the weight of that edge. +/// `default_weight`: A blind callable that returns a default weight to use for +/// edges added to the output +/// +/// Returns a `StableGraph` with the input graph node ids for node weights and edge weights with a +/// tuple of the numeric weight (found via `weight_fn`) and the path. The output will be `None` +/// if `graph` is disconnected. +/// +/// # Example +/// ```rust +/// use std::convert::Infallible; +/// +/// use rustworkx_core::petgraph::Graph; +/// use rustworkx_core::petgraph::Undirected; +/// use rustworkx_core::petgraph::graph::EdgeReference; +/// use rustworkx_core::petgraph::visit::{IntoEdgeReferences, EdgeRef}; +/// +/// use rustworkx_core::steiner_tree::metric_closure; +/// +/// let input_graph = Graph::<(), u8, Undirected>::from_edges(&[ +/// (0, 1, 10), +/// (1, 2, 10), +/// (2, 3, 10), +/// (3, 4, 10), +/// (4, 5, 10), +/// (1, 6, 1), +/// (6, 4, 1), +/// ]); +/// +/// let weight_fn = |e: EdgeReference| -> Result { +/// Ok(*e.weight() as f64) +/// }; +/// +/// let closure = metric_closure(&input_graph, weight_fn).unwrap().unwrap(); +/// let mut output_edge_list: Vec<(usize, usize, (f64, Vec))> = closure.edge_references().map(|edge| (edge.source().index(), edge.target().index(), edge.weight().clone())).collect(); +/// let mut expected_edges: Vec<(usize, usize, (f64, Vec))> = vec![ +/// (0, 1, (10.0, vec![0, 1])), +/// (0, 2, (20.0, vec![0, 1, 2])), +/// (0, 3, (22.0, vec![0, 1, 6, 4, 3])), +/// (0, 4, (12.0, vec![0, 1, 6, 4])), +/// (0, 5, (22.0, vec![0, 1, 6, 4, 5])), +/// (0, 6, (11.0, vec![0, 1, 6])), +/// (1, 2, (10.0, vec![1, 2])), +/// (1, 3, (12.0, vec![1, 6, 4, 3])), +/// (1, 4, (2.0, vec![1, 6, 4])), +/// (1, 5, (12.0, vec![1, 6, 4, 5])), +/// (1, 6, (1.0, vec![1, 6])), +/// (2, 3, (10.0, vec![2, 3])), +/// (2, 4, (12.0, vec![2, 1, 6, 4])), +/// (2, 5, (22.0, vec![2, 1, 6, 4, 5])), +/// (2, 6, (11.0, vec![2, 1, 6])), +/// (3, 4, (10.0, vec![3, 4])), +/// (3, 5, (20.0, vec![3, 4, 5])), +/// (3, 6, (11.0, vec![3, 4, 6])), +/// (4, 5, (10.0, vec![4, 5])), +/// (4, 6, (1.0, vec![4, 6])), +/// (5, 6, (11.0, vec![5, 4, 6])), +/// ]; +/// output_edge_list.sort_by_key(|x| [x.0, x.1]); +/// expected_edges.sort_by_key(|x| [x.0, x.1]); +/// assert_eq!(output_edge_list, expected_edges); /// -/// :return: A metric closure graph from the input graph -/// :rtype: PyGraph -/// :raises ValueError: when an edge weight with NaN or negative value -/// is provided. -pub fn metric_closure( - graph: &StableGraph<(), W, Directed>, - weight_fn: &mut F, -) -> Result, E> +/// ``` +#[allow(clippy::type_complexity)] +pub fn metric_closure( + graph: G, + weight_fn: F, +) -> Result), Undirected>>, E> where - W: Clone, + G: NodeIndexable + + EdgeIndexable + + Sync + + EdgeCount + + NodeCount + + Visitable + + IntoNodeReferences + + IntoEdges + + Visitable + + GraphProp, + G::NodeId: Eq + Hash + NodeRef + Send, + G::EdgeId: Eq + Hash + Send, + G::NodeWeight: Clone, + F: FnMut(G::EdgeRef) -> Result, { - let mut out_graph: StableGraph<(), W, Directed> = graph.clone(); - out_graph.clear_edges(); - let edges = _metric_closure_edges(graph, weight_fn)?; - for edge in edges { + let mut out_graph: StableGraph), Undirected> = + StableGraph::with_capacity(graph.node_count(), graph.edge_count()); + let node_map: HashMap = graph + .node_references() + .map(|node| { + ( + NodeIndexable::to_index(&graph, node.id()), + out_graph.add_node(node.id()), + ) + }) + .collect(); + let edges = metric_closure_edges(graph, weight_fn)?; + if edges.is_none() { + return Ok(None); + } + for edge in edges.unwrap() { out_graph.add_edge( - NodeIndex::new(edge.source), - NodeIndex::new(edge.target), - edge.distance, + node_map[&edge.source], + node_map[&edge.target], + (edge.distance, edge.path), ); } - Ok(out_graph) + Ok(Some(out_graph)) } -fn _metric_closure_edges( - graph: &StableGraph<(), W, Directed>, - weight_fn: &mut F, -) -> Result>, E> { +fn metric_closure_edges( + graph: G, + weight_fn: F, +) -> Result>, E> +where + G: NodeIndexable + + Sync + + Visitable + + IntoNodeReferences + + IntoEdges + + Visitable + + NodeIndexable + + NodeCount + + EdgeCount + + EdgeIndexable, + G::NodeId: Eq + Hash + Send, + G::EdgeId: Eq + Hash + Send, + F: FnMut(G::EdgeRef) -> Result, +{ let node_count = graph.node_count(); if node_count == 0 { - return Ok(Vec::new()); + return Ok(Some(Vec::new())); } - - // TODO implemented - panic!("not implemented"); + let mut out_vec = Vec::with_capacity(node_count * (node_count - 1) / 2); + let paths = all_pairs_dijkstra_shortest_paths(graph, weight_fn)?; + let mut nodes: HashSet = graph + .node_identifiers() + .map(|x| NodeIndexable::to_index(&graph, x)) + .collect(); + let first_node = graph + .node_identifiers() + .map(|x| NodeIndexable::to_index(&graph, x)) + .next() + .unwrap(); + let path_keys: HashSet = paths[&first_node].0.keys().copied().collect(); + // first_node will always be missing from path_keys so if the difference + // is > 1 with nodes that means there is another node in the graph that + // first_node doesn't have a path to. + if nodes.difference(&path_keys).count() > 1 { + return Ok(None); + } + // Iterate over node indices for a deterministic order + for node in graph + .node_identifiers() + .map(|x| NodeIndexable::to_index(&graph, x)) + { + let path_map = &paths[&node].0; + nodes.remove(&node); + let distance = &paths[&node].1; + for v in &nodes { + out_vec.push(MetricClosureEdge { + source: node, + target: *v, + distance: distance[v], + path: path_map[v].clone(), + }); + } + } + Ok(Some(out_vec)) } /// Computes the shortest path between all pairs `(s, t)` of the given `terminal_nodes` @@ -74,68 +305,104 @@ fn _metric_closure_edges( /// and wraps the result inside a `MetricClosureEdge` /// /// For example, if all vertices are terminals, it returns the original edges of the graph. -fn fast_metric_edges( - graph: &mut StableGraph<(), W, Directed>, - terminal_nodes: Vec, - weight_fn: &mut F, -) -> Result>, E> +fn fast_metric_edges( + in_graph: G, + terminal_nodes: &[G::NodeId], + mut weight_fn: F, +) -> Result, E> where - W: Clone - + std::ops::Add - + std::default::Default - + std::marker::Copy - + std::cmp::PartialOrd - + std::fmt::Debug, - F: FnMut(&W) -> Result, + G: IntoEdges + + NodeIndexable + + EdgeIndexable + + Sync + + EdgeCount + + Visitable + + IntoNodeReferences + + NodeCount, + G::NodeId: Eq + Hash + Send, + G::EdgeId: Eq + Hash + Send, + F: FnMut(G::EdgeRef) -> Result, { + let mut graph: StableGraph<(), (), Undirected> = StableGraph::with_capacity( + in_graph.node_count() + 1, + in_graph.edge_count() + terminal_nodes.len(), + ); + let node_map: HashMap = in_graph + .node_references() + .map(|n| (n.id(), graph.add_node(()))) + .collect(); + let reverse_node_map: HashMap = + node_map.iter().map(|(k, v)| (*v, *k)).collect(); + let edge_map: HashMap = in_graph + .edge_references() + .map(|e| { + ( + graph.add_edge(node_map[&e.source()], node_map[&e.target()], ()), + e, + ) + }) + .collect(); + // temporarily add a ``dummy`` node, connect it with // all the terminal nodes and find all the shortest paths // starting from ``dummy`` node. let dummy = graph.add_node(()); for node in terminal_nodes { - graph.add_edge(dummy, NodeIndex::new(node), None); + graph.add_edge(dummy, node_map[node], ()); } - let cost_fn = |edge: EdgeReference<'_, W>| -> Result { - if edge.source() != dummy && edge.target() != dummy { - let weight: f64 = weight_fn(edge.weight())?; - is_valid_weight(weight) - } else { - Ok(W::zero()) - } - }; + let mut paths = DictMap::with_capacity(graph.node_count()); - let mut distance: DictMap = - dijkstra(&*graph, dummy, None, cost_fn, Some(&mut paths))?; - paths.remove(&dummy); - distance.remove(&dummy); - graph.remove_node(dummy); + + let mut wrapped_weight_fn = + |e: <&StableGraph<(), ()> as IntoEdgeReferences>::EdgeRef| -> Result { + if let Some(edge_ref) = edge_map.get(&e.id()) { + weight_fn(*edge_ref) + } else { + Ok(0.0) + } + }; + + let mut distance: DictMap = dijkstra( + &graph, + dummy, + None, + &mut wrapped_weight_fn, + Some(&mut paths), + )?; + paths.swap_remove(&dummy); + distance.swap_remove(&dummy); // ``partition[u]`` holds the terminal node closest to node ``u``. let mut partition: Vec = vec![std::usize::MAX; graph.node_bound()]; for (u, path) in paths.iter() { - let u = u.index(); - partition[u] = path[1].index(); + let u = NodeIndexable::to_index(&in_graph, reverse_node_map[u]); + partition[u] = NodeIndexable::to_index(&in_graph, reverse_node_map[&path[1]]); } - let mut out_edges: Vec> = Vec::with_capacity(graph.edge_count()); + let mut out_edges: Vec = Vec::with_capacity(graph.edge_count()); + for edge in graph.edge_references() { let source = edge.source(); let target = edge.target(); // assert that ``source`` is reachable from a terminal node. if distance.contains_key(&source) { - let weight: W = distance[&source] + cost_fn(edge)? + distance[&target]; - let mut path: Vec = paths[&source].iter().skip(1).map(|x| x.index()).collect(); + let weight = distance[&source] + wrapped_weight_fn(edge)? + distance[&target]; + let mut path: Vec = paths[&source] + .iter() + .skip(1) + .map(|x| NodeIndexable::to_index(&in_graph, reverse_node_map[x])) + .collect(); path.append( &mut paths[&target] .iter() .skip(1) .rev() - .map(|x| x.index()) + .map(|x| NodeIndexable::to_index(&in_graph, reverse_node_map[x])) .collect(), ); - let source = source.index(); - let target = target.index(); + let source = NodeIndexable::to_index(&in_graph, reverse_node_map[&source]); + let target = NodeIndexable::to_index(&in_graph, reverse_node_map[&target]); let mut source = partition[source]; let mut target = partition[target]; @@ -155,8 +422,23 @@ where } } - //TODO - Ok(Vec::new()) + // if parallel edges, keep the edge with minimum distance. + out_edges.par_sort_unstable_by(|a, b| { + let weight_a = (a.source, a.target, a.distance); + let weight_b = (b.source, b.target, b.distance); + weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) + }); + + out_edges.dedup_by(|edge_a, edge_b| { + edge_a.source == edge_b.source && edge_a.target == edge_b.target + }); + + Ok(out_edges) +} + +pub struct SteinerTreeResult { + pub used_node_indices: HashSet, + pub used_edge_endpoints: HashSet<(usize, usize)>, } /// Return an approximation to the minimum Steiner tree of a graph. @@ -178,18 +460,53 @@ where /// reduces the problem to a single source shortest path and a minimum spanning tree /// problem. /// -/// :param PyGraph graph: The graph to compute the minimum Steiner tree for -/// :param list terminal_nodes: The list of node indices for which the Steiner -/// tree is to be computed for. -/// :param weight_fn: A callable object that will be passed an edge's -/// weight/data payload and expected to return a ``float``. For example, -/// you can use ``weight_fn=float`` to cast every weight as a float. +/// Arguments: +/// `graph`: The input graph to compute the steiner tree of +/// `terminal_nodes`: The terminal nodes of the steiner tree +/// `weight_fn`: A callable weight function that will be passed an edge reference +/// for each edge in the graph and it is expected to return a `Result` +/// which if it doesn't error represents the weight of that edge. /// -/// :returns: An approximation to the minimal steiner tree of ``graph`` induced -/// by ``terminal_nodes``. -/// :rtype: PyGraph -/// :raises ValueError: when an edge weight with NaN or negative value -/// is provided. +/// Returns a custom struct that contains a set of nodes and edges and `None` +/// if the graph is disconnected relative to the terminal nodes. +/// +/// # Example +/// +/// ```rust +/// use std::convert::Infallible; +/// +/// use rustworkx_core::petgraph::Graph; +/// use rustworkx_core::petgraph::graph::NodeIndex; +/// use rustworkx_core::petgraph::Undirected; +/// use rustworkx_core::petgraph::graph::EdgeReference; +/// use rustworkx_core::petgraph::visit::{IntoEdgeReferences, EdgeRef}; +/// +/// use rustworkx_core::steiner_tree::steiner_tree; +/// +/// let input_graph = Graph::<(), u8, Undirected>::from_edges(&[ +/// (0, 1, 10), +/// (1, 2, 10), +/// (2, 3, 10), +/// (3, 4, 10), +/// (4, 5, 10), +/// (1, 6, 1), +/// (6, 4, 1), +/// ]); +/// +/// let weight_fn = |e: EdgeReference| -> Result { +/// Ok(*e.weight() as f64) +/// }; +/// let terminal_nodes = vec![ +/// NodeIndex::new(0), +/// NodeIndex::new(1), +/// NodeIndex::new(2), +/// NodeIndex::new(3), +/// NodeIndex::new(4), +/// NodeIndex::new(5), +/// ]; +/// +/// let tree = steiner_tree(&input_graph, &terminal_nodes, weight_fn).unwrap().unwrap(); +/// ``` /// /// .. [1] Kou, Markowsky & Berman, /// "A fast algorithm for Steiner trees" @@ -198,29 +515,33 @@ where /// .. [2] Kurt Mehlhorn, /// "A faster approximation algorithm for the Steiner problem in graphs" /// https://doi.org/10.1016/0020-0190(88)90066-X -pub fn steiner_tree( - graph: &mut StableGraph<(), W, Directed>, - terminal_nodes: Vec, - weight_fn: &mut F, -) -> Result, E> +pub fn steiner_tree( + graph: G, + terminal_nodes: &[G::NodeId], + weight_fn: F, +) -> Result, E> where - W: Copy - + Clone - + PartialOrd - + std::fmt::Debug - + std::default::Default - + std::ops::Add, - F: FnMut(&W) -> Result, - MetricClosureEdge: Send, + G: IntoEdges + + NodeIndexable + + Sync + + EdgeCount + + IntoNodeReferences + + EdgeIndexable + + Visitable + + NodeCount, + G::NodeId: Eq + Hash + Send, + G::EdgeId: Eq + Hash + Send, + F: FnMut(G::EdgeRef) -> Result, { + let node_bound = graph.node_bound(); let mut edge_list = fast_metric_edges(graph, terminal_nodes, weight_fn)?; - let mut subgraphs = UnionFind::::new(graph.node_bound()); + let mut subgraphs = UnionFind::::new(node_bound); edge_list.par_sort_unstable_by(|a, b| { let weight_a = (a.distance, a.source, a.target); let weight_b = (b.distance, b.source, b.target); weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) }); - let mut mst_edges: Vec> = Vec::new(); + let mut mst_edges: Vec = Vec::new(); for float_edge_pair in edge_list { let u = float_edge_pair.source; let v = float_edge_pair.target; @@ -228,11 +549,10 @@ where mst_edges.push(float_edge_pair); } } - //TODO implement error // assert that the terminal nodes are connected. - //if !terminal_nodes.is_empty() && mst_edges.len() != terminal_nodes.len() - 1 { - //return Err(PyValueError::new_err( "The terminal nodes in the input graph must belong to the same connected component. The steiner tree is not defined for a graph with unconnected terminal nodes",)); - //} + if !terminal_nodes.is_empty() && mst_edges.len() != terminal_nodes.len() - 1 { + return Ok(None); + } // Generate the output graph from the MST let out_edge_list: Vec<[usize; 2]> = mst_edges .into_iter() @@ -240,94 +560,13 @@ where .filter_map(|x| x.0.map(|a| [a, x.1])) .collect(); let out_edges: HashSet<(usize, usize)> = out_edge_list.iter().map(|x| (x[0], x[1])).collect(); - let mut out_graph = graph.clone(); - let out_nodes: HashSet = out_edge_list + let out_nodes: HashSet = out_edge_list .iter() .flat_map(|x| x.iter()) .copied() - .map(NodeIndex::new) .collect(); - for node in graph - .node_indices() - .filter(|node| !out_nodes.contains(node)) - { - out_graph.remove_node(node); - // out_graph.node_removed = true; - } - for edge in graph.edge_references().filter(|edge| { - let source = edge.source().index(); - let target = edge.target().index(); - !out_edges.contains(&(source, target)) && !out_edges.contains(&(target, source)) - }) { - out_graph.remove_edge(edge.id()); - } - // Deduplicate potential duplicate edges - deduplicate_edges(&mut out_graph, weight_fn)?; - - Ok(out_graph) -} - -fn deduplicate_edges( - out_graph: &mut StableGraph<(), W, Directed>, - weight_fn: &mut F, -) -> Result<(), E> -where - W: Clone, - F: FnMut(&W) -> Result, -{ - //if out_graph.multigraph { - if true { - // Find all edges between nodes - let mut duplicate_map: HashMap< - [NodeIndex; 2], - Vec<( as GraphBase>::EdgeId, W)>, - > = HashMap::new(); - for edge in out_graph.edge_references() { - if duplicate_map.contains_key(&[edge.source(), edge.target()]) { - duplicate_map - .get_mut(&[edge.source(), edge.target()]) - .unwrap() - .push((edge.id(), edge.weight().clone())); - } else if duplicate_map.contains_key(&[edge.target(), edge.source()]) { - duplicate_map - .get_mut(&[edge.target(), edge.source()]) - .unwrap() - .push((edge.id(), edge.weight().clone())); - } else { - duplicate_map.insert( - [edge.source(), edge.target()], - vec![(edge.id(), edge.weight().clone())], - ); - } - } - // For a node pair with > 1 edge find minimum edge and remove others - for edges_raw in duplicate_map.values().filter(|x| x.len() > 1) { - let mut edges: Vec<( as GraphBase>::EdgeId, f64)> = - Vec::with_capacity(edges_raw.len()); - for edge in edges_raw { - let w = weight_fn(&edge.1)?; - edges.push((edge.0, w)); - } - edges.sort_unstable_by(|a, b| (a.1).partial_cmp(&b.1).unwrap_or(Ordering::Less)); - edges[1..].iter().for_each(|x| { - out_graph.remove_edge(x.0); - }); - } - } - Ok(()) -} - -#[inline] -fn is_valid_weight(val: W) -> Result { - if val.is_sign_negative() { - return Err(E); - //return Err(E "Negative weights not supported."); - } - - if val.is_nan() { - return Err(E); - //return Err(E "NaN weights not supported."); - } - - Ok(val) + Ok(Some(SteinerTreeResult { + used_node_indices: out_nodes, + used_edge_endpoints: out_edges, + })) } diff --git a/src/steiner_tree.rs b/src/steiner_tree.rs index c8df908da..735395411 100644 --- a/src/steiner_tree.rs +++ b/src/steiner_tree.rs @@ -12,7 +12,7 @@ use std::cmp::Ordering; -use hashbrown::{HashMap, HashSet}; +use hashbrown::HashMap; use rayon::prelude::*; use pyo3::exceptions::PyValueError; @@ -20,23 +20,12 @@ use pyo3::prelude::*; use pyo3::Python; use petgraph::stable_graph::{EdgeIndex, EdgeReference, NodeIndex}; -use petgraph::unionfind::UnionFind; -use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeIndexable}; +use petgraph::visit::{EdgeRef, IntoEdgeReferences}; use crate::graph; -use crate::is_valid_weight; -use crate::shortest_path::all_pairs_dijkstra::all_pairs_dijkstra_shortest_paths; -use rustworkx_core::dictmap::*; -use rustworkx_core::shortest_path::dijkstra; -use rustworkx_core::utils::pairwise; - -struct MetricClosureEdge { - source: usize, - target: usize, - distance: f64, - path: Vec, -} +use rustworkx_core::steiner_tree::metric_closure as core_metric_closure; +use rustworkx_core::steiner_tree::steiner_tree as core_steiner_tree; /// Return the metric closure of a graph /// @@ -59,165 +48,30 @@ pub fn metric_closure( graph: &graph::PyGraph, weight_fn: PyObject, ) -> PyResult { - let mut out_graph = graph.clone(); - out_graph.graph.clear_edges(); - let edges = _metric_closure_edges(py, graph, weight_fn)?; - for edge in edges { - out_graph.graph.add_edge( - NodeIndex::new(edge.source), - NodeIndex::new(edge.target), - (edge.distance, edge.path).to_object(py), - ); - } - Ok(out_graph) -} - -fn _metric_closure_edges( - py: Python, - graph: &graph::PyGraph, - weight_fn: PyObject, -) -> PyResult> { - let node_count = graph.graph.node_count(); - if node_count == 0 { - return Ok(Vec::new()); - } - let mut out_vec = Vec::with_capacity(node_count * (node_count - 1) / 2); - let mut distances = HashMap::with_capacity(graph.graph.node_count()); - let paths = - all_pairs_dijkstra_shortest_paths(py, &graph.graph, weight_fn, Some(&mut distances))?.paths; - let mut nodes: HashSet = graph.graph.node_indices().map(|x| x.index()).collect(); - let first_node = graph - .graph - .node_indices() - .map(|x| x.index()) - .next() - .unwrap(); - let path_keys: HashSet = paths[&first_node].paths.keys().copied().collect(); - // first_node will always be missing from path_keys so if the difference - // is > 1 with nodes that means there is another node in the graph that - // first_node doesn't have a path to. - if nodes.difference(&path_keys).count() > 1 { - return Err(PyValueError::new_err( - "The input graph must be a connected graph. The metric closure is \ - not defined for a graph with unconnected nodes", - )); - } - // Iterate over node indices for a deterministic order - for node in graph.graph.node_indices().map(|x| x.index()) { - let path_map = &paths[&node].paths; - nodes.remove(&node); - let distance = &distances[&node]; - for v in &nodes { - let v_index = NodeIndex::new(*v); - out_vec.push(MetricClosureEdge { - source: node, - target: *v, - distance: distance[&v_index], - path: path_map[v].clone(), - }); - } - } - Ok(out_vec) -} - -/// Computes the shortest path between all pairs `(s, t)` of the given `terminal_nodes` -/// *provided* that: -/// - there is an edge `(u, v)` in the graph and path pass through this edge. -/// - node `s` is the closest node to `u` among all `terminal_nodes` -/// - node `t` is the closest node to `v` among all `terminal_nodes` -/// and wraps the result inside a `MetricClosureEdge` -/// -/// For example, if all vertices are terminals, it returns the original edges of the graph. -fn fast_metric_edges( - py: Python, - graph: &mut graph::PyGraph, - terminal_nodes: &[usize], - weight_fn: &PyObject, -) -> PyResult> { - // temporarily add a ``dummy`` node, connect it with - // all the terminal nodes and find all the shortest paths - // starting from ``dummy`` node. - let dummy = graph.graph.add_node(py.None()); - for node in terminal_nodes { - graph - .graph - .add_edge(dummy, NodeIndex::new(*node), py.None()); - } - - let cost_fn = |edge: EdgeReference<'_, PyObject>| -> PyResult { - if edge.source() != dummy && edge.target() != dummy { - let weight: f64 = weight_fn.call1(py, (edge.weight(),))?.extract(py)?; - is_valid_weight(weight) - } else { - Ok(0.0) - } + let callable = |e: EdgeReference| -> PyResult { + let data = e.weight(); + let raw = weight_fn.call1(py, (data,))?; + raw.extract(py) }; - - let mut paths = DictMap::with_capacity(graph.graph.node_count()); - let mut distance: DictMap = - dijkstra(&graph.graph, dummy, None, cost_fn, Some(&mut paths))?; - paths.swap_remove(&dummy); - distance.swap_remove(&dummy); - graph.graph.remove_node(dummy); - - // ``partition[u]`` holds the terminal node closest to node ``u``. - let mut partition: Vec = vec![std::usize::MAX; graph.graph.node_bound()]; - for (u, path) in paths.iter() { - let u = u.index(); - partition[u] = path[1].index(); - } - - let mut out_edges: Vec = Vec::with_capacity(graph.graph.edge_count()); - - for edge in graph.graph.edge_references() { - let source = edge.source(); - let target = edge.target(); - // assert that ``source`` is reachable from a terminal node. - if distance.contains_key(&source) { - let weight = distance[&source] + cost_fn(edge)? + distance[&target]; - let mut path: Vec = paths[&source].iter().skip(1).map(|x| x.index()).collect(); - path.append( - &mut paths[&target] - .iter() - .skip(1) - .rev() - .map(|x| x.index()) - .collect(), + if let Some(result_graph) = core_metric_closure(&graph.graph, callable)? { + let mut out_graph = graph.clone(); + out_graph.graph.clear_edges(); + for edge in result_graph.edge_indices() { + let (source, target) = result_graph.edge_endpoints(edge).unwrap(); + let weight = result_graph.edge_weight(edge).unwrap(); + out_graph.graph.add_edge( + *result_graph.node_weight(source).unwrap(), + *result_graph.node_weight(target).unwrap(), + weight.to_object(py), ); - - let source = source.index(); - let target = target.index(); - - let mut source = partition[source]; - let mut target = partition[target]; - - match source.cmp(&target) { - Ordering::Equal => continue, - Ordering::Greater => std::mem::swap(&mut source, &mut target), - _ => {} - } - - out_edges.push(MetricClosureEdge { - source, - target, - distance: weight, - path, - }); } + Ok(out_graph) + } else { + Err(PyValueError::new_err( + "The input graph must be a connected graph. The metric closure is \ + not defined for a graph with unconnected nodes", + )) } - - // if parallel edges, keep the edge with minimum distance. - out_edges.par_sort_unstable_by(|a, b| { - let weight_a = (a.source, a.target, a.distance); - let weight_b = (b.source, b.target, b.distance); - weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) - }); - - out_edges.dedup_by(|edge_a, edge_b| { - edge_a.source == edge_b.source && edge_a.target == edge_b.target - }); - - Ok(out_edges) } /// Return an approximation to the minimum Steiner tree of a graph. @@ -267,60 +121,47 @@ pub fn steiner_tree( terminal_nodes: Vec, weight_fn: PyObject, ) -> PyResult { - let mut edge_list = fast_metric_edges(py, graph, &terminal_nodes, &weight_fn)?; - let mut subgraphs = UnionFind::::new(graph.graph.node_bound()); - edge_list.par_sort_unstable_by(|a, b| { - let weight_a = (a.distance, a.source, a.target); - let weight_b = (b.distance, b.source, b.target); - weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) - }); - let mut mst_edges: Vec = Vec::new(); - for float_edge_pair in edge_list { - let u = float_edge_pair.source; - let v = float_edge_pair.target; - if subgraphs.union(u, v) { - mst_edges.push(float_edge_pair); + let callable = |e: EdgeReference| -> PyResult { + let data = e.weight(); + let raw = weight_fn.call1(py, (data,))?; + raw.extract(py) + }; + let result = core_steiner_tree( + &graph.graph, + &terminal_nodes + .into_iter() + .map(NodeIndex::new) + .collect::>(), + callable, + )?; + if let Some(result) = result { + let mut out_graph = graph.clone(); + for node in graph + .graph + .node_indices() + .filter(|node| !result.used_node_indices.contains(&node.index())) + { + out_graph.graph.remove_node(node); } - } - // assert that the terminal nodes are connected. - if !terminal_nodes.is_empty() && mst_edges.len() != terminal_nodes.len() - 1 { - return Err(PyValueError::new_err( + for edge in graph.graph.edge_references().filter(|edge| { + let source = edge.source().index(); + let target = edge.target().index(); + !result.used_edge_endpoints.contains(&(source, target)) + && !result.used_edge_endpoints.contains(&(target, source)) + }) { + out_graph.graph.remove_edge(edge.id()); + } + deduplicate_edges(py, &mut out_graph, &weight_fn)?; + if out_graph.graph.node_count() != graph.graph.node_count() { + out_graph.node_removed = true; + } + Ok(out_graph) + } else { + Err(PyValueError::new_err( "The terminal nodes in the input graph must belong to the same connected component. \ - The steiner tree is not defined for a graph with unconnected terminal nodes", - )); - } - // Generate the output graph from the MST - let out_edge_list: Vec<[usize; 2]> = mst_edges - .into_iter() - .flat_map(|edge| pairwise(edge.path)) - .filter_map(|x| x.0.map(|a| [a, x.1])) - .collect(); - let out_edges: HashSet<(usize, usize)> = out_edge_list.iter().map(|x| (x[0], x[1])).collect(); - let mut out_graph = graph.clone(); - let out_nodes: HashSet = out_edge_list - .iter() - .flat_map(|x| x.iter()) - .copied() - .map(NodeIndex::new) - .collect(); - for node in graph - .graph - .node_indices() - .filter(|node| !out_nodes.contains(node)) - { - out_graph.graph.remove_node(node); - out_graph.node_removed = true; - } - for edge in graph.graph.edge_references().filter(|edge| { - let source = edge.source().index(); - let target = edge.target().index(); - !out_edges.contains(&(source, target)) && !out_edges.contains(&(target, source)) - }) { - out_graph.graph.remove_edge(edge.id()); + The steiner tree is not defined for a graph with unconnected terminal nodes", + )) } - // Deduplicate potential duplicate edges - deduplicate_edges(py, &mut out_graph, &weight_fn)?; - Ok(out_graph) } fn deduplicate_edges( diff --git a/tests/graph/test_steiner_tree.py b/tests/graph/test_steiner_tree.py index 74ea5de82..0d144c0b4 100644 --- a/tests/graph/test_steiner_tree.py +++ b/tests/graph/test_steiner_tree.py @@ -151,7 +151,7 @@ def test_steiner_graph_multigraph(self): def test_not_connected_steiner_tree(self): self.graph.add_node(None) with self.assertRaises(ValueError): - rustworkx.steiner_tree(self.graph, [1, 2, 8], weight_fn=float) + rustworkx.steiner_tree(self.graph, [1, 2, 0], weight_fn=float) def test_steiner_tree_empty_graph(self): graph = rustworkx.PyGraph() From 17d04d14b85b30cf7c8437db8c13298a6df56959 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Sun, 3 Mar 2024 15:41:50 -0500 Subject: [PATCH 13/13] Add back checking on valid weights and terminal nodes --- src/steiner_tree.rs | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/steiner_tree.rs b/src/steiner_tree.rs index 735395411..57819b992 100644 --- a/src/steiner_tree.rs +++ b/src/steiner_tree.rs @@ -22,7 +22,7 @@ use pyo3::Python; use petgraph::stable_graph::{EdgeIndex, EdgeReference, NodeIndex}; use petgraph::visit::{EdgeRef, IntoEdgeReferences}; -use crate::graph; +use crate::{graph, is_valid_weight}; use rustworkx_core::steiner_tree::metric_closure as core_metric_closure; use rustworkx_core::steiner_tree::steiner_tree as core_steiner_tree; @@ -51,7 +51,8 @@ pub fn metric_closure( let callable = |e: EdgeReference| -> PyResult { let data = e.weight(); let raw = weight_fn.call1(py, (data,))?; - raw.extract(py) + let weight = raw.extract(py)?; + is_valid_weight(weight) }; if let Some(result_graph) = core_metric_closure(&graph.graph, callable)? { let mut out_graph = graph.clone(); @@ -126,14 +127,18 @@ pub fn steiner_tree( let raw = weight_fn.call1(py, (data,))?; raw.extract(py) }; - let result = core_steiner_tree( - &graph.graph, - &terminal_nodes - .into_iter() - .map(NodeIndex::new) - .collect::>(), - callable, - )?; + let mut terminal_n: Vec = Vec::with_capacity(terminal_nodes.len()); + for n in &terminal_nodes { + let index = NodeIndex::new(*n); + if graph.graph.node_weight(index).is_none() { + return Err(PyValueError::new_err(format!( + "Provided terminal node index {} is not present in graph", + n + ))); + } + terminal_n.push(index); + } + let result = core_steiner_tree(&graph.graph, &terminal_n, callable)?; if let Some(result) = result { let mut out_graph = graph.clone(); for node in graph