Skip to content

Commit

Permalink
Move collect_runs() to rustworkx-core (Qiskit#1210)
Browse files Browse the repository at this point in the history
* Initial implementaiton and testing

* initial implementation with filter function support, no error handling

* Added error handling

* Making collect_runs in rustworkx-core iteratable

* Implemented iteratable collect_runs all the way through

* More idiomatic rust

* Minor language fix

* Added documentation and fixed clippy warnings

* Running cargo fmt

* Addressing review comments

* Updating documentation

* Formatting

* Adding missing use statements to the doc test

* Rust newbie misses... more formatting

---------

Co-authored-by: Matthew Treinish <[email protected]>
  • Loading branch information
eliarbel and mtreinish authored Jun 20, 2024
1 parent ace1f83 commit c2eed52
Show file tree
Hide file tree
Showing 2 changed files with 297 additions and 40 deletions.
281 changes: 279 additions & 2 deletions rustworkx-core/src/dag_algo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::cmp::{Eq, Ordering};
use std::collections::BinaryHeap;
use std::hash::Hash;

use hashbrown::HashMap;
use hashbrown::{HashMap, HashSet};

use petgraph::algo;
use petgraph::data::DataMap;
Expand Down Expand Up @@ -323,7 +323,7 @@ where
///
/// # Arguments:
///
/// * `dag`: The DAG to find bicolor runs in
/// * `graph`: The DAG to find bicolor runs in
/// * `filter_fn`: The filter function to use for matching nodes. It takes
/// in one argument, the node data payload/weight object, and will return a
/// boolean whether the node matches the conditions or not.
Expand Down Expand Up @@ -475,6 +475,161 @@ where
Ok(Some(block_list))
}

/// Collect runs that match a filter function
///
/// A run is a path of nodes where there is only a single successor and all
/// nodes in the path match the given condition. Each node in the graph can
/// appear in only a single run.
///
/// # Arguments:
///
/// * `graph`: The DAG to collect runs from
/// * `include_node_fn`: A filter function used for matching nodes. It takes
/// in one argument, the node data payload/weight object, and returns a
/// boolean whether the node matches the conditions or not.
/// If it returns ``false``, the node will be skipped, cutting the run it's part of.
///
/// # Returns:
///
/// * An Iterator object for extracting the runs one by one. Each run is of type `Result<Vec<G::NodeId>>`.
/// * `None` if a cycle is found in the graph.
///
/// # Example
///
/// ```rust
/// use petgraph::graph::DiGraph;
/// use rustworkx_core::dag_algo::collect_runs;
///
/// let mut graph: DiGraph<i32, ()> = DiGraph::new();
/// let n1 = graph.add_node(-1);
/// let n2 = graph.add_node(2);
/// let n3 = graph.add_node(3);
/// graph.add_edge(n1, n2, ());
/// graph.add_edge(n1, n3, ());
///
/// let positive_payload = |n| -> Result<bool, ()> {Ok(*graph.node_weight(n).expect("i32") > 0)};
/// let mut runs = collect_runs(&graph, positive_payload).expect("Some");
///
/// assert_eq!(runs.next(), Some(Ok(vec![n3])));
/// assert_eq!(runs.next(), Some(Ok(vec![n2])));
/// assert_eq!(runs.next(), None);
/// ```
///
pub fn collect_runs<G, F, E>(
graph: G,
include_node_fn: F,
) -> Option<impl Iterator<Item = Result<Vec<G::NodeId>, E>>>
where
G: GraphProp<EdgeType = Directed>
+ IntoNeighborsDirected
+ IntoNodeIdentifiers
+ Visitable
+ NodeCount,
F: Fn(G::NodeId) -> Result<bool, E>,
<G as GraphBase>::NodeId: Hash + Eq,
{
let mut nodes = match algo::toposort(graph, None) {
Ok(nodes) => nodes,
Err(_) => return None,
};

nodes.reverse(); // reversing so that pop() in Runs::next obeys the topological order

let runs = Runs {
graph,
seen: HashSet::with_capacity(nodes.len()),
sorted_nodes: nodes,
include_node_fn,
};

Some(runs)
}

/// Auxiliary struct to make the output of [`collect_runs`] iteratable
///
/// If the filtering function passed to [`collect_runs`] returns an error, it is propagated
/// through `next` as `Err`. In this case the run in which the error occurred will be skipped
/// but the iterator can be used further until consumed.
///
struct Runs<G, F, E>
where
G: GraphProp<EdgeType = Directed>
+ IntoNeighborsDirected
+ IntoNodeIdentifiers
+ Visitable
+ NodeCount,
F: Fn(G::NodeId) -> Result<bool, E>,
{
graph: G,
sorted_nodes: Vec<G::NodeId>, // topologically-sorted nodes
seen: HashSet<G::NodeId>,
include_node_fn: F, // filtering function of the nodes
}

