Skip to content

Commit

Permalink
refactor!: Replace Circuit trait with a struct (#370)
Browse files Browse the repository at this point in the history
Replaces the `trait Circuit : HugrView` with a struct containing a `T:
HugrView` and a node id indicating the container node for the circuit in
the hugr.

There are a few design going into this definition:

- #### Struct with checked parent node type.
The previous `Circuit` trait asked the user to only use it for
dataflow-based hugrs. This of course lead to functions assuming things
about the hugr structure and panicking with random errors when passed an
otherwise valid hugr.
  
Bar checking the root node at each method call, this can only be solved
with an opaque struct that checks its preconditions at construction
time.
  
The code here will throw a user-friendly error when trying to use
incompatible hugrs:
  ```
Node(0) cannot be used as a circuit parent. A Module is not a dataflow
container.
  ```

- #### Generic `T` defaulting to `Hugr`
Most uses in this library manipulate circuits as owned structures (e.g.
passing circuits around in badger).
However, sometimes we don't own the hugr so we cannot create a
`Circuit<Hugr>` from it. With the generics definition we can use
`Circuit<&Hugr>` or `Circuit<&mut Hugr>` to respectively view or modify
non-owned structures.
  
- #### Parent node pointer
Circuits represented as hugrs are not necessarily a flat dataflow
structure.
For example, a guppy-defined circuit may include calls to other methods
from the module.
By including a pointer to the entry point of the circuit we can keep all
this structure, and rewrite it latter as necessary.
This is also useful for non-owned hugrs; we can have a `Circuit<&Hugr>`
pointing to a region in the hugr without needing to modify anything
else.

Closes #112.

BREAKING CHANGE: Replaced the `Circuit` trait with a wrapper struct.

---------

Co-authored-by: Seyon Sivarajah <[email protected]>
  • Loading branch information
aborgna-q and ss2165 authored Jun 3, 2024
1 parent 58f0de4 commit ec5dd22
Show file tree
Hide file tree
Showing 38 changed files with 846 additions and 557 deletions.
54 changes: 42 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ missing_docs = "warn"
[workspace.dependencies]

tket2 = { path = "./tket2" }
hugr = "0.4.0"
hugr = "0.5.0"
portgraph = "0.12"
pyo3 = "0.21.2"
itertools = "0.13.0"
Expand Down Expand Up @@ -60,3 +60,4 @@ tracing-subscriber = "0.3.17"
typetag = "0.2.8"
urlencoding = "2.1.2"
webbrowser = "1.0.0"
cool_asserts = "2.0.3"
1 change: 0 additions & 1 deletion badger-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use tket2::json::{load_tk1_json_file, save_tk1_json_file};
use tket2::optimiser::badger::log::BadgerLogger;
use tket2::optimiser::badger::BadgerOptions;
use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser};
use tket2::rewrite::trace::RewriteTracer;

