diff --git a/rustworkx-core/src/dag_algo.rs b/rustworkx-core/src/dag_algo.rs index 081802fa7..a4507779e 100644 --- a/rustworkx-core/src/dag_algo.rs +++ b/rustworkx-core/src/dag_algo.rs @@ -14,7 +14,7 @@ use std::cmp::{Eq, Ordering}; use std::collections::BinaryHeap; use std::hash::Hash; -use hashbrown::HashMap; +use hashbrown::{HashMap, HashSet}; use petgraph::algo; use petgraph::data::DataMap; @@ -323,7 +323,7 @@ where /// /// # Arguments: /// -/// * `dag`: The DAG to find bicolor runs in +/// * `graph`: The DAG to find bicolor runs in /// * `filter_fn`: The filter function to use for matching nodes. It takes /// in one argument, the node data payload/weight object, and will return a /// boolean whether the node matches the conditions or not. @@ -475,6 +475,161 @@ where Ok(Some(block_list)) } +/// Collect runs that match a filter function +/// +/// A run is a path of nodes where there is only a single successor and all +/// nodes in the path match the given condition. Each node in the graph can +/// appear in only a single run. +/// +/// # Arguments: +/// +/// * `graph`: The DAG to collect runs from +/// * `include_node_fn`: A filter function used for matching nodes. It takes +/// in one argument, the node data payload/weight object, and returns a +/// boolean whether the node matches the conditions or not. +/// If it returns ``false``, the node will be skipped, cutting the run it's part of. +/// +/// # Returns: +/// +/// * An Iterator object for extracting the runs one by one. Each run is of type `Result>`. +/// * `None` if a cycle is found in the graph. +/// +/// # Example +/// +/// ```rust +/// use petgraph::graph::DiGraph; +/// use rustworkx_core::dag_algo::collect_runs; +/// +/// let mut graph: DiGraph = DiGraph::new(); +/// let n1 = graph.add_node(-1); +/// let n2 = graph.add_node(2); +/// let n3 = graph.add_node(3); +/// graph.add_edge(n1, n2, ()); +/// graph.add_edge(n1, n3, ()); +/// +/// let positive_payload = |n| -> Result {Ok(*graph.node_weight(n).expect("i32") > 0)}; +/// let mut runs = collect_runs(&graph, positive_payload).expect("Some"); +/// +/// assert_eq!(runs.next(), Some(Ok(vec![n3]))); +/// assert_eq!(runs.next(), Some(Ok(vec![n2]))); +/// assert_eq!(runs.next(), None); +/// ``` +/// +pub fn collect_runs( + graph: G, + include_node_fn: F, +) -> Option, E>>> +where + G: GraphProp + + IntoNeighborsDirected + + IntoNodeIdentifiers + + Visitable + + NodeCount, + F: Fn(G::NodeId) -> Result, + ::NodeId: Hash + Eq, +{ + let mut nodes = match algo::toposort(graph, None) { + Ok(nodes) => nodes, + Err(_) => return None, + }; + + nodes.reverse(); // reversing so that pop() in Runs::next obeys the topological order + + let runs = Runs { + graph, + seen: HashSet::with_capacity(nodes.len()), + sorted_nodes: nodes, + include_node_fn, + }; + + Some(runs) +} + +/// Auxiliary struct to make the output of [`collect_runs`] iteratable +/// +/// If the filtering function passed to [`collect_runs`] returns an error, it is propagated +/// through `next` as `Err`. In this case the run in which the error occurred will be skipped +/// but the iterator can be used further until consumed. +/// +struct Runs +where + G: GraphProp + + IntoNeighborsDirected + + IntoNodeIdentifiers + + Visitable + + NodeCount, + F: Fn(G::NodeId) -> Result, +{ + graph: G, + sorted_nodes: Vec, // topologically-sorted nodes + seen: HashSet, + include_node_fn: F, // filtering function of the nodes +} + +impl Iterator for Runs +where + G: GraphProp + + IntoNeighborsDirected + + IntoNodeIdentifiers + + Visitable + + NodeCount, + F: Fn(G::NodeId) -> Result, + ::NodeId: Hash + Eq, +{ + // This is a run, wrapped in Result for catching filter function errors + type Item = Result, E>; + + fn next(&mut self) -> Option { + while let Some(node) = self.sorted_nodes.pop() { + if self.seen.contains(&node) { + continue; + } + self.seen.insert(node); + + match (self.include_node_fn)(node) { + Ok(false) => continue, + Err(e) => return Some(Err(e)), + _ => (), + } + + let mut run: Vec = vec![node]; + loop { + let mut successors: Vec = self + .graph + .neighbors_directed(*run.last().unwrap(), petgraph::Direction::Outgoing) + .collect(); + successors.dedup(); + + if successors.len() != 1 || self.seen.contains(&successors[0]) { + break; + } + + self.seen.insert(successors[0]); + + match (self.include_node_fn)(successors[0]) { + Ok(false) => continue, + Err(e) => return Some(Err(e)), + _ => (), + } + + run.push(successors[0]); + } + + if !run.is_empty() { + return Some(Ok(run)); + } + } + + None + } + + fn size_hint(&self) -> (usize, Option) { + // Lower bound is 0 in case all remaining nodes are filtered out + // Upper bound is the remaining unprocessed nodes (some of which may be seen already), potentially all resulting with singleton runs + (0, Some(self.sorted_nodes.len())) + } +} + // Tests for longest_path #[cfg(test)] mod test_longest_path { @@ -972,3 +1127,125 @@ mod test_collect_bicolor_runs { assert_eq!(result, Some(expected)) } } + +#[cfg(test)] +mod test_collect_runs { + use super::collect_runs; + use petgraph::{graph::DiGraph, visit::GraphBase}; + + type BareDiGraph = DiGraph<(), ()>; + type RunResult = Result::NodeId>, ()>; + + #[test] + fn test_empty_graph() { + let graph: BareDiGraph = DiGraph::new(); + + let mut runs = collect_runs(&graph, |_| -> Result { Ok(true) }).expect("Some"); + + let run = runs.next(); + assert!(run == None); + + let runs = collect_runs(&graph, |_| -> Result { Ok(true) }).expect("Some"); + + let runs: Vec = runs.collect(); + + assert_eq!(runs, Vec::::new()); + } + + #[test] + fn test_simple_run_w_filter() { + let mut graph: BareDiGraph = DiGraph::new(); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + let n3 = graph.add_node(()); + graph.add_edge(n1, n2, ()); + graph.add_edge(n2, n3, ()); + + let mut runs = collect_runs(&graph, |_| -> Result { Ok(true) }).expect("Some"); + + let the_run = runs.next().expect("Some").expect("3 nodes"); + assert_eq!(the_run.len(), 3); + assert_eq!(runs.next(), None); + + assert_eq!(the_run, vec![n1, n2, n3]); + + // Now with some filters + let mut runs = collect_runs(&graph, |_| -> Result { Ok(false) }).expect("Some"); + + assert_eq!(runs.next(), None); + + let mut runs = collect_runs(&graph, |n| -> Result { Ok(n != n2) }).expect("Some"); + + assert_eq!(runs.next(), Some(Ok(vec![n1]))); + assert_eq!(runs.next(), Some(Ok(vec![n3]))); + } + + #[test] + fn test_multiple_runs_w_filter() { + let mut graph: BareDiGraph = DiGraph::new(); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + let n3 = graph.add_node(()); + let n4 = graph.add_node(()); + let n5 = graph.add_node(()); + let n6 = graph.add_node(()); + let n7 = graph.add_node(()); + + graph.add_edge(n1, n2, ()); + graph.add_edge(n2, n3, ()); + graph.add_edge(n3, n7, ()); + graph.add_edge(n4, n3, ()); + graph.add_edge(n4, n7, ()); + graph.add_edge(n5, n4, ()); + graph.add_edge(n6, n5, ()); + + let runs: Vec = collect_runs(&graph, |_| -> Result { Ok(true) }) + .expect("Some") + .collect(); + + assert_eq!(runs, vec![Ok(vec![n6, n5, n4]), Ok(vec![n1, n2, n3, n7])]); + + // And now with some filter + let runs: Vec = + collect_runs(&graph, |n| -> Result { Ok(n != n4 && n != n2) }) + .expect("Some") + .collect(); + + assert_eq!(runs, vec![Ok(vec![n6, n5]), Ok(vec![n1]), Ok(vec![n3, n7])]); + } + + #[test] + fn test_singleton_runs_w_filter() { + let mut graph: BareDiGraph = DiGraph::new(); + let n1 = graph.add_node(()); + let n2 = graph.add_node(()); + let n3 = graph.add_node(()); + + graph.add_edge(n1, n2, ()); + graph.add_edge(n1, n3, ()); + + let mut runs = collect_runs(&graph, |_| -> Result { Ok(true) }).expect("Some"); + + assert_eq!(runs.next().expect("n1"), Ok(vec![n1])); + assert_eq!(runs.next().expect("n3"), Ok(vec![n3])); + assert_eq!(runs.next().expect("n2"), Ok(vec![n2])); + + // And now with some filter + let runs: Vec = collect_runs(&graph, |n| -> Result { Ok(n != n1) }) + .expect("Some") + .collect(); + + assert_eq!(runs, vec![Ok(vec![n3]), Ok(vec![n2])]); + } + + #[test] + fn test_error_propagation() { + let mut graph: BareDiGraph = DiGraph::new(); + graph.add_node(()); + + let mut runs = collect_runs(&graph, |_| -> Result { Err(()) }).expect("Some"); + + assert!(runs.next().expect("Some").is_err()); + assert_eq!(runs.next(), None); + } +} diff --git a/src/dag_algo/mod.rs b/src/dag_algo/mod.rs index 99ccc86c9..0571efd16 100644 --- a/src/dag_algo/mod.rs +++ b/src/dag_algo/mod.rs @@ -19,6 +19,7 @@ use super::iterators::NodeIndices; use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph}; use rustworkx_core::dag_algo::collect_bicolor_runs as core_collect_bicolor_runs; +use rustworkx_core::dag_algo::collect_runs as core_collect_runs; use rustworkx_core::dag_algo::lexicographical_topological_sort as core_lexico_topo_sort; use rustworkx_core::dag_algo::longest_path as core_longest_path; use rustworkx_core::traversal::dfs_edges; @@ -29,10 +30,8 @@ use pyo3::types::PyList; use pyo3::Python; use petgraph::algo; -use petgraph::graph::NodeIndex; use petgraph::prelude::*; use petgraph::stable_graph::EdgeReference; -use petgraph::visit::NodeCount; use num_traits::{Num, Zero}; @@ -561,47 +560,28 @@ pub fn collect_runs( graph: &digraph::PyDiGraph, filter_fn: PyObject, ) -> PyResult>> { - let mut out_list: Vec> = Vec::new(); - let mut seen: HashSet = HashSet::with_capacity(graph.node_count()); - - let filter_node = |node: &PyObject| -> PyResult { - let res = filter_fn.call1(py, (node,))?; - res.extract(py) + let filter_node = |node_id| -> Result { + let py_node = graph.graph.node_weight(node_id); + filter_fn.call1(py, (py_node,))?.extract::(py) }; - let nodes = match algo::toposort(&graph.graph, None) { - Ok(nodes) => nodes, - Err(_err) => return Err(DAGHasCycle::new_err("Sort encountered a cycle")), + let core_runs = match core_collect_runs(&graph.graph, filter_node) { + Some(runs) => runs, + None => return Err(DAGHasCycle::new_err("The DAG contains a cycle")), }; - for node in nodes { - if !filter_node(&graph.graph[node])? || seen.contains(&node) { - continue; - } - seen.insert(node); - let mut group: Vec = vec![graph.graph[node].clone_ref(py)]; - let mut successors: Vec = graph - .graph - .neighbors_directed(node, petgraph::Direction::Outgoing) + + let mut result: Vec> = Vec::with_capacity(core_runs.size_hint().1.unwrap_or(0)); + for run_result in core_runs { + // This is where a filter function error will be returned, otherwise Result is stripped away + let py_run: Vec = run_result? + .iter() + .map(|node| return graph.graph.node_weight(*node).into_py(py)) .collect(); - successors.dedup(); - - while successors.len() == 1 - && filter_node(&graph.graph[successors[0]])? - && !seen.contains(&successors[0]) - { - group.push(graph.graph[successors[0]].clone_ref(py)); - seen.insert(successors[0]); - successors = graph - .graph - .neighbors_directed(successors[0], petgraph::Direction::Outgoing) - .collect(); - successors.dedup(); - } - if !group.is_empty() { - out_list.push(group); - } + + result.push(py_run) } - Ok(out_list) + + Ok(result) } /// Collect runs that match a filter function given edge colors.