impl<G, F, E> Iterator for Runs<G, F, E>
where
G: GraphProp<EdgeType = Directed>
+ IntoNeighborsDirected
+ IntoNodeIdentifiers
+ Visitable
+ NodeCount,
F: Fn(G::NodeId) -> Result<bool, E>,
<G as GraphBase>::NodeId: Hash + Eq,
{
// This is a run, wrapped in Result for catching filter function errors
type Item = Result<Vec<G::NodeId>, E>;

fn next(&mut self) -> Option<Self::Item> {
while let Some(node) = self.sorted_nodes.pop() {
if self.seen.contains(&node) {
continue;
}
self.seen.insert(node);

match (self.include_node_fn)(node) {
Ok(false) => continue,
Err(e) => return Some(Err(e)),
_ => (),
}

let mut run: Vec<G::NodeId> = vec![node];
loop {
let mut successors: Vec<G::NodeId> = self
.graph
.neighbors_directed(*run.last().unwrap(), petgraph::Direction::Outgoing)
.collect();
successors.dedup();

if successors.len() != 1 || self.seen.contains(&successors[0]) {
break;
}

self.seen.insert(successors[0]);

match (self.include_node_fn)(successors[0]) {
Ok(false) => continue,
Err(e) => return Some(Err(e)),
_ => (),
}

run.push(successors[0]);
}

if !run.is_empty() {
return Some(Ok(run));
}
}

None
}

fn size_hint(&self) -> (usize, Option<usize>) {
// Lower bound is 0 in case all remaining nodes are filtered out
// Upper bound is the remaining unprocessed nodes (some of which may be seen already), potentially all resulting with singleton runs
(0, Some(self.sorted_nodes.len()))
}
}

// Tests for longest_path
#[cfg(test)]
mod test_longest_path {
Expand Down Expand Up @@ -972,3 +1127,125 @@ mod test_collect_bicolor_runs {
assert_eq!(result, Some(expected))
}
}

#[cfg(test)]
mod test_collect_runs {
use super::collect_runs;
use petgraph::{graph::DiGraph, visit::GraphBase};

type BareDiGraph = DiGraph<(), ()>;
type RunResult = Result<Vec<<BareDiGraph as GraphBase>::NodeId>, ()>;

#[test]
fn test_empty_graph() {
let graph: BareDiGraph = DiGraph::new();

let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) }).expect("Some");

let run = runs.next();
assert!(run == None);

let runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) }).expect("Some");

let runs: Vec<RunResult> = runs.collect();

assert_eq!(runs, Vec::<RunResult>::new());
}

#[test]
fn test_simple_run_w_filter() {
let mut graph: BareDiGraph = DiGraph::new();
let n1 = graph.add_node(());
let n2 = graph.add_node(());
let n3 = graph.add_node(());
graph.add_edge(n1, n2, ());
graph.add_edge(n2, n3, ());

let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) }).expect("Some");

let the_run = runs.next().expect("Some").expect("3 nodes");
assert_eq!(the_run.len(), 3);
assert_eq!(runs.next(), None);

assert_eq!(the_run, vec![n1, n2, n3]);

// Now with some filters
let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(false) }).expect("Some");

assert_eq!(runs.next(), None);

let mut runs = collect_runs(&graph, |n| -> Result<bool, ()> { Ok(n != n2) }).expect("Some");

assert_eq!(runs.next(), Some(Ok(vec![n1])));
assert_eq!(runs.next(), Some(Ok(vec![n3])));
}

#[test]
fn test_multiple_runs_w_filter() {
let mut graph: BareDiGraph = DiGraph::new();
let n1 = graph.add_node(());
let n2 = graph.add_node(());
let n3 = graph.add_node(());
let n4 = graph.add_node(());
let n5 = graph.add_node(());
let n6 = graph.add_node(());
let n7 = graph.add_node(());

graph.add_edge(n1, n2, ());
graph.add_edge(n2, n3, ());
graph.add_edge(n3, n7, ());
graph.add_edge(n4, n3, ());
graph.add_edge(n4, n7, ());
graph.add_edge(n5, n4, ());
graph.add_edge(n6, n5, ());

let runs: Vec<RunResult> = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) })
.expect("Some")
.collect();

assert_eq!(runs, vec![Ok(vec![n6, n5, n4]), Ok(vec![n1, n2, n3, n7])]);

// And now with some filter
let runs: Vec<RunResult> =
collect_runs(&graph, |n| -> Result<bool, ()> { Ok(n != n4 && n != n2) })
.expect("Some")
.collect();