#[cfg(feature = "peak_alloc")]
#[global_allocator]
Expand Down
12 changes: 7 additions & 5 deletions tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ impl CircuitType {
/// Converts a `Hugr` into the format indicated by the flag.
pub fn convert(self, py: Python, hugr: Hugr) -> PyResult<Bound<PyAny>> {
match self {
CircuitType::Tket1 => SerialCircuit::encode(&hugr).convert_pyerrs()?.to_tket1(py),
CircuitType::Tket2 => Ok(Bound::new(py, Tk2Circuit { hugr })?.into_any()),
CircuitType::Tket1 => SerialCircuit::encode(&hugr.into())
.convert_pyerrs()?
.to_tket1(py),
CircuitType::Tket2 => Ok(Bound::new(py, Tk2Circuit { circ: hugr.into() })?.into_any()),
}
}
}
Expand All @@ -58,16 +60,16 @@ where
E: ConvertPyErr<Output = PyErr>,
F: FnOnce(Hugr, CircuitType) -> Result<T, E>,
{
let (hugr, typ) = match Tk2Circuit::extract_bound(circ) {
let (circ, typ) = match Tk2Circuit::extract_bound(circ) {
// hugr circuit
Ok(t2circ) => (t2circ.hugr, CircuitType::Tket2),
Ok(t2circ) => (t2circ.circ, CircuitType::Tket2),
// tket1 circuit
Err(_) => (
SerialCircuit::from_tket1(circ)?.decode().convert_pyerrs()?,
CircuitType::Tket1,
),
};
(f)(hugr, typ).map_err(|e| e.convert_pyerrs())
(f)(circ.into_hugr(), typ).map_err(|e| e.convert_pyerrs())
}

/// Apply a function expecting a hugr on a python circuit.
Expand Down
40 changes: 22 additions & 18 deletions tket2-py/src/circuit/tk2circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ use super::{cost, with_hugr, PyCircuitCost, PyCustom, PyHugrType, PyNode, PyWire
#[derive(Clone, Debug, PartialEq, From)]
pub struct Tk2Circuit {
/// Rust representation of the circuit.
pub hugr: Hugr,
pub circ: Circuit,
}

#[pymethods]
Expand All @@ -67,40 +67,40 @@ impl Tk2Circuit {
#[new]
pub fn new(circ: &Bound<PyAny>) -> PyResult<Self> {
Ok(Self {
hugr: with_hugr(circ, |hugr, _| hugr)?,
circ: with_hugr(circ, |hugr, _| hugr)?.into(),
})
}

/// Convert the [`Tk2Circuit`] to a tket1 circuit.
pub fn to_tket1<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
SerialCircuit::encode(&self.hugr)
SerialCircuit::encode(&self.circ)
.convert_pyerrs()?
.to_tket1(py)
}

/// Apply a rewrite on the circuit.
pub fn apply_rewrite(&mut self, rw: PyCircuitRewrite) {
rw.rewrite.apply(&mut self.hugr).expect("Apply error.");
rw.rewrite.apply(&mut self.circ).expect("Apply error.");
}

/// Encode the circuit as a HUGR json string.
//
// TODO: Bind a messagepack encoder/decoder too.
pub fn to_hugr_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&self.hugr).unwrap())
Ok(serde_json::to_string(self.circ.hugr()).unwrap())
}

/// Decode a HUGR json string to a circuit.
#[staticmethod]
pub fn from_hugr_json(json: &str) -> PyResult<Self> {
let hugr = serde_json::from_str(json)
let hugr: Hugr = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit { hugr })
Ok(Tk2Circuit { circ: hugr.into() })
}

/// Encode the circuit as a tket1 json string.
pub fn to_tket1_json(&self) -> PyResult<String> {
Ok(serde_json::to_string(&SerialCircuit::encode(&self.hugr).convert_pyerrs()?).unwrap())
Ok(serde_json::to_string(&SerialCircuit::encode(&self.circ).convert_pyerrs()?).unwrap())
}

/// Decode a tket1 json string to a circuit.
Expand All @@ -109,7 +109,7 @@ impl Tk2Circuit {
let tk1: SerialCircuit = serde_json::from_str(json)
.map_err(|e| PyErr::new::<PyAttributeError, _>(format!("Invalid encoded HUGR: {e}")))?;
Ok(Tk2Circuit {
hugr: tk1.decode().convert_pyerrs()?,
circ: tk1.decode().convert_pyerrs()?,
})
}

Expand All @@ -134,13 +134,13 @@ impl Tk2Circuit {
cost: cost.to_object(py),
})
};
let circ_cost = self.hugr.circuit_cost(cost_fn)?;
let circ_cost = self.circ.circuit_cost(cost_fn)?;
Ok(circ_cost.cost.into_bound(py))
}

/// Returns a hash of the circuit.
pub fn hash(&self) -> u64 {
self.hugr.circuit_hash().unwrap()
self.circ.circuit_hash().unwrap()
}

