Skip to content

Commit

Permalink
feat: Partial tuple unpack
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Jul 19, 2024
1 parent 4bd81c3 commit a30897e
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions tket2/src/passes/tuple_unpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use core::panic;

use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr};
use hugr::ops::{OpTrait, OpType};
use hugr::ops::{MakeTuple, OpTrait, OpType};
use hugr::types::Type;
use hugr::{HugrView, Node};
use itertools::Itertools;
Expand Down Expand Up @@ -72,41 +72,49 @@ fn make_rewrite<T: HugrView>(circ: &Circuit<T>, cmd: Command<T>) -> Option<Circu
}

// Remove all unpack operations, but only remove the pack operation if all neighbours are unpacks.
match links.len() == unpack_nodes.len() {
true => Some(remove_pack_unpack(
circ,
&tuple_types,
tuple_node,
unpack_nodes,
)),
false => {
// TODO: Add a rewrite to remove some of the unpack operations.
None
}
}
let num_other_outputs = links.len() - unpack_nodes.len();
Some(remove_pack_unpack(
circ,
&tuple_types,
tuple_node,
unpack_nodes,
num_other_outputs,
))
}

/// Returns a rewrite to remove a tuple pack operation that's only followed by unpack operations.
/// Returns a rewrite to remove a tuple pack operation that's followed by unpack operations,
/// and `other_tuple_links` other operations.
fn remove_pack_unpack<T: HugrView>(
circ: &Circuit<T>,
tuple_types: &[Type],
pack_node: Node,
unpack_nodes: Vec<Node>,
num_other_outputs: usize,
) -> CircuitRewrite {
let num_outputs = tuple_types.len() * unpack_nodes.len();
let num_unpack_outputs = tuple_types.len() * unpack_nodes.len();

let mut nodes = unpack_nodes;
nodes.push(pack_node);
let subcirc = Subcircuit::try_from_nodes(nodes, circ).unwrap();

let replacement = DFGBuilder::new(subcirc.signature(circ)).unwrap();
let wires = replacement
.input_wires()
.cycle()
.take(num_outputs)
.collect_vec();
let mut replacement = DFGBuilder::new(subcirc.signature(circ)).unwrap();
let mut outputs = Vec::with_capacity(num_unpack_outputs + num_other_outputs);

// If needed, re-add the tuple pack node and connect its output to the tuple outputs.
if num_other_outputs > 0 {
let op = MakeTuple::new(tuple_types.to_vec().into());
let [tuple] = replacement
.add_dataflow_op(op, replacement.input_wires())
.unwrap()
.outputs_arr();
outputs.extend(std::iter::repeat(tuple).take(num_other_outputs))
}

// Wire the inputs directly to the unpack outputs
outputs.extend(replacement.input_wires().cycle().take(num_unpack_outputs));

let replacement = replacement
.finish_prelude_hugr_with_outputs(wires)
.finish_prelude_hugr_with_outputs(outputs)
.unwrap_or_else(|e| {
panic!("Failed to create replacement for removing tuple pack/unpack operations. {e}")
})
Expand Down Expand Up @@ -205,8 +213,6 @@ mod test {
#[rstest]
#[case::simple(simple_pack_unpack(), 1, 0)]
#[case::multi(multi_unpack(), 1, 0)]
// TODO: Partial unpack is not currently supported.
#[ignore = "Unimplemented."]
#[case::partial(partial_unpack(), 1, 1)]
fn test_pack_unpack(
#[case] mut circ: Circuit,
Expand Down

0 comments on commit a30897e

Please sign in to comment.