diff --git a/tket2/src/static_circ.rs b/tket2/src/static_circ.rs index d34692d4..9cae5286 100644 --- a/tket2/src/static_circ.rs +++ b/tket2/src/static_circ.rs @@ -2,17 +2,36 @@ mod match_op; -use hugr::{Direction, HugrView}; +use std::{collections::BTreeMap, fmt, rc::Rc}; + +use hugr::{Direction, HugrView, Port, PortIndex}; pub(crate) use match_op::MatchOp; use derive_more::{From, Into}; +use thiserror::Error; use crate::{circuit::units::filter, Circuit}; /// A circuit with a fixed number of qubits numbered from 0 to `num_qubits - 1`. -pub(crate) struct StaticSizeCircuit { +#[derive(Clone, Default)] +pub struct StaticSizeCircuit { /// All quantum operations on qubits. - qubit_ops: Vec>, + qubit_ops: Vec>>, + /// Map operations to their locations in `qubit_ops`. + op_locations: BTreeMap>, +} + +type MatchOpPtr = *const MatchOp; + +/// The location of an operation in a `StaticSizeCircuit`. +/// +/// Given by the qubit index and the position within that qubit's op list. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct OpLocation { + /// The index of the qubit the operation acts on. + qubit: StaticQubitIndex, + /// The index of the operation in the qubit's operation list. + op_idx: usize, } impl StaticSizeCircuit { @@ -22,35 +41,96 @@ impl StaticSizeCircuit { self.qubit_ops.len() } + /// Returns an iterator over the qubits in the circuit. + pub fn qubits_iter(&self) -> impl ExactSizeIterator + '_ { + (0..self.qubit_count()).map(StaticQubitIndex) + } + /// Returns the operations on a given qubit. + pub fn qubit_ops( + &self, + qubit: StaticQubitIndex, + ) -> impl ExactSizeIterator + '_ { + self.qubit_ops[qubit.0].iter().map(|op| op.as_ref()) + } + + fn get(&self, loc: OpLocation) -> Option<&Rc> { + self.qubit_ops.get(loc.qubit.0)?.get(loc.op_idx) + } + + fn exists(&self, loc: OpLocation) -> bool { + self.qubit_ops + .get(loc.qubit.0) + .map_or(false, |ops| ops.get(loc.op_idx).is_some()) + } + + /// Returns the location of the operation linked to the given operation at + /// the given port. + pub fn linked_op(&self, loc: OpLocation, port: Port) -> Option { + let op = self.get(loc)?; + let loc = self.op_location(op).get(port.index())?; + match port.direction() { + Direction::Outgoing => { + let next_loc = OpLocation { + qubit: loc.qubit, + op_idx: loc.op_idx + 1, + }; + if self.exists(next_loc) { + Some(next_loc) + } else { + None + } + } + Direction::Incoming => { + if loc.op_idx == 0 { + None + } else { + Some(OpLocation { + qubit: loc.qubit, + op_idx: loc.op_idx - 1, + }) + } + } + } + } + + fn op_location(&self, op: &Rc) -> &[OpLocation] { + self.op_locations[&Rc::as_ptr(op)].as_slice() + } + + fn append_op(&mut self, op: MatchOp, qubits: impl IntoIterator) { + let qubits = qubits.into_iter(); + let op = Rc::new(op); + let op_ptr = Rc::as_ptr(&op); + for qubit in qubits { + if qubit.0 >= self.qubit_count() { + self.qubit_ops.resize(qubit.0 + 1, Vec::new()); + } + let op_idx = self.qubit_ops[qubit.0].len(); + self.qubit_ops[qubit.0].push(op.clone()); + self.op_locations + .entry(op_ptr) + .or_default() + .push(OpLocation { qubit, op_idx }); + } + } + #[allow(unused)] - pub fn qubit_ops(&self, qubit: usize) -> &[StaticOp] { - &self.qubit_ops[qubit] + fn all_ops_iter(&self) -> impl Iterator> { + self.qubit_ops.iter().flat_map(|ops| ops.iter()) } } /// A qubit index within a `StaticSizeCircuit`. #[repr(transparent)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, From, Into)] -pub(crate) struct StaticQubitIndex(usize); - -/// An operation in a `StaticSizeCircuit`. -/// -/// Currently only support quantum operations without any classical IO. -#[derive(Debug, Clone)] -pub(crate) struct StaticOp { - #[allow(unused)] - op: MatchOp, - #[allow(unused)] - qubits: Vec, - // TODO: clbits -} +pub struct StaticQubitIndex(usize); impl TryFrom<&Circuit> for StaticSizeCircuit { type Error = StaticSizeCircuitError; fn try_from(circuit: &Circuit) -> Result { - let mut qubit_ops = vec![Vec::new(); circuit.qubit_count()]; + let mut res = Self::default(); for cmd in circuit.commands() { let qubits = cmd .units(Direction::Incoming) @@ -64,24 +144,13 @@ impl TryFrom<&Circuit> for StaticSizeCircuit { if cmd.units(Direction::Outgoing).count() != qubits.len() { return Err(StaticSizeCircuitError::InvalidCircuit); } - let op = StaticOp { - op: cmd.optype().clone().into(), - qubits: qubits - .iter() - .copied() - .map(|u| StaticQubitIndex(u.index())) - .collect(), - }; - for qb in qubits { - qubit_ops[qb.index()].push(op.clone()); - } + let op = cmd.optype().clone().into(); + res.append_op(op, qubits.into_iter().map(|u| StaticQubitIndex(u.index()))); } - Ok(Self { qubit_ops }) + Ok(res) } } -use thiserror::Error; - /// Errors that can occur when converting a `Circuit` to a `StaticSizeCircuit`. #[derive(Debug, Error)] pub enum StaticSizeCircuitError { @@ -94,10 +163,30 @@ pub enum StaticSizeCircuitError { InvalidCircuit, } +impl fmt::Debug for StaticSizeCircuit { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("StaticSizeCircuit") + .field("qubit_ops", &self.qubit_ops) + .finish() + } +} + +impl PartialEq for StaticSizeCircuit { + fn eq(&self, other: &Self) -> bool { + self.qubit_ops == other.qubit_ops + } +} + +impl Eq for StaticSizeCircuit {} + #[cfg(test)] mod tests { + use portgraph::PortOffset; + use rstest::rstest; + use super::StaticSizeCircuit; use crate::ops::Tk2Op; + use crate::static_circ::OpLocation; use crate::utils::build_simple_circuit; #[test] @@ -116,8 +205,51 @@ mod tests { // Check the conversion assert_eq!(static_circuit.qubit_count(), 2); - assert_eq!(static_circuit.qubit_ops(0).len(), 2); // H gate on qubit 0 - dbg!(static_circuit.qubit_ops(0)); - assert_eq!(static_circuit.qubit_ops(1).len(), 2); // CX and H gate on qubit 1 + assert_eq!(static_circuit.qubit_ops(0.into()).len(), 2); // H gate on qubit 0 + assert_eq!(static_circuit.qubit_ops(1.into()).len(), 2); // CX and H gate on qubit 1 + } + + #[rstest] + #[case(PortOffset::Outgoing(0), None)] + #[case(PortOffset::Incoming(1), None)] + #[case( + PortOffset::Outgoing(1), + Some(OpLocation { + qubit: 1.into(), + op_idx: 1, + }) + )] + #[case( + PortOffset::Incoming(0), + Some(OpLocation { + qubit: 0.into(), + op_idx: 0, + }) + )] + fn test_linked_op(#[case] port: PortOffset, #[case] expected_loc: Option) { + let circuit = build_simple_circuit(2, |circ| { + circ.append(Tk2Op::H, [0])?; + circ.append(Tk2Op::CX, [0, 1])?; + circ.append(Tk2Op::H, [1])?; + Ok(()) + }) + .unwrap(); + // Convert the circuit to StaticSizeCircuit + let static_circuit: StaticSizeCircuit = (&circuit).try_into().unwrap(); + + // Define the location of the CX gate + let cx_location = OpLocation { + qubit: 0.into(), + op_idx: 1, + }; + + // Define the port for the CX gate + let cx_port = port.into(); + + // Get the linked operation for the CX gate + let linked_op_location = static_circuit.linked_op(cx_location, cx_port); + + // Check if the linked operation is correct + assert_eq!(linked_op_location, expected_loc); } } diff --git a/tket2/src/static_circ/match_op.rs b/tket2/src/static_circ/match_op.rs index c52f8345..5ffaa713 100644 --- a/tket2/src/static_circ/match_op.rs +++ b/tket2/src/static_circ/match_op.rs @@ -5,7 +5,7 @@ use smol_str::SmolStr; #[derive( Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize, )] -pub(crate) struct MatchOp { +pub struct MatchOp { /// The operation identifier op_name: SmolStr, /// The encoded operation, if necessary for comparisons.