/// Hash the circuit
Expand All @@ -160,7 +160,8 @@ impl Tk2Circuit {

fn node_op(&self, node: PyNode) -> PyResult<PyCustom> {
let custom: CustomOp = self
.hugr
.circ
.hugr()
.get_optype(node.node)
.clone()
.try_into()
Expand All @@ -174,25 +175,27 @@ impl Tk2Circuit {
}

fn node_inputs(&self, node: PyNode) -> Vec<PyWire> {
self.hugr
self.circ
.hugr()
.all_linked_outputs(node.node)
.map(|(n, p)| Wire::new(n, p).into())
.collect()
}

fn node_outputs(&self, node: PyNode) -> Vec<PyWire> {
self.hugr
self.circ
.hugr()
.node_outputs(node.node)
.map(|p| Wire::new(node.node, p).into())
.collect()
}

fn input_node(&self) -> PyNode {
self.hugr.input().into()
self.circ.input_node().into()
}

fn output_node(&self) -> PyNode {
self.hugr.output().into()
self.circ.output_node().into()
}
}
impl Tk2Circuit {
Expand Down Expand Up @@ -236,11 +239,12 @@ impl Dfg {

fn finish(&mut self, outputs: Vec<PyWire>) -> PyResult<Tk2Circuit> {
Ok(Tk2Circuit {
hugr: self
circ: self
.builder
.clone()
.finish_hugr_with_outputs(outputs.into_iter().map_into(), &REGISTRY)
.convert_pyerrs()?,
.convert_pyerrs()?
.into(),
})
}
}
Expand Down
11 changes: 7 additions & 4 deletions tket2-py/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
use std::io::BufWriter;
use std::{fs, num::NonZeroUsize, path::PathBuf};

use hugr::Hugr;
use pyo3::prelude::*;
use tket2::optimiser::badger::BadgerOptions;
use tket2::optimiser::{BadgerLogger, DefaultBadgerOptimiser};
use tket2::Circuit;

use crate::circuit::update_hugr;

Expand Down Expand Up @@ -96,18 +96,21 @@ impl PyBadgerOptimiser {
split_circuit: split_circ.unwrap_or(false),
queue_size: queue_size.unwrap_or(100),
};
update_hugr(circ, |circ, _| self.optimise(circ, log_progress, options))
update_hugr(circ, |circ, _| {
self.optimise(circ.into(), log_progress, options)
.into_hugr()
})
}
}

impl PyBadgerOptimiser {
/// The Python optimise method, but on Hugrs.
pub(super) fn optimise(
&self,
circ: Hugr,
circ: Circuit,
log_progress: Option<PathBuf>,
options: BadgerOptions,
) -> Hugr {
) -> Circuit {
let badger_logger = log_progress
.map(|file_name| {
let log_file = fs::File::create(file_name).unwrap();
Expand Down
12 changes: 7 additions & 5 deletions tket2-py/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ create_py_exception!(
#[pyfunction]
fn greedy_depth_reduce<'py>(circ: &Bound<'py, PyAny>) -> PyResult<(Bound<'py, PyAny>, u32)> {
let py = circ.py();
try_with_hugr(circ, |mut h, typ| {
let n_moves = apply_greedy_commutation(&mut h).convert_pyerrs()?;
let circ = typ.convert(py, h)?;
try_with_hugr(circ, |h, typ| {
let mut circ: Circuit = h.into();
let n_moves = apply_greedy_commutation(&mut circ).convert_pyerrs()?;
let circ = typ.convert(py, circ.into_hugr())?;
PyResult::Ok((circ, n_moves))
})
}
Expand Down Expand Up @@ -117,7 +118,8 @@ fn badger_optimise<'py>(
_ => unreachable!(),
};
// Optimise
try_update_hugr(circ, |mut circ, _| {
try_update_hugr(circ, |hugr, _| {
let mut circ: Circuit = hugr.into();
let n_cx = circ
.commands()
.filter(|c| op_matches(c.optype(), Tk2Op::CX))
Expand All @@ -142,6 +144,6 @@ fn badger_optimise<'py>(
};
circ = optimiser.optimise(circ, log_file, options);
}
PyResult::Ok(circ)
PyResult::Ok(circ.into_hugr())
})
}
Loading

0 comments on commit ec5dd22

Please sign in to comment.