diff --git a/Cargo.lock b/Cargo.lock index e4a13327..7e9bab20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -325,6 +325,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "cool_asserts" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee9f254e53f61e2688d3677fa2cbe4e9b950afd56f48819c98817417cf6b28ec" +dependencies = [ + "indent_write", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -707,9 +716,19 @@ dependencies = [ [[package]] name = "hugr" -version = "0.4.0" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ebf436d9d4d0239fcb511a1f7d745548bbdc6316baf50bb2e944f294542ea24" +dependencies = [ + "hugr-core", + "hugr-passes", +] + +[[package]] +name = "hugr-core" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8e15eaffd64f1cceac13429e5ceaf20017691e647cc56b8b7c53d73ad9f8714" +checksum = "1f2b8cfdccedf45a563e526a4a2c7443025ece0e2518bca1c3f227318607de44" dependencies = [ "bitvec", "cgmath", @@ -719,7 +738,7 @@ dependencies = [ "downcast-rs", "enum_dispatch", "html-escape", - "itertools 0.12.1", + "itertools 0.13.0", "lazy_static", "num-rational", "paste", @@ -736,6 +755,19 @@ dependencies = [ "typetag", ] +[[package]] +name = "hugr-passes" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c1ae428e41786736cc92cbc3ae11ad39f178591d36a9ac3d1c8990e1a9bf76" +dependencies = [ + "hugr-core", + "itertools 0.13.0", + "lazy_static", + "paste", + "thiserror", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -769,6 +801,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indent_write" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cfe9645a18782869361d9c8732246be7b410ad4e919d3609ebabdac00ba12c3" + [[package]] name = "indexmap" version = "2.2.6" @@ -817,15 +855,6 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.13.0" @@ -1742,6 +1771,7 @@ dependencies = [ "bytemuck", "cgmath", "chrono", + "cool_asserts", "criterion", "crossbeam-channel", "csv", diff --git a/Cargo.toml b/Cargo.toml index 6c67a2b9..57ed1b73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ missing_docs = "warn" [workspace.dependencies] tket2 = { path = "./tket2" } -hugr = "0.4.0" +hugr = "0.5.0" portgraph = "0.12" pyo3 = "0.21.2" itertools = "0.13.0" @@ -60,3 +60,4 @@ tracing-subscriber = "0.3.17" typetag = "0.2.8" urlencoding = "2.1.2" webbrowser = "1.0.0" +cool_asserts = "2.0.3" diff --git a/badger-optimiser/src/main.rs b/badger-optimiser/src/main.rs index b87ecaef..08aa4d8e 100644 --- a/badger-optimiser/src/main.rs +++ b/badger-optimiser/src/main.rs @@ -15,7 +15,6 @@ use tket2::json::{load_tk1_json_file, save_tk1_json_file}; use tket2::optimiser::badger::log::BadgerLogger; use tket2::optimiser::badger::BadgerOptions; use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser}; -use tket2::rewrite::trace::RewriteTracer; #[cfg(feature = "peak_alloc")] #[global_allocator] diff --git a/tket2-py/src/circuit/convert.rs b/tket2-py/src/circuit/convert.rs index 42c29813..c8f70d63 100644 --- a/tket2-py/src/circuit/convert.rs +++ b/tket2-py/src/circuit/convert.rs @@ -44,8 +44,10 @@ impl CircuitType { /// Converts a `Hugr` into the format indicated by the flag. pub fn convert(self, py: Python, hugr: Hugr) -> PyResult> { match self { - CircuitType::Tket1 => SerialCircuit::encode(&hugr).convert_pyerrs()?.to_tket1(py), - CircuitType::Tket2 => Ok(Bound::new(py, Tk2Circuit { hugr })?.into_any()), + CircuitType::Tket1 => SerialCircuit::encode(&hugr.into()) + .convert_pyerrs()? + .to_tket1(py), + CircuitType::Tket2 => Ok(Bound::new(py, Tk2Circuit { circ: hugr.into() })?.into_any()), } } } @@ -58,16 +60,16 @@ where E: ConvertPyErr, F: FnOnce(Hugr, CircuitType) -> Result, { - let (hugr, typ) = match Tk2Circuit::extract_bound(circ) { + let (circ, typ) = match Tk2Circuit::extract_bound(circ) { // hugr circuit - Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2), + Ok(t2circ) => (t2circ.circ, CircuitType::Tket2), // tket1 circuit Err(_) => ( SerialCircuit::from_tket1(circ)?.decode().convert_pyerrs()?, CircuitType::Tket1, ), }; - (f)(hugr, typ).map_err(|e| e.convert_pyerrs()) + (f)(circ.into_hugr(), typ).map_err(|e| e.convert_pyerrs()) } /// Apply a function expecting a hugr on a python circuit. diff --git a/tket2-py/src/circuit/tk2circuit.rs b/tket2-py/src/circuit/tk2circuit.rs index a31f4a86..024a9c0f 100644 --- a/tket2-py/src/circuit/tk2circuit.rs +++ b/tket2-py/src/circuit/tk2circuit.rs @@ -55,7 +55,7 @@ use super::{cost, with_hugr, PyCircuitCost, PyCustom, PyHugrType, PyNode, PyWire #[derive(Clone, Debug, PartialEq, From)] pub struct Tk2Circuit { /// Rust representation of the circuit. - pub hugr: Hugr, + pub circ: Circuit, } #[pymethods] @@ -67,40 +67,40 @@ impl Tk2Circuit { #[new] pub fn new(circ: &Bound) -> PyResult { Ok(Self { - hugr: with_hugr(circ, |hugr, _| hugr)?, + circ: with_hugr(circ, |hugr, _| hugr)?.into(), }) } /// Convert the [`Tk2Circuit`] to a tket1 circuit. pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult> { - SerialCircuit::encode(&self.hugr) + SerialCircuit::encode(&self.circ) .convert_pyerrs()? .to_tket1(py) } /// Apply a rewrite on the circuit. pub fn apply_rewrite(&mut self, rw: PyCircuitRewrite) { - rw.rewrite.apply(&mut self.hugr).expect("Apply error."); + rw.rewrite.apply(&mut self.circ).expect("Apply error."); } /// Encode the circuit as a HUGR json string. // // TODO: Bind a messagepack encoder/decoder too. pub fn to_hugr_json(&self) -> PyResult { - Ok(serde_json::to_string(&self.hugr).unwrap()) + Ok(serde_json::to_string(self.circ.hugr()).unwrap()) } /// Decode a HUGR json string to a circuit. #[staticmethod] pub fn from_hugr_json(json: &str) -> PyResult { - let hugr = serde_json::from_str(json) + let hugr: Hugr = serde_json::from_str(json) .map_err(|e| PyErr::new::(format!("Invalid encoded HUGR: {e}")))?; - Ok(Tk2Circuit { hugr }) + Ok(Tk2Circuit { circ: hugr.into() }) } /// Encode the circuit as a tket1 json string. pub fn to_tket1_json(&self) -> PyResult { - Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr).convert_pyerrs()?).unwrap()) + Ok(serde_json::to_string(&SerialCircuit::encode(&self.circ).convert_pyerrs()?).unwrap()) } /// Decode a tket1 json string to a circuit. @@ -109,7 +109,7 @@ impl Tk2Circuit { let tk1: SerialCircuit = serde_json::from_str(json) .map_err(|e| PyErr::new::(format!("Invalid encoded HUGR: {e}")))?; Ok(Tk2Circuit { - hugr: tk1.decode().convert_pyerrs()?, + circ: tk1.decode().convert_pyerrs()?, }) } @@ -134,13 +134,13 @@ impl Tk2Circuit { cost: cost.to_object(py), }) }; - let circ_cost = self.hugr.circuit_cost(cost_fn)?; + let circ_cost = self.circ.circuit_cost(cost_fn)?; Ok(circ_cost.cost.into_bound(py)) } /// Returns a hash of the circuit. pub fn hash(&self) -> u64 { - self.hugr.circuit_hash().unwrap() + self.circ.circuit_hash().unwrap() } /// Hash the circuit @@ -160,7 +160,8 @@ impl Tk2Circuit { fn node_op(&self, node: PyNode) -> PyResult { let custom: CustomOp = self - .hugr + .circ + .hugr() .get_optype(node.node) .clone() .try_into() @@ -174,25 +175,27 @@ impl Tk2Circuit { } fn node_inputs(&self, node: PyNode) -> Vec { - self.hugr + self.circ + .hugr() .all_linked_outputs(node.node) .map(|(n, p)| Wire::new(n, p).into()) .collect() } fn node_outputs(&self, node: PyNode) -> Vec { - self.hugr + self.circ + .hugr() .node_outputs(node.node) .map(|p| Wire::new(node.node, p).into()) .collect() } fn input_node(&self) -> PyNode { - self.hugr.input().into() + self.circ.input_node().into() } fn output_node(&self) -> PyNode { - self.hugr.output().into() + self.circ.output_node().into() } } impl Tk2Circuit { @@ -236,11 +239,12 @@ impl Dfg { fn finish(&mut self, outputs: Vec) -> PyResult { Ok(Tk2Circuit { - hugr: self + circ: self .builder .clone() .finish_hugr_with_outputs(outputs.into_iter().map_into(), ®ISTRY) - .convert_pyerrs()?, + .convert_pyerrs()? + .into(), }) } } diff --git a/tket2-py/src/optimiser.rs b/tket2-py/src/optimiser.rs index 49c5abc6..03166d0f 100644 --- a/tket2-py/src/optimiser.rs +++ b/tket2-py/src/optimiser.rs @@ -3,10 +3,10 @@ use std::io::BufWriter; use std::{fs, num::NonZeroUsize, path::PathBuf}; -use hugr::Hugr; use pyo3::prelude::*; use tket2::optimiser::badger::BadgerOptions; use tket2::optimiser::{BadgerLogger, DefaultBadgerOptimiser}; +use tket2::Circuit; use crate::circuit::update_hugr; @@ -96,7 +96,10 @@ impl PyBadgerOptimiser { split_circuit: split_circ.unwrap_or(false), queue_size: queue_size.unwrap_or(100), }; - update_hugr(circ, |circ, _| self.optimise(circ, log_progress, options)) + update_hugr(circ, |circ, _| { + self.optimise(circ.into(), log_progress, options) + .into_hugr() + }) } } @@ -104,10 +107,10 @@ impl PyBadgerOptimiser { /// The Python optimise method, but on Hugrs. pub(super) fn optimise( &self, - circ: Hugr, + circ: Circuit, log_progress: Option, options: BadgerOptions, - ) -> Hugr { + ) -> Circuit { let badger_logger = log_progress .map(|file_name| { let log_file = fs::File::create(file_name).unwrap(); diff --git a/tket2-py/src/passes.rs b/tket2-py/src/passes.rs index d05d37f0..b9488573 100644 --- a/tket2-py/src/passes.rs +++ b/tket2-py/src/passes.rs @@ -39,9 +39,10 @@ create_py_exception!( #[pyfunction] fn greedy_depth_reduce<'py>(circ: &Bound<'py, PyAny>) -> PyResult<(Bound<'py, PyAny>, u32)> { let py = circ.py(); - try_with_hugr(circ, |mut h, typ| { - let n_moves = apply_greedy_commutation(&mut h).convert_pyerrs()?; - let circ = typ.convert(py, h)?; + try_with_hugr(circ, |h, typ| { + let mut circ: Circuit = h.into(); + let n_moves = apply_greedy_commutation(&mut circ).convert_pyerrs()?; + let circ = typ.convert(py, circ.into_hugr())?; PyResult::Ok((circ, n_moves)) }) } @@ -117,7 +118,8 @@ fn badger_optimise<'py>( _ => unreachable!(), }; // Optimise - try_update_hugr(circ, |mut circ, _| { + try_update_hugr(circ, |hugr, _| { + let mut circ: Circuit = hugr.into(); let n_cx = circ .commands() .filter(|c| op_matches(c.optype(), Tk2Op::CX)) @@ -142,6 +144,6 @@ fn badger_optimise<'py>( }; circ = optimiser.optimise(circ, log_file, options); } - PyResult::Ok(circ) + PyResult::Ok(circ.into_hugr()) }) } diff --git a/tket2-py/src/passes/chunks.rs b/tket2-py/src/passes/chunks.rs index ce81fa16..ad973d7a 100644 --- a/tket2-py/src/passes/chunks.rs +++ b/tket2-py/src/passes/chunks.rs @@ -15,7 +15,7 @@ use crate::utils::ConvertPyErr; pub fn chunks(c: &Bound, max_chunk_size: usize) -> PyResult { with_hugr(c, |hugr, typ| { // TODO: Detect if the circuit is in tket1 format or Tk2Circuit. - let chunks = CircuitChunks::split(&hugr, max_chunk_size); + let chunks = CircuitChunks::split(&hugr.into(), max_chunk_size); (chunks, typ).into() }) } @@ -39,27 +39,28 @@ pub struct PyCircuitChunks { impl PyCircuitChunks { /// Reassemble the chunks into a circuit. fn reassemble<'py>(&self, py: Python<'py>) -> PyResult> { - let hugr = self.clone().chunks.reassemble().convert_pyerrs()?; - self.original_type.convert(py, hugr) + let circ = self.clone().chunks.reassemble().convert_pyerrs()?; + self.original_type.convert(py, circ.into_hugr()) } /// Returns clones of the split circuits. fn circuits<'py>(&self, py: Python<'py>) -> PyResult>> { self.chunks .iter() - .map(|hugr| self.original_type.convert(py, hugr.clone())) + .map(|circ| self.original_type.convert(py, circ.clone().into_hugr())) .collect() } /// Replaces a chunk's circuit with an updated version. fn update_circuit(&mut self, index: usize, new_circ: &Bound) -> PyResult<()> { try_with_hugr(new_circ, |hugr, _| { - if hugr.circuit_signature() != self.chunks[index].circuit_signature() { + let circ: Circuit = hugr.into(); + if circ.circuit_signature() != self.chunks[index].circuit_signature() { return Err(PyAttributeError::new_err( "The new circuit has a different signature.", )); } - self.chunks[index] = hugr; + self.chunks[index] = circ; Ok(()) }) } diff --git a/tket2-py/src/pattern.rs b/tket2-py/src/pattern.rs index 61d433b1..86081506 100644 --- a/tket2-py/src/pattern.rs +++ b/tket2-py/src/pattern.rs @@ -6,9 +6,10 @@ use crate::circuit::Tk2Circuit; use crate::rewrite::PyCircuitRewrite; use crate::utils::{create_py_exception, ConvertPyErr}; -use hugr::Hugr; +use hugr::HugrView; use pyo3::prelude::*; use tket2::portmatching::{CircuitPattern, PatternMatch, PatternMatcher}; +use tket2::Circuit; /// The module definition pub fn module(py: Python<'_>) -> PyResult> { @@ -47,7 +48,7 @@ create_py_exception!( #[derive(Clone)] #[pyclass] /// A rewrite rule defined by a left hand side and right hand side of an equation. -pub struct Rule(pub [Hugr; 2]); +pub struct Rule(pub [Circuit; 2]); #[pymethods] impl Rule { @@ -56,13 +57,13 @@ impl Rule { let l = Tk2Circuit::new(l)?; let r = Tk2Circuit::new(r)?; - Ok(Rule([l.hugr, r.hugr])) + Ok(Rule([l.circ, r.circ])) } } #[pyclass] struct RuleMatcher { matcher: PatternMatcher, - rights: Vec, + rights: Vec, } #[pymethods] @@ -79,25 +80,29 @@ impl RuleMatcher { } pub fn find_match(&self, target: &Tk2Circuit) -> PyResult> { - let h = &target.hugr; - if let Some(pmatch) = self.matcher.find_matches_iter(h).next() { - Ok(Some(self.match_to_rewrite(pmatch, h)?)) + let circ = &target.circ; + if let Some(pmatch) = self.matcher.find_matches_iter(circ).next() { + Ok(Some(self.match_to_rewrite(pmatch, circ)?)) } else { Ok(None) } } pub fn find_matches(&self, target: &Tk2Circuit) -> PyResult> { - let h = &target.hugr; + let circ = &target.circ; self.matcher - .find_matches_iter(h) - .map(|m| self.match_to_rewrite(m, h)) + .find_matches_iter(circ) + .map(|m| self.match_to_rewrite(m, circ)) .collect() } } impl RuleMatcher { - fn match_to_rewrite(&self, pmatch: PatternMatch, target: &Hugr) -> PyResult { + fn match_to_rewrite( + &self, + pmatch: PatternMatch, + target: &Circuit, + ) -> PyResult { let r = self.rights.get(pmatch.pattern_id().0).unwrap().clone(); let rw = pmatch.to_rewrite(target, r).convert_pyerrs()?; Ok(rw.into()) diff --git a/tket2-py/src/pattern/portmatching.rs b/tket2-py/src/pattern/portmatching.rs index 4ed90b2e..49695fe4 100644 --- a/tket2-py/src/pattern/portmatching.rs +++ b/tket2-py/src/pattern/portmatching.rs @@ -30,7 +30,9 @@ impl PyCircuitPattern { /// Construct a pattern from a TKET1 circuit #[new] pub fn from_circuit(circ: &Bound) -> PyResult { - let pattern = try_with_hugr(circ, |circ, _| CircuitPattern::try_from_circuit(&circ))?; + let pattern = try_with_hugr(circ, |circ, _| { + CircuitPattern::try_from_circuit(&circ.into()) + })?; Ok(pattern.into()) } @@ -81,7 +83,10 @@ impl PyPatternMatcher { /// Find one convex match in a circuit. pub fn find_match(&self, circ: &Bound) -> PyResult> { with_hugr(circ, |circ, _| { - self.matcher.find_matches_iter(&circ).next().map(Into::into) + self.matcher + .find_matches_iter(&circ.into()) + .next() + .map(Into::into) }) } @@ -89,7 +94,7 @@ impl PyPatternMatcher { pub fn find_matches(&self, circ: &Bound) -> PyResult> { with_hugr(circ, |circ, _| { self.matcher - .find_matches(&circ) + .find_matches(&circ.into()) .into_iter() .map_into() .collect() diff --git a/tket2-py/src/rewrite.rs b/tket2-py/src/rewrite.rs index 5c01038b..18c34b41 100644 --- a/tket2-py/src/rewrite.rs +++ b/tket2-py/src/rewrite.rs @@ -43,7 +43,7 @@ impl PyCircuitRewrite { /// The replacement subcircuit. pub fn replacement(&self) -> Tk2Circuit { - self.rewrite.replacement().clone().into() + self.rewrite.replacement().to_owned().into() } #[new] @@ -55,8 +55,8 @@ impl PyCircuitRewrite { Ok(Self { rewrite: CircuitRewrite::try_new( &source_position.0, - &source_circ.hugr, - replacement.hugr, + &source_circ.circ, + replacement.circ, ) .map_err(|e| PyErr::new::(e.to_string()))?, }) @@ -80,7 +80,7 @@ impl PySubcircuit { fn from_nodes(nodes: Vec, circ: &Tk2Circuit) -> PyResult { let nodes: Vec<_> = nodes.into_iter().map_into().collect(); Ok(Self( - Subcircuit::try_from_nodes(nodes, &circ.hugr) + Subcircuit::try_from_nodes(nodes, &circ.circ) .map_err(|e| PyErr::new::(e.to_string()))?, )) } @@ -107,7 +107,7 @@ impl PyECCRewriter { /// Returns a list of circuit rewrites that can be applied to the given Tk2Circuit. pub fn get_rewrites(&self, circ: &Tk2Circuit) -> Vec { self.0 - .get_rewrites(&circ.hugr) + .get_rewrites(&circ.circ) .into_iter() .map_into() .collect() diff --git a/tket2/Cargo.toml b/tket2/Cargo.toml index 0e81385a..f5d6e4e9 100644 --- a/tket2/Cargo.toml +++ b/tket2/Cargo.toml @@ -68,6 +68,7 @@ rstest = { workspace = true } criterion = { workspace = true, features = ["html_reports"] } webbrowser = { workspace = true } urlencoding = { workspace = true } +cool_asserts = { workspace = true } [[bench]] name = "bench_main" diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index 348876d4..d6d90576 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -15,9 +15,9 @@ use derive_more::From; use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::NodeType; use hugr::ops::dataflow::IOTrait; -use hugr::ops::{Input, Output, DFG}; +use hugr::ops::{Input, NamedOp, OpParent, OpTag, OpTrait, Output, DFG}; use hugr::types::PolyFuncType; -use hugr::PortIndex; +use hugr::{Hugr, PortIndex}; use hugr::{HugrView, OutgoingPort}; use itertools::Itertools; use thiserror::Error; @@ -28,54 +28,138 @@ pub use hugr::{Node, Port, Wire}; use self::units::{filter, LinearUnit, Units}; -/// An object behaving like a quantum circuit. -// -// TODO: More methods: -// - other_{in,out}puts (for non-linear i/o + const inputs)? -// - Vertical slice iterator -// - Depth -pub trait Circuit: HugrView { +/// A quantum circuit, represented as a function in a HUGR. +#[derive(Debug, Clone, PartialEq)] +pub struct Circuit { + /// The HUGR containing the circuit. + hugr: T, + /// The parent node of the circuit. + /// + /// This is checked at runtime to ensure that the node is a DFG node. + parent: Node, +} + +impl Default for Circuit { + fn default() -> Self { + let hugr = T::default(); + let parent = hugr.root(); + Self { hugr, parent } + } +} + +impl Circuit { + /// Create a new circuit from a HUGR and a node. + /// + /// # Errors + /// + /// Returns an error if the parent node is not a DFG node in the HUGR. + pub fn try_new(hugr: T, parent: Node) -> Result { + check_hugr(&hugr, parent)?; + Ok(Self { hugr, parent }) + } + + /// Create a new circuit from a HUGR and a node. + /// + /// See [`Circuit::try_new`] for a version that returns an error. + /// + /// # Panics + /// + /// Panics if the parent node is not a DFG node in the HUGR. + pub fn new(hugr: T, parent: Node) -> Self { + Self::try_new(hugr, parent).unwrap_or_else(|e| panic!("{}", e)) + } + + /// Returns the node containing the circuit definition. + pub fn parent(&self) -> Node { + self.parent + } + + /// Get a reference to the HUGR containing the circuit. + pub fn hugr(&self) -> &T { + &self.hugr + } + + /// Unwrap the HUGR containing the circuit. + pub fn into_hugr(self) -> T { + self.hugr + } + + /// Get a mutable reference to the HUGR containing the circuit. + /// + /// Mutation of the hugr MUST NOT invalidate the parent node, + /// by changing the node's type to a non-DFG node or by removing it. + pub fn hugr_mut(&mut self) -> &mut T { + &mut self.hugr + } + + /// Ensures the circuit contains an owned HUGR. + pub fn to_owned(&self) -> Circuit { + let hugr = self.hugr.base_hugr().clone(); + Circuit { + hugr, + parent: self.parent, + } + } + /// Return the name of the circuit #[inline] - fn name(&self) -> Option<&str> { - self.get_metadata(self.root(), "name")?.as_str() + pub fn name(&self) -> Option<&str> { + self.hugr.get_metadata(self.parent(), "name")?.as_str() } /// Returns the function type of the circuit. /// /// Equivalent to [`HugrView::get_function_type`]. #[inline] - fn circuit_signature(&self) -> PolyFuncType { - self.get_function_type() - .expect("Circuit has no function type") + pub fn circuit_signature(&self) -> PolyFuncType { + let op = self.hugr.get_optype(self.parent); + match op { + OpType::FuncDecl(decl) => decl.signature.clone(), + OpType::FuncDefn(defn) => defn.signature.clone(), + _ => op + .inner_function_type() + .expect("Circuit parent should have a function type") + .into(), + } } /// Returns the input node to the circuit. #[inline] - fn input(&self) -> Node { - self.get_io(self.root()).expect("Circuit has no input node")[0] + pub fn input_node(&self) -> Node { + self.hugr + .get_io(self.parent) + .expect("Circuit has no input node")[0] } /// Returns the output node to the circuit. #[inline] - fn output(&self) -> Node { - self.get_io(self.root()) + pub fn output_node(&self) -> Node { + self.hugr + .get_io(self.parent) .expect("Circuit has no output node")[1] } + /// Returns the input and output nodes of the circuit. + #[inline] + pub fn io_nodes(&self) -> [Node; 2] { + self.hugr + .get_io(self.parent) + .expect("Circuit has no I/O nodes") + } + /// The number of quantum gates in the circuit. #[inline] - fn num_gates(&self) -> usize + pub fn num_gates(&self) -> usize where Self: Sized, { - // TODO: Implement discern quantum gates in the commands iterator. - self.children(self.root()).count() - 2 + // TODO: Discern quantum gates in the commands iterator. + self.hugr().children(self.parent).count() - 2 } /// Count the number of qubits in the circuit. #[inline] - fn qubit_count(&self) -> usize + pub fn qubit_count(&self) -> usize where Self: Sized, { @@ -84,7 +168,7 @@ pub trait Circuit: HugrView { /// Get the input units of the circuit and their types. #[inline] - fn units(&self) -> Units + pub fn units(&self) -> Units where Self: Sized, { @@ -93,7 +177,7 @@ pub trait Circuit: HugrView { /// Get the linear input units of the circuit and their types. #[inline] - fn linear_units(&self) -> impl Iterator + '_ + pub fn linear_units(&self) -> impl Iterator + '_ where Self: Sized, { @@ -102,7 +186,7 @@ pub trait Circuit: HugrView { /// Get the non-linear input units of the circuit and their types. #[inline] - fn nonlinear_units(&self) -> impl Iterator + '_ + pub fn nonlinear_units(&self) -> impl Iterator + '_ where Self: Sized, { @@ -111,7 +195,7 @@ pub trait Circuit: HugrView { /// Returns the units corresponding to qubits inputs to the circuit. #[inline] - fn qubits(&self) -> impl Iterator + '_ + pub fn qubits(&self) -> impl Iterator + '_ where Self: Sized, { @@ -122,7 +206,7 @@ pub trait Circuit: HugrView { /// /// Ignores the Input and Output nodes. #[inline] - fn commands(&self) -> CommandIterator<'_, Self> + pub fn commands(&self) -> CommandIterator<'_, T> where Self: Sized, { @@ -132,7 +216,7 @@ pub trait Circuit: HugrView { /// Compute the cost of the circuit based on a per-operation cost function. #[inline] - fn circuit_cost(&self, op_cost: F) -> C + pub fn circuit_cost(&self, op_cost: F) -> C where Self: Sized, C: Sum, @@ -144,15 +228,62 @@ pub trait Circuit: HugrView { /// Compute the cost of a group of nodes in a circuit based on a /// per-operation cost function. #[inline] - fn nodes_cost(&self, nodes: impl IntoIterator, op_cost: F) -> C + pub fn nodes_cost(&self, nodes: impl IntoIterator, op_cost: F) -> C where C: Sum, F: Fn(&OpType) -> C, { - nodes.into_iter().map(|n| op_cost(self.get_optype(n))).sum() + nodes + .into_iter() + .map(|n| op_cost(self.hugr.get_optype(n))) + .sum() + } + + /// Return the graphviz representation of the underlying graph and hierarchy side by side. + /// + /// For a simpler representation, use the [`Circuit::mermaid_string`] format instead. + pub fn dot_string(&self) -> String { + // TODO: This will print the whole HUGR without identifying the circuit container. + // Should we add some extra formatting for that? + self.hugr.dot_string() + } + + /// Return the mermaid representation of the underlying hierarchical graph. + /// + /// The hierarchy is represented using subgraphs. Edges are labelled with + /// their source and target ports. + /// + /// For a more detailed representation, use the [`Circuit::dot_string`] + /// format instead. + pub fn mermaid_string(&self) -> String { + // TODO: See comment in `dot_string`. + self.hugr.mermaid_string() + } +} + +impl From for Circuit { + fn from(hugr: T) -> Self { + let parent = hugr.root(); + Self::new(hugr, parent) } } +/// Checks if the passed hugr is a valid circuit, +/// and return [`CircuitError`] if not. +fn check_hugr(hugr: &impl HugrView, parent: Node) -> Result<(), CircuitError> { + if !hugr.contains_node(parent) { + return Err(CircuitError::MissingParentNode { parent }); + } + let optype = hugr.get_optype(parent); + if !OpTag::DataflowParent.is_superset(optype.tag()) { + return Err(CircuitError::NonDFGParent { + parent, + optype: optype.clone(), + }); + } + Ok(()) +} + /// Remove an empty wire in a dataflow HUGR. /// /// The wire to be removed is identified by the index of the outgoing port @@ -167,15 +298,18 @@ pub trait Circuit: HugrView { /// occurs. #[allow(dead_code)] pub(crate) fn remove_empty_wire( - circ: &mut impl HugrMut, + circ: &mut Circuit, input_port: usize, ) -> Result<(), CircuitMutError> { - let [inp, out] = circ.get_io(circ.root()).expect("no IO nodes found at root"); - if input_port >= circ.num_outputs(inp) { + let parent = circ.parent(); + let hugr = circ.hugr_mut(); + + let [inp, out] = hugr.get_io(parent).expect("no IO nodes found at parent"); + if input_port >= hugr.num_outputs(inp) { return Err(CircuitMutError::InvalidPortOffset(input_port)); } let input_port = OutgoingPort::from(input_port); - let link = circ + let link = hugr .linked_inputs(inp, input_port) .at_most_one() .map_err(|_| CircuitMutError::DeleteNonEmptyWire(input_port.index()))?; @@ -183,25 +317,52 @@ pub(crate) fn remove_empty_wire( return Err(CircuitMutError::DeleteNonEmptyWire(input_port.index())); } if link.is_some() { - circ.disconnect(inp, input_port); + hugr.disconnect(inp, input_port); } // Shift ports at input - shift_ports(circ, inp, input_port, circ.num_outputs(inp))?; + shift_ports(hugr, inp, input_port, hugr.num_outputs(inp))?; // Shift ports at output if let Some((out, output_port)) = link { - shift_ports(circ, out, output_port, circ.num_inputs(out))?; + shift_ports(hugr, out, output_port, hugr.num_inputs(out))?; } - // Update input node, output node (if necessary) and root signatures. - update_signature(circ, input_port.index(), link.map(|(_, p)| p.index())); + // Update input node, output node (if necessary) and parent signatures. + update_signature( + hugr, + parent, + input_port.index(), + link.map(|(_, p)| p.index()), + ); // Resize ports at input/output node - circ.set_num_ports(inp, 0, circ.num_outputs(inp) - 1); + hugr.set_num_ports(inp, 0, hugr.num_outputs(inp) - 1); if let Some((out, _)) = link { - circ.set_num_ports(out, circ.num_inputs(out) - 1, 0); + hugr.set_num_ports(out, hugr.num_inputs(out) - 1, 0); } Ok(()) } +/// Errors that can occur when mutating a circuit. +#[derive(Debug, Clone, Error, PartialEq)] +pub enum CircuitError { + /// The parent node for the circuit does not exist in the HUGR. + #[error("{parent} cannot define a circuit as it is not present in the HUGR.")] + MissingParentNode { + /// The node that was used as the parent. + parent: Node, + }, + /// The parent node for the circuit is not a DFG node. + #[error( + "{parent} cannot be used as a circuit parent. A {} is not a dataflow container.", + optype.name() + )] + NonDFGParent { + /// The node that was used as the parent. + parent: Node, + /// The parent optype. + optype: OpType, + }, +} + /// Errors that can occur when mutating a circuit. #[derive(Debug, Clone, Error, PartialEq, Eq, From)] pub enum CircuitMutError { @@ -252,15 +413,18 @@ fn shift_ports( // Update the signature of circ when removing the in_index-th input wire and // the out_index-th output wire. -fn update_signature( - circ: &mut C, +fn update_signature( + hugr: &mut impl HugrMut, + parent: Node, in_index: usize, out_index: Option, ) { - let inp = circ.input(); + let inp = hugr + .get_io(parent) + .expect("no IO nodes found at circuit parent")[0]; // Update input node let inp_types: TypeRow = { - let OpType::Input(Input { types }) = circ.get_optype(inp).clone() else { + let OpType::Input(Input { types }) = hugr.get_optype(inp).clone() else { panic!("invalid circuit") }; let mut types = types.into_owned(); @@ -268,15 +432,15 @@ fn update_signature( types.into() }; let new_inp_op = Input::new(inp_types.clone()); - let inp_exts = circ.get_nodetype(inp).input_extensions().cloned(); - circ.replace_op(inp, NodeType::new(new_inp_op, inp_exts)) + let inp_exts = hugr.get_nodetype(inp).input_extensions().cloned(); + hugr.replace_op(inp, NodeType::new(new_inp_op, inp_exts)) .unwrap(); // Update output node if necessary. let out_types = out_index.map(|out_index| { - let out = circ.output(); + let out = hugr.get_io(parent).unwrap()[1]; let out_types: TypeRow = { - let OpType::Output(Output { types }) = circ.get_optype(out).clone() else { + let OpType::Output(Output { types }) = hugr.get_optype(out).clone() else { panic!("invalid circuit") }; let mut types = types.into_owned(); @@ -284,14 +448,14 @@ fn update_signature( types.into() }; let new_out_op = Output::new(out_types.clone()); - let inp_exts = circ.get_nodetype(out).input_extensions().cloned(); - circ.replace_op(out, NodeType::new(new_out_op, inp_exts)) + let inp_exts = hugr.get_nodetype(out).input_extensions().cloned(); + hugr.replace_op(out, NodeType::new(new_out_op, inp_exts)) .unwrap(); out_types }); - // Update root - let OpType::DFG(DFG { mut signature, .. }) = circ.get_optype(circ.root()).clone() else { + // Update parent + let OpType::DFG(DFG { mut signature, .. }) = hugr.get_optype(parent).clone() else { panic!("invalid circuit") }; signature.input = inp_types; @@ -299,26 +463,25 @@ fn update_signature( signature.output = out_types; } let new_dfg_op = DFG { signature }; - let inp_exts = circ.get_nodetype(circ.root()).input_extensions().cloned(); - circ.replace_op(circ.root(), NodeType::new(new_dfg_op, inp_exts)) + let inp_exts = hugr.get_nodetype(parent).input_extensions().cloned(); + hugr.replace_op(parent, NodeType::new(new_dfg_op, inp_exts)) .unwrap(); } -impl Circuit for T where T: HugrView {} - #[cfg(test)] mod tests { + use cool_asserts::assert_matches; + use hugr::types::FunctionType; use hugr::{ builder::{DFGBuilder, DataflowHugr}, extension::{prelude::BOOL_T, PRELUDE_REGISTRY}, - Hugr, }; use super::*; use crate::{json::load_tk1_json_str, utils::build_simple_circuit, Tk2Op}; - fn test_circuit() -> Hugr { + fn test_circuit() -> Circuit { load_tk1_json_str( r#"{ "phase": "0", "bits": [["c", [0]]], @@ -367,10 +530,23 @@ mod tests { ); } + #[test] + fn test_invalid_parent() { + let hugr = Hugr::default(); + + assert_matches!( + Circuit::try_new(hugr.clone(), hugr.root()), + Err(CircuitError::NonDFGParent { .. }), + ); + } + #[test] fn remove_bit() { let h = DFGBuilder::new(FunctionType::new(vec![BOOL_T], vec![])).unwrap(); - let mut circ = h.finish_hugr_with_outputs([], &PRELUDE_REGISTRY).unwrap(); + let mut circ: Circuit = h + .finish_hugr_with_outputs([], &PRELUDE_REGISTRY) + .unwrap() + .into(); assert_eq!(circ.units().count(), 1); assert!(remove_empty_wire(&mut circ, 0).is_ok()); diff --git a/tket2/src/circuit/command.rs b/tket2/src/circuit/command.rs index 84ed2a18..09af12a1 100644 --- a/tket2/src/circuit/command.rs +++ b/tket2/src/circuit/command.rs @@ -8,7 +8,7 @@ use std::iter::FusedIterator; use hugr::hugr::NodeType; use hugr::ops::{OpTag, OpTrait}; -use hugr::{IncomingPort, OutgoingPort}; +use hugr::{HugrView, IncomingPort, OutgoingPort}; use itertools::Either::{self, Left, Right}; use itertools::{EitherOrBoth, Itertools}; use petgraph::visit as pv; @@ -21,9 +21,9 @@ pub use hugr::types::{EdgeKind, Type, TypeRow}; pub use hugr::{CircuitUnit, Direction, Node, Port, PortIndex, Wire}; /// An operation applied to specific wires. -pub struct Command<'circ, Circ> { +pub struct Command<'circ, T> { /// The circuit. - circ: &'circ Circ, + circ: &'circ Circuit, /// The operation node. node: Node, /// An assignment of linear units to the node's input ports. @@ -32,7 +32,7 @@ pub struct Command<'circ, Circ> { output_linear_units: Vec, } -impl<'circ, Circ: Circuit> Command<'circ, Circ> { +impl<'circ, T: HugrView> Command<'circ, T> { /// Returns the node corresponding to this command. #[inline] pub fn node(&self) -> Node { @@ -42,13 +42,13 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { /// Returns the [`NodeType`] of the command. #[inline] pub fn nodetype(&self) -> &NodeType { - self.circ.get_nodetype(self.node) + self.circ.hugr().get_nodetype(self.node) } /// Returns the [`OpType`] of the command. #[inline] pub fn optype(&self) -> &OpType { - self.circ.get_optype(self.node) + self.circ.hugr().get_optype(self.node) } /// Returns the units of this command in a given direction. @@ -162,7 +162,7 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> { } } -impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { +impl<'a, 'circ, T: HugrView> UnitLabeller for &'a Command<'circ, T> { #[inline] fn assign_linear(&self, _: Node, port: Port, _linear_count: usize) -> LinearUnit { let units = match port.direction() { @@ -181,7 +181,7 @@ impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { fn assign_wire(&self, node: Node, port: Port) -> Option { match port.as_directed() { Left(to_port) => { - let (from, from_port) = self.circ.linked_outputs(node, to_port).next()?; + let (from, from_port) = self.circ.hugr().linked_outputs(node, to_port).next()?; Some(Wire::new(from, from_port)) } Right(from_port) => Some(Wire::new(node, from_port)), @@ -189,7 +189,7 @@ impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> { } } -impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> { +impl<'circ, T: HugrView> std::fmt::Debug for Command<'circ, T> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Command") .field("circuit name", &self.circ.name()) @@ -200,7 +200,7 @@ impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> { } } -impl<'circ, Circ> PartialEq for Command<'circ, Circ> { +impl<'circ, T: HugrView> PartialEq for Command<'circ, T> { fn eq(&self, other: &Self) -> bool { self.node == other.node && self.input_linear_units == other.input_linear_units @@ -208,9 +208,9 @@ impl<'circ, Circ> PartialEq for Command<'circ, Circ> { } } -impl<'circ, Circ> Eq for Command<'circ, Circ> {} +impl<'circ, T: HugrView> Eq for Command<'circ, T> {} -impl<'circ, Circ> Clone for Command<'circ, Circ> { +impl<'circ, T: HugrView> Clone for Command<'circ, T> { fn clone(&self) -> Self { Self { circ: self.circ, @@ -221,7 +221,7 @@ impl<'circ, Circ> Clone for Command<'circ, Circ> { } } -impl<'circ, Circ> std::hash::Hash for Command<'circ, Circ> { +impl<'circ, T: HugrView> std::hash::Hash for Command<'circ, T> { fn hash(&self, state: &mut H) { self.node.hash(state); self.input_linear_units.hash(state); @@ -234,9 +234,9 @@ type NodeWalker = pv::Topo>; /// An iterator over the commands of a circuit. #[derive(Clone)] -pub struct CommandIterator<'circ, Circ> { +pub struct CommandIterator<'circ, T> { /// The circuit. - circ: &'circ Circ, + circ: &'circ Circuit, /// Toposorted nodes. nodes: NodeWalker, /// Last wire for each [`LinearUnit`] in the circuit. @@ -263,28 +263,25 @@ pub struct CommandIterator<'circ, Circ> { delayed_node: Option, } -impl<'circ, Circ> CommandIterator<'circ, Circ> -where - Circ: Circuit, -{ +impl<'circ, T: HugrView> CommandIterator<'circ, T> { /// Create a new iterator over the commands of a circuit. - pub(super) fn new(circ: &'circ Circ) -> Self { + pub(super) fn new(circ: &'circ Circuit) -> Self { // Initialize the map assigning linear units to the input's linear // ports. // // TODO: `with_wires` combinator for `Units`? let wire_unit = circ .linear_units() - .map(|(linear_unit, port, _)| (Wire::new(circ.input(), port), linear_unit.index())) + .map(|(linear_unit, port, _)| (Wire::new(circ.input_node(), port), linear_unit.index())) .collect(); - let nodes = pv::Topo::new(&circ.as_petgraph()); + let nodes = pv::Topo::new(&circ.hugr().as_petgraph()); Self { circ, nodes, wire_unit, // Ignore the input and output nodes, and the root. - remaining: circ.node_count() - 3, + remaining: circ.hugr().node_count() - 3, delayed_consts: HashSet::new(), delayed_consumers: HashMap::new(), delayed_node: None, @@ -299,13 +296,13 @@ where let node = self .delayed_node .take() - .or_else(|| self.nodes.next(&self.circ.as_petgraph()))?; + .or_else(|| self.nodes.next(&self.circ.hugr().as_petgraph()))?; // If this node is a constant or load const node, delay it. - let tag = self.circ.get_optype(node).tag(); + let tag = self.circ.hugr().get_optype(node).tag(); if tag == OpTag::Const || tag == OpTag::LoadConst { self.delayed_consts.insert(node); - for consumer in self.circ.output_neighbours(node) { + for consumer in self.circ.hugr().output_neighbours(node) { *self.delayed_consumers.entry(consumer).or_default() += 1; } return self.next_node(); @@ -316,7 +313,7 @@ where true => { let delayed = self.next_delayed_node(node); self.delayed_consts.remove(&delayed); - for consumer in self.circ.output_neighbours(delayed) { + for consumer in self.circ.hugr().output_neighbours(delayed) { let Entry::Occupied(mut entry) = self.delayed_consumers.entry(consumer) else { panic!("Delayed node consumer was not in delayed_consumers. Delayed node: {delayed:?}, consumer: {consumer:?}."); }; @@ -336,6 +333,7 @@ where fn next_delayed_node(&mut self, consumer: Node) -> Node { let Some(delayed_pred) = self .circ + .hugr() .input_neighbours(consumer) .find(|k| self.delayed_consts.contains(k)) else { @@ -359,12 +357,12 @@ where /// mutable borrow here. fn process_node(&mut self, node: Node) -> Option<(Vec, Vec)> { // The root node is ignored. - if node == self.circ.root() { + if node == self.circ.parent() { return None; } // Inputs and outputs are also ignored. // The input wire ids are already set in the `wire_unit` map during initialization. - let tag = self.circ.get_optype(node).tag(); + let tag = self.circ.hugr().get_optype(node).tag(); if tag == OpTag::Input || tag == OpTag::Output { return None; } @@ -390,7 +388,7 @@ where // Returns the linear id of the terminated unit. let mut terminate_input = |port: IncomingPort, wire_unit: &mut HashMap| -> Option { - let linear_id = self.circ.single_linked_output(node, port).and_then( + let linear_id = self.circ.hugr().single_linked_output(node, port).and_then( |(wire_node, wire_port)| wire_unit.remove(&Wire::new(wire_node, wire_port)), )?; input_linear_units.push(LinearUnit::new(linear_id)); @@ -425,11 +423,8 @@ where } } -impl<'circ, Circ> Iterator for CommandIterator<'circ, Circ> -where - Circ: Circuit, -{ - type Item = Command<'circ, Circ>; +impl<'circ, T: HugrView> Iterator for CommandIterator<'circ, T> { + type Item = Command<'circ, T>; #[inline] fn next(&mut self) -> Option { @@ -454,9 +449,9 @@ where } } -impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ> where Circ: Circuit {} +impl<'circ, T: HugrView> FusedIterator for CommandIterator<'circ, T> {} -impl<'circ, Circ: Circuit> std::fmt::Debug for CommandIterator<'circ, Circ> { +impl<'circ, T: HugrView> std::fmt::Debug for CommandIterator<'circ, T> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("CommandIterator") .field("circuit name", &self.circ.name()) @@ -558,9 +553,10 @@ mod test { .add_dataflow_op(Tk2Op::RzF64, [q_in, loaded_const]) .unwrap(); - let circ = h + let circ: Circuit = h .finish_hugr_with_outputs(rz.outputs(), &FLOAT_OPS_REGISTRY) - .unwrap(); + .unwrap() + .into(); assert_eq!(CommandIterator::new(&circ).count(), 3); let mut commands = CommandIterator::new(&circ); @@ -631,7 +627,7 @@ mod test { let free = h.add_dataflow_op(Tk2Op::QFree, [q_in])?; - let circ = h.finish_hugr_with_outputs([q_new], ®ISTRY)?; + let circ: Circuit = h.finish_hugr_with_outputs([q_new], ®ISTRY)?.into(); let mut cmds = circ.commands(); @@ -678,7 +674,7 @@ mod test { let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), vec![]))?; let [q_in] = h.input_wires_arr(); h.add_dataflow_op(Tk2Op::QFree, [q_in])?; - let circ = h.finish_hugr_with_outputs([], ®ISTRY)?; + let circ: Circuit = h.finish_hugr_with_outputs([], ®ISTRY)?.into(); let cmd1 = circ.commands().next().unwrap(); let cmd2 = circ.commands().next().unwrap(); diff --git a/tket2/src/circuit/hash.rs b/tket2/src/circuit/hash.rs index 478f97d9..84806874 100644 --- a/tket2/src/circuit/hash.rs +++ b/tket2/src/circuit/hash.rs @@ -12,7 +12,7 @@ use thiserror::Error; use super::Circuit; /// Circuit hashing utilities. -pub trait CircuitHash<'circ>: HugrView { +pub trait CircuitHash { /// Compute hash of a circuit. /// /// We compute a hash for each command from its operation and the hash of @@ -23,14 +23,26 @@ pub trait CircuitHash<'circ>: HugrView { /// /// Adapted from Quartz (Apache 2.0) /// - fn circuit_hash(&'circ self) -> Result; + fn circuit_hash(&self) -> Result; } -impl<'circ, T> CircuitHash<'circ> for T +impl CircuitHash for Circuit { + fn circuit_hash(&self) -> Result { + let hugr = self.hugr(); + let container: SiblingGraph = SiblingGraph::try_new(hugr, self.parent()).unwrap(); + container.circuit_hash() + } +} + +impl CircuitHash for T where T: HugrView, { - fn circuit_hash(&'circ self) -> Result { + fn circuit_hash(&self) -> Result { + let Some([_, output_node]) = self.get_io(self.root()) else { + return Err(HashError::NotADfg); + }; + let mut node_hashes = HashState::default(); for node in pg::Topo::new(&self.as_petgraph()) @@ -45,7 +57,7 @@ where // If the output node has no hash, the topological sort failed due to a cycle. node_hashes - .node_hash(self.output()) + .node_hash(output_node) .ok_or(HashError::CyclicCircuit) } } @@ -132,11 +144,13 @@ pub enum HashError { /// The circuit contains a cycle. #[error("The circuit contains a cycle.")] CyclicCircuit, + /// The hashed hugr is not a DFG. + #[error("Tried to hash a non-dfg hugr.")] + NotADfg, } #[cfg(test)] mod test { - use hugr::Hugr; use tket_json_rs::circuit_json; use crate::json::TKETDecode; @@ -185,7 +199,7 @@ mod test { fn hash_constants() { let c_str = r#"{"bits": [], "commands": [{"args": [["q", [0]]], "op": {"params": ["0.5"], "type": "Rz"}}], "created_qubits": [], "discarded_qubits": [], "implicit_permutation": [[["q", [0]], ["q", [0]]]], "phase": "0.0", "qubits": [["q", [0]]]}"#; let ser: circuit_json::SerialCircuit = serde_json::from_str(c_str).unwrap(); - let circ: Hugr = ser.decode().unwrap(); + let circ: Circuit = ser.decode().unwrap(); circ.circuit_hash().unwrap(); } @@ -197,7 +211,7 @@ mod test { let mut all_hashes = Vec::with_capacity(2); for c_str in [c_str1, c_str2] { let ser: circuit_json::SerialCircuit = serde_json::from_str(c_str).unwrap(); - let circ: Hugr = ser.decode().unwrap(); + let circ: Circuit = ser.decode().unwrap(); all_hashes.push(circ.circuit_hash().unwrap()); } assert_ne!(all_hashes[0], all_hashes[1]); diff --git a/tket2/src/circuit/units.rs b/tket2/src/circuit/units.rs index 17383aa7..173186f5 100644 --- a/tket2/src/circuit/units.rs +++ b/tket2/src/circuit/units.rs @@ -18,7 +18,7 @@ use std::iter::FusedIterator; use std::marker::PhantomData; use hugr::types::{EdgeKind, Type, TypeRow}; -use hugr::{CircuitUnit, IncomingPort, OutgoingPort}; +use hugr::{CircuitUnit, HugrView, IncomingPort, OutgoingPort}; use hugr::{Direction, Node, Port, Wire}; use crate::utils::type_is_linear; @@ -83,8 +83,8 @@ impl Units { /// This iterator will yield all units originating from the circuit's input /// node. #[inline] - pub(super) fn new_circ_input(circuit: &impl Circuit) -> Self { - Self::new_outgoing(circuit, circuit.input(), DefaultUnitLabeller) + pub(super) fn new_circ_input(circuit: &Circuit) -> Self { + Self::new_outgoing(circuit, circuit.input_node(), DefaultUnitLabeller) } } @@ -94,7 +94,11 @@ where { /// Create a new iterator over the units originating from node. #[inline] - pub(super) fn new_outgoing(circuit: &impl Circuit, node: Node, unit_labeller: UL) -> Self { + pub(super) fn new_outgoing( + circuit: &Circuit, + node: Node, + unit_labeller: UL, + ) -> Self { Self::new_with_dir(circuit, node, Direction::Outgoing, unit_labeller) } } @@ -105,7 +109,11 @@ where { /// Create a new iterator over the units terminating on the node. #[inline] - pub(super) fn new_incoming(circuit: &impl Circuit, node: Node, unit_labeller: UL) -> Self { + pub(super) fn new_incoming( + circuit: &Circuit, + node: Node, + unit_labeller: UL, + ) -> Self { Self::new_with_dir(circuit, node, Direction::Incoming, unit_labeller) } } @@ -117,8 +125,8 @@ where { /// Create a new iterator over the units of a node. #[inline] - fn new_with_dir( - circuit: &impl Circuit, + fn new_with_dir( + circuit: &Circuit, node: Node, direction: Direction, unit_labeller: UL, @@ -142,9 +150,10 @@ where // We should revisit it once this is reworked on the HUGR side. // // TODO: EdgeKind::Function is not currently supported. - fn init_types(circuit: &impl Circuit, node: Node, direction: Direction) -> TypeRow { - let optype = circuit.get_optype(node); - let sig = circuit.signature(node).unwrap_or_default(); + fn init_types(circuit: &Circuit, node: Node, direction: Direction) -> TypeRow { + let hugr = circuit.hugr(); + let optype = hugr.get_optype(node); + let sig = hugr.signature(node).unwrap_or_default(); let mut types = match direction { Direction::Outgoing => sig.output, Direction::Incoming => sig.input, diff --git a/tket2/src/json.rs b/tket2/src/json.rs index aab0b33b..34c71089 100644 --- a/tket2/src/json.rs +++ b/tket2/src/json.rs @@ -14,7 +14,6 @@ use std::{fs, io}; use hugr::ops::{OpType, Value}; use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; -use hugr::Hugr; use stringreader::StringReader; use thiserror::Error; @@ -37,23 +36,25 @@ const METADATA_Q_REGISTERS: &str = "TKET1_JSON.qubit_registers"; /// Explicit names for the input bit registers. const METADATA_B_REGISTERS: &str = "TKET1_JSON.bit_registers"; -/// A JSON-serialized circuit that can be converted to a [`Hugr`]. +/// A serialized representation of a [`Circuit`]. +/// +/// Implemented by [`SerialCircuit`], the JSON format used by tket1's `pytket` library. pub trait TKETDecode: Sized { /// The error type for decoding. type DecodeError; /// The error type for decoding. type EncodeError; - /// Convert the serialized circuit to a [`Hugr`]. - fn decode(self) -> Result; - /// Convert a [`Hugr`] to a new serialized circuit. - fn encode(circuit: &impl Circuit) -> Result; + /// Convert the serialized circuit to a circuit. + fn decode(self) -> Result; + /// Convert a circuit to a new serialized circuit. + fn encode(circuit: &Circuit) -> Result; } impl TKETDecode for SerialCircuit { type DecodeError = OpConvertError; type EncodeError = OpConvertError; - fn decode(self) -> Result { + fn decode(self) -> Result { let mut decoder = JsonDecoder::new(&self); if !self.phase.is_empty() { @@ -65,10 +66,10 @@ impl TKETDecode for SerialCircuit { for com in self.commands { decoder.add_command(com); } - Ok(decoder.finish()) + Ok(decoder.finish().into()) } - fn encode(circ: &impl Circuit) -> Result { + fn encode(circ: &Circuit) -> Result { let mut encoder = JsonEncoder::new(circ); let f64_inputs = circ.units().filter_map(|(wire, _, t)| match (wire, t) { (CircuitUnit::Wire(wire), t) if t == FLOAT64_TYPE => Some(wire), @@ -102,43 +103,41 @@ pub enum OpConvertError { } /// Load a TKET1 circuit from a JSON file. -pub fn load_tk1_json_file(path: impl AsRef) -> Result { +pub fn load_tk1_json_file(path: impl AsRef) -> Result { let file = fs::File::open(path)?; let reader = io::BufReader::new(file); load_tk1_json_reader(reader) } /// Load a TKET1 circuit from a JSON reader. -pub fn load_tk1_json_reader(json: impl io::Read) -> Result { +pub fn load_tk1_json_reader(json: impl io::Read) -> Result { let ser: SerialCircuit = serde_json::from_reader(json)?; - Ok(ser.decode()?) + let circ: Circuit = ser.decode()?; + Ok(circ) } /// Load a TKET1 circuit from a JSON string. -pub fn load_tk1_json_str(json: &str) -> Result { +pub fn load_tk1_json_str(json: &str) -> Result { let reader = StringReader::new(json); load_tk1_json_reader(reader) } /// Save a circuit to file in TK1 JSON format. -pub fn save_tk1_json_file( - circ: &impl Circuit, - path: impl AsRef, -) -> Result<(), TK1ConvertError> { +pub fn save_tk1_json_file(circ: &Circuit, path: impl AsRef) -> Result<(), TK1ConvertError> { let file = fs::File::create(path)?; let writer = io::BufWriter::new(file); save_tk1_json_writer(circ, writer) } /// Save a circuit in TK1 JSON format to a writer. -pub fn save_tk1_json_writer(circ: &impl Circuit, w: impl io::Write) -> Result<(), TK1ConvertError> { +pub fn save_tk1_json_writer(circ: &Circuit, w: impl io::Write) -> Result<(), TK1ConvertError> { let serial_circ = SerialCircuit::encode(circ)?; serde_json::to_writer(w, &serial_circ)?; Ok(()) } /// Save a circuit in TK1 JSON format to a String. -pub fn save_tk1_json_str(circ: &impl Circuit) -> Result { +pub fn save_tk1_json_str(circ: &Circuit) -> Result { let mut buf = io::BufWriter::new(Vec::new()); save_tk1_json_writer(circ, &mut buf)?; let bytes = buf.into_inner().unwrap(); diff --git a/tket2/src/json/encoder.rs b/tket2/src/json/encoder.rs index 589d871c..0e65efbd 100644 --- a/tket2/src/json/encoder.rs +++ b/tket2/src/json/encoder.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use hugr::extension::prelude::QB_T; use hugr::ops::{NamedOp, OpType}; use hugr::std_extensions::arithmetic::float_types::ConstF64; -use hugr::Wire; +use hugr::{HugrView, Wire}; use itertools::{Either, Itertools}; use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit}; @@ -48,8 +48,9 @@ pub(super) struct JsonEncoder { impl JsonEncoder { /// Create a new [`JsonEncoder`] from a [`Circuit`]. - pub fn new(circ: &impl Circuit) -> Self { + pub fn new(circ: &Circuit) -> Self { let name = circ.name().map(str::to_string); + let hugr = circ.hugr(); let mut qubit_registers = vec![]; let mut bit_registers = vec![]; @@ -58,17 +59,17 @@ impl JsonEncoder { // Recover other parameters stored in the metadata // TODO: Check for invalid encoded metadata - let root = circ.root(); - if let Some(p) = circ.get_metadata(root, METADATA_PHASE) { + let root = circ.parent(); + if let Some(p) = hugr.get_metadata(root, METADATA_PHASE) { phase = p.as_str().unwrap().to_string(); } - if let Some(perm) = circ.get_metadata(root, METADATA_IMPLICIT_PERM) { + if let Some(perm) = hugr.get_metadata(root, METADATA_IMPLICIT_PERM) { implicit_permutation = serde_json::from_value(perm.clone()).unwrap(); } - if let Some(q_regs) = circ.get_metadata(root, METADATA_Q_REGISTERS) { + if let Some(q_regs) = hugr.get_metadata(root, METADATA_Q_REGISTERS) { qubit_registers = serde_json::from_value(q_regs.clone()).unwrap(); } - if let Some(b_regs) = circ.get_metadata(root, METADATA_B_REGISTERS) { + if let Some(b_regs) = hugr.get_metadata(root, METADATA_B_REGISTERS) { bit_registers = serde_json::from_value(b_regs.clone()).unwrap(); } @@ -109,9 +110,9 @@ impl JsonEncoder { } /// Add a circuit command to the serialization. - pub fn add_command( + pub fn add_command( &mut self, - command: Command<'_, C>, + command: Command<'_, T>, optype: &OpType, ) -> Result<(), OpConvertError> { // Register any output of the command that can be used as a TKET1 parameter. @@ -169,7 +170,11 @@ impl JsonEncoder { /// Record any output of the command that can be used as a TKET1 parameter. /// Returns whether parameters were recorded. /// Associates the output wires with the parameter expression. - fn record_parameters(&mut self, command: &Command<'_, C>, optype: &OpType) -> bool { + fn record_parameters( + &mut self, + command: &Command<'_, T>, + optype: &OpType, + ) -> bool { // Only consider commands where all inputs are parameters. let inputs = command .inputs() diff --git a/tket2/src/json/tests.rs b/tket2/src/json/tests.rs index f08ef58e..fde3e641 100644 --- a/tket2/src/json/tests.rs +++ b/tket2/src/json/tests.rs @@ -7,7 +7,6 @@ use hugr::extension::prelude::QB_T; use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use hugr::types::FunctionType; -use hugr::Hugr; use rstest::{fixture, rstest}; use tket_json_rs::circuit_json::{self, SerialCircuit}; use tket_json_rs::optype; @@ -64,7 +63,7 @@ fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num let ser: circuit_json::SerialCircuit = serde_json::from_str(circ_s).unwrap(); assert_eq!(ser.commands.len(), num_commands); - let circ: Hugr = ser.clone().decode().unwrap(); + let circ: Circuit = ser.clone().decode().unwrap(); assert_eq!(circ.qubit_count(), num_qubits); @@ -78,13 +77,13 @@ fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num fn json_file_roundtrip(#[case] circ: impl AsRef) { let reader = BufReader::new(std::fs::File::open(circ).unwrap()); let ser: circuit_json::SerialCircuit = serde_json::from_reader(reader).unwrap(); - let circ: Hugr = ser.clone().decode().unwrap(); + let circ: Circuit = ser.clone().decode().unwrap(); let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap(); compare_serial_circs(&ser, &reser); } #[fixture] -fn circ_add_angles_symbolic() -> Hugr { +fn circ_add_angles_symbolic() -> Circuit { let input_t = vec![QB_T, FLOAT64_TYPE, FLOAT64_TYPE]; let output_t = vec![QB_T]; let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); @@ -99,11 +98,11 @@ fn circ_add_angles_symbolic() -> Hugr { let res = h.add_dataflow_op(Tk2Op::RxF64, [qb, f12]).unwrap(); let qb = res.outputs().next().unwrap(); - h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap().into() } #[fixture] -fn circ_add_angles_constants() -> Hugr { +fn circ_add_angles_constants() -> Circuit { let qb_row = vec![QB_T]; let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row)).unwrap(); @@ -120,13 +119,13 @@ fn circ_add_angles_constants() -> Hugr { .add_dataflow_op(Tk2Op::RxF64, [qb, point5]) .unwrap() .outputs(); - h.finish_hugr_with_outputs(qbs, ®ISTRY).unwrap() + h.finish_hugr_with_outputs(qbs, ®ISTRY).unwrap().into() } #[rstest] #[case::symbolic(circ_add_angles_symbolic(), "f0 + f1")] #[case::constants(circ_add_angles_constants(), "0.2 + 0.3")] -fn test_add_angle_serialise(#[case] circ_add_angles: Hugr, #[case] param_str: &str) { +fn test_add_angle_serialise(#[case] circ_add_angles: Circuit, #[case] param_str: &str) { let ser: SerialCircuit = SerialCircuit::encode(&circ_add_angles).unwrap(); assert_eq!(ser.commands.len(), 1); assert_eq!(ser.commands[0].op.op_type, optype::OpType::Rx); @@ -135,7 +134,7 @@ fn test_add_angle_serialise(#[case] circ_add_angles: Hugr, #[case] param_str: &s // Note: this is not a proper roundtrip as the symbols f0 and f1 are not // converted back to circuit inputs. This would require parsing symbolic // expressions. - let deser: Hugr = ser.clone().decode().unwrap(); + let deser: Circuit = ser.clone().decode().unwrap(); let reser = SerialCircuit::encode(&deser).unwrap(); compare_serial_circs(&ser, &reser); } diff --git a/tket2/src/lib.rs b/tket2/src/lib.rs index 750e43d8..756b52d8 100644 --- a/tket2/src/lib.rs +++ b/tket2/src/lib.rs @@ -20,11 +20,11 @@ //! #![cfg_attr(not(miri), doc = "```")] // this doctest reads from the filesystem, so it fails with miri #![cfg_attr(miri, doc = "```ignore")] -//! use tket2::{Circuit, Hugr}; +//! use tket2::Circuit; //! use hugr::HugrView; //! //! // Load a tket1 circuit. -//! let mut circ: Hugr = tket2::json::load_tk1_json_file("../test_files/barenco_tof_5.json").unwrap(); +//! let mut circ: Circuit = tket2::json::load_tk1_json_file("../test_files/barenco_tof_5.json").unwrap(); //! //! assert_eq!(circ.qubit_count(), 9); //! assert_eq!(circ.num_gates(), 170); diff --git a/tket2/src/ops.rs b/tket2/src/ops.rs index a6e143be..f941c517 100644 --- a/tket2/src/ops.rs +++ b/tket2/src/ops.rs @@ -257,15 +257,16 @@ pub(crate) mod test { use std::sync::Arc; use hugr::extension::simple_op::MakeOpDef; + use hugr::extension::OpDef; use hugr::ops::NamedOp; use hugr::CircuitUnit; - use hugr::{extension::OpDef, Hugr}; use rstest::{fixture, rstest}; use strum::IntoEnumIterator; use super::Tk2Op; + use crate::circuit::Circuit; use crate::extension::{TKET2_EXTENSION as EXTENSION, TKET2_EXTENSION_ID as EXTENSION_ID}; - use crate::{circuit::Circuit, utils::build_simple_circuit}; + use crate::utils::build_simple_circuit; fn get_opdef(op: impl NamedOp) -> Option<&'static Arc> { EXTENSION.get_op(&op.name()) } @@ -279,7 +280,7 @@ pub(crate) mod test { } #[fixture] - pub(crate) fn t2_bell_circuit() -> Hugr { + pub(crate) fn t2_bell_circuit() -> Circuit { let h = build_simple_circuit(2, |circ| { circ.append(Tk2Op::H, [0])?; circ.append(Tk2Op::CX, [0, 1])?; @@ -290,7 +291,7 @@ pub(crate) mod test { } #[rstest] - fn check_t2_bell(t2_bell_circuit: Hugr) { + fn check_t2_bell(t2_bell_circuit: Circuit) { assert_eq!(t2_bell_circuit.commands().count(), 2); } diff --git a/tket2/src/optimiser/badger.rs b/tket2/src/optimiser/badger.rs index 5487095a..9aa6bafa 100644 --- a/tket2/src/optimiser/badger.rs +++ b/tket2/src/optimiser/badger.rs @@ -22,14 +22,13 @@ use crossbeam_channel::select; pub use eq_circ_class::{load_eccs_json_file, EqCircClass}; use fxhash::FxHashSet; use hugr::hugr::HugrError; +use hugr::HugrView; pub use log::BadgerLogger; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; use std::{mem, thread}; -use hugr::Hugr; - use crate::circuit::cost::CircuitCost; use crate::circuit::CircuitHash; use crate::optimiser::badger::hugr_pchannel::{HugrPriorityChannel, PriorityChannelLog}; @@ -37,7 +36,6 @@ use crate::optimiser::badger::hugr_pqueue::{Entry, HugrPQ}; use crate::optimiser::badger::worker::BadgerWorker; use crate::passes::CircuitChunks; use crate::rewrite::strategy::RewriteStrategy; -use crate::rewrite::trace::RewriteTracer; use crate::rewrite::Rewriter; use crate::Circuit; @@ -113,7 +111,7 @@ impl BadgerOptimiser { Self { rewriter, strategy } } - fn cost(&self, circ: &Hugr) -> S::Cost + fn cost(&self, circ: &Circuit) -> S::Cost where S: RewriteStrategy, { @@ -130,7 +128,7 @@ where /// Run the Badger optimiser on a circuit. /// /// A timeout (in seconds) can be provided. - pub fn optimise(&self, circ: &Hugr, options: BadgerOptions) -> Hugr { + pub fn optimise(&self, circ: &Circuit, options: BadgerOptions) -> Circuit { self.optimise_with_log(circ, Default::default(), options) } @@ -139,10 +137,10 @@ where /// A timeout (in seconds) can be provided. pub fn optimise_with_log( &self, - circ: &Hugr, + circ: &Circuit, log_config: BadgerLogger, options: BadgerOptions, - ) -> Hugr { + ) -> Circuit { if options.split_circuit && options.n_threads.get() > 1 { return self.split_run(circ, log_config, options).unwrap(); } @@ -153,12 +151,18 @@ where } #[tracing::instrument(target = "badger::metrics", skip(self, circ, logger))] - fn badger(&self, circ: &Hugr, mut logger: BadgerLogger, opt: BadgerOptions) -> Hugr { + fn badger( + &self, + circ: &Circuit, + mut logger: BadgerLogger, + opt: BadgerOptions, + ) -> Circuit { let start_time = Instant::now(); let mut last_best_time = Instant::now(); + let circ = circ.to_owned(); let mut best_circ = circ.clone(); - let mut best_circ_cost = self.cost(circ); + let mut best_circ_cost = self.cost(&circ); let num_rewrites = best_circ.rewrite_trace().map(|rs| rs.len()); logger.log_best(&best_circ_cost, num_rewrites); @@ -170,12 +174,12 @@ where // The priority queue of circuits to be processed (this should not get big) let cost_fn = { let strategy = self.strategy.clone(); - move |circ: &'_ Hugr| strategy.circuit_cost(circ) + move |circ: &'_ Circuit| strategy.circuit_cost(circ) }; - let cost = (cost_fn)(circ); + let cost = (cost_fn)(&circ); let mut pq = HugrPQ::new(cost_fn, opt.queue_size); - pq.push_unchecked(circ.clone(), hash, cost); + pq.push_unchecked(circ.to_owned(), hash, cost); let mut circ_cnt = 0; let mut timeout_flag = false; @@ -250,16 +254,17 @@ where #[tracing::instrument(target = "badger::metrics", skip(self, circ, logger))] fn badger_multithreaded( &self, - circ: &Hugr, + circ: &Circuit, mut logger: BadgerLogger, opt: BadgerOptions, - ) -> Hugr { + ) -> Circuit { let n_threads: usize = opt.n_threads.get(); + let circ = circ.to_owned(); // multi-consumer priority channel for queuing circuits to be processed by the workers let cost_fn = { let strategy = self.strategy.clone(); - move |circ: &'_ Hugr| strategy.circuit_cost(circ) + move |circ: &'_ Circuit| strategy.circuit_cost(circ) }; let (pq, rx_log) = HugrPriorityChannel::init(cost_fn.clone(), opt.queue_size); @@ -271,7 +276,7 @@ where pq.send(vec![Work { cost: best_circ_cost.clone(), hash: initial_circ_hash, - circ: circ.clone(), + circ, }]) .unwrap(); @@ -381,18 +386,19 @@ where #[tracing::instrument(target = "badger::metrics", skip(self, circ, logger))] fn split_run( &self, - circ: &Hugr, + circ: &Circuit, mut logger: BadgerLogger, opt: BadgerOptions, - ) -> Result { - let circ_cost = self.cost(circ); + ) -> Result { + let circ = circ.to_owned(); + let circ_cost = self.cost(&circ); let max_chunk_cost = circ_cost.clone().div_cost(opt.n_threads); logger.log(format!( "Splitting circuit with cost {:?} into chunks of at most {max_chunk_cost:?}.", circ_cost.clone() )); let mut chunks = - CircuitChunks::split_with_cost(circ, max_chunk_cost, |op| self.strategy.op_cost(op)); + CircuitChunks::split_with_cost(&circ, max_chunk_cost, |op| self.strategy.op_cost(op)); let num_rewrites = circ.rewrite_trace().map(|rs| rs.len()); logger.log_best(circ_cost.clone(), num_rewrites); @@ -495,7 +501,6 @@ mod tests { extension::prelude::QB_T, std_extensions::arithmetic::float_types::FLOAT64_TYPE, types::FunctionType, - Hugr, }; use rstest::{fixture, rstest}; @@ -506,14 +511,14 @@ mod tests { use super::{BadgerOptimiser, DefaultBadgerOptimiser}; /// Simplified description of the circuit's commands. - fn gates(circ: &Hugr) -> Vec { + fn gates(circ: &Circuit) -> Vec { circ.commands() .map(|cmd| cmd.optype().try_into().unwrap()) .collect() } #[fixture] - fn rz_rz() -> Hugr { + fn rz_rz() -> Circuit { let input_t = vec![QB_T, FLOAT64_TYPE, FLOAT64_TYPE]; let output_t = vec![QB_T]; let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); @@ -528,7 +533,7 @@ mod tests { let res = h.add_dataflow_op(Tk2Op::RzF64, [qb, f2]).unwrap(); let qb = res.outputs().next().unwrap(); - h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap().into() } /// This hugr corresponds to the qasm circuit: @@ -549,9 +554,9 @@ mod tests { /// ``` const NON_COMPOSABLE: &str = r#"{"phase":"0.0","commands":[{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[1]],["q",[2]]]},{"op":{"type":"U3","params":["0.5","0","0.5"],"signature":["Q"]},"args":[["q",[1]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[4]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[4]],["q",[0]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[0]],["q",[2]]]},{"op":{"type":"CX","n_qb":2,"signature":["Q","Q"]},"args":[["q",[3]],["q",[1]]]}],"qubits":[["q",[0]],["q",[1]],["q",[2]],["q",[3]],["q",[4]]],"bits":[],"implicit_permutation":[[["q",[0]],["q",[0]]],[["q",[1]],["q",[1]]],[["q",[2]],["q",[2]]],[["q",[3]],["q",[3]]],[["q",[4]],["q",[4]]]]}"#; - /// A Hugr that would trigger non-composable rewrites, if we applied them blindly from nam_6_3 matches. + /// A circuit that would trigger non-composable rewrites, if we applied them blindly from nam_6_3 matches. #[fixture] - fn non_composable_rw_hugr() -> Hugr { + fn non_composable_rw_hugr() -> Circuit { load_tk1_json_str(NON_COMPOSABLE).unwrap() } @@ -571,7 +576,7 @@ mod tests { } #[rstest] - fn rz_rz_cancellation(rz_rz: Hugr, badger_opt: DefaultBadgerOptimiser) { + fn rz_rz_cancellation(rz_rz: Circuit, badger_opt: DefaultBadgerOptimiser) { let opt_rz = badger_opt.optimise( &rz_rz, BadgerOptions { @@ -584,7 +589,7 @@ mod tests { } #[rstest] - fn rz_rz_cancellation_parallel(rz_rz: Hugr, badger_opt: DefaultBadgerOptimiser) { + fn rz_rz_cancellation_parallel(rz_rz: Circuit, badger_opt: DefaultBadgerOptimiser) { let mut opt_rz = badger_opt.optimise( &rz_rz, BadgerOptions { @@ -594,13 +599,13 @@ mod tests { ..Default::default() }, ); - opt_rz.update_validate(®ISTRY).unwrap(); + opt_rz.hugr_mut().update_validate(®ISTRY).unwrap(); } #[rstest] #[ignore = "Loading the ECC set is really slow (~5 seconds)"] fn non_composable_rewrites( - non_composable_rw_hugr: Hugr, + non_composable_rw_hugr: Circuit, badger_opt_full: DefaultBadgerOptimiser, ) { let mut opt = badger_opt_full.optimise( @@ -612,7 +617,7 @@ mod tests { }, ); // No rewrites applied. - opt.update_validate(®ISTRY).unwrap(); + opt.hugr_mut().update_validate(®ISTRY).unwrap(); } #[test] diff --git a/tket2/src/optimiser/badger/eq_circ_class.rs b/tket2/src/optimiser/badger/eq_circ_class.rs index f45c2846..c2c7f118 100644 --- a/tket2/src/optimiser/badger/eq_circ_class.rs +++ b/tket2/src/optimiser/badger/eq_circ_class.rs @@ -26,10 +26,10 @@ pub struct EqCircClass { impl EqCircClass { /// Create a new equivalence class with a representative circuit. - pub fn new(rep_circ: Hugr, other_circs: Vec) -> Self { + pub fn new(rep_circ: Circuit, other_circs: impl IntoIterator) -> Self { Self { - rep_circ, - other_circs, + rep_circ: rep_circ.into_hugr(), + other_circs: other_circs.into_iter().map(|c| c.into_hugr()).collect(), } } @@ -64,8 +64,11 @@ impl EqCircClass { /// Create an equivalence class from a set of circuits. /// /// The smallest circuit is chosen as the representative. - pub fn from_circuits(circs: impl Into>) -> Result { - let mut circs: Vec<_> = circs.into(); + pub fn from_circuits( + circs: impl IntoIterator, + ) -> Result { + let mut circs: Vec = circs.into_iter().collect(); + if circs.is_empty() { return Err(EqCircClassError::NoRepresentative); }; diff --git a/tket2/src/optimiser/badger/hugr_pchannel.rs b/tket2/src/optimiser/badger/hugr_pchannel.rs index e50d8980..b69d7cdc 100644 --- a/tket2/src/optimiser/badger/hugr_pchannel.rs +++ b/tket2/src/optimiser/badger/hugr_pchannel.rs @@ -6,19 +6,19 @@ use std::time::Instant; use crossbeam_channel::{select, Receiver, RecvError, SendError, Sender}; use fxhash::FxHashSet; -use hugr::Hugr; use crate::circuit::cost::CircuitCost; +use crate::Circuit; use super::hugr_pqueue::{Entry, HugrPQ}; /// A unit of work for a worker, consisting of a circuit to process, along its /// hash and cost. -pub type Work

= Entry; +pub type Work

= Entry; -/// A priority channel for HUGRs. +/// A priority channel for circuits. /// -/// Queues hugrs using a cost function `C` that produces priority values `P`. +/// Queues circuits using a cost function `C` that produces priority values `P`. /// /// Uses a thread internally to orchestrate the queueing. #[derive(Debug, Clone)] @@ -50,7 +50,7 @@ pub struct HugrPriorityChannel { /// Logging information from the priority channel. #[derive(Debug, Clone)] pub enum PriorityChannelLog

{ - NewBestCircuit(Hugr, P), + NewBestCircuit(Circuit, P), CircuitCount { processed_count: usize, seen_count: usize, @@ -106,12 +106,12 @@ impl PriorityChannelCommunication

{ impl HugrPriorityChannel where - C: Fn(&Hugr) -> P + Send + Sync + 'static, + C: Fn(&Circuit) -> P + Send + Sync + 'static, P: CircuitCost + Send + Sync + 'static, { /// Initialize the queueing system. /// - /// Start the Hugr priority queue in a new thread. + /// Start the circuit priority queue in a new thread. /// /// Get back a [`PriorityChannelCommunication`] for adding and removing circuits to/from the queue, /// and a channel receiver to receive logging information. diff --git a/tket2/src/optimiser/badger/hugr_pqueue.rs b/tket2/src/optimiser/badger/hugr_pqueue.rs index 16569ef6..b0429237 100644 --- a/tket2/src/optimiser/badger/hugr_pqueue.rs +++ b/tket2/src/optimiser/badger/hugr_pqueue.rs @@ -1,9 +1,9 @@ use delegate::delegate; use fxhash::FxHashMap; -use hugr::Hugr; use priority_queue::DoublePriorityQueue; use crate::circuit::CircuitHash; +use crate::Circuit; /// A min-priority queue for Hugrs. /// @@ -12,7 +12,7 @@ use crate::circuit::CircuitHash; #[derive(Debug, Clone, Default)] pub struct HugrPQ { queue: DoublePriorityQueue, - hash_lookup: FxHashMap, + hash_lookup: FxHashMap, cost_fn: C, max_size: usize, } @@ -34,9 +34,9 @@ impl HugrPQ { } } - /// Reference to the minimal Hugr in the queue. + /// Reference to the minimal circuit in the queue. #[allow(unused)] - pub fn peek(&self) -> Option> { + pub fn peek(&self) -> Option> { let (hash, cost) = self.queue.peek_min()?; let circ = self.hash_lookup.get(hash)?; Some(Entry { @@ -46,20 +46,20 @@ impl HugrPQ { }) } - /// Push a Hugr into the queue. + /// Push a circuit into the queue. /// /// If the queue is full, the element with the highest cost will be dropped. #[allow(unused)] - pub fn push(&mut self, hugr: Hugr) + pub fn push(&mut self, circ: Circuit) where - C: Fn(&Hugr) -> P, + C: Fn(&Circuit) -> P, { - let hash = hugr.circuit_hash().unwrap(); - let cost = (self.cost_fn)(&hugr); - self.push_unchecked(hugr, hash, cost); + let hash = circ.circuit_hash().unwrap(); + let cost = (self.cost_fn)(&circ); + self.push_unchecked(circ, hash, cost); } - /// Push a Hugr into the queue with a precomputed hash and cost. + /// Push a circuit into the queue with a precomputed hash and cost. /// /// This is useful to avoid recomputing the hash and cost function in /// [`HugrPQ::push`] when they are already known. @@ -67,9 +67,9 @@ impl HugrPQ { /// This does not check that the hash is valid. /// /// If the queue is full, the most last will be dropped. - pub fn push_unchecked(&mut self, hugr: Hugr, hash: u64, cost: P) + pub fn push_unchecked(&mut self, circ: Circuit, hash: u64, cost: P) where - C: Fn(&Hugr) -> P, + C: Fn(&Circuit) -> P, { if !self.check_accepted(&cost) { return; @@ -78,18 +78,18 @@ impl HugrPQ { self.pop_max(); } self.queue.push(hash, cost); - self.hash_lookup.insert(hash, hugr); + self.hash_lookup.insert(hash, circ); } - /// Pop the minimal Hugr from the queue. - pub fn pop(&mut self) -> Option> { + /// Pop the minimal circuit from the queue. + pub fn pop(&mut self) -> Option> { let (hash, cost) = self.queue.pop_min()?; let circ = self.hash_lookup.remove(&hash)?; Some(Entry { circ, cost, hash }) } - /// Pop the maximal Hugr from the queue. - pub fn pop_max(&mut self) -> Option> { + /// Pop the maximal circuit from the queue. + pub fn pop_max(&mut self) -> Option> { let (hash, cost) = self.queue.pop_max()?; let circ = self.hash_lookup.remove(&hash)?; Some(Entry { circ, cost, hash }) diff --git a/tket2/src/optimiser/badger/log.rs b/tket2/src/optimiser/badger/log.rs index e0fb5eea..e7201d66 100644 --- a/tket2/src/optimiser/badger/log.rs +++ b/tket2/src/optimiser/badger/log.rs @@ -40,7 +40,7 @@ impl<'w> BadgerLogger<'w> { /// /// [`log`]: pub fn new(best_progress_csv_writer: impl io::Write + 'w) -> Self { - let boxed_candidates_writer: Box = Box::new(best_progress_csv_writer); + let boxed_candidates_writer: Box = Box::new(best_progress_csv_writer); Self { circ_candidates_csv: Some(csv::Writer::from_writer(boxed_candidates_writer)), ..Default::default() diff --git a/tket2/src/optimiser/badger/qtz_circuit.rs b/tket2/src/optimiser/badger/qtz_circuit.rs index 81a13922..f9701c76 100644 --- a/tket2/src/optimiser/badger/qtz_circuit.rs +++ b/tket2/src/optimiser/badger/qtz_circuit.rs @@ -7,12 +7,11 @@ use hugr::extension::prelude::QB_T; use hugr::ops::OpType as Op; use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::{FunctionType, Type}; -use hugr::CircuitUnit; -use hugr::Hugr as Circuit; +use hugr::{CircuitUnit, Hugr}; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use crate::Tk2Op; +use crate::{Circuit, Tk2Op}; #[derive(Debug, Serialize, Deserialize)] struct RepCircOp { @@ -60,7 +59,7 @@ fn map_op(opstr: &str) -> Op { } // TODO change to TryFrom -impl From for Circuit { +impl From for Circuit { fn from(RepCircData { circ: rc, meta }: RepCircData) -> Self { let qb_types: Vec = vec![QB_T; meta.n_qb]; let param_types: Vec = vec![FLOAT64_TYPE; meta.n_input_param]; @@ -107,10 +106,13 @@ impl From for Circuit { builder .finish_hugr_with_outputs(circ_outputs, &crate::extension::REGISTRY) .unwrap() + .into() } } -pub(super) fn load_ecc_set(path: impl AsRef) -> io::Result>> { +pub(super) fn load_ecc_set( + path: impl AsRef, +) -> io::Result>>> { let jsons = std::fs::read_to_string(path)?; let (_, ecc_map): (Vec<()>, HashMap>) = serde_json::from_str(&jsons).unwrap(); @@ -127,12 +129,9 @@ pub(super) fn load_ecc_set(path: impl AsRef) -> io::Result HashMap { + + fn load_representative_set(path: &str) -> HashMap> { let jsons = std::fs::read_to_string(path).unwrap(); // read_rep_json(&jsons).unwrap(); let st: Vec = serde_json::from_str(&jsons).unwrap(); @@ -144,8 +143,7 @@ mod tests { #[test] #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri fn test_read_rep() { - let rep_map: HashMap = - load_representative_set("../test_files/h_rz_cxrepresentative_set.json"); + let rep_map = load_representative_set("../test_files/h_rz_cxrepresentative_set.json"); for c in rep_map.values().take(1) { println!("{}", c.dot_string()); @@ -155,8 +153,7 @@ mod tests { #[test] #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri fn test_read_complete() { - let _ecc: HashMap> = - load_ecc_set("../test_files/h_rz_cxcomplete_ECC_set.json").unwrap(); + let _ecc = load_ecc_set("../test_files/h_rz_cxcomplete_ECC_set.json").unwrap(); // ecc.values() // .flatten() diff --git a/tket2/src/passes/chunks.rs b/tket2/src/passes/chunks.rs index 41312c05..8f728853 100644 --- a/tket2/src/passes/chunks.rs +++ b/tket2/src/passes/chunks.rs @@ -15,7 +15,7 @@ use hugr::hugr::{HugrError, NodeMetadataMap}; use hugr::ops::handle::DataflowParentID; use hugr::ops::OpType; use hugr::types::FunctionType; -use hugr::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; +use hugr::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; use itertools::Itertools; use portgraph::algorithms::ConvexChecker; @@ -36,7 +36,7 @@ pub struct ChunkConnection(Wire); #[derive(Debug, Clone)] pub struct Chunk { /// The extracted circuit. - pub circ: Hugr, + pub circ: Circuit, /// The original wires connected to the input. inputs: Vec, /// The original wires connected to the output. @@ -47,18 +47,18 @@ impl Chunk { /// Extract a chunk from a circuit. /// /// The chunk is extracted from the input wires to the output wires. - pub(self) fn extract( - circ: &H, + pub(self) fn extract( + circ: &Circuit, nodes: impl IntoIterator, checker: &impl ConvexChecker, ) -> Self { let subgraph = SiblingSubgraph::try_from_nodes_with_checker( nodes.into_iter().collect_vec(), - circ, + circ.hugr(), checker, ) .expect("Failed to define the chunk subgraph"); - let extracted = subgraph.extract_subgraph(circ, "Chunk"); + let extracted = subgraph.extract_subgraph(circ.hugr(), "Chunk").into(); // Transform the subgraph's input/output sets into wires that can be // matched between different chunks. // @@ -70,6 +70,7 @@ impl Chunk { .map(|wires| { let (inp_node, inp_port) = wires[0]; let (out_node, out_port) = circ + .hugr() .linked_outputs(inp_node, inp_port) .exactly_one() .ok() @@ -91,18 +92,20 @@ impl Chunk { /// Insert the chunk back into a circuit. pub(self) fn insert(&self, circ: &mut impl HugrMut, root: Node) -> ChunkInsertResult { - if self.circ.children(self.circ.root()).nth(2).is_none() { + let chunk = self.circ.hugr(); + let chunk_root = chunk.root(); + if chunk.children(self.circ.parent()).nth(2).is_none() { // The chunk is empty. We don't need to insert anything. return self.empty_chunk_insert_result(); } - let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap(); + let [chunk_inp, chunk_out] = chunk.get_io(chunk_root).unwrap(); let chunk_sg: SiblingGraph<'_, DataflowParentID> = - SiblingGraph::try_new(&self.circ, self.circ.root()).unwrap(); + SiblingGraph::try_new(&chunk, chunk_root).unwrap(); // Insert the chunk circuit into the original circuit. let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&chunk_sg) .unwrap_or_else(|e| panic!("The chunk circuit is no longer a dataflow graph: {e}")); - let node_map = circ.insert_subgraph(root, &self.circ, &subgraph); + let node_map = circ.insert_subgraph(root, &chunk, &subgraph); let mut input_map = HashMap::with_capacity(self.inputs.len()); let mut output_map = HashMap::with_capacity(self.outputs.len()); @@ -111,11 +114,8 @@ impl Chunk { // // Connections to an inserted node are translated into a [`ConnectionTarget::InsertedNode`]. // Connections from the input directly into the output become a [`ConnectionTarget::TransitiveConnection`]. - for (&connection, chunk_inp_port) in - self.inputs.iter().zip(self.circ.node_outputs(chunk_inp)) - { - let connection_targets: Vec = self - .circ + for (&connection, chunk_inp_port) in self.inputs.iter().zip(chunk.node_outputs(chunk_inp)) { + let connection_targets: Vec = chunk .linked_inputs(chunk_inp, chunk_inp_port) .map(|(node, port)| { if node == chunk_out { @@ -131,9 +131,8 @@ impl Chunk { input_map.insert(connection, connection_targets); } - for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) { - let (node, port) = self - .circ + for (&wire, chunk_out_port) in self.outputs.iter().zip(chunk.node_inputs(chunk_out)) { + let (node, port) = chunk .linked_outputs(chunk_out, chunk_out_port) .exactly_one() .ok() @@ -159,15 +158,13 @@ impl Chunk { /// /// TODO: Support empty Subgraphs in Hugr? fn empty_chunk_insert_result(&self) -> ChunkInsertResult { - let [chunk_inp, chunk_out] = self.circ.get_io(self.circ.root()).unwrap(); + let hugr = self.circ.hugr(); + let [chunk_inp, chunk_out] = self.circ.io_nodes(); let mut input_map = HashMap::with_capacity(self.inputs.len()); let mut output_map = HashMap::with_capacity(self.outputs.len()); - for (&connection, chunk_inp_port) in - self.inputs.iter().zip(self.circ.node_outputs(chunk_inp)) - { - let connection_targets: Vec = self - .circ + for (&connection, chunk_inp_port) in self.inputs.iter().zip(hugr.node_outputs(chunk_inp)) { + let connection_targets: Vec = hugr .linked_ports(chunk_inp, chunk_inp_port) .map(|(node, port)| { assert_eq!(node, chunk_out); @@ -178,9 +175,8 @@ impl Chunk { input_map.insert(connection, connection_targets); } - for (&wire, chunk_out_port) in self.outputs.iter().zip(self.circ.node_inputs(chunk_out)) { - let (node, port) = self - .circ + for (&wire, chunk_out_port) in self.outputs.iter().zip(hugr.node_inputs(chunk_out)) { + let (node, port) = hugr .linked_ports(chunk_out, chunk_out_port) .exactly_one() .ok() @@ -251,38 +247,39 @@ impl CircuitChunks { /// Split a circuit into chunks. /// /// The circuit is split into chunks of at most `max_size` gates. - pub fn split(circ: &impl Circuit, max_size: usize) -> Self { + pub fn split(circ: &Circuit, max_size: usize) -> Self { Self::split_with_cost(circ, max_size.saturating_sub(1), |_| 1) } /// Split a circuit into chunks. /// /// The circuit is split into chunks of at most `max_cost`, using the provided cost function. - pub fn split_with_cost( - circ: &H, + pub fn split_with_cost( + circ: &Circuit, max_cost: C, op_cost: impl Fn(&OpType) -> C, ) -> Self { - let root_meta = circ.get_node_metadata(circ.root()).cloned(); + let hugr = circ.hugr(); + let root_meta = hugr.get_node_metadata(circ.parent()).cloned(); let signature = circ.circuit_signature().body().clone(); - let [circ_input, circ_output] = circ.get_io(circ.root()).unwrap(); - let input_connections = circ + let [circ_input, circ_output] = circ.io_nodes(); + let input_connections = hugr .node_outputs(circ_input) .map(|port| Wire::new(circ_input, port).into()) .collect(); - let output_connections = circ + let output_connections = hugr .node_inputs(circ_output) - .flat_map(|p| circ.linked_outputs(circ_output, p)) + .flat_map(|p| hugr.linked_outputs(circ_output, p)) .map(|(n, p)| Wire::new(n, p).into()) .collect(); let mut chunks = Vec::new(); - let convex_checker = TopoConvexChecker::new(circ); + let convex_checker = TopoConvexChecker::new(circ.hugr()); let mut running_cost = C::default(); let mut current_group = 0; for (_, commands) in &circ.commands().map(|cmd| cmd.node()).chunk_by(|&node| { - let new_cost = running_cost.clone() + op_cost(circ.get_optype(node)); + let new_cost = running_cost.clone() + op_cost(hugr.get_optype(node)); if new_cost.sub_cost(&max_cost).as_isize() > 0 { running_cost = C::default(); current_group += 1; @@ -303,7 +300,7 @@ impl CircuitChunks { } /// Reassemble the chunks into a circuit. - pub fn reassemble(self) -> Result { + pub fn reassemble(self) -> Result { let name = self .root_meta .as_ref() @@ -423,16 +420,16 @@ impl CircuitChunks { reassembled.overwrite_node_metadata(root, self.root_meta); - Ok(reassembled) + Ok(reassembled.into()) } /// Returns a list of references to the split circuits. - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.chunks.iter().map(|chunk| &chunk.circ) } /// Returns a list of references to the split circuits. - pub fn iter_mut(&mut self) -> impl Iterator { + pub fn iter_mut(&mut self) -> impl Iterator { self.chunks.iter_mut().map(|chunk| &mut chunk.circ) } @@ -448,7 +445,7 @@ impl CircuitChunks { } impl Index for CircuitChunks { - type Output = Hugr; + type Output = Circuit; fn index(&self, index: usize) -> &Self::Output { &self.chunks[index].circ @@ -490,7 +487,7 @@ mod test { let mut reassembled = chunks.reassemble().unwrap(); - reassembled.update_validate(®ISTRY).unwrap(); + reassembled.hugr_mut().update_validate(®ISTRY).unwrap(); assert_eq!(circ.circuit_hash(), reassembled.circuit_hash()); } @@ -520,14 +517,14 @@ mod test { let mut reassembled = chunks.reassemble().unwrap(); - reassembled.update_validate(®ISTRY).unwrap(); + reassembled.hugr_mut().update_validate(®ISTRY).unwrap(); assert_eq!(reassembled.commands().count(), 1); let h = reassembled.commands().next().unwrap().node(); - let [inp, out] = reassembled.get_io(reassembled.root()).unwrap(); + let [inp, out] = reassembled.io_nodes(); assert_eq!( - &reassembled.output_neighbours(inp).collect_vec(), + &reassembled.hugr().output_neighbours(inp).collect_vec(), &[h, out, out] ); } diff --git a/tket2/src/passes/commutation.rs b/tket2/src/passes/commutation.rs index 4424cb84..564a6b6a 100644 --- a/tket2/src/passes/commutation.rs +++ b/tket2/src/passes/commutation.rs @@ -1,12 +1,13 @@ use std::{collections::HashMap, rc::Rc}; use hugr::hugr::{hugrmut::HugrMut, HugrError, Rewrite}; -use hugr::{CircuitUnit, Direction, Hugr, HugrView, Node, Port, PortIndex}; +use hugr::{CircuitUnit, Direction, HugrView, Node, Port, PortIndex}; use itertools::Itertools; use portgraph::PortOffset; +use crate::Circuit; use crate::{ - circuit::{command::Command, Circuit}, + circuit::command::Command, ops::{Pauli, Tk2Op}, }; @@ -26,11 +27,8 @@ struct ComCommand { inputs: Vec, } -impl<'c, Circ> From> for ComCommand -where - Circ: HugrView, -{ - fn from(com: Command<'c, Circ>) -> Self { +impl<'c, T: HugrView> From> for ComCommand { + fn from(com: Command<'c, T>) -> Self { ComCommand { node: com.node(), inputs: com.inputs().map(|(c, _, _)| c).collect(), @@ -69,13 +67,16 @@ fn add_to_slice(slice: &mut Slice, com: Rc) { } } -fn load_slices(circ: &impl Circuit) -> SliceVec { +fn load_slices(circ: &Circuit) -> SliceVec { let mut slices = vec![]; let n_qbs = circ.qubit_count(); let mut qubit_free_slice = vec![0; n_qbs]; - for command in circ.commands().filter(|c| is_slice_op(circ, c.node())) { + for command in circ + .commands() + .filter(|c| is_slice_op(circ.hugr(), c.node())) + { let command: ComCommand = command.into(); let free_slice = command .qubits() @@ -106,7 +107,7 @@ fn is_slice_op(h: &impl HugrView, node: Node) -> bool { /// Starting from starting_index, work back along slices to check for the /// earliest slice that can accommodate this command, if any. fn available_slice( - circ: &impl HugrView, + circ: &Circuit, slice_vec: &[Slice], starting_index: usize, command: &Rc, @@ -142,7 +143,7 @@ fn available_slice( fn commutes_at_slice( command: &Rc, slice: &Slice, - circ: &impl HugrView, + circ: &Circuit, ) -> Option>> { // map from qubit to node it is connected to immediately after the free slice. let mut prev_nodes: HashMap> = HashMap::new(); @@ -155,12 +156,22 @@ fn commutes_at_slice( let port = command.port_of_qb(q, Direction::Incoming)?; - let op: Tk2Op = circ.get_optype(command.node()).clone().try_into().ok()?; + let op: Tk2Op = circ + .hugr() + .get_optype(command.node()) + .clone() + .try_into() + .ok()?; // TODO: if not tk2op, might still have serialized commutation data we // can use. let pauli = commutation_on_port(&op.qubit_commutation(), port)?; - let other_op: Tk2Op = circ.get_optype(other_com.node()).clone().try_into().ok()?; + let other_op: Tk2Op = circ + .hugr() + .get_optype(other_com.node()) + .clone() + .try_into() + .ok()?; let other_pauli = commutation_on_port( &other_op.qubit_commutation(), other_com.port_of_qb(q, Direction::Outgoing)?, @@ -287,7 +298,7 @@ impl Rewrite for PullForward { } /// Pass which greedily commutes operations forwards in order to reduce depth. -pub fn apply_greedy_commutation(circ: &mut Hugr) -> Result { +pub fn apply_greedy_commutation(circ: &mut Circuit) -> Result { let mut count = 0; let mut slice_vec = load_slices(circ); @@ -301,7 +312,7 @@ pub fn apply_greedy_commutation(circ: &mut Hugr) -> Result Result Hugr { + fn example_cx() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::CX, [0, 2])?; circ.append(Tk2Op::CX, [1, 2])?; @@ -353,7 +362,7 @@ mod test { #[fixture] // example circuit from original task with lower depth - fn example_cx_better() -> Hugr { + fn example_cx_better() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::CX, [0, 2])?; circ.append(Tk2Op::CX, [1, 3])?; @@ -365,7 +374,7 @@ mod test { #[fixture] // can't commute anything here - fn cant_commute() -> Hugr { + fn cant_commute() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::Z, [1])?; circ.append(Tk2Op::CX, [0, 1])?; @@ -376,7 +385,7 @@ mod test { } #[fixture] - fn big_example() -> Hugr { + fn big_example() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::CX, [0, 3])?; circ.append(Tk2Op::CX, [1, 2])?; @@ -395,7 +404,7 @@ mod test { #[fixture] // commute a single qubit gate - fn single_qb_commute() -> Hugr { + fn single_qb_commute() -> Circuit { build_simple_circuit(3, |circ| { circ.append(Tk2Op::H, [1])?; circ.append(Tk2Op::CX, [0, 1])?; @@ -407,7 +416,7 @@ mod test { #[fixture] // commute 2 single qubit gates - fn single_qb_commute_2() -> Hugr { + fn single_qb_commute_2() -> Circuit { build_simple_circuit(4, |circ| { circ.append(Tk2Op::CX, [1, 2])?; circ.append(Tk2Op::CX, [1, 0])?; @@ -421,7 +430,7 @@ mod test { #[fixture] // A commutation forward exists but depth doesn't change - fn commutes_but_same_depth() -> Hugr { + fn commutes_but_same_depth() -> Circuit { build_simple_circuit(3, |circ| { circ.append(Tk2Op::H, [1])?; circ.append(Tk2Op::CX, [0, 1])?; @@ -434,7 +443,7 @@ mod test { #[fixture] // Gate being commuted has a non-linear input - fn non_linear_inputs() -> Hugr { + fn non_linear_inputs() -> Circuit { let build = || { let mut dfg = DFGBuilder::new(FunctionType::new( type_row![QB_T, QB_T, FLOAT64_TYPE], @@ -451,12 +460,12 @@ mod test { let qbs = circ.finish(); dfg.finish_hugr_with_outputs(qbs, ®ISTRY) }; - build().unwrap() + build().unwrap().into() } #[fixture] // Gates being commuted have non-linear outputs - fn non_linear_outputs() -> Hugr { + fn non_linear_outputs() -> Circuit { let build = || { let mut dfg = DFGBuilder::new(FunctionType::new( type_row![QB_T, QB_T], @@ -474,11 +483,11 @@ mod test { outs.extend(measured); dfg.finish_hugr_with_outputs(outs, ®ISTRY) }; - build().unwrap() + build().unwrap().into() } // bug https://github.com/CQCL/tket2/issues/253 - fn cx_commute_bug() -> Hugr { + fn cx_commute_bug() -> Circuit { build_simple_circuit(3, |circ| { circ.append(Tk2Op::H, [2])?; circ.append(Tk2Op::CX, [2, 1])?; @@ -508,7 +517,7 @@ mod test { } #[rstest] - fn test_load_slices_cx(example_cx: Hugr) { + fn test_load_slices_cx(example_cx: Circuit) { let circ = example_cx; let commands: Vec = circ.commands().map_into().collect(); let slices = load_slices(&circ); @@ -518,7 +527,7 @@ mod test { } #[rstest] - fn test_load_slices_cx_better(example_cx_better: Hugr) { + fn test_load_slices_cx_better(example_cx_better: Circuit) { let circ = example_cx_better; let commands: Vec = circ.commands().map_into().collect(); @@ -529,7 +538,7 @@ mod test { } #[rstest] - fn test_load_slices_bell(t2_bell_circuit: Hugr) { + fn test_load_slices_bell(t2_bell_circuit: Circuit) { let circ = t2_bell_circuit; let commands: Vec = circ.commands().map_into().collect(); @@ -540,7 +549,7 @@ mod test { } #[rstest] - fn test_available_slice(example_cx: Hugr) { + fn test_available_slice(example_cx: Circuit) { let circ = example_cx; let slices = load_slices(&circ); let (found, prev_nodes) = @@ -556,7 +565,7 @@ mod test { } #[rstest] - fn big_test(big_example: Hugr) { + fn big_test(big_example: Circuit) { let circ = big_example; let slices = load_slices(&circ); assert_eq!(slices.len(), 6); @@ -578,7 +587,7 @@ mod test { } /// Calculate depth by placing commands in slices. - fn depth(h: &Hugr) -> usize { + fn depth(h: &Circuit) -> usize { load_slices(h).len() } #[rstest] @@ -594,14 +603,14 @@ mod test { #[case(non_linear_outputs(), true, 1)] #[case(cx_commute_bug(), true, 1)] fn commutation_example( - #[case] mut case: Hugr, + #[case] mut case: Circuit, #[case] should_reduce: bool, #[case] expected_moves: u32, ) { - let node_count = case.node_count(); + let node_count = case.hugr().node_count(); let depth_before = depth(&case); let move_count = apply_greedy_commutation(&mut case).unwrap(); - case.update_validate(®ISTRY).unwrap(); + case.hugr_mut().update_validate(®ISTRY).unwrap(); assert_eq!( move_count, expected_moves, @@ -619,7 +628,7 @@ mod test { } assert_eq!( - case.node_count(), + case.hugr().node_count(), node_count, "depth optimisation should not change the number of nodes." ) diff --git a/tket2/src/portmatching.rs b/tket2/src/portmatching.rs index 8e74695e..b1b57d4e 100644 --- a/tket2/src/portmatching.rs +++ b/tket2/src/portmatching.rs @@ -18,7 +18,7 @@ //! let mut dfg = DFGBuilder::new(FunctionType::new(vec![], vec![QB_T]))?; //! let alloc = dfg.add_dataflow_op(Tk2Op::QAlloc, [])?; //! dfg.finish_hugr_with_outputs(alloc.outputs(), &tket2::extension::REGISTRY) -//! }?; +//! }?.into(); //! let pattern = CircuitPattern::try_from_circuit(&circuit_pattern)?; //! //! // Define a circuit that contains a qubit allocation. @@ -39,7 +39,7 @@ //! .append(Tk2Op::CX, [1, 0])?; //! let outputs = circuit.finish(); //! -//! let circuit = dfg.finish_hugr_with_outputs(outputs, &tket2::extension::REGISTRY)?; +//! let circuit = dfg.finish_hugr_with_outputs(outputs, &tket2::extension::REGISTRY)?.into(); //! (circuit, alloc.node()) //! }; //! @@ -56,7 +56,7 @@ pub mod matcher; pub mod pattern; -use hugr::OutgoingPort; +use hugr::{HugrView, OutgoingPort}; use itertools::Itertools; pub use matcher::{PatternMatch, PatternMatcher}; pub use pattern::CircuitPattern; @@ -111,13 +111,10 @@ enum InvalidEdgeProperty { } impl PEdge { - fn try_from_port( - node: Node, - port: Port, - circ: &impl Circuit, - ) -> Result { + fn try_from_port(node: Node, port: Port, circ: &Circuit) -> Result { + let hugr = circ.hugr(); let src = port; - let (dst_node, dst) = circ + let (dst_node, dst) = hugr .linked_ports(node, src) .exactly_one() .map_err(|mut e| { @@ -127,10 +124,10 @@ impl PEdge { InvalidEdgeProperty::NoLinkedEdge(src) } })?; - if circ.get_optype(dst_node).tag() == OpTag::Input { + if hugr.get_optype(dst_node).tag() == OpTag::Input { return Ok(Self::InputEdge { src }); } - let port_type = circ + let port_type = hugr .signature(node) .unwrap() .port_type(src) @@ -202,29 +199,30 @@ impl From for NodeID { #[cfg(test)] mod tests { - use crate::Tk2Op; + use crate::{Circuit, Tk2Op}; use hugr::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, extension::{prelude::QB_T, PRELUDE_REGISTRY}, types::FunctionType, - Hugr, }; use rstest::{fixture, rstest}; use super::{CircuitPattern, PatternMatcher}; #[fixture] - fn lhs() -> Hugr { + fn lhs() -> Circuit { let mut h = DFGBuilder::new(FunctionType::new(vec![], vec![QB_T])).unwrap(); let res = h.add_dataflow_op(Tk2Op::QAlloc, []).unwrap(); let q = res.out_wire(0); - h.finish_hugr_with_outputs([q], &PRELUDE_REGISTRY).unwrap() + h.finish_hugr_with_outputs([q], &PRELUDE_REGISTRY) + .unwrap() + .into() } #[fixture] - pub fn circ() -> Hugr { + pub fn circ() -> Circuit { let mut h = DFGBuilder::new(FunctionType::new(vec![QB_T], vec![QB_T])).unwrap(); let mut inps = h.input_wires(); let q_in = inps.next().unwrap(); @@ -238,10 +236,11 @@ mod tests { h.finish_hugr_with_outputs([q_out], &PRELUDE_REGISTRY) .unwrap() + .into() } #[rstest] - fn simple_match(circ: Hugr, lhs: Hugr) { + fn simple_match(circ: Circuit, lhs: Circuit) { let p = CircuitPattern::try_from_circuit(&lhs).unwrap(); let m = PatternMatcher::from_patterns(vec![p]); diff --git a/tket2/src/portmatching/matcher.rs b/tket2/src/portmatching/matcher.rs index c5423614..f949e5ea 100644 --- a/tket2/src/portmatching/matcher.rs +++ b/tket2/src/portmatching/matcher.rs @@ -13,7 +13,7 @@ use hugr::hugr::views::sibling_subgraph::{ }; use hugr::hugr::views::SiblingSubgraph; use hugr::ops::{NamedOp, OpType}; -use hugr::{Hugr, IncomingPort, Node, OutgoingPort, Port, PortIndex}; +use hugr::{HugrView, IncomingPort, Node, OutgoingPort, Port, PortIndex}; use itertools::Itertools; use portgraph::algorithms::ConvexChecker; use portmatching::{ @@ -115,10 +115,10 @@ impl PatternMatch { pub fn try_from_root_match( root: Node, pattern: PatternID, - circ: &impl Circuit, + circ: &Circuit, matcher: &PatternMatcher, ) -> Result { - let checker = TopoConvexChecker::new(circ); + let checker = TopoConvexChecker::new(circ.hugr()); Self::try_from_root_match_with_checker(root, pattern, circ, matcher, &checker) } @@ -128,10 +128,10 @@ impl PatternMatch { /// checker object to speed up convexity checking. /// /// See [`PatternMatch::try_from_root_match`] for more details. - pub fn try_from_root_match_with_checker( + pub fn try_from_root_match_with_checker( root: Node, pattern: PatternID, - circ: &C, + circ: &Circuit, matcher: &PatternMatcher, checker: &impl ConvexChecker, ) -> Result { @@ -172,11 +172,11 @@ impl PatternMatch { pub fn try_from_io( root: Node, pattern: PatternID, - circ: &impl Circuit, + circ: &Circuit, inputs: Vec>, outputs: Vec<(Node, OutgoingPort)>, ) -> Result { - let checker = TopoConvexChecker::new(circ); + let checker = TopoConvexChecker::new(circ.hugr()); Self::try_from_io_with_checker(root, pattern, circ, inputs, outputs, &checker) } @@ -188,15 +188,16 @@ impl PatternMatch { /// /// This checks at construction time that the match is convex. This will /// have runtime linear in the size of the circuit. - pub fn try_from_io_with_checker( + pub fn try_from_io_with_checker( root: Node, pattern: PatternID, - circ: &C, + circ: &Circuit, inputs: Vec>, outputs: Vec<(Node, OutgoingPort)>, checker: &impl ConvexChecker, ) -> Result { - let subgraph = SiblingSubgraph::try_new_with_checker(inputs, outputs, circ, checker)?; + let subgraph = + SiblingSubgraph::try_new_with_checker(inputs, outputs, circ.hugr(), checker)?; Ok(Self { position: subgraph.into(), pattern, @@ -207,8 +208,8 @@ impl PatternMatch { /// Construct a rewrite to replace `self` with `repl`. pub fn to_rewrite( &self, - source: &Hugr, - target: Hugr, + source: &Circuit, + target: Circuit, ) -> Result { CircuitRewrite::try_new(&self.position, source, target) } @@ -263,25 +264,25 @@ impl PatternMatcher { } /// Find all convex pattern matches in a circuit. - pub fn find_matches_iter<'a, 'c: 'a, C: Circuit + Clone>( + pub fn find_matches_iter<'a, 'c: 'a>( &'a self, - circuit: &'c C, + circuit: &'c Circuit, ) -> impl Iterator + 'a { - let checker = TopoConvexChecker::new(circuit); + let checker = TopoConvexChecker::new(circuit.hugr()); circuit .commands() .flat_map(move |cmd| self.find_rooted_matches(circuit, cmd.node(), &checker)) } /// Find all convex pattern matches in a circuit.and collect in to a vector - pub fn find_matches(&self, circuit: &C) -> Vec { + pub fn find_matches(&self, circuit: &Circuit) -> Vec { self.find_matches_iter(circuit).collect() } /// Find all convex pattern matches in a circuit rooted at a given node. - fn find_rooted_matches( + fn find_rooted_matches( &self, - circ: &C, + circ: &Circuit, root: Node, checker: &impl ConvexChecker, ) -> Vec { @@ -429,23 +430,24 @@ fn compatible_offsets(e1: &PEdge, e2: &PEdge) -> bool { /// Returns a predicate checking that an edge at `src` satisfies `prop` in `circ`. pub(super) fn validate_circuit_edge( - circ: &impl Circuit, + circ: &Circuit, ) -> impl for<'a> Fn(NodeID, &'a PEdge) -> Option + '_ { move |src, &prop| { let NodeID::HugrNode(src) = src else { return None; }; + let hugr = circ.hugr(); match prop { PEdge::InternalEdge { src: src_port, dst: dst_port, .. } => { - let (next_node, next_port) = circ.linked_ports(src, src_port).exactly_one().ok()?; + let (next_node, next_port) = hugr.linked_ports(src, src_port).exactly_one().ok()?; (dst_port == next_port).then_some(NodeID::HugrNode(next_node)) } PEdge::InputEdge { src: src_port } => { - let (next_node, next_port) = circ.linked_ports(src, src_port).exactly_one().ok()?; + let (next_node, next_port) = hugr.linked_ports(src, src_port).exactly_one().ok()?; Some(NodeID::CopyNode(next_node, next_port)) } } @@ -454,13 +456,13 @@ pub(super) fn validate_circuit_edge( /// Returns a predicate checking that `node` satisfies `prop` in `circ`. pub(crate) fn validate_circuit_node( - circ: &impl Circuit, + circ: &Circuit, ) -> impl for<'a> Fn(NodeID, &PNode) -> bool + '_ { move |node, prop| { let NodeID::HugrNode(node) = node else { return false; }; - &MatchOp::from(circ.get_optype(node).clone()) == prop + &MatchOp::from(circ.hugr().get_optype(node).clone()) == prop } } @@ -479,16 +481,15 @@ fn handle_match_error(match_res: Result, root: Node) #[cfg(test)] mod tests { - use hugr::Hugr; use itertools::Itertools; use rstest::{fixture, rstest}; use crate::utils::build_simple_circuit; - use crate::Tk2Op; + use crate::{Circuit, Tk2Op}; use super::{CircuitPattern, PatternMatcher}; - fn h_cx() -> Hugr { + fn h_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::H, [0]).unwrap(); @@ -497,7 +498,7 @@ mod tests { .unwrap() } - fn cx_xc() -> Hugr { + fn cx_xc() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::CX, [1, 0]).unwrap(); @@ -507,7 +508,7 @@ mod tests { } #[fixture] - fn cx_cx_3() -> Hugr { + fn cx_cx_3() -> Circuit { build_simple_circuit(3, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::CX, [2, 1]).unwrap(); @@ -517,7 +518,7 @@ mod tests { } #[fixture] - fn cx_cx() -> Hugr { + fn cx_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::CX, [0, 1]).unwrap(); @@ -558,7 +559,7 @@ mod tests { } #[rstest] - fn cx_cx_replace_to_id(cx_cx: Hugr, cx_cx_3: Hugr) { + fn cx_cx_replace_to_id(cx_cx: Circuit, cx_cx_3: Circuit) { let p = CircuitPattern::try_from_circuit(&cx_cx_3).unwrap(); let m = PatternMatcher::from_patterns(vec![p]); diff --git a/tket2/src/portmatching/pattern.rs b/tket2/src/portmatching/pattern.rs index 4020a6bf..4a460ebd 100644 --- a/tket2/src/portmatching/pattern.rs +++ b/tket2/src/portmatching/pattern.rs @@ -1,6 +1,6 @@ //! Circuit Patterns for pattern matching -use hugr::IncomingPort; +use hugr::{HugrView, IncomingPort}; use hugr::{Node, Port}; use itertools::Itertools; use portmatching::{patterns::NoRootFound, HashMap, Pattern, SinglePatternMatcher}; @@ -30,7 +30,8 @@ impl CircuitPattern { } /// Construct a pattern from a circuit. - pub fn try_from_circuit(circuit: &impl Circuit) -> Result { + pub fn try_from_circuit(circuit: &Circuit) -> Result { + let hugr = circuit.hugr(); if circuit.num_gates() == 0 { return Err(InvalidPattern::EmptyCircuit); } @@ -42,10 +43,9 @@ impl CircuitPattern { let in_offset: IncomingPort = in_offset.into(); let edge_prop = PEdge::try_from_port(cmd.node(), in_offset.into(), circuit) .expect("Invalid HUGR"); - let (prev_node, prev_port) = circuit + let (prev_node, prev_port) = hugr .linked_outputs(cmd.node(), in_offset) .exactly_one() - .ok() .expect("invalid HUGR"); let prev_node = match edge_prop { PEdge::InternalEdge { .. } => NodeID::HugrNode(prev_node), @@ -58,18 +58,16 @@ impl CircuitPattern { if !pattern.is_valid() { return Err(InvalidPattern::NotConnected); } - let (inp, out) = (circuit.input(), circuit.output()); - let inp_ports = circuit.signature(inp).unwrap().output_ports(); - let out_ports = circuit.signature(out).unwrap().input_ports(); + let [inp, out] = circuit.io_nodes(); + let inp_ports = hugr.signature(inp).unwrap().output_ports(); + let out_ports = hugr.signature(out).unwrap().input_ports(); let inputs = inp_ports - .map(|p| circuit.linked_ports(inp, p).collect()) + .map(|p| hugr.linked_ports(inp, p).collect()) .collect_vec(); let outputs = out_ports .map(|p| { - circuit - .linked_ports(out, p) + hugr.linked_ports(out, p) .exactly_one() - .ok() .expect("invalid circuit") }) .collect_vec(); @@ -87,7 +85,11 @@ impl CircuitPattern { } /// Compute the map from pattern nodes to circuit nodes in `circ`. - pub fn get_match_map(&self, root: Node, circ: &impl Circuit) -> Option> { + pub fn get_match_map( + &self, + root: Node, + circ: &Circuit, + ) -> Option> { let single_matcher = SinglePatternMatcher::from_pattern(self.pattern.clone()); single_matcher .get_match_map( @@ -143,7 +145,6 @@ mod tests { use hugr::ops::OpType; use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use hugr::types::FunctionType; - use hugr::Hugr; use crate::extension::REGISTRY; use crate::utils::build_simple_circuit; @@ -151,7 +152,7 @@ mod tests { use super::*; - fn h_cx() -> Hugr { + fn h_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1])?; circ.append(Tk2Op::H, [0])?; @@ -161,7 +162,7 @@ mod tests { } /// A circuit with two rotation gates in sequence, sharing a param - fn circ_with_copy() -> Hugr { + fn circ_with_copy() -> Circuit { let input_t = vec![QB_T, FLOAT64_TYPE]; let output_t = vec![QB_T]; let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); @@ -175,11 +176,11 @@ mod tests { let res = h.add_dataflow_op(Tk2Op::RxF64, [qb, f]).unwrap(); let qb = res.outputs().next().unwrap(); - h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap().into() } /// A circuit with two rotation gates in parallel, sharing a param - fn circ_with_copy_disconnected() -> Hugr { + fn circ_with_copy_disconnected() -> Circuit { let input_t = vec![QB_T, QB_T, FLOAT64_TYPE]; let output_t = vec![QB_T, QB_T]; let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); @@ -194,14 +195,16 @@ mod tests { let res = h.add_dataflow_op(Tk2Op::RxF64, [qb2, f]).unwrap(); let qb2 = res.outputs().next().unwrap(); - h.finish_hugr_with_outputs([qb1, qb2], ®ISTRY).unwrap() + h.finish_hugr_with_outputs([qb1, qb2], ®ISTRY) + .unwrap() + .into() } #[test] fn construct_pattern() { - let hugr = h_cx(); + let circ = h_cx(); - let p = CircuitPattern::try_from_circuit(&hugr).unwrap(); + let p = CircuitPattern::try_from_circuit(&circ).unwrap(); let edges: HashSet<_> = p .pattern @@ -210,9 +213,9 @@ mod tests { .iter() .map(|e| (e.source.unwrap(), e.target.unwrap())) .collect(); - let inp = hugr.input(); - let cx_gate = NodeID::HugrNode(get_nodes_by_tk2op(&hugr, Tk2Op::CX)[0]); - let h_gate = NodeID::HugrNode(get_nodes_by_tk2op(&hugr, Tk2Op::H)[0]); + let inp = circ.input_node(); + let cx_gate = NodeID::HugrNode(get_nodes_by_tk2op(&circ, Tk2Op::CX)[0]); + let h_gate = NodeID::HugrNode(get_nodes_by_tk2op(&circ, Tk2Op::H)[0]); assert_eq!( edges, [ @@ -252,10 +255,11 @@ mod tests { ); } - fn get_nodes_by_tk2op(circ: &impl Circuit, t2_op: Tk2Op) -> Vec { + fn get_nodes_by_tk2op(circ: &Circuit, t2_op: Tk2Op) -> Vec { let t2_op: OpType = t2_op.into(); - circ.nodes() - .filter(|n| circ.get_optype(*n) == &t2_op) + circ.hugr() + .nodes() + .filter(|n| circ.hugr().get_optype(*n) == &t2_op) .collect() } @@ -265,7 +269,7 @@ mod tests { let pattern = CircuitPattern::try_from_circuit(&circ).unwrap(); let edges = pattern.pattern.edges().unwrap(); let rx_ns = get_nodes_by_tk2op(&circ, Tk2Op::RxF64); - let inp = circ.input(); + let inp = circ.input_node(); for rx_n in rx_ns { assert!(edges.iter().any(|e| { e.reverse().is_none() diff --git a/tket2/src/rewrite.rs b/tket2/src/rewrite.rs index 45dfafec..85eeceee 100644 --- a/tket2/src/rewrite.rs +++ b/tket2/src/rewrite.rs @@ -10,17 +10,16 @@ use bytemuck::TransparentWrapper; pub use ecc_rewriter::ECCRewriter; use derive_more::{From, Into}; +use hugr::hugr::hugrmut::HugrMut; use hugr::hugr::views::sibling_subgraph::{InvalidReplacement, InvalidSubgraph}; -use hugr::Node; use hugr::{ - hugr::{hugrmut::HugrMut, views::SiblingSubgraph, Rewrite, SimpleReplacementError}, - Hugr, SimpleReplacement, + hugr::{views::SiblingSubgraph, Rewrite, SimpleReplacementError}, + SimpleReplacement, }; +use hugr::{Hugr, HugrView, Node}; use crate::circuit::Circuit; -use self::trace::RewriteTracer; - /// A subcircuit of a circuit. #[derive(Debug, Clone, From, Into)] #[repr(transparent)] @@ -34,9 +33,9 @@ impl Subcircuit { /// Create a new subcircuit induced from a set of nodes. pub fn try_from_nodes( nodes: impl Into>, - hugr: &Hugr, + circ: &Circuit, ) -> Result { - let subgraph = SiblingSubgraph::try_from_nodes(nodes, hugr)?; + let subgraph = SiblingSubgraph::try_from_nodes(nodes, circ.hugr())?; Ok(Self { subgraph }) } @@ -53,12 +52,13 @@ impl Subcircuit { /// Create a rewrite rule to replace the subcircuit. pub fn create_rewrite( &self, - source: &Hugr, - target: Hugr, + source: &Circuit, + target: Circuit, ) -> Result { - Ok(CircuitRewrite( - self.subgraph.create_simple_replacement(source, target)?, - )) + Ok(CircuitRewrite(self.subgraph.create_simple_replacement( + source.hugr(), + target.into_hugr(), + )?)) } } @@ -70,12 +70,12 @@ impl CircuitRewrite { /// Create a new rewrite rule. pub fn try_new( source_position: &Subcircuit, - source: &Hugr, - target: Hugr, + source: &Circuit, + target: Circuit, ) -> Result { source_position .subgraph - .create_simple_replacement(source, target) + .create_simple_replacement(source.hugr(), target.into_hugr()) .map(Self) } @@ -95,8 +95,8 @@ impl CircuitRewrite { } /// The replacement subcircuit. - pub fn replacement(&self) -> &Hugr { - self.0.replacement() + pub fn replacement(&self) -> Circuit<&Hugr> { + self.0.replacement().into() } /// Returns a set of nodes referenced by the rewrite. Modifying any these @@ -111,20 +111,23 @@ impl CircuitRewrite { /// Apply the rewrite rule to a circuit. #[inline] - pub fn apply(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { + pub fn apply(self, circ: &mut Circuit) -> Result<(), SimpleReplacementError> { circ.add_rewrite_trace(&self); - self.0.apply(circ) + self.0.apply(circ.hugr_mut()) } /// Apply the rewrite rule to a circuit, without registering it in the rewrite trace. #[inline] - pub fn apply_notrace(self, circ: &mut impl HugrMut) -> Result<(), SimpleReplacementError> { - self.0.apply(circ) + pub fn apply_notrace( + self, + circ: &mut Circuit, + ) -> Result<(), SimpleReplacementError> { + self.0.apply(circ.hugr_mut()) } } /// Generate rewrite rules for circuits. pub trait Rewriter { /// Get the rewrite rules for a circuit. - fn get_rewrites(&self, circ: &C) -> Vec; + fn get_rewrites(&self, circ: &Circuit) -> Vec; } diff --git a/tket2/src/rewrite/ecc_rewriter.rs b/tket2/src/rewrite/ecc_rewriter.rs index 954717b9..1c6be502 100644 --- a/tket2/src/rewrite/ecc_rewriter.rs +++ b/tket2/src/rewrite/ecc_rewriter.rs @@ -13,7 +13,7 @@ //! of the Quartz repository. use derive_more::{From, Into}; -use hugr::PortIndex; +use hugr::{Hugr, HugrView, PortIndex}; use itertools::Itertools; use portmatching::PatternID; use std::{ @@ -24,8 +24,6 @@ use std::{ }; use thiserror::Error; -use hugr::Hugr; - use crate::{ circuit::{remove_empty_wire, Circuit}, optimiser::badger::{load_eccs_json_file, EqCircClass}, @@ -76,7 +74,7 @@ impl ECCRewriter { /// Equivalence classes are represented as [`EqCircClass`]s, lists of /// HUGRs where one of the elements is chosen as the representative. pub fn from_eccs(eccs: impl Into>) -> Self { - let eccs = eccs.into(); + let eccs: Vec = eccs.into(); let rewrite_rules = get_rewrite_rules(&eccs); let patterns = get_patterns(&eccs); let targets = into_targets(eccs); @@ -90,7 +88,7 @@ impl ECCRewriter { let targets = r .into_iter() .filter(|&id| { - let circ = &targets[id.0]; + let circ = (&targets[id.0]).into(); let target_empty_wires: HashSet<_> = empty_wires(&circ).into_iter().collect(); pattern_empty_wires @@ -111,10 +109,10 @@ impl ECCRewriter { } /// Get all targets of rewrite rules given a source pattern. - fn get_targets(&self, pattern: PatternID) -> impl Iterator { + fn get_targets(&self, pattern: PatternID) -> impl Iterator> { self.rewrite_rules[pattern.0] .iter() - .map(|id| &self.targets[id.0]) + .map(|id| (&self.targets[id.0]).into()) } /// Serialise a rewriter to an IO stream. @@ -167,19 +165,18 @@ impl ECCRewriter { } impl Rewriter for ECCRewriter { - fn get_rewrites(&self, circ: &C) -> Vec { + fn get_rewrites(&self, circ: &Circuit) -> Vec { let matches = self.matcher.find_matches(circ); matches .into_iter() .flat_map(|m| { let pattern_id = m.pattern_id(); self.get_targets(pattern_id).map(move |repl| { - let mut repl = repl.clone(); + let mut repl = repl.to_owned(); for &empty_qb in self.empty_wires[pattern_id.0].iter().rev() { remove_empty_wire(&mut repl, empty_qb).unwrap(); } - m.to_rewrite(circ.base_hugr(), repl) - .expect("invalid replacement") + m.to_rewrite(circ, repl).expect("invalid replacement") }) }) .collect() @@ -231,9 +228,9 @@ fn get_patterns(rep_sets: &[EqCircClass]) -> Vec Vec Vec { - let input = circ.input(); - let input_sig = circ.signature(input).unwrap(); - circ.node_outputs(input) +fn empty_wires(circ: &Circuit) -> Vec { + let hugr = circ.hugr(); + let input = circ.input_node(); + let input_sig = hugr.signature(input).unwrap(); + hugr.node_outputs(input) // Only consider dataflow edges .filter(|&p| input_sig.out_port_type(p).is_some()) // Only consider ports linked to at most one other port - .filter_map(|p| Some((p, circ.linked_ports(input, p).at_most_one().ok()?))) + .filter_map(|p| Some((p, hugr.linked_ports(input, p).at_most_one().ok()?))) // Ports are either connected to output or nothing .filter_map(|(from, to)| { if let Some((n, _)) = to { // Wires connected to output - (n == circ.output()).then_some(from.index()) + (n == circ.output_node()).then_some(from.index()) } else { // Wires connected to nothing Some(from.index()) @@ -272,11 +270,11 @@ mod tests { use super::*; - fn empty() -> Hugr { + fn empty() -> Circuit { build_simple_circuit(2, |_| Ok(())).unwrap() } - fn h_h() -> Hugr { + fn h_h() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::H, [0]).unwrap(); circ.append(Tk2Op::H, [0]).unwrap(); @@ -286,7 +284,7 @@ mod tests { .unwrap() } - fn cx_cx() -> Hugr { + fn cx_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::CX, [0, 1]).unwrap(); @@ -295,7 +293,7 @@ mod tests { .unwrap() } - fn cx_x() -> Hugr { + fn cx_x() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::CX, [0, 1]).unwrap(); circ.append(Tk2Op::X, [1]).unwrap(); @@ -304,7 +302,7 @@ mod tests { .unwrap() } - fn x_cx() -> Hugr { + fn x_cx() -> Circuit { build_simple_circuit(2, |circ| { circ.append(Tk2Op::X, [1]).unwrap(); circ.append(Tk2Op::CX, [0, 1]).unwrap(); @@ -328,7 +326,13 @@ mod tests { vec![TargetID(3)], ] ); - assert_eq!(rewriter.get_targets(PatternID(1)).collect_vec(), [&h_h()]); + assert_eq!( + rewriter + .get_targets(PatternID(1)) + .map(|c| c.to_owned()) + .collect_vec(), + [h_h()] + ); } #[test] diff --git a/tket2/src/rewrite/strategy.rs b/tket2/src/rewrite/strategy.rs index e6576942..3e44b638 100644 --- a/tket2/src/rewrite/strategy.rs +++ b/tket2/src/rewrite/strategy.rs @@ -25,13 +25,13 @@ use std::{collections::HashSet, fmt::Debug}; use derive_more::From; use hugr::ops::OpType; -use hugr::Hugr; +use hugr::HugrView; use itertools::Itertools; use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, LexicographicCost}; use crate::Circuit; -use super::trace::{RewriteTrace, RewriteTracer}; +use super::trace::RewriteTrace; use super::CircuitRewrite; /// Rewriting strategies for circuit optimisation. @@ -51,7 +51,7 @@ pub trait RewriteStrategy { fn apply_rewrites( &self, rewrites: impl IntoIterator, - circ: &Hugr, + circ: &Circuit, ) -> impl Iterator>; /// The cost of a single operation for this strategy's cost function. @@ -59,13 +59,13 @@ pub trait RewriteStrategy { /// The cost of a circuit using this strategy's cost function. #[inline] - fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + fn circuit_cost(&self, circ: &Circuit) -> Self::Cost { circ.circuit_cost(|op| self.op_cost(op)) } /// Returns the cost of a rewrite's matched subcircuit before replacing it. #[inline] - fn pre_rewrite_cost(&self, rw: &CircuitRewrite, circ: &Hugr) -> Self::Cost { + fn pre_rewrite_cost(&self, rw: &CircuitRewrite, circ: &Circuit) -> Self::Cost { circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), |op| { self.op_cost(op) }) @@ -81,15 +81,18 @@ pub trait RewriteStrategy { #[derive(Debug, Clone)] pub struct RewriteResult { /// The rewritten circuit. - pub circ: Hugr, + pub circ: Circuit, /// The cost delta of the rewrite. pub cost_delta: C::CostDelta, } -impl From<(Hugr, C::CostDelta)> for RewriteResult { +impl From<(Circuit, C::CostDelta)> for RewriteResult { #[inline] - fn from((circ, cost_delta): (Hugr, C::CostDelta)) -> Self { - Self { circ, cost_delta } + fn from((circ, cost_delta): (Circuit, C::CostDelta)) -> Self { + Self { + circ: circ.to_owned(), + cost_delta, + } } } @@ -113,7 +116,7 @@ impl RewriteStrategy for GreedyRewriteStrategy { fn apply_rewrites( &self, rewrites: impl IntoIterator, - circ: &Hugr, + circ: &Circuit, ) -> impl Iterator> { let rewrites = rewrites .into_iter() @@ -140,7 +143,7 @@ impl RewriteStrategy for GreedyRewriteStrategy { iter::once((circ, cost_delta).into()) } - fn circuit_cost(&self, circ: &Hugr) -> Self::Cost { + fn circuit_cost(&self, circ: &Circuit) -> Self::Cost { circ.num_gates() } @@ -184,7 +187,7 @@ impl RewriteStrategy for ExhaustiveGreedyStrategy { fn apply_rewrites( &self, rewrites: impl IntoIterator, - circ: &Hugr, + circ: &Circuit, ) -> impl Iterator> { // Check only the rewrites that reduce the size of the circuit. let rewrites = rewrites @@ -262,7 +265,7 @@ impl RewriteStrategy for ExhaustiveThresholdStrategy { fn apply_rewrites( &self, rewrites: impl IntoIterator, - circ: &Hugr, + circ: &Circuit, ) -> impl Iterator> { rewrites.into_iter().filter_map(|rw| { let pattern_cost = self.pre_rewrite_cost(&rw, circ); @@ -429,7 +432,7 @@ impl GammaStrategyCost usize> { #[cfg(test)] mod tests { use super::*; - use hugr::{Hugr, Node}; + use hugr::Node; use itertools::Itertools; use crate::rewrite::trace::REWRITE_TRACING_ENABLED; @@ -440,7 +443,7 @@ mod tests { Tk2Op, }; - fn n_cx(n_gates: usize) -> Hugr { + fn n_cx(n_gates: usize) -> Circuit { let qbs = [0, 1]; build_simple_circuit(2, |circ| { for _ in 0..n_gates { @@ -452,15 +455,15 @@ mod tests { } /// Rewrite cx_nodes -> empty - fn rw_to_empty(hugr: &Hugr, cx_nodes: impl Into>) -> CircuitRewrite { - let subcirc = Subcircuit::try_from_nodes(cx_nodes, hugr).unwrap(); - subcirc.create_rewrite(hugr, n_cx(0)).unwrap() + fn rw_to_empty(circ: &Circuit, cx_nodes: impl Into>) -> CircuitRewrite { + let subcirc = Subcircuit::try_from_nodes(cx_nodes, circ).unwrap(); + subcirc.create_rewrite(circ, n_cx(0)).unwrap() } /// Rewrite cx_nodes -> 10x CX - fn rw_to_full(hugr: &Hugr, cx_nodes: impl Into>) -> CircuitRewrite { - let subcirc = Subcircuit::try_from_nodes(cx_nodes, hugr).unwrap(); - subcirc.create_rewrite(hugr, n_cx(10)).unwrap() + fn rw_to_full(circ: &Circuit, cx_nodes: impl Into>) -> CircuitRewrite { + let subcirc = Subcircuit::try_from_nodes(cx_nodes, circ).unwrap(); + subcirc.create_rewrite(circ, n_cx(10)).unwrap() } #[test] diff --git a/tket2/src/rewrite/trace.rs b/tket2/src/rewrite/trace.rs index f2c17494..d469135f 100644 --- a/tket2/src/rewrite/trace.rs +++ b/tket2/src/rewrite/trace.rs @@ -17,13 +17,13 @@ pub const METADATA_REWRITES: &str = "TKET2.rewrites"; /// Enable it by setting the `rewrite-tracing` feature. /// /// Note that circuits must be explicitly enabled for rewrite tracing by calling -/// [`RewriteTracer::enable_rewrite_tracing`]. +/// [`Circuit::enable_rewrite_tracing`]. pub const REWRITE_TRACING_ENABLED: bool = cfg!(feature = "rewrite-tracing"); /// The trace of a rewrite applied to a circuit. /// /// Traces are only enabled if the `rewrite-tracing` feature is enabled and -/// [`RewriteTracer::enable_rewrite_tracing`] is called on the circuit. +/// [`Circuit::enable_rewrite_tracing`] is called on the circuit. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(transparent)] pub struct RewriteTrace { @@ -67,18 +67,19 @@ impl From for serde_json::Value { } } -/// Extension trait for circuits that can trace rewrites applied to them. +/// Implementation for rewrite tracing in circuits. /// /// This is only tracked if the `rewrite-tracing` feature is enabled and /// `enable_rewrite_tracing` is called on the circuit before. -pub trait RewriteTracer: Circuit + HugrMut + Sized { +impl Circuit { /// Enable rewrite tracing for the circuit. #[inline] - fn enable_rewrite_tracing(&mut self) { + pub fn enable_rewrite_tracing(&mut self) { if !REWRITE_TRACING_ENABLED { return; } - let meta = self.get_metadata_mut(self.root(), METADATA_REWRITES); + let root = self.parent(); + let meta = self.hugr_mut().get_metadata_mut(root, METADATA_REWRITES); if *meta == NodeMetadata::Null { *meta = NodeMetadata::Array(vec![]); } @@ -88,12 +89,14 @@ pub trait RewriteTracer: Circuit + HugrMut + Sized { /// /// Returns `true` if the rewrite was successfully registered, or `false` if it was ignored. #[inline] - fn add_rewrite_trace(&mut self, rewrite: impl Into) -> bool { + pub fn add_rewrite_trace(&mut self, rewrite: impl Into) -> bool { if !REWRITE_TRACING_ENABLED { return false; } + let root = self.parent(); match self - .get_metadata_mut(self.root(), METADATA_REWRITES) + .hugr_mut() + .get_metadata_mut(root, METADATA_REWRITES) .as_array_mut() { Some(meta) => { @@ -112,14 +115,12 @@ pub trait RewriteTracer: Circuit + HugrMut + Sized { // // TODO return an `impl Iterator` once rust 1.75 lands. #[inline] - fn rewrite_trace(&self) -> Option> { + pub fn rewrite_trace(&self) -> Option> { if !REWRITE_TRACING_ENABLED { return None; } - let meta = self.get_metadata(self.root(), METADATA_REWRITES)?; + let meta = self.hugr().get_metadata(self.parent(), METADATA_REWRITES)?; let rewrites = meta.as_array()?; Some(rewrites.iter().map_into().collect_vec()) } } - -impl RewriteTracer for T {} diff --git a/tket2/src/utils.rs b/tket2/src/utils.rs index 1c46692d..ec4daa87 100644 --- a/tket2/src/utils.rs +++ b/tket2/src/utils.rs @@ -9,6 +9,8 @@ use hugr::{ Hugr, }; +use crate::circuit::Circuit; + pub(crate) fn type_is_linear(typ: &Type) -> bool { !TypeBound::Copyable.contains(typ.least_upper_bound()) } @@ -18,7 +20,7 @@ pub(crate) fn type_is_linear(typ: &Type) -> bool { pub(crate) fn build_simple_circuit( num_qubits: usize, f: impl FnOnce(&mut CircuitBuilder>) -> Result<(), BuildError>, -) -> Result { +) -> Result { let qb_row = vec![QB_T; num_qubits]; let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?; @@ -29,14 +31,16 @@ pub(crate) fn build_simple_circuit( f(&mut circ)?; let qbs = circ.finish(); - h.finish_hugr_with_outputs(qbs, &PRELUDE_REGISTRY) + let hugr = h.finish_hugr_with_outputs(qbs, &PRELUDE_REGISTRY)?; + Ok(hugr.into()) } // Test only utils #[allow(dead_code)] +#[allow(unused_imports)] #[cfg(test)] pub(crate) mod test { - #[allow(unused_imports)] + use crate::Circuit; use hugr::HugrView; /// Open a browser page to render a dot string graph. @@ -51,6 +55,14 @@ pub(crate) mod test { webbrowser::open(&base).unwrap(); } + /// Open a browser page to render a Circuit's dot string graph. + /// + /// Only for use in local testing. Will fail to compile on CI. + #[cfg(not(ci_run))] + pub(crate) fn viz_circ(circ: &Circuit) { + viz_dotstr(circ.dot_string()); + } + /// Open a browser page to render a HugrView's dot string graph. /// /// Only for use in local testing. Will fail to compile on CI. diff --git a/tket2/tests/badger_termination.rs b/tket2/tests/badger_termination.rs index f899ca9d..b79d46f1 100644 --- a/tket2/tests/badger_termination.rs +++ b/tket2/tests/badger_termination.rs @@ -1,6 +1,5 @@ #![cfg(feature = "portmatching")] -use hugr::Hugr; use rstest::{fixture, rstest}; use tket2::optimiser::badger::BadgerOptions; use tket2::{ @@ -29,7 +28,7 @@ fn nam_4_2() -> DefaultBadgerOptimiser { ///q_2: ┤ X ├───────────────────────────────────────────┤ X ├───────────── /// └───┘ └───┘ #[fixture] -fn simple_circ() -> Hugr { +fn simple_circ() -> Circuit { // The TK1 json of the circuit let json = r#"{ "bits": [], @@ -57,7 +56,7 @@ fn simple_circ() -> Hugr { #[rstest] //#[ignore = "Takes 200ms"] -fn badger_termination(simple_circ: Hugr, nam_4_2: DefaultBadgerOptimiser) { +fn badger_termination(simple_circ: Circuit, nam_4_2: DefaultBadgerOptimiser) { let opt_circ = nam_4_2.optimise( &simple_circ, BadgerOptions {