From ec5dd2269d5b28a0f3399ad00549643d35502c0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:01:17 +0100 Subject: [PATCH] refactor!: Replace Circuit trait with a struct (#370) Replaces the `trait Circuit : HugrView` with a struct containing a `T: HugrView` and a node id indicating the container node for the circuit in the hugr. There are a few design going into this definition: - #### Struct with checked parent node type. The previous `Circuit` trait asked the user to only use it for dataflow-based hugrs. This of course lead to functions assuming things about the hugr structure and panicking with random errors when passed an otherwise valid hugr. Bar checking the root node at each method call, this can only be solved with an opaque struct that checks its preconditions at construction time. The code here will throw a user-friendly error when trying to use incompatible hugrs: ``` Node(0) cannot be used as a circuit parent. A Module is not a dataflow container. ``` - #### Generic `T` defaulting to `Hugr` Most uses in this library manipulate circuits as owned structures (e.g. passing circuits around in badger). However, sometimes we don't own the hugr so we cannot create a `Circuit` from it. With the generics definition we can use `Circuit<&Hugr>` or `Circuit<&mut Hugr>` to respectively view or modify non-owned structures. - #### Parent node pointer Circuits represented as hugrs are not necessarily a flat dataflow structure. For example, a guppy-defined circuit may include calls to other methods from the module. By including a pointer to the entry point of the circuit we can keep all this structure, and rewrite it latter as necessary. This is also useful for non-owned hugrs; we can have a `Circuit<&Hugr>` pointing to a region in the hugr without needing to modify anything else. Closes #112. BREAKING CHANGE: Replaced the `Circuit` trait with a wrapper struct. --------- Co-authored-by: Seyon Sivarajah --- Cargo.lock | 54 +++- Cargo.toml | 3 +- badger-optimiser/src/main.rs | 1 - tket2-py/src/circuit/convert.rs | 12 +- tket2-py/src/circuit/tk2circuit.rs | 40 +-- tket2-py/src/optimiser.rs | 11 +- tket2-py/src/passes.rs | 12 +- tket2-py/src/passes/chunks.rs | 13 +- tket2-py/src/pattern.rs | 27 +- tket2-py/src/pattern/portmatching.rs | 11 +- tket2-py/src/rewrite.rs | 10 +- tket2/Cargo.toml | 1 + tket2/src/circuit.rs | 296 ++++++++++++++++---- tket2/src/circuit/command.rs | 78 +++--- tket2/src/circuit/hash.rs | 30 +- tket2/src/circuit/units.rs | 29 +- tket2/src/json.rs | 37 ++- tket2/src/json/encoder.rs | 25 +- tket2/src/json/tests.rs | 17 +- tket2/src/lib.rs | 4 +- tket2/src/ops.rs | 9 +- tket2/src/optimiser/badger.rs | 67 +++-- tket2/src/optimiser/badger/eq_circ_class.rs | 13 +- tket2/src/optimiser/badger/hugr_pchannel.rs | 14 +- tket2/src/optimiser/badger/hugr_pqueue.rs | 36 +-- tket2/src/optimiser/badger/log.rs | 2 +- tket2/src/optimiser/badger/qtz_circuit.rs | 25 +- tket2/src/passes/chunks.rs | 89 +++--- tket2/src/passes/commutation.rs | 89 +++--- tket2/src/portmatching.rs | 33 ++- tket2/src/portmatching/matcher.rs | 61 ++-- tket2/src/portmatching/pattern.rs | 58 ++-- tket2/src/rewrite.rs | 47 ++-- tket2/src/rewrite/ecc_rewriter.rs | 56 ++-- tket2/src/rewrite/strategy.rs | 45 +-- tket2/src/rewrite/trace.rs | 25 +- tket2/src/utils.rs | 18 +- tket2/tests/badger_termination.rs | 5 +- 38 files changed, 846 insertions(+), 557 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e4a13327..7e9bab20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -325,6 +325,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "cool_asserts" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee9f254e53f61e2688d3677fa2cbe4e9b950afd56f48819c98817417cf6b28ec" +dependencies = [ + "indent_write", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -707,9 +716,19 @@ dependencies = [ [[package]] name = "hugr" -version = "0.4.0" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ebf436d9d4d0239fcb511a1f7d745548bbdc6316baf50bb2e944f294542ea24" +dependencies = [ + "hugr-core", + "hugr-passes", +] + +[[package]] +name = "hugr-core" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8e15eaffd64f1cceac13429e5ceaf20017691e647cc56b8b7c53d73ad9f8714" +checksum = "1f2b8cfdccedf45a563e526a4a2c7443025ece0e2518bca1c3f227318607de44" dependencies = [ "bitvec", "cgmath", @@ -719,7 +738,7 @@ dependencies = [ "downcast-rs", "enum_dispatch", "html-escape", - "itertools 0.12.1", + "itertools 0.13.0", "lazy_static", "num-rational", "paste", @@ -736,6 +755,19 @@ dependencies = [ "typetag", ] +[[package]] +name = "hugr-passes" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c1ae428e41786736cc92cbc3ae11ad39f178591d36a9ac3d1c8990e1a9bf76" +dependencies = [ + "hugr-core", + "itertools 0.13.0", + "lazy_static", + "paste", + "thiserror", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -769,6 +801,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indent_write" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cfe9645a18782869361d9c8732246be7b410ad4e919d3609ebabdac00ba12c3" + [[package]] name = "indexmap" version = "2.2.6" @@ -817,15 +855,6 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -1742,6 +1771,7 @@ dependencies = [ "bytemuck", "cgmath", "chrono", + "cool_asserts", "criterion", "crossbeam-channel", "csv", diff --git a/Cargo.toml b/Cargo.toml index 6c67a2b9..57ed1b73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ missing_docs = "warn" [workspace.dependencies] tket2 = { path = "./tket2" } -hugr = "0.4.0" +hugr = "0.5.0" portgraph = "0.12" pyo3 = "0.21.2" itertools = "0.13.0" @@ -60,3 +60,4 @@ tracing-subscriber = "0.3.17" typetag = "0.2.8" urlencoding = "2.1.2" webbrowser = "1.0.0" +cool_asserts = "2.0.3" diff --git a/badger-optimiser/src/main.rs b/badger-optimiser/src/main.rs index b87ecaef..08aa4d8e 100644 --- a/badger-optimiser/src/main.rs +++ b/badger-optimiser/src/main.rs @@ -15,7 +15,6 @@ use tket2::json::{load_tk1_json_file, save_tk1_json_file}; use tket2::optimiser::badger::log::BadgerLogger; use tket2::optimiser::badger::BadgerOptions; use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser}; -use tket2::rewrite::trace::RewriteTracer; #[cfg(feature = "peak_alloc")] #[global_allocator] diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index 42c29813..c8f70d63 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -44,8 +44,10 @@ impl CircuitType { /// Converts a `Hugr` into the format indicated by the flag. pub fn convert(self, py: Python, hugr: Hugr) -> PyResult> { match self { - CircuitType::Tket1 => SerialCircuit::encode(&hugr).convert_pyerrs()?.to_tket1(py), - CircuitType::Tket2 => Ok(Bound::new(py, Tk2Circuit { hugr })?.into_any()), + CircuitType::Tket1 => SerialCircuit::encode(&hugr.into()) + .convert_pyerrs()? + .to_tket1(py), + CircuitType::Tket2 => Ok(Bound::new(py, Tk2Circuit { circ: hugr.into() })?.into_any()), } } } @@ -58,16 +60,16 @@ where E: ConvertPyErr, F: FnOnce(Hugr, CircuitType) -> Result, { - let (hugr, typ) = match Tk2Circuit::extract_bound(circ) { + let (circ, typ) = match Tk2Circuit::extract_bound(circ) { // hugr circuit - Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2), + Ok(t2circ) => (t2circ.circ, CircuitType::Tket2), // tket1 circuit Err(_) => ( SerialCircuit::from_tket1(circ)?.decode().convert_pyerrs()?, CircuitType::Tket1, ), }; - (f)(hugr, typ).map_err(|e| e.convert_pyerrs()) + (f)(circ.into_hugr(), typ).map_err(|e| e.convert_pyerrs()) } /// Apply a function expecting a hugr on a python circuit. diff --git a/tket2-py/src/circuit/tk2circuit.rs b/tket2-py/src/circuit/tk2circuit.rs index a31f4a86..024a9c0f 100644 --- a/tket2-py/src/circuit/tk2circuit.rs +++ b/tket2-py/src/circuit/tk2circuit.rs @@ -55,7 +55,7 @@ use super::{cost, with_hugr, PyCircuitCost, PyCustom, PyHugrType, PyNode, PyWire #[derive(Clone, Debug, PartialEq, From)] pub struct Tk2Circuit { /// Rust representation of the circuit. - pub hugr: Hugr, + pub circ: Circuit, } #[pymethods] @@ -67,40 +67,40 @@ impl Tk2Circuit { #[new] pub fn new(circ: &Bound) -> PyResult { Ok(Self { - hugr: with_hugr(circ, |hugr, _| hugr)?, + circ: with_hugr(circ, |hugr, _| hugr)?.into(), }) } /// Convert the [`Tk2Circuit`] to a tket1 circuit. pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult> { - SerialCircuit::encode(&self.hugr) + SerialCircuit::encode(&self.circ) .convert_pyerrs()? .to_tket1(py) } /// Apply a rewrite on the circuit. pub fn apply_rewrite(&mut self, rw: PyCircuitRewrite) { - rw.rewrite.apply(&mut self.hugr).expect("Apply error."); + rw.rewrite.apply(&mut self.circ).expect("Apply error."); } /// Encode the circuit as a HUGR json string. // // TODO: Bind a messagepack encoder/decoder too. pub fn to_hugr_json(&self) -> PyResult { - Ok(serde_json::to_string(&self.hugr).unwrap()) + Ok(serde_json::to_string(self.circ.hugr()).unwrap()) } /// Decode a HUGR json string to a circuit. #[staticmethod] pub fn from_hugr_json(json: &str) -> PyResult { - let hugr = serde_json::from_str(json) + let hugr: Hugr = serde_json::from_str(json) .map_err(|e| PyErr::new::(format!("Invalid encoded HUGR: {e}")))?; - Ok(Tk2Circuit { hugr }) + Ok(Tk2Circuit { circ: hugr.into() }) } /// Encode the circuit as a tket1 json string. pub fn to_tket1_json(&self) -> PyResult { - Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr).convert_pyerrs()?).unwrap()) + Ok(serde_json::to_string(&SerialCircuit::encode(&self.circ).convert_pyerrs()?).unwrap()) } /// Decode a tket1 json string to a circuit. @@ -109,7 +109,7 @@ impl Tk2Circuit { let tk1: SerialCircuit = serde_json::from_str(json) .map_err(|e| PyErr::new::(format!("Invalid encoded HUGR: {e}")))?; Ok(Tk2Circuit { - hugr: tk1.decode().convert_pyerrs()?, + circ: tk1.decode().convert_pyerrs()?, }) } @@ -134,13 +134,13 @@ impl Tk2Circuit { cost: cost.to_object(py), }) }; - let circ_cost = self.hugr.circuit_cost(cost_fn)?; + let circ_cost = self.circ.circuit_cost(cost_fn)?; Ok(circ_cost.cost.into_bound(py)) } /// Returns a hash of the circuit. pub fn hash(&self) -> u64 { - self.hugr.circuit_hash().unwrap() + self.circ.circuit_hash().unwrap() } /// Hash the circuit @@ -160,7 +160,8 @@ impl Tk2Circuit { fn node_op(&self, node: PyNode) -> PyResult { let custom: CustomOp = self - .hugr + .circ + .hugr() .get_optype(node.node) .clone() .try_into() @@ -174,25 +175,27 @@ impl Tk2Circuit { } fn node_inputs(&self, node: PyNode) -> Vec { - self.hugr + self.circ + .hugr() .all_linked_outputs(node.node) .map(|(n, p)| Wire::new(n, p).into()) .collect() } fn node_outputs(&self, node: PyNode) -> Vec { - self.hugr + self.circ + .hugr() .node_outputs(node.node) .map(|p| Wire::new(node.node, p).into()) .collect() } fn input_node(&self) -> PyNode { - self.hugr.input().into() + self.circ.input_node().into() } fn output_node(&self) -> PyNode { - self.hugr.output().into() + self.circ.output_node().into() } } impl Tk2Circuit { @@ -236,11 +239,12 @@ impl Dfg { fn finish(&mut self, outputs: Vec) -> PyResult { Ok(Tk2Circuit { - hugr: self + circ: self .builder .clone() .finish_hugr_with_outputs(outputs.into_iter().map_into(), ®ISTRY) - .convert_pyerrs()?, + .convert_pyerrs()? + .into(), }) } } diff --git a/tket2-py/src/optimiser.rs b/tket2-py/src/optimiser.rs index 49c5abc6..03166d0f 100644 --- a/tket2-py/src/optimiser.rs +++ b/tket2-py/src/optimiser.rs @@ -3,10 +3,10 @@ use std::io::BufWriter; use std::{fs, num::NonZeroUsize, path::PathBuf}; -use hugr::Hugr; use pyo3::prelude::*; use tket2::optimiser::badger::BadgerOptions; use tket2::optimiser::{BadgerLogger, DefaultBadgerOptimiser}; +use tket2::Circuit; use crate::circuit::update_hugr; @@ -96,7 +96,10 @@ impl PyBadgerOptimiser { split_circuit: split_circ.unwrap_or(false), queue_size: queue_size.unwrap_or(100), }; - update_hugr(circ, |circ, _| self.optimise(circ, log_progress, options)) + update_hugr(circ, |circ, _| { + self.optimise(circ.into(), log_progress, options) + .into_hugr() + }) } } @@ -104,10 +107,10 @@ impl PyBadgerOptimiser { /// The Python optimise method, but on Hugrs. pub(super) fn optimise( &self, - circ: Hugr, + circ: Circuit, log_progress: Option, options: BadgerOptions, - ) -> Hugr { + ) -> Circuit { let badger_logger = log_progress .map(|file_name| { let log_file = fs::File::create(file_name).unwrap(); diff --git a/tket2-py/src/passes.rs b/tket2-py/src/passes.rs index d05d37f0..b9488573 100644 --- a/tket2-py/src/passes.rs +++ b/tket2-py/src/passes.rs @@ -39,9 +39,10 @@ create_py_exception!( #[pyfunction] fn greedy_depth_reduce<'py>(circ: &Bound<'py, PyAny>) -> PyResult<(Bound<'py, PyAny>, u32)> { let py = circ.py(); - try_with_hugr(circ, |mut h, typ| { - let n_moves = apply_greedy_commutation(&mut h).convert_pyerrs()?; - let circ = typ.convert(py, h)?; + try_with_hugr(circ, |h, typ| { + let mut circ: Circuit = h.into(); + let n_moves = apply_greedy_commutation(&mut circ).convert_pyerrs()?; + let circ = typ.convert(py, circ.into_hugr())?; PyResult::Ok((circ, n_moves)) }) } @@ -117,7 +118,8 @@ fn badger_optimise<'py>( _ => unreachable!(), }; // Optimise - try_update_hugr(circ, |mut circ, _| { + try_update_hugr(circ, |hugr, _| { + let mut circ: Circuit = hugr.into(); let n_cx = circ .commands() .filter(|c| op_matches(c.optype(), Tk2Op::CX)) @@ -142,6 +144,6 @@ fn badger_optimise<'py>( }; circ = optimiser.optimise(circ, log_file, options); } - PyResult::Ok(circ) + PyResult::Ok(circ.into_hugr()) }) } diff --git a/tket2-py/src/passes/chunks.rs b/tket2-py/src/passes/chunks.rs index ce81fa16..ad973d7a 100644 --- a/tket2-py/src/passes/chunks.rs +++ b/tket2-py/src/passes/chunks.rs @@ -15,7 +15,7 @@ use crate::utils::ConvertPyErr; pub fn chunks(c: &Bound, max_chunk_size: usize) -> PyResult { with_hugr(c, |hugr, typ| { // TODO: Detect if the circuit is in tket1 format or Tk2Circuit. - let chunks = CircuitChunks::split(&hugr, max_chunk_size); + let chunks = CircuitChunks::split(&hugr.into(), max_chunk_size); (chunks, typ).into() }) } @@ -39,27 +39,28 @@ pub struct PyCircuitChunks { impl PyCircuitChunks { /// Reassemble the chunks into a circuit. fn reassemble<'py>(&self, py: Python<'py>) -> PyResult> { - let hugr = self.clone().chunks.reassemble().convert_pyerrs()?; - self.original_type.convert(py, hugr) + let circ = self.clone().chunks.reassemble().convert_pyerrs()?; + self.original_type.convert(py, circ.into_hugr()) } /// Returns clones of the split circuits. fn circuits<'py>(&self, py: Python<'py>) -> PyResult>> { self.chunks .iter() - .map(|hugr| self.original_type.convert(py, hugr.clone())) + .map(|circ| self.original_type.convert(py, circ.clone().into_hugr())) .collect() } /// Replaces a chunk's circuit with an updated version. fn update_circuit(&mut self, index: usize, new_circ: &Bound) -> PyResult<()> { try_with_hugr(new_circ, |hugr, _| { - if hugr.circuit_signature() != self.chunks[index].circuit_signature() { + let circ: Circuit = hugr.into(); + if circ.circuit_signature() != self.chunks[index].circuit_signature() { return Err(PyAttributeError::new_err( "The new circuit has a different signature.", )); } - self.chunks[index] = hugr; + self.chunks[index] = circ; Ok(()) }) } diff --git a/tket2-py/src/pattern.rs b/tket2-py/src/pattern.rs index 61d433b1..86081506 100644 --- a/tket2-py/src/pattern.rs +++ b/tket2-py/src/pattern.rs @@ -6,9 +6,10 @@ use crate::circuit::Tk2Circuit; use crate::rewrite::PyCircuitRewrite; use crate::utils::{create_py_exception, ConvertPyErr}; -use hugr::Hugr; +use hugr::HugrView; use pyo3::prelude::*; use tket2::portmatching::{CircuitPattern, PatternMatch, PatternMatcher}; +use tket2::Circuit; /// The module definition pub fn module(py: Python<'_>) -> PyResult> { @@ -47,7 +48,7 @@ create_py_exception!( #[derive(Clone)] #[pyclass] /// A rewrite rule defined by a left hand side and right hand side of an equation. -pub struct Rule(pub [Hugr; 2]); +pub struct Rule(pub [Circuit; 2]); #[pymethods] impl Rule { @@ -56,13 +57,13 @@ impl Rule { let l = Tk2Circuit::new(l)?; let r = Tk2Circuit::new(r)?; - Ok(Rule([l.hugr, r.hugr])) + Ok(Rule([l.circ, r.circ])) } } #[pyclass] struct RuleMatcher { matcher: PatternMatcher, - rights: Vec, + rights: Vec, } #[pymethods] @@ -79,25 +80,29 @@ impl RuleMatcher { } pub fn find_match(&self, target: &Tk2Circuit) -> PyResult> { - let h = &target.hugr; - if let Some(pmatch) = self.matcher.find_matches_iter(h).next() { - Ok(Some(self.match_to_rewrite(pmatch, h)?)) + let circ = &target.circ; + if let Some(pmatch) = self.matcher.find_matches_iter(circ).next() { + Ok(Some(self.match_to_rewrite(pmatch, circ)?)) } else { Ok(None) } } pub fn find_matches(&self, target: &Tk2Circuit) -> PyResult> { - let h = &target.hugr; + let circ = &target.circ; self.matcher - .find_matches_iter(h) - .map(|m| self.match_to_rewrite(m, h)) + .find_matches_iter(circ) + .map(|m| self.match_to_rewrite(m, circ)) .collect() } } impl RuleMatcher { - fn match_to_rewrite(&self, pmatch: PatternMatch, target: &Hugr) -> PyResult { + fn match_to_rewrite( + &self, + pmatch: PatternMatch, + target: &Circuit, + ) -> PyResult { let r = self.rights.get(pmatch.pattern_id().0).unwrap().clone(); let rw = pmatch.to_rewrite(target, r).convert_pyerrs()?; Ok(rw.into()) diff --git a/tket2-py/src/pattern/portmatching.rs b/tket2-py/src/pattern/portmatching.rs index 4ed90b2e..49695fe4 100644 --- a/tket2-py/src/pattern/portmatching.rs +++ b/tket2-py/src/pattern/portmatching.rs @@ -30,7 +30,9 @@ impl PyCircuitPattern { /// Construct a pattern from a TKET1 circuit #[new] pub fn from_circuit(circ: &Bound) -> PyResult { - let pattern = try_with_hugr(circ, |circ, _| CircuitPattern::try_from_circuit(&circ))?; + let pattern = try_with_hugr(circ, |circ, _| { + CircuitPattern::try_from_circuit(&circ.into()) + })?; Ok(pattern.into()) } @@ -81,7 +83,10 @@ impl PyPatternMatcher { /// Find one convex match in a circuit. pub fn find_match(&self, circ: &Bound) -> PyResult> { with_hugr(circ, |circ, _| { - self.matcher.find_matches_iter(&circ).next().map(Into::into) + self.matcher + .find_matches_iter(&circ.into()) + .next() + .map(Into::into) }) } @@ -89,7 +94,7 @@ impl PyPatternMatcher { pub fn find_matches(&self, circ: &Bound) -> PyResult> { with_hugr(circ, |circ, _| { self.matcher - .find_matches(&circ) + .find_matches(&circ.into()) .into_iter() .map_into() .collect() diff --git a/tket2-py/src/rewrite.rs b/tket2-py/src/rewrite.rs index 5c01038b..18c34b41 100644 --- a/tket2-py/src/rewrite.rs +++ b/tket2-py/src/rewrite.rs @@ -43,7 +43,7 @@ impl PyCircuitRewrite { /// The replacement subcircuit. pub fn replacement(&self) -> Tk2Circuit { - self.rewrite.replacement().clone().into() + self.rewrite.replacement().to_owned().into() } #[new] @@ -55,8 +55,8 @@ impl PyCircuitRewrite { Ok(Self { rewrite: CircuitRewrite::try_new( &source_position.0, - &source_circ.hugr, - replacement.hugr, + &source_circ.circ, + replacement.circ, ) .map_err(|e| PyErr::new::(e.to_string()))?, }) @@ -80,7 +80,7 @@ impl PySubcircuit { fn from_nodes(nodes: Vec, circ: &Tk2Circuit) -> PyResult { let nodes: Vec<_> = nodes.into_iter().map_into().collect(); Ok(Self( - Subcircuit::try_from_nodes(nodes, &circ.hugr) + Subcircuit::try_from_nodes(nodes, &circ.circ) .map_err(|e| PyErr::new::(e.to_string()))?, )) } @@ -107,7 +107,7 @@ impl PyECCRewriter { /// Returns a list of circuit rewrites that can be applied to the given Tk2Circuit. pub fn get_rewrites(&self, circ: &Tk2Circuit) -> Vec { self.0 - .get_rewrites(&circ.hugr) + .get_rewrites(&circ.circ) .into_iter() .map_into() .collect() diff --git a/tket2/Cargo.toml b/tket2/Cargo.toml index 0e81385a..f5d6e4e9 100644 --- a/tket2/Cargo.toml +++ b/tket2/Cargo.toml @@ -68,6 +68,7 @@ rstest = { workspace = true } criterion = { workspace = true, features = ["html_reports"] } webbrowser = { workspace = true } urlencoding = { workspace = true } +cool_asserts = { workspace = true } [[bench]] name = "bench_main" diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index 348876d4..d6d90576 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -15,9 +15,9 @@ use derive_more::From; use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::NodeType; use hugr::ops::dataflow::IOTrait; -use hugr::ops::{Input, Output, DFG}; +use hugr::ops::{Input, NamedOp, OpParent, OpTag, OpTrait, Output, DFG}; use hugr::types::PolyFuncType; -use hugr::PortIndex; +use hugr::{Hugr, PortIndex}; use hugr::{HugrView, OutgoingPort}; use itertools::Itertools; use thiserror::Error; @@ -28,54 +28,138 @@ pub use hugr::{Node, Port, Wire}; use self::units::{filter, LinearUnit, Units}; -/// An object behaving like a quantum circuit. -// -// TODO: More methods: -// - other_{in,out}puts (for non-linear i/o + const inputs)? -// - Vertical slice iterator -// - Depth -pub trait Circuit: HugrView { +/// A quantum circuit, represented as a function in a HUGR. +#[derive(Debug, Clone, PartialEq)] +pub struct Circuit { + /// The HUGR containing the circuit. + hugr: T, + /// The parent node of the circuit. + /// + /// This is checked at runtime to ensure that the node is a DFG node. + parent: Node, +} + +impl Default for Circuit { + fn default() -> Self { + let hugr = T::default(); + let parent = hugr.root(); + Self { hugr, parent } + } +} + +impl Circuit { + /// Create a new circuit from a HUGR and a node. + /// + /// # Errors + /// + /// Returns an error if the parent node is not a DFG node in the HUGR. + pub fn try_new(hugr: T, parent: Node) -> Result { + check_hugr(&hugr, parent)?; + Ok(Self { hugr, parent }) + } + + /// Create a new circuit from a HUGR and a node. + /// + /// See [`Circuit::try_new`] for a version that returns an error. + /// + /// # Panics + /// + /// Panics if the parent node is not a DFG node in the HUGR. + pub fn new(hugr: T, parent: Node) -> Self { + Self::try_new(hugr, parent).unwrap_or_else(|e| panic!("{}", e)) + } + + /// Returns the node containing the circuit definition. + pub fn parent(&self) -> Node { + self.parent + } + + /// Get a reference to the HUGR containing the circuit. + pub fn hugr(&self) -> &T { + &self.hugr + } + + /// Unwrap the HUGR containing the circuit. + pub fn into_hugr(self) -> T { + self.hugr + } + + /// Get a mutable reference to the HUGR containing the circuit. + /// + /// Mutation of the hugr MUST NOT invalidate the parent node, + /// by changing the node's type to a non-DFG node or by removing it. + pub fn hugr_mut(&mut self) -> &mut T { + &mut self.hugr + } + + /// Ensures the circuit contains an owned HUGR. + pub fn to_owned(&self) -> Circuit { + let hugr = self.hugr.base_hugr().clone(); + Circuit { + hugr, + parent: self.parent, + } + } + /// Return the name of the circuit #[inline] - fn name(&self) -> Option<&str> { - self.get_metadata(self.root(), "name")?.as_str() + pub fn name(&self) -> Option<&str> { + self.hugr.get_metadata(self.parent(), "name")?.as_str() } /// Returns the function type of the circuit. /// /// Equivalent to [`HugrView::get_function_type`]. #[inline] - fn circuit_signature(&self) -> PolyFuncType { - self.get_function_type() - .expect("Circuit has no function type") + pub fn circuit_signature(&self) -> PolyFuncType { + let op = self.hugr.get_optype(self.parent); + match op { + OpType::FuncDecl(decl) => decl.signature.clone(), + OpType::FuncDefn(defn) => defn.signature.clone(), + _ => op + .inner_function_type() + .expect("Circuit parent should have a function type") + .into(), + } } /// Returns the input node to the circuit. #[inline] - fn input(&self) -> Node { - self.get_io(self.root()).expect("Circuit has no input node")[0] + pub fn input_node(&self) -> Node { + self.hugr + .get_io(self.parent) + .expect("Circuit has no input node")[0] } /// Returns the output node to the circuit. #[inline] - fn output(&self) -> Node { - self.get_io(self.root()) + pub fn output_node(&self) -> Node { + self.hugr + .get_io(self.parent) .expect("Circuit has no output node")[1] } + /// Returns the input and output nodes of the circuit. + #[inline] + pub fn io_nodes(&self) -> [Node; 2] { + self.hugr + .get_io(self.parent) + .expect("Circuit has no I/O nodes") + } + /// The number of quantum gates in the circuit. #[inline] - fn num_gates(&self) -> usize + pub fn num_gates(&self) -> usize where Self: Sized, { - // TODO: Implement discern quantum gates in the commands iterator. - self.children(self.root()).count() - 2 + // TODO: Discern quantum gates in the commands iterator. + self.hugr().children(self.parent).count() - 2 } /// Count the number of qubits in the circuit. #[inline] - fn qubit_count(&self) -> usize + pub fn qubit_count(&self) -> usize where Self: Sized, { @@ -84,7 +168,7 @@ pub trait Circuit: HugrView { /// Get the input units of the circuit and their types. #[inline] - fn units(&self) -> Units + pub fn units(&self) -> Units where Self: Sized, { @@ -93,7 +177,7 @@ pub trait Circuit: HugrView { /// Get the linear input units of the circuit and their types. #[inline] - fn linear_units(&self) -> impl Iterator + '_ + pub fn linear_units(&self) -> impl Iterator + '_ where Self: Sized, { @@ -102,7 +186,7 @@ pub trait Circuit: HugrView { /// Get the non-linear input units of the circuit and their types. #[inline] - fn nonlinear_units(&self) -> impl Iterator + '_ + pub fn nonlinear_units(&self) -> impl Iterator + '_ where Self: Sized, { @@ -111,7 +195,7 @@ pub trait Circuit: HugrView { /// Returns the units corresponding to qubits inputs to the circuit. #[inline] - fn qubits(&self) -> impl Iterator + '_ + pub fn qubits(&self) -> impl Iterator + '_ where Self: Sized, { @@ -122,7 +206,7 @@ pub trait Circuit: HugrView { /// /// Ignores the Input and Output nodes. #[inline] - fn commands(&self) -> CommandIterator<'_, Self> + pub fn commands(&self) -> CommandIterator<'_, T> where Self: Sized, { @@ -132,7 +216,7 @@ pub trait Circuit: HugrView { /// Compute the cost of the circuit based on a per-operation cost function. #[inline] - fn circuit_cost(&self, op_cost: F) -> C + pub fn circuit_cost(&self, op_cost: F) -> C where Self: Sized, C: Sum, @@ -144,15 +228,62 @@ pub trait Circuit: HugrView { /// Compute the cost of a group of nodes in a circuit based on a /// per-operation cost function. #[inline] - fn nodes_cost(&self, nodes: impl IntoIterator, op_cost: F) -> C + pub fn nodes_cost(&self, nodes: impl IntoIterator, op_cost: F) -> C where C: Sum, F: Fn(&OpType) -> C, { - nodes.into_iter().map(|n| op_cost(self.get_optype(n))).sum() + nodes + .into_iter() + .map(|n| op_cost(self.hugr.get_optype(n))) + .sum() + } + + /// Return the graphviz representation of the underlying graph and hierarchy side by side. + /// + /// For a simpler representation, use the [`Circuit::mermaid_string`] format instead. + pub fn dot_string(&self) -> String { + // TODO: This will print the whole HUGR without identifying the circuit container. + // Should we add some extra formatting for that? + self.hugr.dot_string() + } + + /// Return the mermaid representation of the underlying hierarchical graph. + /// + /// The hierarchy is represented using subgraphs. Edges are labelled with + /// their source and target ports. + /// + /// For a more detailed representation, use the [`Circuit::dot_string`] + /// format instead. + pub fn mermaid_string(&self) -> String { + // TODO: See comment in `dot_string`. + self.hugr.mermaid_string() + } +} + +impl From for Circuit { + fn from(hugr: T) -> Self { + let parent = hugr.root(); + Self::new(hugr, parent) } } +/// Checks if the passed hugr is a valid circuit, +/// and return [`CircuitError`] if not. +fn check_hugr(hugr: &impl HugrView, parent: Node) -> Result<(), CircuitError> { + if !hugr.contains_node(parent) { + return Err(CircuitError::MissingParentNode { parent }); + } + let optype = hugr.get_optype(parent); + if !OpTag::DataflowParent.is_superset(optype.tag()) { + return Err(CircuitError::NonDFGParent { + parent, + optype: optype.clone(), + }); + } + Ok(()) +} + /// Remove an empty wire in a dataflow HUGR. /// /// The wire to be removed is identified by the index of the outgoing port @@ -167,15 +298,18 @@ pub trait Circuit: HugrView { /// occurs. #[allow(dead_code)] pub(crate) fn remove_empty_wire( - circ: &mut impl HugrMut, + circ: &mut Circuit, input_port: usize, ) -> Result<(), CircuitMutError> { - let [inp, out] = circ.get_io(circ.root()).expect("no IO nodes found at root"); - if input_port >= circ.num_outputs(inp) { + let parent = circ.parent(); + let hugr = circ.hugr_mut(); + + let [inp, out] = hugr.get_io(parent).expect("no IO nodes found at parent"); + if input_port >= hugr.num_outputs(inp) { return Err(CircuitMutError::InvalidPortOffset(input_port)); } let input_port = OutgoingPort::from(input_port); - let link = circ + let link = hugr .linked_inputs(inp, input_port) .at_most_one() .map_err(|_| CircuitMutError::DeleteNonEmptyWire(input_port.index()))?; @@ -183,25 +317,52 @@ pub(crate) fn remove_empty_wire( return Err(CircuitMutError::DeleteNonEmptyWire(input_port.index())); } if link.is_some() { - circ.disconnect(inp, input_port); + hugr.disconnect(inp, input_port); } // Shift ports at input - shift_ports(circ, inp, input_port, circ.num_outputs(inp))?; + shift_ports(hugr, inp, input_port, hugr.num_outputs(inp))?; // Shift ports at output if let Some((out, output_port)) = link { - shift_ports(circ, out, output_port, circ.num_inputs(out))?; + shift_ports(hugr, out, output_port, hugr.num_inputs(out))?; } - // Update input node, output node (if necessary) and root signatures. - update_signature(circ, input_port.index(), link.map(|(_, p)| p.index())); + // Update input node, output node (if necessary) and parent signatures. + update_signature( + hugr, + parent, + input_port.index(), + link.map(|(_, p)| p.index()), + ); // Resize ports at input/output node - circ.set_num_ports(inp, 0, circ.num_outputs(inp) - 1); + hugr.set_num_ports(inp, 0, hugr.num_outputs(inp) - 1); if let Some((out, _)) = link { - circ.set_num_ports(out, circ.num_inputs(out) - 1, 0); + hugr.set_num_ports(out, hugr.num_inputs(out) - 1, 0); } Ok(()) } +/// Errors that can occur when mutating a circuit. +#[derive(Debug, Clone, Error, PartialEq)] +pub enum CircuitError { + /// The parent node for the circuit does not exist in the HUGR. + #[error("{parent} cannot define a circuit as it is not present in the HUGR.")] + MissingParentNode { + /// The node that was used as the parent. + parent: Node, + }, + /// The parent node for the circuit is not a DFG node. + #[error( + "{parent} cannot be used as a circuit parent. A {} is not a dataflow container.", + optype.name() + )] + NonDFGParent { + /// The node that was used as the parent. + parent: Node, + /// The parent optype. + optype: OpType, + }, +} + /// Errors that can occur when mutating a circuit. #[derive(Debug, Clone, Error, PartialEq, Eq, From)] pub enum CircuitMutError { @@ -252,15 +413,18 @@ fn shift_ports( // Update the signature of circ when removing the in_index-th input wire and // the out_index-th output wire. -fn update_signature( - circ: &mut C, +fn update_signature( + hugr: &mut impl HugrMut, + parent: Node, in_index: usize, out_index: Option, ) { - let inp = circ.input(); + let inp = hugr + .get_io(parent) + .expect("no IO nodes found at circuit parent")[0]; // Update input node let inp_types: TypeRow = { - let OpType::Input(Input { types }) = circ.get_optype(inp).clone() else { + let OpType::Input(Input { types }) = hugr.get_optype(inp).clone() else { panic!("invalid circuit") }; let mut types = types.into_owned(); @@ -268,15 +432,15 @@ fn update_signature( types.into() }; let new_inp_op = Input::new(inp_types.clone()); - let inp_exts = circ.get_nodetype(inp).input_extensions().cloned(); - circ.replace_op(inp, NodeType::new(new_inp_op, inp_exts)) + let inp_exts = hugr.get_nodetype(inp).input_extensions().cloned(); + hugr.replace_op(inp, NodeType::new(new_inp_op, inp_exts)) .unwrap(); // Update output node if necessary. let out_types = out_index.map(|out_index| { - let out = circ.output(); + let out = hugr.get_io(parent).unwrap()[1]; let out_types: TypeRow = { - let OpType::Output(Output { types }) = circ.get_optype(out).clone() else { + let OpType::Output(Output { types }) = hugr.get_optype(out).clone() else { panic!("invalid circuit") }; let mut types = types.into_owned(); @@ -284,14 +448,14 @@ fn update_signature( types.into() }; let new_out_op = Output::new(out_types.clone()); - let inp_exts = circ.get_nodetype(out).input_extensions().cloned(); - circ.replace_op(out, NodeType::new(new_out_op, inp_exts)) + let inp_exts = hugr.get_nodetype(out).input_extensions().cloned(); + hugr.replace_op(out, NodeType::new(new_out_op, inp_exts)) .unwrap(); out_types }); - // Update root - let OpType::DFG(DFG { mut signature, .. }) = circ.get_optype(circ.root()).clone() else { + // Update parent + let OpType::DFG(DFG { mut signature, .. }) = hugr.get_optype(parent).clone() else { panic!("invalid circuit") }; signature.input = inp_types; @@ -299,26 +463,25 @@ fn update_signature( signature.output = out_types; } let new_dfg_op = DFG { signature }; - let inp_exts = circ.get_nodetype(circ.root()).input_extensions().cloned(); - circ.replace_op(circ.root(), NodeType::new(new_dfg_op, inp_exts)) + let inp_exts = hugr.get_nodetype(parent).input_extensions().cloned(); + hugr.replace_op(parent, NodeType::new(new_dfg_op, inp_exts)) .unwrap(); } -impl Circuit for T where T: HugrView {} - #[cfg(test)] mod tests { + use cool_asserts::assert_matches; + use hugr::types::FunctionType; use hugr::{ builder::{DFGBuilder, DataflowHugr}, extension::{prelude::BOOL_T, PRELUDE_REGISTRY}, - Hugr, }; use super::*; use crate::{json::load_tk1_json_str, utils::build_simple_circuit, Tk2Op}; - fn test_circuit() -> Hugr { + fn test_circuit() -> Circuit { load_tk1_json_str( r#"{ "phase": "0", "bits": [["c", [0]]], @@ -367,10 +530,23 @@ mod tests { ); } + #[test] + fn test_invalid_parent() { + let hugr = Hugr::default(); + + assert_matches!( + Circuit::try_new(hugr.clone(), hugr.root()), + Err(CircuitError::NonDFGParent { .. }), + ); + } + #[test] fn remove_bit() { let h = DFGBuilder::new(FunctionType::new(vec![BOOL_T], vec![])).unwrap(); - let mut circ = h.finish_hugr_with_outputs([], &PRELUDE_REGISTRY).unwrap(); + let mut circ: Circuit = h + .finish_hugr_with_outputs([], &PRELUDE_REGISTRY) + .unwrap() + .into(); assert_eq!(circ.units().count(), 1); assert!(remove_empty_wire(&mut circ, 0).is_ok()); diff --git a/tket2/src/circuit/command.rs b/tket2/src/circuit/command.rs index 84ed2a18..09af12a1 100644 --- a/tket2/src/circuit/command.rs +++ b/tket2/src/circuit/command.rs @@ -8,7 +8,7 @@ use std::iter::FusedIterator; use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; -use hugr::{IncomingPort, OutgoingPort}; +use hugr::{HugrView, IncomingPort, OutgoingPort}; use itertools::Either::{self, Left, Right}; use itertools::{EitherOrBoth, Itertools}; use petgraph::visit as pv; @@ -21,9 +21,9 @@ pub use hugr::types::{EdgeKind, Type, TypeRow}; pub use hugr::{CircuitUnit, Direction, Node, Port, PortIndex, Wire}; /// An operation applied to specific wires. -pub struct Command<'circ, Circ> { +pub struct Command<'circ, T> { /// The circuit. - circ: &'circ Circ, + circ: &'circ Circuit, /// The operation node. node: Node, /// An assignment of linear units to the node's input ports. @@ -32,7 +32,7 @@ pub struct Command<'circ, Circ> { output_linear_units: Vec, } -impl<'circ, Circ: Circuit> Command<'circ, Circ> { +impl<'circ, T: HugrView> Command<'circ, T> { /// Returns the node corresponding to this command. #[inline] pub fn node(&self) -> Node { @@ -42,13 +42,13 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the [`NodeType`] of the command. #[inline] pub fn nodetype(&self) -> &NodeType { - self.circ.get_nodetype(self.node) + self.circ.hugr().get_nodetype(self.node) } /// Returns the [`OpType`] of the command. #[inline] pub fn optype(&self) -> &OpType { - self.circ.get_optype(self.node) + self.circ.hugr().get_optype(self.node) } /// Returns the units of this command in a given direction. @@ -162,7 +162,7 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { } } -impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { +impl<'a, 'circ, T: HugrView> UnitLabeller for &'a Command<'circ, T> { #[inline] fn assign_linear(&self, _: Node, port: Port, _linear_count: usize) -> LinearUnit { let units = match port.direction() { @@ -181,7 +181,7 @@ impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { fn assign_wire(&self, node: Node, port: Port) -> Option { match port.as_directed() { Left(to_port) => { - let (from, from_port) = self.circ.linked_outputs(node, to_port).next()?; + let (from, from_port) = self.circ.hugr().linked_outputs(node, to_port).next()?; Some(Wire::new(from, from_port)) } Right(from_port) => Some(Wire::new(node, from_port)), @@ -189,7 +189,7 @@ impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { } } -impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> { +impl<'circ, T: HugrView> std::fmt::Debug for Command<'circ, T> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Command") .field("circuit name", &self.circ.name()) @@ -200,7 +200,7 @@ impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> { } } -impl<'circ, Circ> PartialEq for Command<'circ, Circ> { +impl<'circ, T: HugrView> PartialEq for Command<'circ, T> { fn eq(&self, other: &Self) -> bool { self.node == other.node && self.input_linear_units == other.input_linear_units @@ -208,9 +208,9 @@ impl<'circ, Circ> PartialEq for Command<'circ, Circ> { } } -impl<'circ, Circ> Eq for Command<'circ, Circ> {} +impl<'circ, T: HugrView> Eq for Command<'circ, T> {} -impl<'circ, Circ> Clone for Command<'circ, Circ> { +impl<'circ, T: HugrView> Clone for Command<'circ, T> { fn clone(&self) -> Self { Self { circ: self.circ, @@ -221,7 +221,7 @@ impl<'circ, Circ> Clone for Command<'circ, Circ> { } } -impl<'circ, Circ> std::hash::Hash for Command<'circ, Circ> { +impl<'circ, T: HugrView> std::hash::Hash for Command<'circ, T> { fn hash(&self, state: &mut H) { self.node.hash(state); self.input_linear_units.hash(state); @@ -234,9 +234,9 @@ type NodeWalker = pv::Topo>; /// An iterator over the commands of a circuit. #[derive(Clone)] -pub struct CommandIterator<'circ, Circ> { +pub struct CommandIterator<'circ, T> { /// The circuit. - circ: &'circ Circ, + circ: &'circ Circuit, /// Toposorted nodes. nodes: NodeWalker, /// Last wire for each [`LinearUnit`] in the circuit. @@ -263,28 +263,25 @@ pub struct CommandIterator<'circ, Circ> { delayed_node: Option, } -impl<'circ, Circ> CommandIterator<'circ, Circ> -where - Circ: Circuit, -{ +impl<'circ, T: HugrView> CommandIterator<'circ, T> { /// Create a new iterator over the commands of a circuit. - pub(super) fn new(circ: &'circ Circ) -> Self { + pub(super) fn new(circ: &'circ Circuit) -> Self { // Initialize the map assigning linear units to the input's linear // ports. // // TODO: `with_wires` combinator for `Units`? let wire_unit = circ .linear_units() - .map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit.index())) + .map(|(linear_unit, port, _)| (Wire::new(circ.input_node(), port), linear_unit.index())) .collect(); - let nodes = pv::Topo::new(&circ.as_petgraph()); + let nodes = pv::Topo::new(&circ.hugr().as_petgraph()); Self { circ, nodes, wire_unit, // Ignore the input and output nodes, and the root. - remaining: circ.node_count() - 3, + remaining: circ.hugr().node_count() - 3, delayed_consts: HashSet::new(), delayed_consumers: HashMap::new(), delayed_node: None, @@ -299,13 +296,13 @@ where let node = self .delayed_node .take() - .or_else(|| self.nodes.next(&self.circ.as_petgraph()))?; + .or_else(|| self.nodes.next(&self.circ.hugr().as_petgraph()))?; // If this node is a constant or load const node, delay it. - let tag = self.circ.get_optype(node).tag(); + let tag = self.circ.hugr().get_optype(node).tag(); if tag == OpTag::Const || tag == OpTag::LoadConst { self.delayed_consts.insert(node); - for consumer in self.circ.output_neighbours(node) { + for consumer in self.circ.hugr().output_neighbours(node) { *self.delayed_consumers.entry(consumer).or_default() += 1; } return self.next_node(); @@ -316,7 +313,7 @@ where true => { let delayed = self.next_delayed_node(node); self.delayed_consts.remove(&delayed); - for consumer in self.circ.output_neighbours(delayed) { + for consumer in self.circ.hugr().output_neighbours(delayed) { let Entry::Occupied(mut entry) = self.delayed_consumers.entry(consumer) else { panic!("Delayed node consumer was not in delayed_consumers. Delayed node: {delayed:?}, consumer: {consumer:?}."); }; @@ -336,6 +333,7 @@ where fn next_delayed_node(&mut self, consumer: Node) -> Node { let Some(delayed_pred) = self .circ + .hugr() .input_neighbours(consumer) .find(|k| self.delayed_consts.contains(k)) else { @@ -359,12 +357,12 @@ where /// mutable borrow here. fn process_node(&mut self, node: Node) -> Option<(Vec, Vec)> { // The root node is ignored. - if node == self.circ.root() { + if node == self.circ.parent() { return None; } // Inputs and outputs are also ignored. // The input wire ids are already set in the `wire_unit` map during initialization. - let tag = self.circ.get_optype(node).tag(); + let tag = self.circ.hugr().get_optype(node).tag(); if tag == OpTag::Input || tag == OpTag::Output { return None; } @@ -390,7 +388,7 @@ where // Returns the linear id of the terminated unit. let mut terminate_input = |port: IncomingPort, wire_unit: &mut HashMap| -> Option { - let linear_id = self.circ.single_linked_output(node, port).and_then( + let linear_id = self.circ.hugr().single_linked_output(node, port).and_then( |(wire_node, wire_port)| wire_unit.remove(&Wire::new(wire_node, wire_port)), )?; input_linear_units.push(LinearUnit::new(linear_id)); @@ -425,11 +423,8 @@ where } } -impl<'circ, Circ> Iterator for CommandIterator<'circ, Circ> -where - Circ: Circuit, -{ - type Item = Command<'circ, Circ>; +impl<'circ, T: HugrView> Iterator for CommandIterator<'circ, T> { + type Item = Command<'circ, T>; #[inline] fn next(&mut self) -> Option { @@ -454,9 +449,9 @@ where } } -impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ> where Circ: Circuit {} +impl<'circ, T: HugrView> FusedIterator for CommandIterator<'circ, T> {} -impl<'circ, Circ: Circuit> std::fmt::Debug for CommandIterator<'circ, Circ> { +impl<'circ, T: HugrView> std::fmt::Debug for CommandIterator<'circ, T> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("CommandIterator") .field("circuit name", &self.circ.name()) @@ -558,9 +553,10 @@ mod test { .add_dataflow_op(Tk2Op::RzF64, [q_in, loaded_const]) .unwrap(); - let circ = h + let circ: Circuit = h .finish_hugr_with_outputs(rz.outputs(), &FLOAT_OPS_REGISTRY) - .unwrap(); + .unwrap() + .into(); assert_eq!(CommandIterator::new(&circ).count(), 3); let mut commands = CommandIterator::new(&circ); @@ -631,7 +627,7 @@ mod test { let free = h.add_dataflow_op(Tk2Op::QFree, [q_in])?; - let circ = h.finish_hugr_with_outputs([q_new], ®ISTRY)?; + let circ: Circuit = h.finish_hugr_with_outputs([q_new], ®ISTRY)?.into(); let mut cmds = circ.commands(); @@ -678,7 +674,7 @@ mod test { let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), vec![]))?; let [q_in] = h.input_wires_arr(); h.add_dataflow_op(Tk2Op::QFree, [q_in])?; - let circ = h.finish_hugr_with_outputs([], ®ISTRY)?; + let circ: Circuit = h.finish_hugr_with_outputs([], ®ISTRY)?.into(); let cmd1 = circ.commands().next().unwrap(); let cmd2 = circ.commands().next().unwrap(); diff --git a/tket2/src/circuit/hash.rs b/tket2/src/circuit/hash.rs index 478f97d9..84806874 100644 --- a/tket2/src/circuit/hash.rs +++ b/tket2/src/circuit/hash.rs @@ -12,7 +12,7 @@ use thiserror::Error; use super::Circuit; /// Circuit hashing utilities. -pub trait CircuitHash<'circ>: HugrView { +pub trait CircuitHash { /// Compute hash of a circuit. /// /// We compute a hash for each command from its operation and the hash of @@ -23,14 +23,26 @@ pub trait CircuitHash<'circ>: HugrView { /// /// Adapted from Quartz (Apache 2.0) /// - fn circuit_hash(&'circ self) -> Result; + fn circuit_hash(&self) -> Result; } -impl<'circ, T> CircuitHash<'circ> for T +impl CircuitHash for Circuit { + fn circuit_hash(&self) -> Result { + let hugr = self.hugr(); + let container: SiblingGraph = SiblingGraph::try_new(hugr, self.parent()).unwrap(); + container.circuit_hash() + } +} + +impl CircuitHash for T where T: HugrView, { - fn circuit_hash(&'circ self) -> Result { + fn circuit_hash(&self) -> Result { + let Some([_, output_node]) = self.get_io(self.root()) else { + return Err(HashError::NotADfg); + }; + let mut node_hashes = HashState::default(); for node in pg::Topo::new(&self.as_petgraph()) @@ -45,7 +57,7 @@ where // If the output node has no hash, the topological sort failed due to a cycle. node_hashes - .node_hash(self.output()) + .node_hash(output_node) .ok_or(HashError::CyclicCircuit) } } @@ -132,11 +144,13 @@ pub enum HashError { /// The circuit contains a cycle. #[error("The circuit contains a cycle.")] CyclicCircuit, + /// The hashed hugr is not a DFG. + #[error("Tried to hash a non-dfg hugr.")] + NotADfg, } #[cfg(test)] mod test { - use hugr::Hugr; use tket_json_rs::circuit_json; use crate::json::TKETDecode; @@ -185,7 +199,7 @@ mod test { fn hash_constants() { let c_str = r#"{"bits": [], "commands": [{"args": [["q", [0]]], "op": {"params": ["0.5"], "type": "Rz"}}], "created_qubits": [], "discarded_qubits": [], "implicit_permutation": [[["q", [0]], ["q", [0]]]], "phase": "0.0", "qubits": [["q", [0]]]}"#; let ser: circuit_json::SerialCircuit = serde_json::from_str(c_str).unwrap(); - let circ: Hugr = ser.decode().unwrap(); + let circ: Circuit = ser.decode().unwrap(); circ.circuit_hash().unwrap(); } @@ -197,7 +211,7 @@ mod test { let mut all_hashes = Vec::with_capacity(2); for c_str in [c_str1, c_str2] { let ser: circuit_json::SerialCircuit = serde_json::from_str(c_str).unwrap(); - let circ: Hugr = ser.decode().unwrap(); + let circ: Circuit = ser.decode().unwrap(); all_hashes.push(circ.circuit_hash().unwrap()); } assert_ne!(all_hashes[0], all_hashes[1]); diff --git a/tket2/src/circuit/units.rs b/tket2/src/circuit/units.rs index 17383aa7..173186f5 100644 --- a/tket2/src/circuit/units.rs +++ b/tket2/src/circuit/units.rs @@ -18,7 +18,7 @@ use std::iter::FusedIterator; use std::marker::PhantomData; use hugr::types::{EdgeKind, Type, TypeRow}; -use hugr::{CircuitUnit, IncomingPort, OutgoingPort}; +use hugr::{CircuitUnit, HugrView, IncomingPort, OutgoingPort}; use hugr::{Direction, Node, Port, Wire}; use crate::utils::type_is_linear; @@ -83,8 +83,8 @@ impl Units { /// This iterator will yield all units originating from the circuit's input /// node. #[inline] - pub(super) fn new_circ_input(circuit: &impl Circuit) -> Self { - Self::new_outgoing(circuit, circuit.input(), DefaultUnitLabeller) + pub(super) fn new_circ_input(circuit: &Circuit) -> Self { + Self::new_outgoing(circuit, circuit.input_node(), DefaultUnitLabeller) } } @@ -94,7 +94,11 @@ where { /// Create a new iterator over the units originating from node. #[inline] - pub(super) fn new_outgoing(circuit: &impl Circuit, node: Node, unit_labeller: UL) -> Self { + pub(super) fn new_outgoing( + circuit: &Circuit, + node: Node, + unit_labeller: UL, + ) -> Self { Self::new_with_dir(circuit, node, Direction::Outgoing, unit_labeller) } } @@ -105,7 +109,11 @@ where { /// Create a new iterator over the units terminating on the node. #[inline] - pub(super) fn new_incoming(circuit: &impl Circuit, node: Node, unit_labeller: UL) -> Self { + pub(super) fn new_incoming( + circuit: &Circuit, + node: Node, + unit_labeller: UL, + ) -> Self { Self::new_with_dir(circuit, node, Direction::Incoming, unit_labeller) } } @@ -117,8 +125,8 @@ where { /// Create a new iterator over the units of a node. #[inline] - fn new_with_dir( - circuit: &impl Circuit, + fn new_with_dir( + circuit: &Circuit, node: Node, direction: Direction, unit_labeller: UL, @@ -142,9 +150,10 @@ where // We should revisit it once this is reworked on the HUGR side. // // TODO: EdgeKind::Function is not currently supported. - fn init_types(circuit: &impl Circuit, node: Node, direction: Direction) -> TypeRow { - let optype = circuit.get_optype(node); - let sig = circuit.signature(node).unwrap_or_default(); + fn init_types(circuit: &Circuit, node: Node, direction: Direction) -> TypeRow { + let hugr = circuit.hugr(); + let optype = hugr.get_optype(node); + let sig = hugr.signature(node).unwrap_or_default(); let mut types = match direction { Direction::Outgoing => sig.output, Direction::Incoming => sig.input, diff --git a/tket2/src/json.rs b/tket2/src/json.rs index aab0b33b..34c71089 100644 --- a/tket2/src/json.rs +++ b/tket2/src/json.rs @@ -14,7 +14,6 @@ use std::{fs, io}; use hugr::ops::{OpType, Value}; use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; -use hugr::Hugr; use stringreader::StringReader; use thiserror::Error; @@ -37,23 +36,25 @@ const METADATA_Q_REGISTERS: &str = "TKET1_JSON.qubit_registers"; /// Explicit names for the input bit registers. const METADATA_B_REGISTERS: &str = "TKET1_JSON.bit_registers"; -/// A JSON-serialized circuit that can be converted to a [`Hugr`]. +/// A serialized representation of a [`Circuit`]. +/// +/// Implemented by [`SerialCircuit`], the JSON format used by tket1's `pytket` library. pub trait TKETDecode: Sized { /// The error type for decoding. type DecodeError; /// The error type for decoding. type EncodeError; - /// Convert the serialized circuit to a [`Hugr`]. - fn decode(self) -> Result; - /// Convert a [`Hugr`] to a new serialized circuit. - fn encode(circuit: &impl Circuit) -> Result; + /// Convert the serialized circuit to a circuit. + fn decode(self) -> Result; + /// Convert a circuit to a new serialized circuit. + fn encode(circuit: &Circuit) -> Result; } impl TKETDecode for SerialCircuit { type DecodeError = OpConvertError; type EncodeError = OpConvertError; - fn decode(self) -> Result { + fn decode(self) -> Result { let mut decoder = JsonDecoder::new(&self); if !self.phase.is_empty() { @@ -65,10 +66,10 @@ impl TKETDecode for SerialCircuit { for com in self.commands { decoder.add_command(com); } - Ok(decoder.finish()) + Ok(decoder.finish().into()) } - fn encode(circ: &impl Circuit) -> Result { + fn encode(circ: &Circuit) -> Result { let mut encoder = JsonEncoder::new(circ); let f64_inputs = circ.units().filter_map(|(wire, _, t)| match (wire, t) { (CircuitUnit::Wire(wire), t) if t == FLOAT64_TYPE => Some(wire), @@ -102,43 +103,41 @@ pub enum OpConvertError { } /// Load a TKET1 circuit from a JSON file. -pub fn load_tk1_json_file(path: impl AsRef) -> Result { +pub fn load_tk1_json_file(path: impl AsRef) -> Result { let file = fs::File::open(path)?; let reader = io::BufReader::new(file); load_tk1_json_reader(reader) } /// Load a TKET1 circuit from a JSON reader. -pub fn load_tk1_json_reader(json: impl io::Read) -> Result { +pub fn load_tk1_json_reader(json: impl io::Read) -> Result { let ser: SerialCircuit = serde_json::from_reader(json)?; - Ok(ser.decode()?) + let circ: Circuit = ser.decode()?; + Ok(circ) } /// Load a TKET1 circuit from a JSON string. -pub fn load_tk1_json_str(json: &str) -> Result { +pub fn load_tk1_json_str(json: &str) -> Result { let reader = StringReader::new(json); load_tk1_json_reader(reader) } /// Save a circuit to file in TK1 JSON format. -pub fn save_tk1_json_file( - circ: &impl Circuit, - path: impl AsRef, -) -> Result<(), TK1ConvertError> { +pub fn save_tk1_json_file(circ: &Circuit, path: impl AsRef) -> Result<(), TK1ConvertError> { let file = fs::File::create(path)?; let writer = io::BufWriter::new(file); save_tk1_json_writer(circ, writer) } /// Save a circuit in TK1 JSON format to a writer. -pub fn save_tk1_json_writer(circ: &impl Circuit, w: impl io::Write) -> Result<(), TK1ConvertError> { +pub fn save_tk1_json_writer(circ: &Circuit, w: impl io::Write) -> Result<(), TK1ConvertError> { let serial_circ = SerialCircuit::encode(circ)?; serde_json::to_writer(w, &serial_circ)?; Ok(()) } /// Save a circuit in TK1 JSON format to a String. -pub fn save_tk1_json_str(circ: &impl Circuit) -> Result { +pub fn save_tk1_json_str(circ: &Circuit) -> Result { let mut buf = io::BufWriter::new(Vec::new()); save_tk1_json_writer(circ, &mut buf)?; let bytes = buf.into_inner().unwrap(); diff --git a/tket2/src/json/encoder.rs b/tket2/src/json/encoder.rs index 589d871c..0e65efbd 100644 --- a/tket2/src/json/encoder.rs +++ b/tket2/src/json/encoder.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use hugr::extension::prelude::QB_T; use hugr::ops::{NamedOp, OpType}; use hugr::std_extensions::arithmetic::float_types::ConstF64; -use hugr::Wire; +use hugr::{HugrView, Wire}; use itertools::{Either, Itertools}; use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit}; @@ -48,8 +48,9 @@ pub(super) struct JsonEncoder { impl JsonEncoder { /// Create a new [`JsonEncoder`] from a [`Circuit`]. - pub fn new(circ: &impl Circuit) -> Self { + pub fn new(circ: &Circuit) -> Self { let name = circ.name().map(str::to_string); + let hugr = circ.hugr(); let mut qubit_registers = vec![]; let mut bit_registers = vec![]; @@ -58,17 +59,17 @@ impl JsonEncoder { // Recover other parameters stored in the metadata // TODO: Check for invalid encoded metadata - let root = circ.root(); - if let Some(p) = circ.get_metadata(root, METADATA_PHASE) { + let root = circ.parent(); + if let Some(p) = hugr.get_metadata(root, METADATA_PHASE) { phase = p.as_str().unwrap().to_string(); } - if let Some(perm) = circ.get_metadata(root, METADATA_IMPLICIT_PERM) { + if let Some(perm) = hugr.get_metadata(root, METADATA_IMPLICIT_PERM) { implicit_permutation = serde_json::from_value(perm.clone()).unwrap(); } - if let Some(q_regs) = circ.get_metadata(root, METADATA_Q_REGISTERS) { + if let Some(q_regs) = hugr.get_metadata(root, METADATA_Q_REGISTERS) { qubit_registers = serde_json::from_value(q_regs.clone()).unwrap(); } - if let Some(b_regs) = circ.get_metadata(root, METADATA_B_REGISTERS) { + if let Some(b_regs) = hugr.get_metadata(root, METADATA_B_REGISTERS) { bit_registers = serde_json::from_value(b_regs.clone()).unwrap(); } @@ -109,9 +110,9 @@ impl JsonEncoder { } /// Add a circuit command to the serialization. - pub fn add_command( + pub fn add_command( &mut self, - command: Command<'_, C>, + command: Command<'_, T>, optype: &OpType, ) -> Result<(), OpConvertError> { // Register any output of the command that can be used as a TKET1 parameter. @@ -169,7 +170,11 @@ impl JsonEncoder { /// Record any output of the command that can be used as a TKET1 parameter. /// Returns whether parameters were recorded. /// Associates the output wires with the parameter expression. - fn record_parameters(&mut self, command: &Command<'_, C>, optype: &OpType) -> bool { + fn record_parameters( + &mut self, + command: &Command<'_, T>, + optype: &OpType, + ) -> bool { // Only consider commands where all inputs are parameters. let inputs = command .inputs() diff --git a/tket2/src/json/tests.rs b/tket2/src/json/tests.rs index f08ef58e..fde3e641 100644 --- a/tket2/src/json/tests.rs +++ b/tket2/src/json/tests.rs @@ -7,7 +7,6 @@ use hugr::extension::prelude::QB_T; use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use hugr::types::FunctionType; -use hugr::Hugr; use rstest::{fixture, rstest}; use tket_json_rs::circuit_json::{self, SerialCircuit}; use tket_json_rs::optype; @@ -64,7 +63,7 @@ fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num let ser: circuit_json::SerialCircuit = serde_json::from_str(circ_s).unwrap(); assert_eq!(ser.commands.len(), num_commands); - let circ: Hugr = ser.clone().decode().unwrap(); + let circ: Circuit = ser.clone().decode().unwrap(); assert_eq!(circ.qubit_count(), num_qubits); @@ -78,13 +77,13 @@ fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num fn json_file_roundtrip(#[case] circ: impl AsRef) { let reader = BufReader::new(std::fs::File::open(circ).unwrap()); let ser: circuit_json::SerialCircuit = serde_json::from_reader(reader).unwrap(); - let circ: Hugr = ser.clone().decode().unwrap(); + let circ: Circuit = ser.clone().decode().unwrap(); let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap(); compare_serial_circs(&ser, &reser); } #[fixture] -fn circ_add_angles_symbolic() -> Hugr { +fn circ_add_angles_symbolic() -> Circuit { let input_t = vec![QB_T, FLOAT64_TYPE, FLOAT64_TYPE]; let output_t = vec![QB_T]; let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); @@ -99,11 +98,11 @@ fn circ_add_angles_symbolic() -> Hugr { let res = h.add_dataflow_op(Tk2Op::RxF64, [qb, f12]).unwrap(); let qb = res.outputs().next().unwrap(); - h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap().into() } #[fixture] -fn circ_add_angles_constants() -> Hugr { +fn circ_add_angles_constants() -> Circuit { let qb_row = vec![QB_T]; let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row)).unwrap(); @@ -120,13 +119,13 @@ fn circ_add_angles_constants() -> Hugr { .add_dataflow_op(Tk2Op::RxF64, [qb, point5]) .unwrap() .outputs(); - h.finish_hugr_with_outputs(qbs, ®ISTRY).unwrap() + h.finish_hugr_with_outputs(qbs, ®ISTRY).unwrap().into() } #[rstest] #[case::symbolic(circ_add_angles_symbolic(), "f0 + f1")] #[case::constants(circ_add_angles_constants(), "0.2 + 0.3")] -fn test_add_angle_serialise(#[case] circ_add_angles: Hugr, #[case] param_str: &str) { +fn test_add_angle_serialise(#[case] circ_add_angles: Circuit, #[case] param_str: &str) { let ser: SerialCircuit = SerialCircuit::encode(&circ_add_angles).unwrap(); assert_eq!(ser.commands.len(), 1); assert_eq!(ser.commands[0].op.op_type, optype::OpType::Rx); @@ -135,7 +134,7 @@ fn test_add_angle_serialise(#[case] circ_add_angles: Hugr, #[case] param_str: &s // Note: this is not a proper roundtrip as the symbols f0 and f1 are not // converted back to circuit inputs. This would require parsing symbolic // expressions. - let deser: Hugr = ser.clone().decode().unwrap(); + let deser: Circuit = ser.clone().decode().unwrap(); let reser = SerialCircuit::encode(&deser).unwrap(); compare_serial_circs(&ser, &reser); } diff --git a/tket2/src/lib.rs b/tket2/src/lib.rs index 750e43d8..756b52d8 100644 --- a/tket2/src/lib.rs +++ b/tket2/src/lib.rs @@ -20,11 +20,11 @@ //! #![cfg_attr(not(miri), doc = "```")] // this doctest reads from the filesystem, so it fails with miri #![cfg_attr(miri, doc = "```ignore")] -//! use tket2::{Circuit, Hugr}; +//! use tket2::Circuit; //! use hugr::HugrView; //! //! // Load a tket1 circuit. -//! let mut circ: Hugr = tket2::json::load_tk1_json_file("../test_files/barenco_tof_5.json").unwrap(); +//! let mut circ: Circuit = tket2::json::load_tk1_json_file("../test_files/barenco_tof_5.json").unwrap(); //! //! assert_eq!(circ.qubit_count(), 9); //! assert_eq!(circ.num_gates(), 170); diff --git a/tket2/src/ops.rs b/tket2/src/ops.rs index a6e143be..f941c517 100644 --- a/tket2/src/ops.rs +++ b/tket2/src/ops.rs @@ -257,15 +257,16 @@ pub(crate) mod test { use std::sync::Arc; use hugr::extension::simple_op::MakeOpDef; + use hugr::extension::OpDef; use hugr::ops::NamedOp; use hugr::CircuitUnit; - use hugr::{extension::OpDef, Hugr}; use rstest::{fixture, rstest}; use strum::IntoEnumIterator; use super::Tk2Op; + use crate::circuit::Circuit; use crate::extension::{TKET2_EXTENSION as EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID}; - use crate::{circuit::Circuit, utils::build_simple_circuit}; + use crate::utils::build_simple_circuit; fn get_opdef(op: impl NamedOp) -> Option<&'static Arc> { EXTENSION.get_op(&op.name()) } @@ -279,7 +280,7 @@ pub(crate) mod test { } #[fixture] - pub(crate) fn t2_bell_circuit() -> Hugr { + pub(crate) fn t2_bell_circuit() -> Circuit { let h = build_simple_circuit(2, |circ| { circ.append(Tk2Op::H, [0])?; circ.append(Tk2Op::CX, [0, 1])?; @@ -290,7 +291,7 @@ pub(crate) mod test { } #[rstest] - fn check_t2_bell(t2_bell_circuit: Hugr) { + fn check_t2_bell(t2_bell_circuit: Circuit) { assert_eq!(t2_bell_circuit.commands().count(), 2); } diff --git a/tket2/src/optimiser/badger.rs b/tket2/src/optimiser/badger.rs index 5487095a..9aa6bafa 100644 --- a/tket2/src/optimiser/badger.rs +++ b/tket2/src/optimiser/badger.rs @@ -22,14 +22,13 @@ use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; use fxhash::FxHashSet; use hugr::hugr::HugrError; +use hugr::HugrView; pub use log::BadgerLogger; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; use std::{mem, thread}; -use hugr::Hugr; - use crate::circuit::cost::CircuitCost; use crate::circuit::CircuitHash; use crate::optimiser::badger::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog}; @@ -37,7 +36,6 @@ use crate::optimiser::badger::hugr_pqueue::{Entry, HugrPQ}; use crate::optimiser::badger::worker::BadgerWorker; use crate::passes::CircuitChunks; use crate::rewrite::strategy::RewriteStrategy; -use crate::rewrite::trace::RewriteTracer; use crate::rewrite::Rewriter; use crate::Circuit; @@ -113,7 +111,7 @@ impl BadgerOptimiser { Self { rewriter, strategy } } - fn cost(&self, circ: &Hugr) -> S::Cost + fn cost(&self, circ: &Circuit) -> S::Cost where S: RewriteStrategy, { @@ -130,7 +128,7 @@ where /// Run the Badger optimiser on a circuit. /// /// A timeout (in seconds) can be provided. - pub fn optimise(&self, circ: &Hugr, options: BadgerOptions) -> Hugr { + pub fn optimise(&self, circ: &Circuit, options: BadgerOptions) -> Circuit { self.optimise_with_log(circ, Default::default(), options) } @@ -139,10 +137,10 @@ where /// A timeout (in seconds) can be provided. pub fn optimise_with_log( &self, - circ: &Hugr, + circ: &Circuit, log_config: BadgerLogger, options: BadgerOptions, - ) -> Hugr { + ) -> Circuit { if options.split_circuit && options.n_threads.get() > 1 { return self.split_run(circ, log_config, options).unwrap(); } @@ -153,12 +151,18 @@ where } #[tracing::instrument(target = "badger::metrics", skip(self, circ, logger))] - fn badger(&self, circ: &Hugr, mut logger: BadgerLogger, opt: BadgerOptions) -> Hugr { + fn badger( + &self, + circ: &Circuit, + mut logger: BadgerLogger, + opt: BadgerOptions, + ) -> Circuit { let start_time = Instant::now(); let mut last_best_time = Instant::now(); + let circ = circ.to_owned(); let mut best_circ = circ.clone(); - let mut best_circ_cost = self.cost(circ); + let mut best_circ_cost = self.cost(&circ); let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); logger.log_best(&best_circ_cost, num_rewrites); @@ -170,12 +174,12 @@ where // The priority queue of circuits to be processed (this should not get big) let cost_fn = { let strategy = self.strategy.clone(); - move |circ: &'_ Hugr| strategy.circuit_cost(circ) + move |circ: &'_ Circuit| strategy.circuit_cost(circ) }; - let cost = (cost_fn)(circ); + let cost = (cost_fn)(&circ); let mut pq = HugrPQ::new(cost_fn, opt.queue_size); - pq.push_unchecked(circ.clone(), hash, cost); + pq.push_unchecked(circ.to_owned(), hash, cost); let mut circ_cnt = 0; let mut timeout_flag = false; @@ -250,16 +254,17 @@ where #[tracing::instrument(target = "badger::metrics", skip(self, circ, logger))] fn badger_multithreaded( &self, - circ: &Hugr, + circ: &Circuit, mut logger: BadgerLogger, opt: BadgerOptions, - ) -> Hugr { + ) -> Circuit { let n_threads: usize = opt.n_threads.get(); + let circ = circ.to_owned(); // multi-consumer priority channel for queuing circuits to be processed by the workers let cost_fn = { let strategy = self.strategy.clone(); - move |circ: &'_ Hugr| strategy.circuit_cost(circ) + move |circ: &'_ Circuit| strategy.circuit_cost(circ) }; let (pq, rx_log) = HugrPriorityChannel::init(cost_fn.clone(), opt.queue_size); @@ -271,7 +276,7 @@ where pq.send(vec![Work { cost: best_circ_cost.clone(), hash: initial_circ_hash, - circ: circ.clone(), + circ, }]) .unwrap(); @@ -381,18 +386,19 @@ where #[tracing::instrument(target = "badger::metrics", skip(self, circ, logger))] fn split_run( &self, - circ: &Hugr, + circ: &Circuit, mut logger: BadgerLogger, opt: BadgerOptions, - ) -> Result { - let circ_cost = self.cost(circ); + ) -> Result { + let circ = circ.to_owned(); + let circ_cost = self.cost(&circ); let max_chunk_cost = circ_cost.clone().div_cost(opt.n_threads); logger.log(format!( "Splitting circuit with cost {:?} into chunks of at most {max_chunk_cost:?}.", circ_cost.clone() )); let mut chunks = - CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op)); + CircuitChunks::split_with_cost(&circ, max_chunk_cost, |op| self.strategy.op_cost(op)); let num_rewrites = circ.rewrite_trace().map(|rs| rs.len()); logger.log_best(circ_cost.clone(), num_rewrites); @@ -495,7 +501,6 @@ mod tests { extension::prelude::QB_T, std_extensions::arithmetic::float_types::FLOAT64_TYPE, types::FunctionType, - Hugr, }; use rstest::{fixture, rstest}; @@ -506,14 +511,14 @@ mod tests { use super::{BadgerOptimiser, DefaultBadgerOptimiser}; /// Simplified description of the circuit's commands. - fn gates(circ: &Hugr) -> Vec { + fn gates(circ: &Circuit) -> Vec { circ.commands() .map(|cmd| cmd.optype().try_into().unwrap()) .collect() } #[fixture] - fn rz_rz() -> Hugr { + fn rz_rz() -> Circuit { let input_t = vec![QB_T, FLOAT64_TYPE, FLOAT64_TYPE]; let output_t = vec![QB_T]; let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); @@ -528,7 +533,7 @@ mod tests { let res = h.add_dataflow_op(Tk2Op::RzF64, [qb, f2]).unwrap(); let qb = res.outputs().next().unwrap(); - h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap().into() } /// This hugr corresponds to the qasm circuit: @@ -549,9 +554,9 @@ mod tests { /// ``` const NON_COMPOSABLE: &str = r#"{"phase":"0.0","commands":[{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[1]],["q",[2]]]},{"op":{"type":"U3","params":["0.5","0","0.5"],"signature":["Q"]},"args":[["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[4]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[0]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[1]]]}],"qubits":[["q",[0]],["q",[1]],["q",[2]],["q",[3]],["q",[4]]],"bits":[],"implicit_permutation":[[["q",[0]],["q",[0]]],[["q",[1]],["q",[1]]],[["q",[2]],["q",[2]]],[["q",[3]],["q",[3]]],[["q",[4]],["q",[4]]]]}"#; - /// A Hugr that would trigger non-composable rewrites, if we applied them blindly from nam_6_3 matches. + /// A circuit that would trigger non-composable rewrites, if we applied them blindly from nam_6_3 matches. #[fixture] - fn non_composable_rw_hugr() -> Hugr { + fn non_composable_rw_hugr() -> Circuit { load_tk1_json_str(NON_COMPOSABLE).unwrap() } @@ -571,7 +576,7 @@ mod tests { } #[rstest] - fn rz_rz_cancellation(rz_rz: Hugr, badger_opt: DefaultBadgerOptimiser) { + fn rz_rz_cancellation(rz_rz: Circuit, badger_opt: DefaultBadgerOptimiser) { let opt_rz = badger_opt.optimise( &rz_rz, BadgerOptions { @@ -584,7 +589,7 @@ mod tests { } #[rstest] - fn rz_rz_cancellation_parallel(rz_rz: Hugr, badger_opt: DefaultBadgerOptimiser) { + fn rz_rz_cancellation_parallel(rz_rz: Circuit, badger_opt: DefaultBadgerOptimiser) { let mut opt_rz = badger_opt.optimise( &rz_rz, BadgerOptions { @@ -594,13 +599,13 @@ mod tests { ..Default::default() }, ); - opt_rz.update_validate(®ISTRY).unwrap(); + opt_rz.hugr_mut().update_validate(®ISTRY).unwrap(); } #[rstest] #[ignore = "Loading the ECC set is really slow (~5 seconds)"] fn non_composable_rewrites( - non_composable_rw_hugr: Hugr, + non_composable_rw_hugr: Circuit, badger_opt_full: DefaultBadgerOptimiser, ) { let mut opt = badger_opt_full.optimise( @@ -612,7 +617,7 @@ mod tests { }, ); // No rewrites applied. - opt.update_validate(®ISTRY).unwrap(); + opt.hugr_mut().update_validate(®ISTRY).unwrap(); } #[test] diff --git a/tket2/src/optimiser/badger/eq_circ_class.rs b/tket2/src/optimiser/badger/eq_circ_class.rs index f45c2846..c2c7f118 100644 --- a/tket2/src/optimiser/badger/eq_circ_class.rs +++ b/tket2/src/optimiser/badger/eq_circ_class.rs @@ -26,10 +26,10 @@ pub struct EqCircClass { impl EqCircClass { /// Create a new equivalence class with a representative circuit. - pub fn new(rep_circ: Hugr, other_circs: Vec) -> Self { + pub fn new(rep_circ: Circuit, other_circs: impl IntoIterator) -> Self { Self { - rep_circ, - other_circs, + rep_circ: rep_circ.into_hugr(), + other_circs: other_circs.into_iter().map(|c| c.into_hugr()).collect(), } } @@ -64,8 +64,11 @@ impl EqCircClass { /// Create an equivalence class from a set of circuits. /// /// The smallest circuit is chosen as the representative. - pub fn from_circuits(circs: impl Into>) -> Result { - let mut circs: Vec<_> = circs.into(); + pub fn from_circuits( + circs: impl IntoIterator, + ) -> Result { + let mut circs: Vec = circs.into_iter().collect(); + if circs.is_empty() { return Err(EqCircClassError::NoRepresentative); }; diff --git a/tket2/src/optimiser/badger/hugr_pchannel.rs b/tket2/src/optimiser/badger/hugr_pchannel.rs index e50d8980..b69d7cdc 100644 --- a/tket2/src/optimiser/badger/hugr_pchannel.rs +++ b/tket2/src/optimiser/badger/hugr_pchannel.rs @@ -6,19 +6,19 @@ use std::time::Instant; use crossbeam_channel::{select, Receiver, RecvError, SendError, Sender}; use fxhash::FxHashSet; -use hugr::Hugr; use crate::circuit::cost::CircuitCost; +use crate::Circuit; use super::hugr_pqueue::{Entry, HugrPQ}; /// A unit of work for a worker, consisting of a circuit to process, along its /// hash and cost. -pub type Work

= Entry; +pub type Work

= Entry; -/// A priority channel for HUGRs. +/// A priority channel for circuits. /// -/// Queues hugrs using a cost function `C` that produces priority values `P`. +/// Queues circuits using a cost function `C` that produces priority values `P`. /// /// Uses a thread internally to orchestrate the queueing. #[derive(Debug, Clone)] @@ -50,7 +50,7 @@ pub struct HugrPriorityChannel { /// Logging information from the priority channel. #[derive(Debug, Clone)] pub enum PriorityChannelLog

{ - NewBestCircuit(Hugr, P), + NewBestCircuit(Circuit, P), CircuitCount { processed_count: usize, seen_count: usize, @@ -106,12 +106,12 @@ impl PriorityChannelCommunication

{ impl HugrPriorityChannel where - C: Fn(&Hugr) -> P + Send + Sync + 'static, + C: Fn(&Circuit) -> P + Send + Sync + 'static, P: CircuitCost + Send + Sync + 'static, { /// Initialize the queueing system. /// - /// Start the Hugr priority queue in a new thread. + /// Start the circuit priority queue in a new thread. /// /// Get back a [`PriorityChannelCommunication`] for adding and removing circuits to/from the queue, /// and a channel receiver to receive logging information. diff --git a/tket2/src/optimiser/badger/hugr_pqueue.rs b/tket2/src/optimiser/badger/hugr_pqueue.rs index 16569ef6..b0429237 100644 --- a/tket2/src/optimiser/badger/hugr_pqueue.rs +++ b/tket2/src/optimiser/badger/hugr_pqueue.rs @@ -1,9 +1,9 @@ use delegate::delegate; use fxhash::FxHashMap; -use hugr::Hugr; use priority_queue::DoublePriorityQueue; use crate::circuit::CircuitHash; +use crate::Circuit; /// A min-priority queue for Hugrs. /// @@ -12,7 +12,7 @@ use crate::circuit::CircuitHash; #[derive(Debug, Clone, Default)] pub struct HugrPQ { queue: DoublePriorityQueue, - hash_lookup: FxHashMap, + hash_lookup: FxHashMap, cost_fn: C, max_size: usize, } @@ -34,9 +34,9 @@ impl HugrPQ { } } - /// Reference to the minimal Hugr in the queue. + /// Reference to the minimal circuit in the queue. #[allow(unused)] - pub fn peek(&self) -> Option> { + pub fn peek(&self) -> Option> { let (hash, cost) = self.queue.peek_min()?; let circ = self.hash_lookup.get(hash)?; Some(Entry { @@ -46,20 +46,20 @@ impl HugrPQ { }) } - /// Push a Hugr into the queue. + /// Push a circuit into the queue. /// /// If the queue is full, the element with the highest cost will be dropped. #[allow(unused)] - pub fn push(&mut self, hugr: Hugr) + pub fn push(&mut self, circ: Circuit) where - C: Fn(&Hugr) -> P, + C: Fn(&Circuit) -> P, { - let hash = hugr.circuit_hash().unwrap(); - let cost = (self.cost_fn)(&hugr); - self.push_unchecked(hugr, hash, cost); + let hash = circ.circuit_hash().unwrap(); + let cost = (self.cost_fn)(&circ); + self.push_unchecked(circ, hash, cost); } - /// Push a Hugr into the queue with a precomputed hash and cost. + /// Push a circuit into the queue with a precomputed hash and cost. /// /// This is useful to avoid recomputing the hash and cost function in /// [`HugrPQ::push`] when they are already known. @@ -67,9 +67,9 @@ impl HugrPQ { /// This does not check that the hash is valid. /// /// If the queue is full, the most last will be dropped. - pub fn push_unchecked(&mut self, hugr: Hugr, hash: u64, cost: P) + pub fn push_unchecked(&mut self, circ: Circuit, hash: u64, cost: P) where - C: Fn(&Hugr) -> P, + C: Fn(&Circuit) -> P, { if !self.check_accepted(&cost) { return; @@ -78,18 +78,18 @@ impl HugrPQ { self.pop_max(); } self.queue.push(hash, cost); - self.hash_lookup.insert(hash, hugr); + self.hash_lookup.insert(hash, circ); } - /// Pop the minimal Hugr from the queue. - pub fn pop(&mut self) -> Option> { + /// Pop the minimal circuit from the queue. + pub fn pop(&mut self) -> Option> { let (hash, cost) = self.queue.pop_min()?; let circ = self.hash_lookup.remove(&hash)?; Some(Entry { circ, cost, hash }) } - /// Pop the maximal Hugr from the queue. - pub fn pop_max(&mut self) -> Option> { + /// Pop the maximal circuit from the queue. + pub fn pop_max(&mut self) -> Option> { let (hash, cost) = self.queue.pop_max()?; let circ = self.hash_lookup.remove(&hash)?; Some(Entry { circ, cost, hash }) diff --git a/tket2/src/optimiser/badger/log.rs b/tket2/src/optimiser/badger/log.rs index e0fb5eea..e7201d66 100644 --- a/tket2/src/optimiser/badger/log.rs +++ b/tket2/src/optimiser/badger/log.rs @@ -40,7 +40,7 @@ impl<'w> BadgerLogger<'w> { /// /// [`log`]: pub fn new(best_progress_csv_writer: impl io::Write + 'w) -> Self { - let boxed_candidates_writer: Box = Box::new(best_progress_csv_writer); + let boxed_candidates_writer: Box = Box::new(best_progress_csv_writer); Self { circ_candidates_csv: Some(csv::Writer::from_writer(boxed_candidates_writer)), ..Default::default() diff --git a/tket2/src/optimiser/badger/qtz_circuit.rs b/tket2/src/optimiser/badger/qtz_circuit.rs index 81a13922..f9701c76 100644 --- a/tket2/src/optimiser/badger/qtz_circuit.rs +++ b/tket2/src/optimiser/badger/qtz_circuit.rs @@ -7,12 +7,11 @@ use hugr::extension::prelude::QB_T; use hugr::ops::OpType as Op; use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::{FunctionType, Type}; -use hugr::CircuitUnit; -use hugr::Hugr as Circuit; +use hugr::{CircuitUnit, Hugr}; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use crate::Tk2Op; +use crate::{Circuit, Tk2Op}; #[derive(Debug, Serialize, Deserialize)] struct RepCircOp { @@ -60,7 +59,7 @@ fn map_op(opstr: &str) -> Op { } // TODO change to TryFrom -impl From for Circuit { +impl From for Circuit { fn from(RepCircData { circ: rc, meta }: RepCircData) -> Self { let qb_types: Vec = vec![QB_T; meta.n_qb]; let param_types: Vec = vec![FLOAT64_TYPE; meta.n_input_param]; @@ -107,10 +106,13 @@ impl From for Circuit { builder .finish_hugr_with_outputs(circ_outputs, &crate::extension::REGISTRY) .unwrap() + .into() } } -pub(super) fn load_ecc_set(path: impl AsRef) -> io::Result>> { +pub(super) fn load_ecc_set( + path: impl AsRef, +) -> io::Result>>> { let jsons = std::fs::read_to_string(path)?; let (_, ecc_map): (Vec<()>, HashMap>) = serde_json::from_str(&jsons).unwrap(); @@ -127,12 +129,9 @@ pub(super) fn load_ecc_set(path: impl AsRef) -> io::Result HashMap { + + fn load_representative_set(path: &str) -> HashMap> { let jsons = std::fs::read_to_string(path).unwrap(); // read_rep_json(&jsons).unwrap(); let st: Vec = serde_json::from_str(&jsons).unwrap(); @@ -144,8 +143,7 @@ mod tests { #[test] #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri fn test_read_rep() { - let rep_map: HashMap = - load_representative_set("../test_files/h_rz_cxrepresentative_set.json"); + let rep_map = load_representative_set("../test_files/h_rz_cxrepresentative_set.json"); for c in rep_map.values().take(1) { println!("{}", c.dot_string()); @@ -155,8 +153,7 @@ mod tests { #[test] #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri fn test_read_complete() { - let _ecc: HashMap> = - load_ecc_set("../test_files/h_rz_cxcomplete_ECC_set.json").unwrap(); + let _ecc = load_ecc_set("../test_files/h_rz_cxcomplete_ECC_set.json").unwrap(); // ecc.values() // .flatten() diff --git a/tket2/src/passes/chunks.rs b/tket2/src/passes/chunks.rs index 41312c05..8f728853 100644 --- a/tket2/src/passes/chunks.rs +++ b/tket2/src/passes/chunks.rs @@ -15,7 +15,7 @@ use hugr::hugr::{HugrError, NodeMetadataMap}; use hugr::ops::handle::DataflowParentID; use hugr::ops::OpType; use hugr::types::FunctionType; -use hugr::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; +use hugr::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; use itertools::Itertools; use portgraph::algorithms::ConvexChecker; @@ -36,7 +36,7 @@ pub struct ChunkConnection(Wire); #[derive(Debug, Clone)] pub struct Chunk { /// The extracted circuit. - pub circ: Hugr, + pub circ: Circuit, /// The original wires connected to the input. inputs: Vec, /// The original wires connected to the output. @@ -47,18 +47,18 @@ impl Chunk { /// Extract a chunk from a circuit. /// /// The chunk is extracted from the input wires to the output wires. - pub(self) fn extract( - circ: &H, + pub(self) fn extract( + circ: &Circuit, nodes: impl IntoIterator, checker: &impl ConvexChecker, ) -> Self { let subgraph = SiblingSubgraph::try_from_nodes_with_checker( nodes.into_iter().collect_vec(), - circ, + circ.hugr(), checker, ) .expect("Failed to define the chunk subgraph"); - let extracted = subgraph.extract_subgraph(circ, "Chunk"); + let extracted = subgraph.extract_subgraph(circ.hugr(), "Chunk").into(); // Transform the subgraph's input/output sets into wires that can be // matched between different chunks. // @@ -70,6 +70,7 @@ impl Chunk { .map(|wires| { let (inp_node, inp_port) = wires[0]; let (out_node, out_port) = circ + .hugr() .linked_outputs(inp_node, inp_port) .exactly_one() .ok() @@ -91,18 +92,20 @@ impl Chunk { /// Insert the chunk back into a circuit. pub(self) fn insert(&self, circ: &mut impl HugrMut, root: Node) -> ChunkInsertResult { - if self.circ.children(self.circ.root()).nth(2).is_none() { + let chunk = self.circ.hugr(); + let chunk_root = chunk.root(); + if chunk.children(self.circ.parent()).nth(2).is_none() { // The chunk is empty. We don't need to insert anything. return self.empty_chunk_insert_result(); } - let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap(); + let [chunk_inp, chunk_out] = chunk.get_io(chunk_root).unwrap(); let chunk_sg: SiblingGraph<'_, DataflowParentID> = - SiblingGraph::try_new(&self.circ, self.circ.root()).unwrap(); + SiblingGraph::try_new(&chunk, chunk_root).unwrap(); // Insert the chunk circuit into the original circuit. let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&chunk_sg) .unwrap_or_else(|e| panic!("The chunk circuit is no longer a dataflow graph: {e}")); - let node_map = circ.insert_subgraph(root, &self.circ, &subgraph); + let node_map = circ.insert_subgraph(root, &chunk, &subgraph); let mut input_map = HashMap::with_capacity(self.inputs.len()); let mut output_map = HashMap::with_capacity(self.outputs.len()); @@ -111,11 +114,8 @@ impl Chunk { // // Connections to an inserted node are translated into a [`ConnectionTarget::InsertedNode`]. // Connections from the input directly into the output become a [`ConnectionTarget::TransitiveConnection`]. - for (&connection, chunk_inp_port) in - self.inputs.iter().zip(self.circ.node_outputs(chunk_inp)) - { - let connection_targets: Vec = self - .circ + for (&connection, chunk_inp_port) in self.inputs.iter().zip(chunk.node_outputs(chunk_inp)) { + let connection_targets: Vec = chunk .linked_inputs(chunk_inp, chunk_inp_port) .map(|(node, port)| { if node == chunk_out { @@ -131,9 +131,8 @@ impl Chunk { input_map.insert(connection, connection_targets); } - for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) { - let (node, port) = self - .circ + for (&wire, chunk_out_port) in self.outputs.iter().zip(chunk.node_inputs(chunk_out)) { + let (node, port) = chunk .linked_outputs(chunk_out, chunk_out_port) .exactly_one() .ok() @@ -159,15 +158,13 @@ impl Chunk { /// /// TODO: Support empty Subgraphs in Hugr? fn empty_chunk_insert_result(&self) -> ChunkInsertResult { - let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap(); + let hugr = self.circ.hugr(); + let [chunk_inp, chunk_out] = self.circ.io_nodes(); let mut input_map = HashMap::with_capacity(self.inputs.len()); let mut output_map = HashMap::with_capacity(self.outputs.len()); - for (&connection, chunk_inp_port) in - self.inputs.iter().zip(self.circ.node_outputs(chunk_inp)) - { - let connection_targets: Vec = self - .circ + for (&connection, chunk_inp_port) in self.inputs.iter().zip(hugr.node_outputs(chunk_inp)) { + let connection_targets: Vec = hugr .linked_ports(chunk_inp, chunk_inp_port) .map(|(node, port)| { assert_eq!(node, chunk_out); @@ -178,9 +175,8 @@ impl Chunk { input_map.insert(connection, connection_targets); } - for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) { - let (node, port) = self - .circ + for (&wire, chunk_out_port) in self.outputs.iter().zip(hugr.node_inputs(chunk_out)) { + let (node, port) = hugr .linked_ports(chunk_out, chunk_out_port) .exactly_one() .ok() @@ -251,38 +247,39 @@ impl CircuitChunks { /// Split a circuit into chunks. /// /// The circuit is split into chunks of at most `max_size` gates. - pub fn split(circ: &impl Circuit, max_size: usize) -> Self { + pub fn split(circ: &Circuit, max_size: usize) -> Self { Self::split_with_cost(circ, max_size.saturating_sub(1), |_| 1) } /// Split a circuit into chunks. /// /// The circuit is split into chunks of at most `max_cost`, using the provided cost function. - pub fn split_with_cost( - circ: &H, + pub fn split_with_cost( + circ: &Circuit, max_cost: C, op_cost: impl Fn(&OpType) -> C, ) -> Self { - let root_meta = circ.get_node_metadata(circ.root()).cloned(); + let hugr = circ.hugr(); + let root_meta = hugr.get_node_metadata(circ.parent()).cloned(); let signature = circ.circuit_signature().body().clone(); - let [circ_input, circ_output] = circ.get_io(circ.root()).unwrap(); - let input_connections = circ + let [circ_input, circ_output] = circ.io_nodes(); + let input_connections = hugr .node_outputs(circ_input) .map(|port| Wire::new(circ_input, port).into()) .collect(); - let output_connections = circ + let output_connections = hugr .node_inputs(circ_output) - .flat_map(|p| circ.linked_outputs(circ_output, p)) + .flat_map(|p| hugr.linked_outputs(circ_output, p)) .map(|(n, p)| Wire::new(n, p).into()) .collect(); let mut chunks = Vec::new(); - let convex_checker = TopoConvexChecker::new(circ); + let convex_checker = TopoConvexChecker::new(circ.hugr()); let mut running_cost = C::default(); let mut current_group = 0; for (_, commands) in &circ.commands().map(|cmd| cmd.node()).chunk_by(|&node| { - let new_cost = running_cost.clone() + op_cost(circ.get_optype(node)); + let new_cost = running_cost.clone() + op_cost(hugr.get_optype(node)); if new_cost.sub_cost(&max_cost).as_isize() > 0 { running_cost = C::default(); current_group += 1; @@ -303,7 +300,7 @@ impl CircuitChunks { } /// Reassemble the chunks into a circuit. - pub fn reassemble(self) -> Result { + pub fn reassemble(self) -> Result { let name = self .root_meta .as_ref() @@ -423,16 +420,16 @@ impl CircuitChunks { reassembled.overwrite_node_metadata(root, self.root_meta); - Ok(reassembled) + Ok(reassembled.into()) } /// Returns a list of references to the split circuits. - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.chunks.iter().map(|chunk| &chunk.circ) } /// Returns a list of references to the split circuits. - pub fn iter_mut(&mut self) -> impl Iterator { + pub fn iter_mut(&mut self) -> impl Iterator { self.chunks.iter_mut().map(|chunk| &mut chunk.circ) } @@ -448,7 +445,7 @@ impl CircuitChunks { } impl Index for CircuitChunks { - type Output = Hugr; + type Output = Circuit; fn index(&self, index: usize) -> &Self::Output { &self.chunks[index].circ @@ -490,7 +487,7 @@ mod test { let mut reassembled = chunks.reassemble().unwrap(); - reassembled.update_validate(®ISTRY).unwrap(); + reassembled.hugr_mut().update_validate(®ISTRY).unwrap(); assert_eq!(circ.circuit_hash(), reassembled.circuit_hash()); } @@ -520,14 +517,14 @@ mod test { let mut reassembled = chunks.reassemble().unwrap(); - reassembled.update_validate(®ISTRY).unwrap(); + reassembled.hugr_mut().update_validate(®ISTRY).unwrap(); assert_eq!(reassembled.commands().count(), 1); let h = reassembled.commands().next().unwrap().node(); - let [inp, out] = reassembled.get_io(reassembled.root()).unwrap(); + let [inp, out] = reassembled.io_nodes(); assert_eq!( - &reassembled.output_neighbours(inp).collect_vec(), + &reassembled.hugr().output_neighbours(inp).collect_vec(), &[h, out, out] ); } diff --git a/tket2/src/passes/commutation.rs b/tket2/src/passes/commutation.rs index 4424cb84..564a6b6a 100644 --- a/tket2/src/passes/commutation.rs +++ b/tket2/src/passes/commutation.rs @@ -1,12 +1,13 @@ use std::{collections::HashMap, rc::Rc}; use hugr::hugr::{hugrmut::HugrMut, HugrError, Rewrite}; -use hugr::{CircuitUnit, Direction, Hugr, HugrView, Node, Port, PortIndex}; +use hugr::{CircuitUnit, Direction, HugrView, Node, Port, PortIndex}; use itertools::Itertools; use portgraph::PortOffset; +use crate::Circuit; use crate::{ - circuit::{command::Command, Circuit}, + circuit::command::Command, ops::{Pauli, Tk2Op}, }; @@ -26,11 +27,8 @@ struct ComCommand { inputs: Vec, } -impl<'c, Circ> From> for ComCommand -where - Circ: HugrView, -{ - fn from(com: Command<'c, Circ>) -> Self { +impl<'c, T: HugrView> From> for ComCommand { + fn from(com: Command<'c, T>) -> Self { ComCommand { node: com.node(), inputs: com.inputs().map(|(c, _, _)| c).collect(), @@ -69,13 +67,16 @@ fn add_to_slice(slice: &mut Slice, com: Rc) { } } -fn load_slices(circ: &impl Circuit) -> SliceVec { +fn load_slices(circ: &Circuit) -> SliceVec { let mut slices = vec![]; let n_qbs = circ.qubit_count(); let mut qubit_free_slice = vec![0; n_qbs]; - for command in circ.commands().filter(|c| is_slice_op(circ, c.node())) { + for command in circ + .commands() + .filter(|c| is_slice_op(circ.hugr(), c.node())) + { let command: ComCommand = command.into(); let free_slice = command .qubits() @@ -106,7 +107,7 @@ fn is_slice_op(h: &impl HugrView, node: Node) -> bool { /// Starting from starting_index, work back along slices to check for the /// earliest slice that can accommodate this command, if any. fn available_slice( - circ: &impl HugrView, + circ: &Circuit, slice_vec: &[Slice], starting_index: usize, command: &Rc, @@ -142,7 +143,7 @@ fn available_slice( fn commutes_at_slice( command: &Rc, slice: &Slice, - circ: &impl HugrView, + circ: &Circuit, ) -> Option>> { // map from qubit to node it is connected to immediately after the free slice. let mut prev_nodes: HashMap> = HashMap::new(); @@ -155,12 +156,22 @@ fn commutes_at_slice( let port = command.port_of_qb(q, Direction::Incoming)?; - let op: Tk2Op = circ.get_optype(command.node()).clone().try_into().ok()?; + let op: Tk2Op = circ + .hugr() + .get_optype(command.node()) + .clone() + .try_into() + .ok()?; // TODO: if not tk2op, might still have serialized commutation data we // can use. let pauli = commutation_on_port(&op.qubit_commutation(), port)?; - let other_op: Tk2Op = circ.get_optype(other_com.node()).clone().try_into().ok()?; + let other_op: Tk2Op = circ + .hugr() + .get_optype(other_com.node()) + .clone() + .try_into() + .ok()?; let other_pauli = commutation_on_port( &other_op.qubit_commutation(), other_com.port_of_qb(q, Direction::Outgoing)?, @@ -287,7 +298,7 @@ impl Rewrite for PullForward { } /// Pass which greedily commutes operations forwards in order to reduce depth. -pub fn apply_greedy_commutation(circ: &mut Hugr) -> Result { +pub fn apply_greedy_commutation(circ: &mut Circuit) -> Result { let mut count = 0; let mut slice_vec = load_slices(circ); @@ -301,7 +312,7 @@ pub fn apply_greedy_commutation(circ: &mut Hugr) -> Result Result Hugr { + fn example_cx() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::CX, [0, 2])?; circ.append(Tk2Op::CX, [1, 2])?; @@ -353,7 +362,7 @@ mod test { #[fixture] // example circuit from original task with lower depth - fn example_cx_better() -> Hugr { + fn example_cx_better() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::CX, [0, 2])?; circ.append(Tk2Op::CX, [1, 3])?; @@ -365,7 +374,7 @@ mod test { #[fixture] // can't commute anything here - fn cant_commute() -> Hugr { + fn cant_commute() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::Z, [1])?; circ.append(Tk2Op::CX, [0, 1])?; @@ -376,7 +385,7 @@ mod test { } #[fixture] - fn big_example() -> Hugr { + fn big_example() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::CX, [0, 3])?; circ.append(Tk2Op::CX, [1, 2])?; @@ -395,7 +404,7 @@ mod test { #[fixture] // commute a single qubit gate - fn single_qb_commute() -> Hugr { + fn single_qb_commute() -> Circuit { build_simple_circuit(3, |circ| { circ.append(Tk2Op::H, [1])?; circ.append(Tk2Op::CX, [0, 1])?; @@ -407,7 +416,7 @@ mod test { #[fixture] // commute 2 single qubit gates - fn single_qb_commute_2() -> Hugr { + fn single_qb_commute_2() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::CX, [1, 2])?; circ.append(Tk2Op::CX, [1, 0])?; @@ -421,7 +430,7 @@ mod test { #[fixture] // A commutation forward exists but depth doesn't change - fn commutes_but_same_depth() -> Hugr { + fn commutes_but_same_depth() -> Circuit { build_simple_circuit(3, |circ| { circ.append(Tk2Op::H, [1])?; circ.append(Tk2Op::CX, [0, 1])?; @@ -434,7 +443,7 @@ mod test { #[fixture] // Gate being commuted has a non-linear input - fn non_linear_inputs() -> Hugr { + fn non_linear_inputs() -> Circuit { let build = || { let mut dfg = DFGBuilder::new(FunctionType::new( type_row![QB_T, QB_T, FLOAT64_TYPE], @@ -451,12 +460,12 @@ mod test { let qbs = circ.finish(); dfg.finish_hugr_with_outputs(qbs, ®ISTRY) }; - build().unwrap() + build().unwrap().into() } #[fixture] // Gates being commuted have non-linear outputs - fn non_linear_outputs() -> Hugr { + fn non_linear_outputs() -> Circuit { let build = || { let mut dfg = DFGBuilder::new(FunctionType::new( type_row![QB_T, QB_T], @@ -474,11 +483,11 @@ mod test { outs.extend(measured); dfg.finish_hugr_with_outputs(outs, ®ISTRY) }; - build().unwrap() + build().unwrap().into() } // bug https://github.com/CQCL/tket2/issues/253 - fn cx_commute_bug() -> Hugr { + fn cx_commute_bug() -> Circuit { build_simple_circuit(3, |circ| { circ.append(Tk2Op::H, [2])?; circ.append(Tk2Op::CX, [2, 1])?; @@ -508,7 +517,7 @@ mod test { } #[rstest] - fn test_load_slices_cx(example_cx: Hugr) { + fn test_load_slices_cx(example_cx: Circuit) { let circ = example_cx; let commands: Vec = circ.commands().map_into().collect(); let slices = load_slices(&circ); @@ -518,7 +527,7 @@ mod test { } #[rstest] - fn test_load_slices_cx_better(example_cx_better: Hugr) { + fn test_load_slices_cx_better(example_cx_better: Circuit) { let circ = example_cx_better; let commands: Vec = circ.commands().map_into().collect(); @@ -529,7 +538,7 @@ mod test { } #[rstest] - fn test_load_slices_bell(t2_bell_circuit: Hugr) { + fn test_load_slices_bell(t2_bell_circuit: Circuit) { let circ = t2_bell_circuit; let commands: Vec = circ.commands().map_into().collect(); @@ -540,7 +549,7 @@ mod test { } #[rstest] - fn test_available_slice(example_cx: Hugr) { + fn test_available_slice(example_cx: Circuit) { let circ = example_cx; let slices = load_slices(&circ); let (found, prev_nodes) = @@ -556,7 +565,7 @@ mod test { } #[rstest] - fn big_test(big_example: Hugr) { + fn big_test(big_example: Circuit) { let circ = big_example; let slices = load_slices(&circ); assert_eq!(slices.len(), 6); @@ -578,7 +587,7 @@ mod test { } /// Calculate depth by placing commands in slices. - fn depth(h: &Hugr) -> usize { + fn depth(h: &Circuit) -> usize { load_slices(h).len() } #[rstest] @@ -594,14 +603,14 @@ mod test { #[case(non_linear_outputs(), true, 1)] #[case(cx_commute_bug(), true, 1)] fn commutation_example( - #[case] mut case: Hugr, + #[case] mut case: Circuit, #[case] should_reduce: bool, #[case] expected_moves: u32, ) { - let node_count = case.node_count(); + let node_count = case.hugr().node_count(); let depth_before = depth(&case); let move_count = apply_greedy_commutation(&mut case).unwrap(); - case.update_validate(®ISTRY).unwrap(); + case.hugr_mut().update_validate(®ISTRY).unwrap(); assert_eq!( move_count, expected_moves, @@ -619,7 +628,7 @@ mod test { } assert_eq!( - case.node_count(), + case.hugr().node_count(), node_count, "depth optimisation should not change the number of nodes." ) diff --git a/tket2/src/portmatching.rs b/tket2/src/portmatching.rs index 8e74695e..b1b57d4e 100644 --- a/tket2/src/portmatching.rs +++ b/tket2/src/portmatching.rs @@ -18,7 +18,7 @@ //! let mut dfg = DFGBuilder::new(FunctionType::new(vec![], vec![QB_T]))?; //! let alloc = dfg.add_dataflow_op(Tk2Op::QAlloc, [])?; //! dfg.finish_hugr_with_outputs(alloc.outputs(), &tket2::extension::REGISTRY) -//! }?; +//! }?.into(); //! let pattern = CircuitPattern::try_from_circuit(&circuit_pattern)?; //! //! // Define a circuit that contains a qubit allocation. @@ -39,7 +39,7 @@ //! .append(Tk2Op::CX, [1, 0])?; //! let outputs = circuit.finish(); //! -//! let circuit = dfg.finish_hugr_with_outputs(outputs, &tket2::extension::REGISTRY)?; +//! let circuit = dfg.finish_hugr_with_outputs(outputs, &tket2::extension::REGISTRY)?.into(); //! (circuit, alloc.node()) //! }; //! @@ -56,7 +56,7 @@ pub mod matcher; pub mod pattern; -use hugr::OutgoingPort; +use hugr::{HugrView, OutgoingPort}; use itertools::Itertools; pub use matcher::{PatternMatch, PatternMatcher}; pub use pattern::CircuitPattern; @@ -111,13 +111,10 @@ enum InvalidEdgeProperty { } impl PEdge { - fn try_from_port( - node: Node, - port: Port, - circ: &impl Circuit, - ) -> Result { + fn try_from_port(node: Node, port: Port, circ: &Circuit) -> Result { + let hugr = circ.hugr(); let src = port; - let (dst_node, dst) = circ + let (dst_node, dst) = hugr .linked_ports(node, src) .exactly_one() .map_err(|mut e| { @@ -127,10 +124,10 @@ impl PEdge { InvalidEdgeProperty::NoLinkedEdge(src) } })?; - if circ.get_optype(dst_node).tag() == OpTag::Input { + if hugr.get_optype(dst_node).tag() == OpTag::Input { return Ok(Self::InputEdge { src }); } - let port_type = circ + let port_type = hugr .signature(node) .unwrap() .port_type(src) @@ -202,29 +199,30 @@ impl From for NodeID { #[cfg(test)] mod tests { - use crate::Tk2Op; + use crate::{Circuit, Tk2Op}; use hugr::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, extension::{prelude::QB_T, PRELUDE_REGISTRY}, types::FunctionType, - Hugr, }; use rstest::{fixture, rstest}; use super::{CircuitPattern, PatternMatcher}; #[fixture] - fn lhs() -> Hugr { + fn lhs() -> Circuit { let mut h = DFGBuilder::new(FunctionType::new(vec![], vec![QB_T])).unwrap(); let res = h.add_dataflow_op(Tk2Op::QAlloc, []).unwrap(); let q = res.out_wire(0); - h.finish_hugr_with_outputs([q], &PRELUDE_REGISTRY).unwrap() + h.finish_hugr_with_outputs([q], &PRELUDE_REGISTRY) + .unwrap() + .into() } #[fixture] - pub fn circ() -> Hugr { + pub fn circ() -> Circuit { let mut h = DFGBuilder::new(FunctionType::new(vec![QB_T], vec![QB_T])).unwrap(); let mut inps = h.input_wires(); let q_in = inps.next().unwrap(); @@ -238,10 +236,11 @@ mod tests { h.finish_hugr_with_outputs([q_out], &PRELUDE_REGISTRY) .unwrap() + .into() } #[rstest] - fn simple_match(circ: Hugr, lhs: Hugr) { + fn simple_match(circ: Circuit, lhs: Circuit) { let p = CircuitPattern::try_from_circuit(&lhs).unwrap(); let m = PatternMatcher::from_patterns(vec![p]); diff --git a/tket2/src/portmatching/matcher.rs b/tket2/src/portmatching/matcher.rs index c5423614..f949e5ea 100644 --- a/tket2/src/portmatching/matcher.rs +++ b/tket2/src/portmatching/matcher.rs @@ -13,7 +13,7 @@ use hugr::hugr::views::sibling_subgraph::{ }; use hugr::hugr::views::SiblingSubgraph; use hugr::ops::{NamedOp, OpType}; -use hugr::{Hugr, IncomingPort, Node, OutgoingPort, Port, PortIndex}; +use hugr::{HugrView, IncomingPort, Node, OutgoingPort, Port, PortIndex}; use itertools::Itertools; use portgraph::algorithms::ConvexChecker; use portmatching::{ @@ -115,10 +115,10 @@ impl PatternMatch { pub fn try_from_root_match( root: Node, pattern: PatternID, - circ: &impl Circuit, + circ: &Circuit, matcher: &PatternMatcher, ) -> Result { - let checker = TopoConvexChecker::new(circ); + let checker = TopoConvexChecker::new(circ.hugr()); Self::try_from_root_match_with_checker(root, pattern, circ, matcher, &checker) } @@ -128,10 +128,10 @@ impl PatternMatch { /// checker object to speed up convexity checking. /// /// See [`PatternMatch::try_from_root_match`] for more details. - pub fn try_from_root_match_with_checker( + pub fn try_from_root_match_with_checker( root: Node, pattern: PatternID, - circ: &C, + circ: &Circuit, matcher: &PatternMatcher, checker: &impl ConvexChecker, ) -> Result { @@ -172,11 +172,11 @@ impl PatternMatch { pub fn try_from_io( root: Node, pattern: PatternID, - circ: &impl Circuit, + circ: &Circuit, inputs: Vec>, outputs: Vec<(Node, OutgoingPort)>, ) -> Result { - let checker = TopoConvexChecker::new(circ); + let checker = TopoConvexChecker::new(circ.hugr()); Self::try_from_io_with_checker(root, pattern, circ, inputs, outputs, &checker) } @@ -188,15 +188,16 @@ impl PatternMatch { /// /// This checks at construction time that the match is convex. This will /// have runtime linear in the size of the circuit. - pub fn try_from_io_with_checker( + pub fn try_from_io_with_checker( root: Node, pattern: PatternID, - circ: &C, + circ: &Circuit, inputs: Vec>, outputs: Vec<(Node, OutgoingPort)>, checker: &impl ConvexChecker, ) -> Result { - let subgraph = SiblingSubgraph::try_new_with_checker(inputs, outputs, circ, checker)?; + let subgraph = + SiblingSubgraph::try_new_with_checker(inputs, outputs, circ.hugr(), checker)?; Ok(Self { position: subgraph.into(), pattern, @@ -207,8 +208,8 @@ impl PatternMatch { /// Construct a rewrite to replace `self` with `repl`. pub fn to_rewrite( &self, - source: &Hugr, - target: Hugr, + source: &Circuit, + target: Circuit, ) -> Result { CircuitRewrite::try_new(&self.position, source, target) } @@ -263,25 +264,25 @@ impl PatternMatcher { } /// Find all convex pattern matches in a circuit. - pub fn find_matches_iter<'a, 'c: 'a, C: Circuit + Clone>( + pub fn find_matches_iter<'a, 'c: 'a>( &'a self, - circuit: &'c C, + circuit: &'c Circuit, ) -> impl Iterator + 'a { - let checker = TopoConvexChecker::new(circuit); + let checker = TopoConvexChecker::new(circuit.hugr()); circuit .commands() .flat_map(move |cmd| self.find_rooted_matches(circuit, cmd.node(), &checker)) } /// Find all convex pattern matches in a circuit.and collect in to a vector - pub fn find_matches(&self, circuit: &C) -> Vec { + pub fn find_matches(&self, circuit: &Circuit) -> Vec { self.find_matches_iter(circuit).collect() } /// Find all convex pattern matches in a circuit rooted at a given node. - fn find_rooted_matches( + fn find_rooted_matches( &self, - circ: &C, + circ: &Circuit, root: Node, checker: &impl ConvexChecker, ) -> Vec { @@ -429,23 +430,24 @@ fn compatible_offsets(e1: &PEdge, e2: &PEdge) -> bool { /// Returns a predicate checking that an edge at `src` satisfies `prop` in `circ`. pub(super) fn validate_circuit_edge( - circ: &impl Circuit, + circ: &Circuit, ) -> impl for<'a> Fn(NodeID, &'a PEdge) -> Option + '_ { move |src, &prop| { let NodeID::HugrNode(src) = src else { return None; }; + let hugr = circ.hugr(); match prop { PEdge::InternalEdge { src: src_port, dst: dst_port, .. } => { - let (next_node, next_port) = circ.linked_ports(src, src_port).exactly_one().ok()?; + let (next_node, next_port) = hugr.linked_ports(src, src_port).exactly_one().ok()?; (dst_port == next_port).then_some(NodeID::HugrNode(next_node)) } PEdge::InputEdge { src: src_port } => { - let (next_node, next_port) = circ.linked_ports(src, src_port).exactly_one().ok()?; + let (next_node, next_port) = hugr.linked_ports(src, src_port).exactly_one().ok()?; Some(NodeID::CopyNode(next_node, next_port)) } } @@ -454,13 +456,13 @@ pub(super) fn validate_circuit_edge( /// Returns a predicate checking that `node` satisfies `prop` in `circ`. pub(crate) fn validate_circuit_node( - circ: &impl Circuit, + circ: &Circuit, ) -> impl for<'a> Fn(NodeID, &PNode) -> bool + '_ { move |node, prop| { let NodeID::HugrNode(node) = node else { return false; }; - &MatchOp::from(circ.get_optype(node).clone()) == prop + &MatchOp::from(circ.hugr().get_optype(node).clone()) == prop } } @@ -479,16 +481,15 @@ fn handle_match_error(match_res: Result, root: Node) #[cfg(test)] mod tests { - use hugr::Hugr; use itertools::Itertools; use rstest::{fixture, rstest}; use crate::utils::build_simple_circuit; - use crate::Tk2Op; + use crate::{Circuit, Tk2Op}; use super::{CircuitPattern, PatternMatcher}; - fn h_cx() -> Hugr { + fn h_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::H, [0]).unwrap(); @@ -497,7 +498,7 @@ mod tests { .unwrap() } - fn cx_xc() -> Hugr { + fn cx_xc() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::CX, [1, 0]).unwrap(); @@ -507,7 +508,7 @@ mod tests { } #[fixture] - fn cx_cx_3() -> Hugr { + fn cx_cx_3() -> Circuit { build_simple_circuit(3, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::CX, [2, 1]).unwrap(); @@ -517,7 +518,7 @@ mod tests { } #[fixture] - fn cx_cx() -> Hugr { + fn cx_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::CX, [0, 1]).unwrap(); @@ -558,7 +559,7 @@ mod tests { } #[rstest] - fn cx_cx_replace_to_id(cx_cx: Hugr, cx_cx_3: Hugr) { + fn cx_cx_replace_to_id(cx_cx: Circuit, cx_cx_3: Circuit) { let p = CircuitPattern::try_from_circuit(&cx_cx_3).unwrap(); let m = PatternMatcher::from_patterns(vec![p]); diff --git a/tket2/src/portmatching/pattern.rs b/tket2/src/portmatching/pattern.rs index 4020a6bf..4a460ebd 100644 --- a/tket2/src/portmatching/pattern.rs +++ b/tket2/src/portmatching/pattern.rs @@ -1,6 +1,6 @@ //! Circuit Patterns for pattern matching -use hugr::IncomingPort; +use hugr::{HugrView, IncomingPort}; use hugr::{Node, Port}; use itertools::Itertools; use portmatching::{patterns::NoRootFound, HashMap, Pattern, SinglePatternMatcher}; @@ -30,7 +30,8 @@ impl CircuitPattern { } /// Construct a pattern from a circuit. - pub fn try_from_circuit(circuit: &impl Circuit) -> Result { + pub fn try_from_circuit(circuit: &Circuit) -> Result { + let hugr = circuit.hugr(); if circuit.num_gates() == 0 { return Err(InvalidPattern::EmptyCircuit); } @@ -42,10 +43,9 @@ impl CircuitPattern { let in_offset: IncomingPort = in_offset.into(); let edge_prop = PEdge::try_from_port(cmd.node(), in_offset.into(), circuit) .expect("Invalid HUGR"); - let (prev_node, prev_port) = circuit + let (prev_node, prev_port) = hugr .linked_outputs(cmd.node(), in_offset) .exactly_one() - .ok() .expect("invalid HUGR"); let prev_node = match edge_prop { PEdge::InternalEdge { .. } => NodeID::HugrNode(prev_node), @@ -58,18 +58,16 @@ impl CircuitPattern { if !pattern.is_valid() { return Err(InvalidPattern::NotConnected); } - let (inp, out) = (circuit.input(), circuit.output()); - let inp_ports = circuit.signature(inp).unwrap().output_ports(); - let out_ports = circuit.signature(out).unwrap().input_ports(); + let [inp, out] = circuit.io_nodes(); + let inp_ports = hugr.signature(inp).unwrap().output_ports(); + let out_ports = hugr.signature(out).unwrap().input_ports(); let inputs = inp_ports - .map(|p| circuit.linked_ports(inp, p).collect()) + .map(|p| hugr.linked_ports(inp, p).collect()) .collect_vec(); let outputs = out_ports .map(|p| { - circuit - .linked_ports(out, p) + hugr.linked_ports(out, p) .exactly_one() - .ok() .expect("invalid circuit") }) .collect_vec(); @@ -87,7 +85,11 @@ impl CircuitPattern { } /// Compute the map from pattern nodes to circuit nodes in `circ`. - pub fn get_match_map(&self, root: Node, circ: &impl Circuit) -> Option> { + pub fn get_match_map( + &self, + root: Node, + circ: &Circuit, + ) -> Option> { let single_matcher = SinglePatternMatcher::from_pattern(self.pattern.clone()); single_matcher .get_match_map( @@ -143,7 +145,6 @@ mod tests { use hugr::ops::OpType; use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::FunctionType; - use hugr::Hugr; use crate::extension::REGISTRY; use crate::utils::build_simple_circuit; @@ -151,7 +152,7 @@ mod tests { use super::*; - fn h_cx() -> Hugr { + fn h_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1])?; circ.append(Tk2Op::H, [0])?; @@ -161,7 +162,7 @@ mod tests { } /// A circuit with two rotation gates in sequence, sharing a param - fn circ_with_copy() -> Hugr { + fn circ_with_copy() -> Circuit { let input_t = vec![QB_T, FLOAT64_TYPE]; let output_t = vec![QB_T]; let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); @@ -175,11 +176,11 @@ mod tests { let res = h.add_dataflow_op(Tk2Op::RxF64, [qb, f]).unwrap(); let qb = res.outputs().next().unwrap(); - h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap().into() } /// A circuit with two rotation gates in parallel, sharing a param - fn circ_with_copy_disconnected() -> Hugr { + fn circ_with_copy_disconnected() -> Circuit { let input_t = vec![QB_T, QB_T, FLOAT64_TYPE]; let output_t = vec![QB_T, QB_T]; let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); @@ -194,14 +195,16 @@ mod tests { let res = h.add_dataflow_op(Tk2Op::RxF64, [qb2, f]).unwrap(); let qb2 = res.outputs().next().unwrap(); - h.finish_hugr_with_outputs([qb1, qb2], ®ISTRY).unwrap() + h.finish_hugr_with_outputs([qb1, qb2], ®ISTRY) + .unwrap() + .into() } #[test] fn construct_pattern() { - let hugr = h_cx(); + let circ = h_cx(); - let p = CircuitPattern::try_from_circuit(&hugr).unwrap(); + let p = CircuitPattern::try_from_circuit(&circ).unwrap(); let edges: HashSet<_> = p .pattern @@ -210,9 +213,9 @@ mod tests { .iter() .map(|e| (e.source.unwrap(), e.target.unwrap())) .collect(); - let inp = hugr.input(); - let cx_gate = NodeID::HugrNode(get_nodes_by_tk2op(&hugr, Tk2Op::CX)[0]); - let h_gate = NodeID::HugrNode(get_nodes_by_tk2op(&hugr, Tk2Op::H)[0]); + let inp = circ.input_node(); + let cx_gate = NodeID::HugrNode(get_nodes_by_tk2op(&circ, Tk2Op::CX)[0]); + let h_gate = NodeID::HugrNode(get_nodes_by_tk2op(&circ, Tk2Op::H)[0]); assert_eq!( edges, [ @@ -252,10 +255,11 @@ mod tests { ); } - fn get_nodes_by_tk2op(circ: &impl Circuit, t2_op: Tk2Op) -> Vec { + fn get_nodes_by_tk2op(circ: &Circuit, t2_op: Tk2Op) -> Vec { let t2_op: OpType = t2_op.into(); - circ.nodes() - .filter(|n| circ.get_optype(*n) == &t2_op) + circ.hugr() + .nodes() + .filter(|n| circ.hugr().get_optype(*n) == &t2_op) .collect() } @@ -265,7 +269,7 @@ mod tests { let pattern = CircuitPattern::try_from_circuit(&circ).unwrap(); let edges = pattern.pattern.edges().unwrap(); let rx_ns = get_nodes_by_tk2op(&circ, Tk2Op::RxF64); - let inp = circ.input(); + let inp = circ.input_node(); for rx_n in rx_ns { assert!(edges.iter().any(|e| { e.reverse().is_none() diff --git a/tket2/src/rewrite.rs b/tket2/src/rewrite.rs index 45dfafec..85eeceee 100644 --- a/tket2/src/rewrite.rs +++ b/tket2/src/rewrite.rs @@ -10,17 +10,16 @@ use bytemuck::TransparentWrapper; pub use ecc_rewriter::ECCRewriter; use derive_more::{From, Into}; +use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::views::sibling_subgraph::{InvalidReplacement, InvalidSubgraph}; -use hugr::Node; use hugr::{ - hugr::{hugrmut::HugrMut, views::SiblingSubgraph, Rewrite, SimpleReplacementError}, - Hugr, SimpleReplacement, + hugr::{views::SiblingSubgraph, Rewrite, SimpleReplacementError}, + SimpleReplacement, }; +use hugr::{Hugr, HugrView, Node}; use crate::circuit::Circuit; -use self::trace::RewriteTracer; - /// A subcircuit of a circuit. #[derive(Debug, Clone, From, Into)] #[repr(transparent)] @@ -34,9 +33,9 @@ impl Subcircuit { /// Create a new subcircuit induced from a set of nodes. pub fn try_from_nodes( nodes: impl Into>, - hugr: &Hugr, + circ: &Circuit, ) -> Result { - let subgraph = SiblingSubgraph::try_from_nodes(nodes, hugr)?; + let subgraph = SiblingSubgraph::try_from_nodes(nodes, circ.hugr())?; Ok(Self { subgraph }) } @@ -53,12 +52,13 @@ impl Subcircuit { /// Create a rewrite rule to replace the subcircuit. pub fn create_rewrite( &self, - source: &Hugr, - target: Hugr, + source: &Circuit, + target: Circuit, ) -> Result { - Ok(CircuitRewrite( - self.subgraph.create_simple_replacement(source, target)?, - )) + Ok(CircuitRewrite(self.subgraph.create_simple_replacement( + source.hugr(), + target.into_hugr(), + )?)) } } @@ -70,12 +70,12 @@ impl CircuitRewrite { /// Create a new rewrite rule. pub fn try_new( source_position: &Subcircuit, - source: &Hugr, - target: Hugr, + source: &Circuit, + target: Circuit, ) -> Result { source_position .subgraph - .create_simple_replacement(source, target) + .create_simple_replacement(source.hugr(), target.into_hugr()) .map(Self) } @@ -95,8 +95,8 @@ impl CircuitRewrite { } /// The replacement subcircuit. - pub fn replacement(&self) -> &Hugr { - self.0.replacement() + pub fn replacement(&self) -> Circuit<&Hugr> { + self.0.replacement().into() } /// Returns a set of nodes referenced by the rewrite. Modifying any these @@ -111,20 +111,23 @@ impl CircuitRewrite { /// Apply the rewrite rule to a circuit. #[inline] - pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { + pub fn apply(self, circ: &mut Circuit) -> Result<(), SimpleReplacementError> { circ.add_rewrite_trace(&self); - self.0.apply(circ) + self.0.apply(circ.hugr_mut()) } /// Apply the rewrite rule to a circuit, without registering it in the rewrite trace. #[inline] - pub fn apply_notrace(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { - self.0.apply(circ) + pub fn apply_notrace( + self, + circ: &mut Circuit, + ) -> Result<(), SimpleReplacementError> { + self.0.apply(circ.hugr_mut()) } } /// Generate rewrite rules for circuits. pub trait Rewriter { /// Get the rewrite rules for a circuit. - fn get_rewrites(&self, circ: &C) -> Vec; + fn get_rewrites(&self, circ: &Circuit) -> Vec; } diff --git a/tket2/src/rewrite/ecc_rewriter.rs b/tket2/src/rewrite/ecc_rewriter.rs index 954717b9..1c6be502 100644 --- a/tket2/src/rewrite/ecc_rewriter.rs +++ b/tket2/src/rewrite/ecc_rewriter.rs @@ -13,7 +13,7 @@ //! of the Quartz repository. use derive_more::{From, Into}; -use hugr::PortIndex; +use hugr::{Hugr, HugrView, PortIndex}; use itertools::Itertools; use portmatching::PatternID; use std::{ @@ -24,8 +24,6 @@ use std::{ }; use thiserror::Error; -use hugr::Hugr; - use crate::{ circuit::{remove_empty_wire, Circuit}, optimiser::badger::{load_eccs_json_file, EqCircClass}, @@ -76,7 +74,7 @@ impl ECCRewriter { /// Equivalence classes are represented as [`EqCircClass`]s, lists of /// HUGRs where one of the elements is chosen as the representative. pub fn from_eccs(eccs: impl Into>) -> Self { - let eccs = eccs.into(); + let eccs: Vec = eccs.into(); let rewrite_rules = get_rewrite_rules(&eccs); let patterns = get_patterns(&eccs); let targets = into_targets(eccs); @@ -90,7 +88,7 @@ impl ECCRewriter { let targets = r .into_iter() .filter(|&id| { - let circ = &targets[id.0]; + let circ = (&targets[id.0]).into(); let target_empty_wires: HashSet<_> = empty_wires(&circ).into_iter().collect(); pattern_empty_wires @@ -111,10 +109,10 @@ impl ECCRewriter { } /// Get all targets of rewrite rules given a source pattern. - fn get_targets(&self, pattern: PatternID) -> impl Iterator { + fn get_targets(&self, pattern: PatternID) -> impl Iterator> { self.rewrite_rules[pattern.0] .iter() - .map(|id| &self.targets[id.0]) + .map(|id| (&self.targets[id.0]).into()) } /// Serialise a rewriter to an IO stream. @@ -167,19 +165,18 @@ impl ECCRewriter { } impl Rewriter for ECCRewriter { - fn get_rewrites(&self, circ: &C) -> Vec { + fn get_rewrites(&self, circ: &Circuit) -> Vec { let matches = self.matcher.find_matches(circ); matches .into_iter() .flat_map(|m| { let pattern_id = m.pattern_id(); self.get_targets(pattern_id).map(move |repl| { - let mut repl = repl.clone(); + let mut repl = repl.to_owned(); for &empty_qb in self.empty_wires[pattern_id.0].iter().rev() { remove_empty_wire(&mut repl, empty_qb).unwrap(); } - m.to_rewrite(circ.base_hugr(), repl) - .expect("invalid replacement") + m.to_rewrite(circ, repl).expect("invalid replacement") }) }) .collect() @@ -231,9 +228,9 @@ fn get_patterns(rep_sets: &[EqCircClass]) -> Vec Vec Vec { - let input = circ.input(); - let input_sig = circ.signature(input).unwrap(); - circ.node_outputs(input) +fn empty_wires(circ: &Circuit) -> Vec { + let hugr = circ.hugr(); + let input = circ.input_node(); + let input_sig = hugr.signature(input).unwrap(); + hugr.node_outputs(input) // Only consider dataflow edges .filter(|&p| input_sig.out_port_type(p).is_some()) // Only consider ports linked to at most one other port - .filter_map(|p| Some((p, circ.linked_ports(input, p).at_most_one().ok()?))) + .filter_map(|p| Some((p, hugr.linked_ports(input, p).at_most_one().ok()?))) // Ports are either connected to output or nothing .filter_map(|(from, to)| { if let Some((n, _)) = to { // Wires connected to output - (n == circ.output()).then_some(from.index()) + (n == circ.output_node()).then_some(from.index()) } else { // Wires connected to nothing Some(from.index()) @@ -272,11 +270,11 @@ mod tests { use super::*; - fn empty() -> Hugr { + fn empty() -> Circuit { build_simple_circuit(2, |_| Ok(())).unwrap() } - fn h_h() -> Hugr { + fn h_h() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::H, [0]).unwrap(); circ.append(Tk2Op::H, [0]).unwrap(); @@ -286,7 +284,7 @@ mod tests { .unwrap() } - fn cx_cx() -> Hugr { + fn cx_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::CX, [0, 1]).unwrap(); @@ -295,7 +293,7 @@ mod tests { .unwrap() } - fn cx_x() -> Hugr { + fn cx_x() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::X, [1]).unwrap(); @@ -304,7 +302,7 @@ mod tests { .unwrap() } - fn x_cx() -> Hugr { + fn x_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::X, [1]).unwrap(); circ.append(Tk2Op::CX, [0, 1]).unwrap(); @@ -328,7 +326,13 @@ mod tests { vec![TargetID(3)], ] ); - assert_eq!(rewriter.get_targets(PatternID(1)).collect_vec(), [&h_h()]); + assert_eq!( + rewriter + .get_targets(PatternID(1)) + .map(|c| c.to_owned()) + .collect_vec(), + [h_h()] + ); } #[test] diff --git a/tket2/src/rewrite/strategy.rs b/tket2/src/rewrite/strategy.rs index e6576942..3e44b638 100644 --- a/tket2/src/rewrite/strategy.rs +++ b/tket2/src/rewrite/strategy.rs @@ -25,13 +25,13 @@ use std::{collections::HashSet, fmt::Debug}; use derive_more::From; use hugr::ops::OpType; -use hugr::Hugr; +use hugr::HugrView; use itertools::Itertools; use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, LexicographicCost}; use crate::Circuit; -use super::trace::{RewriteTrace, RewriteTracer}; +use super::trace::RewriteTrace; use super::CircuitRewrite; /// Rewriting strategies for circuit optimisation. @@ -51,7 +51,7 @@ pub trait RewriteStrategy { fn apply_rewrites( &self, rewrites: impl IntoIterator, - circ: &Hugr, + circ: &Circuit, ) -> impl Iterator>; /// The cost of a single operation for this strategy's cost function. @@ -59,13 +59,13 @@ pub trait RewriteStrategy { /// The cost of a circuit using this strategy's cost function. #[inline] - fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + fn circuit_cost(&self, circ: &Circuit) -> Self::Cost { circ.circuit_cost(|op| self.op_cost(op)) } /// Returns the cost of a rewrite's matched subcircuit before replacing it. #[inline] - fn pre_rewrite_cost(&self, rw: &CircuitRewrite, circ: &Hugr) -> Self::Cost { + fn pre_rewrite_cost(&self, rw: &CircuitRewrite, circ: &Circuit) -> Self::Cost { circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), |op| { self.op_cost(op) }) @@ -81,15 +81,18 @@ pub trait RewriteStrategy { #[derive(Debug, Clone)] pub struct RewriteResult { /// The rewritten circuit. - pub circ: Hugr, + pub circ: Circuit, /// The cost delta of the rewrite. pub cost_delta: C::CostDelta, } -impl From<(Hugr, C::CostDelta)> for RewriteResult { +impl From<(Circuit, C::CostDelta)> for RewriteResult { #[inline] - fn from((circ, cost_delta): (Hugr, C::CostDelta)) -> Self { - Self { circ, cost_delta } + fn from((circ, cost_delta): (Circuit, C::CostDelta)) -> Self { + Self { + circ: circ.to_owned(), + cost_delta, + } } } @@ -113,7 +116,7 @@ impl RewriteStrategy for GreedyRewriteStrategy { fn apply_rewrites( &self, rewrites: impl IntoIterator, - circ: &Hugr, + circ: &Circuit, ) -> impl Iterator> { let rewrites = rewrites .into_iter() @@ -140,7 +143,7 @@ impl RewriteStrategy for GreedyRewriteStrategy { iter::once((circ, cost_delta).into()) } - fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + fn circuit_cost(&self, circ: &Circuit) -> Self::Cost { circ.num_gates() } @@ -184,7 +187,7 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { fn apply_rewrites( &self, rewrites: impl IntoIterator, - circ: &Hugr, + circ: &Circuit, ) -> impl Iterator> { // Check only the rewrites that reduce the size of the circuit. let rewrites = rewrites @@ -262,7 +265,7 @@ impl RewriteStrategy for ExhaustiveThresholdStrategy { fn apply_rewrites( &self, rewrites: impl IntoIterator, - circ: &Hugr, + circ: &Circuit, ) -> impl Iterator> { rewrites.into_iter().filter_map(|rw| { let pattern_cost = self.pre_rewrite_cost(&rw, circ); @@ -429,7 +432,7 @@ impl GammaStrategyCost usize> { #[cfg(test)] mod tests { use super::*; - use hugr::{Hugr, Node}; + use hugr::Node; use itertools::Itertools; use crate::rewrite::trace::REWRITE_TRACING_ENABLED; @@ -440,7 +443,7 @@ mod tests { Tk2Op, }; - fn n_cx(n_gates: usize) -> Hugr { + fn n_cx(n_gates: usize) -> Circuit { let qbs = [0, 1]; build_simple_circuit(2, |circ| { for _ in 0..n_gates { @@ -452,15 +455,15 @@ mod tests { } /// Rewrite cx_nodes -> empty - fn rw_to_empty(hugr: &Hugr, cx_nodes: impl Into>) -> CircuitRewrite { - let subcirc = Subcircuit::try_from_nodes(cx_nodes, hugr).unwrap(); - subcirc.create_rewrite(hugr, n_cx(0)).unwrap() + fn rw_to_empty(circ: &Circuit, cx_nodes: impl Into>) -> CircuitRewrite { + let subcirc = Subcircuit::try_from_nodes(cx_nodes, circ).unwrap(); + subcirc.create_rewrite(circ, n_cx(0)).unwrap() } /// Rewrite cx_nodes -> 10x CX - fn rw_to_full(hugr: &Hugr, cx_nodes: impl Into>) -> CircuitRewrite { - let subcirc = Subcircuit::try_from_nodes(cx_nodes, hugr).unwrap(); - subcirc.create_rewrite(hugr, n_cx(10)).unwrap() + fn rw_to_full(circ: &Circuit, cx_nodes: impl Into>) -> CircuitRewrite { + let subcirc = Subcircuit::try_from_nodes(cx_nodes, circ).unwrap(); + subcirc.create_rewrite(circ, n_cx(10)).unwrap() } #[test] diff --git a/tket2/src/rewrite/trace.rs b/tket2/src/rewrite/trace.rs index f2c17494..d469135f 100644 --- a/tket2/src/rewrite/trace.rs +++ b/tket2/src/rewrite/trace.rs @@ -17,13 +17,13 @@ pub const METADATA_REWRITES: &str = "TKET2.rewrites"; /// Enable it by setting the `rewrite-tracing` feature. /// /// Note that circuits must be explicitly enabled for rewrite tracing by calling -/// [`RewriteTracer::enable_rewrite_tracing`]. +/// [`Circuit::enable_rewrite_tracing`]. pub const REWRITE_TRACING_ENABLED: bool = cfg!(feature = "rewrite-tracing"); /// The trace of a rewrite applied to a circuit. /// /// Traces are only enabled if the `rewrite-tracing` feature is enabled and -/// [`RewriteTracer::enable_rewrite_tracing`] is called on the circuit. +/// [`Circuit::enable_rewrite_tracing`] is called on the circuit. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct RewriteTrace { @@ -67,18 +67,19 @@ impl From for serde_json::Value { } } -/// Extension trait for circuits that can trace rewrites applied to them. +/// Implementation for rewrite tracing in circuits. /// /// This is only tracked if the `rewrite-tracing` feature is enabled and /// `enable_rewrite_tracing` is called on the circuit before. -pub trait RewriteTracer: Circuit + HugrMut + Sized { +impl Circuit { /// Enable rewrite tracing for the circuit. #[inline] - fn enable_rewrite_tracing(&mut self) { + pub fn enable_rewrite_tracing(&mut self) { if !REWRITE_TRACING_ENABLED { return; } - let meta = self.get_metadata_mut(self.root(), METADATA_REWRITES); + let root = self.parent(); + let meta = self.hugr_mut().get_metadata_mut(root, METADATA_REWRITES); if *meta == NodeMetadata::Null { *meta = NodeMetadata::Array(vec![]); } @@ -88,12 +89,14 @@ pub trait RewriteTracer: Circuit + HugrMut + Sized { /// /// Returns `true` if the rewrite was successfully registered, or `false` if it was ignored. #[inline] - fn add_rewrite_trace(&mut self, rewrite: impl Into) -> bool { + pub fn add_rewrite_trace(&mut self, rewrite: impl Into) -> bool { if !REWRITE_TRACING_ENABLED { return false; } + let root = self.parent(); match self - .get_metadata_mut(self.root(), METADATA_REWRITES) + .hugr_mut() + .get_metadata_mut(root, METADATA_REWRITES) .as_array_mut() { Some(meta) => { @@ -112,14 +115,12 @@ pub trait RewriteTracer: Circuit + HugrMut + Sized { // // TODO return an `impl Iterator` once rust 1.75 lands. #[inline] - fn rewrite_trace(&self) -> Option> { + pub fn rewrite_trace(&self) -> Option> { if !REWRITE_TRACING_ENABLED { return None; } - let meta = self.get_metadata(self.root(), METADATA_REWRITES)?; + let meta = self.hugr().get_metadata(self.parent(), METADATA_REWRITES)?; let rewrites = meta.as_array()?; Some(rewrites.iter().map_into().collect_vec()) } } - -impl RewriteTracer for T {} diff --git a/tket2/src/utils.rs b/tket2/src/utils.rs index 1c46692d..ec4daa87 100644 --- a/tket2/src/utils.rs +++ b/tket2/src/utils.rs @@ -9,6 +9,8 @@ use hugr::{ Hugr, }; +use crate::circuit::Circuit; + pub(crate) fn type_is_linear(typ: &Type) -> bool { !TypeBound::Copyable.contains(typ.least_upper_bound()) } @@ -18,7 +20,7 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool { pub(crate) fn build_simple_circuit( num_qubits: usize, f: impl FnOnce(&mut CircuitBuilder>) -> Result<(), BuildError>, -) -> Result { +) -> Result { let qb_row = vec![QB_T; num_qubits]; let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?; @@ -29,14 +31,16 @@ pub(crate) fn build_simple_circuit( f(&mut circ)?; let qbs = circ.finish(); - h.finish_hugr_with_outputs(qbs, &PRELUDE_REGISTRY) + let hugr = h.finish_hugr_with_outputs(qbs, &PRELUDE_REGISTRY)?; + Ok(hugr.into()) } // Test only utils #[allow(dead_code)] +#[allow(unused_imports)] #[cfg(test)] pub(crate) mod test { - #[allow(unused_imports)] + use crate::Circuit; use hugr::HugrView; /// Open a browser page to render a dot string graph. @@ -51,6 +55,14 @@ pub(crate) mod test { webbrowser::open(&base).unwrap(); } + /// Open a browser page to render a Circuit's dot string graph. + /// + /// Only for use in local testing. Will fail to compile on CI. + #[cfg(not(ci_run))] + pub(crate) fn viz_circ(circ: &Circuit) { + viz_dotstr(circ.dot_string()); + } + /// Open a browser page to render a HugrView's dot string graph. /// /// Only for use in local testing. Will fail to compile on CI. diff --git a/tket2/tests/badger_termination.rs b/tket2/tests/badger_termination.rs index f899ca9d..b79d46f1 100644 --- a/tket2/tests/badger_termination.rs +++ b/tket2/tests/badger_termination.rs @@ -1,6 +1,5 @@ #![cfg(feature = "portmatching")] -use hugr::Hugr; use rstest::{fixture, rstest}; use tket2::optimiser::badger::BadgerOptions; use tket2::{ @@ -29,7 +28,7 @@ fn nam_4_2() -> DefaultBadgerOptimiser { ///q_2: ┤ X ├───────────────────────────────────────────┤ X ├───────────── /// └───┘ └───┘ #[fixture] -fn simple_circ() -> Hugr { +fn simple_circ() -> Circuit { // The TK1 json of the circuit let json = r#"{ "bits": [], @@ -57,7 +56,7 @@ fn simple_circ() -> Hugr { #[rstest] //#[ignore = "Takes 200ms"] -fn badger_termination(simple_circ: Hugr, nam_4_2: DefaultBadgerOptimiser) { +fn badger_termination(simple_circ: Circuit, nam_4_2: DefaultBadgerOptimiser) { let opt_circ = nam_4_2.optimise( &simple_circ, BadgerOptions {