Skip to content

Commit

Permalink
fix!: force_order failing on Const nodes, add arg to rank. (#1300)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: the `rank` argument of `force_order` takes an
additional argument.
  • Loading branch information
doug-q authored Jul 16, 2024
1 parent 056531c commit 36e71e0
Showing 1 changed file with 53 additions and 23 deletions.
76 changes: 53 additions & 23 deletions hugr-passes/src/force_order.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
//! Provides [force_order], a tool for fixing the order of nodes in a Hugr.
use std::{cmp::Reverse, collections::BinaryHeap};
use std::{cmp::Reverse, collections::BinaryHeap, iter};

use hugr_core::{
hugr::{
hugrmut::HugrMut,
views::{DescendantsGraph, HierarchyView, SiblingGraph},
HugrError,
},
ops::{OpTag, OpTrait},
ops::{NamedOp, OpTag, OpTrait},
types::EdgeKind,
Direction, HugrView as _, Node,
HugrView as _, Node,
};
use itertools::Itertools as _;
use petgraph::{
Expand All @@ -36,45 +36,58 @@ use petgraph::{
/// there is no path from `n2` to `n1` (otherwise this would invalidate `hugr`).
/// Nodes of equal rank will be ordered arbitrarily, although that arbitrary
/// order is deterministic.
pub fn force_order(
hugr: &mut impl HugrMut,
pub fn force_order<H: HugrMut>(
hugr: &mut H,
root: Node,
rank: impl Fn(Node) -> i64,
rank: impl Fn(&H, Node) -> i64,
) -> Result<(), HugrError> {
force_order_by_key(hugr, root, rank)
}

/// As [force_order], but allows a generic [Ord] choice for the result of the
/// `rank` function.
pub fn force_order_by_key<K: Ord>(
hugr: &mut impl HugrMut,
pub fn force_order_by_key<H: HugrMut, K: Ord>(
hugr: &mut H,
root: Node,
rank: impl Fn(Node) -> K,
rank: impl Fn(&H, Node) -> K,
) -> Result<(), HugrError> {
let dataflow_parents = DescendantsGraph::<Node>::try_new(hugr, root)?
.nodes()
.filter(|n| hugr.get_optype(*n).tag() <= OpTag::DataflowParent)
.collect_vec();
for dp in dataflow_parents {
// we filter out the input and output nodes from the topological sort
let [i, o] = hugr.get_io(dp).unwrap();
let rank = |n| rank(hugr, n);
let sg = SiblingGraph::<Node>::try_new(hugr, dp)?;
let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp);
let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp && x != i && x != o);
let ordered_nodes = ForceOrder::new(&petgraph, &rank)
.iter(&petgraph)
.filter(|&x| hugr.get_optype(x).tag() <= OpTag::DataflowChild)
.filter(|&x| {
let expected_edge = Some(EdgeKind::StateOrder);
let optype = hugr.get_optype(x);
if optype.other_input() == expected_edge || optype.other_output() == expected_edge {
assert_eq!(
optype.other_input(),
optype.other_output(),
"Optype does not have both input and output order edge: {}",
optype.name()
);
true
} else {
false
}
})
.collect_vec();

for (&n1, &n2) in ordered_nodes.iter().tuple_windows() {
// we iterate over the topologically sorted nodes, prepending the input
// node and suffixing the output node.
for (&n1, &n2) in iter::once(&i)
.chain(ordered_nodes.iter())
.chain(iter::once(&o))
.tuple_windows()
{
let (n1_ot, n2_ot) = (hugr.get_optype(n1), hugr.get_optype(n2));
assert_eq!(
Some(EdgeKind::StateOrder),
n1_ot.other_port_kind(Direction::Outgoing),
"Node {n1} does not support state order edges"
);
assert_eq!(
Some(EdgeKind::StateOrder),
n2_ot.other_port_kind(Direction::Incoming),
"Node {n2} does not support state order edges"
);
if !hugr.output_neighbours(n1).contains(&n2) {
hugr.connect(
n1,
Expand Down Expand Up @@ -192,10 +205,13 @@ mod test {

use super::*;
use hugr_core::builder::{endo_ft, BuildHandle, Dataflow, DataflowHugr};
use hugr_core::extension::EMPTY_REG;
use hugr_core::ops::handle::{DataflowOpID, NodeHandle};

use hugr_core::ops::Value;
use hugr_core::std_extensions::arithmetic::int_ops::{self, IntOpDef};
use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
use hugr_core::types::{FunctionType, Type};
use hugr_core::{builder::DFGBuilder, hugr::Hugr};
use hugr_core::{HugrView, Wire};

Expand Down Expand Up @@ -257,7 +273,7 @@ mod test {
type RankMap = HashMap<Node, i64>;

fn force_order_test_impl(hugr: &mut Hugr, rank_map: RankMap) -> Vec<Node> {
force_order(hugr, hugr.root(), |n| *rank_map.get(&n).unwrap_or(&0)).unwrap();
force_order(hugr, hugr.root(), |_, n| *rank_map.get(&n).unwrap_or(&0)).unwrap();

let topo_sorted = Topo::new(&hugr.as_petgraph())
.iter(&hugr.as_petgraph())
Expand Down Expand Up @@ -303,4 +319,18 @@ mod test {
let topo_sort = force_order_test_impl(&mut hugr, rank_map);
assert_eq!(vec![v0, v1, v2, v3], topo_sort);
}

#[test]
fn test_force_order_const() {
let mut hugr = {
let mut builder =
DFGBuilder::new(FunctionType::new(Type::EMPTY_TYPEROW, Type::UNIT)).unwrap();
let unit = builder.add_load_value(Value::unary_unit_sum());
builder
.finish_hugr_with_outputs([unit], &EMPTY_REG)
.unwrap()
};
let root = hugr.root();
force_order(&mut hugr, root, |_, _| 0).unwrap();
}
}

0 comments on commit 36e71e0

Please sign in to comment.