assert_eq!(runs, vec![Ok(vec![n6, n5]), Ok(vec![n1]), Ok(vec![n3, n7])]);
}

#[test]
fn test_singleton_runs_w_filter() {
let mut graph: BareDiGraph = DiGraph::new();
let n1 = graph.add_node(());
let n2 = graph.add_node(());
let n3 = graph.add_node(());

graph.add_edge(n1, n2, ());
graph.add_edge(n1, n3, ());

let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Ok(true) }).expect("Some");

assert_eq!(runs.next().expect("n1"), Ok(vec![n1]));
assert_eq!(runs.next().expect("n3"), Ok(vec![n3]));
assert_eq!(runs.next().expect("n2"), Ok(vec![n2]));

// And now with some filter
let runs: Vec<RunResult> = collect_runs(&graph, |n| -> Result<bool, ()> { Ok(n != n1) })
.expect("Some")
.collect();

assert_eq!(runs, vec![Ok(vec![n3]), Ok(vec![n2])]);
}

#[test]
fn test_error_propagation() {
let mut graph: BareDiGraph = DiGraph::new();
graph.add_node(());

let mut runs = collect_runs(&graph, |_| -> Result<bool, ()> { Err(()) }).expect("Some");

assert!(runs.next().expect("Some").is_err());
assert_eq!(runs.next(), None);
}
}
56 changes: 18 additions & 38 deletions src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use super::iterators::NodeIndices;
use crate::{digraph, DAGHasCycle, InvalidNode, StablePyGraph};

use rustworkx_core::dag_algo::collect_bicolor_runs as core_collect_bicolor_runs;
use rustworkx_core::dag_algo::collect_runs as core_collect_runs;
use rustworkx_core::dag_algo::lexicographical_topological_sort as core_lexico_topo_sort;
use rustworkx_core::dag_algo::longest_path as core_longest_path;
use rustworkx_core::traversal::dfs_edges;
Expand All @@ -29,10 +30,8 @@ use pyo3::types::PyList;
use pyo3::Python;

use petgraph::algo;
use petgraph::graph::NodeIndex;
use petgraph::prelude::*;
use petgraph::stable_graph::EdgeReference;
use petgraph::visit::NodeCount;

use num_traits::{Num, Zero};

Expand Down Expand Up @@ -561,47 +560,28 @@ pub fn collect_runs(
graph: &digraph::PyDiGraph,
filter_fn: PyObject,
) -> PyResult<Vec<Vec<PyObject>>> {
let mut out_list: Vec<Vec<PyObject>> = Vec::new();
let mut seen: HashSet<NodeIndex> = HashSet::with_capacity(graph.node_count());

let filter_node = |node: &PyObject| -> PyResult<bool> {
let res = filter_fn.call1(py, (node,))?;
res.extract(py)
let filter_node = |node_id| -> Result<bool, PyErr> {
let py_node = graph.graph.node_weight(node_id);
filter_fn.call1(py, (py_node,))?.extract::<bool>(py)
};

let nodes = match algo::toposort(&graph.graph, None) {
Ok(nodes) => nodes,
Err(_err) => return Err(DAGHasCycle::new_err("Sort encountered a cycle")),
let core_runs = match core_collect_runs(&graph.graph, filter_node) {
Some(runs) => runs,
None => return Err(DAGHasCycle::new_err("The DAG contains a cycle")),
};
for node in nodes {
if !filter_node(&graph.graph[node])? || seen.contains(&node) {
continue;
}
seen.insert(node);
let mut group: Vec<PyObject> = vec![graph.graph[node].clone_ref(py)];
let mut successors: Vec<NodeIndex> = graph
.graph
.neighbors_directed(node, petgraph::Direction::Outgoing)

let mut result: Vec<Vec<PyObject>> = Vec::with_capacity(core_runs.size_hint().1.unwrap_or(0));
for run_result in core_runs {
// This is where a filter function error will be returned, otherwise Result is stripped away
let py_run: Vec<PyObject> = run_result?
.iter()
.map(|node| return graph.graph.node_weight(*node).into_py(py))
.collect();
successors.dedup();

while successors.len() == 1
&& filter_node(&graph.graph[successors[0]])?
&& !seen.contains(&successors[0])
{
group.push(graph.graph[successors[0]].clone_ref(py));
seen.insert(successors[0]);
successors = graph
.graph
.neighbors_directed(successors[0], petgraph::Direction::Outgoing)
.collect();
successors.dedup();
}
if !group.is_empty() {
out_list.push(group);
}

result.push(py_run)
}
Ok(out_list)

Ok(result)
}

/// Collect runs that match a filter function given edge colors.
Expand Down

0 comments on commit c2eed52

Please sign in to comment.