Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Guppy hugr -> PHIR lowering #215

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ lto = "thin"

[workspace]
resolver = "2"
members = ["tket2", "tket2-py", "compile-rewriter", "taso-optimiser"]
members = ["tket2", "tket2-py", "compile-rewriter", "taso-optimiser", "hugr2phir"]
default-members = ["tket2"]

[workspace.package]
Expand Down
14 changes: 14 additions & 0 deletions hugr2phir/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "hugr2phir"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
clap = { version = "4.4.2", features = ["derive"] }
tket2 = { workspace = true }
quantinuum-hugr = { workspace = true }
rmp-serde = "1.1.2"
serde_json = "1.0.107"
itertools.workspace = true
81 changes: 81 additions & 0 deletions hugr2phir/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
mod normalize;

use std::fs::File;
use std::io::BufReader;
use std::path::PathBuf;

use clap::Parser;

use hugr::{
hugr::views::{DescendantsGraph, HierarchyView},
ops::{OpTag, OpTrait, OpType},
Hugr, HugrView,
};

use tket2::phir::circuit_to_phir;

#[derive(Parser, Debug)]
#[clap(version = "1.0", long_about = None)]
#[clap(about = "Convert from hugr msgpack serialized form to PHIR JSON.")]
#[command(long_about = "Sets the input file to use. It must be serialized HUGR.")]
struct CmdLineArgs {
/// Name of input file/folder
input: PathBuf,
/// Name of output file/folder
#[arg(
short,
long,
value_name = "FILE",
default_value = None,
help = "Sets the output file or folder. Defaults to the same as the input file with a .json extension."
)]
output: Option<PathBuf>,
}

fn main() {
let CmdLineArgs { input, output } = CmdLineArgs::parse();

let reader = BufReader::new(File::open(&input).unwrap());
let output = output.unwrap_or_else(|| {
let mut output = input.clone();
output.set_extension("json");
output
});

let mut hugr: Hugr = rmp_serde::from_read(reader).unwrap();
normalize::remove_identity_tuples(&mut hugr);
// DescendantsGraph::try_new(&hugr, root).unwrap()
let root = hugr.root();
let root_op_tag = hugr.get_optype(root).tag();
let circ: DescendantsGraph = if OpTag::DataflowParent.is_superset(root_op_tag) {
// Some dataflow graph
DescendantsGraph::try_new(&hugr, root).unwrap()
} else if OpTag::ModuleRoot.is_superset(root_op_tag) {
// Assume Guppy generated module

// just take the first function
let main_node = hugr
.children(hugr.root())
.find(|n| matches!(hugr.get_optype(*n), OpType::FuncDefn(_)))
.expect("Module contains no functions.");
// just take the first node again...assume guppy source so always top
// level CFG
let cfg_node = hugr
.children(main_node)
.find(|n| matches!(hugr.get_optype(*n), OpType::CFG(_)))
.expect("Function contains no cfg.");

// Now is a bit sketchy...assume only one basic block in CFG
let block_node = hugr
.children(cfg_node)
.find(|n| matches!(hugr.get_optype(*n), OpType::BasicBlock(_)))
.expect("CFG contains no basic block.");
DescendantsGraph::try_new(&hugr, block_node).unwrap()
} else {
panic!("HUGR Root Op type {root_op_tag:?} not supported");
};

let phir = circuit_to_phir(&circ).unwrap();

serde_json::to_writer(File::create(output).unwrap(), &phir).unwrap();
}
118 changes: 118 additions & 0 deletions hugr2phir/src/normalize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use hugr::builder::Dataflow;
use hugr::builder::DataflowHugr;
use hugr::HugrView;
use hugr::SimpleReplacement;

use hugr::hugr::hugrmut::HugrMut;
use hugr::hugr::views::SiblingSubgraph;
use hugr::ops::OpType;
use itertools::Itertools;
use tket2::extension::REGISTRY;

use hugr::ops::LeafOp;

use hugr::types::FunctionType;

use hugr::builder::DFGBuilder;

use hugr::types::TypeRow;

use hugr::Hugr;

fn identity_dfg(type_combination: TypeRow) -> Hugr {
let identity_build = DFGBuilder::new(FunctionType::new(
type_combination.clone(),
type_combination,
))
.unwrap();
let inputs = identity_build.input_wires();
identity_build
.finish_hugr_with_outputs(inputs, &REGISTRY)
.unwrap()
}

fn find_make_unmake(hugr: &impl HugrView) -> impl Iterator<Item = SimpleReplacement> + '_ {
hugr.nodes().filter_map(|n| {
let op = hugr.get_optype(n);

let OpType::LeafOp(LeafOp::MakeTuple { tys }) = op else {
return None;
};

let Ok(neighbour) = hugr.output_neighbours(n).exactly_one() else {
return None;
};

let OpType::LeafOp(LeafOp::UnpackTuple { .. }) = hugr.get_optype(neighbour) else {
return None;
};

let sibling_graph = SiblingSubgraph::try_from_nodes([n, neighbour], hugr)
.expect("Make unmake should be valid subgraph.");

let replacement = identity_dfg(tys.clone());
sibling_graph
.create_simple_replacement(hugr, replacement)
.ok()
})
}

/// Remove any pairs of MakeTuple immediately followed by UnpackTuple (an
/// identity operation)
pub(crate) fn remove_identity_tuples(circ: &mut impl HugrMut) {
let rewrites: Vec<_> = find_make_unmake(circ).collect();
// should be able to apply all in parallel unless there are copies...

for rw in rewrites {
circ.apply_rewrite(rw).unwrap();
}
}

#[cfg(test)]
mod test {
use super::*;
use hugr::extension::prelude::BOOL_T;
use hugr::extension::prelude::QB_T;
use hugr::type_row;
use hugr::HugrView;

fn make_unmake_tuple(type_combination: TypeRow) -> Hugr {
let mut b = DFGBuilder::new(FunctionType::new(
type_combination.clone(),
type_combination.clone(),
))
.unwrap();
let input_wires = b.input_wires();

let tuple = b
.add_dataflow_op(
LeafOp::MakeTuple {
tys: type_combination.clone(),
},
input_wires,
)
.unwrap();

let unpacked = b
.add_dataflow_op(
LeafOp::UnpackTuple {
tys: type_combination,
},
tuple.outputs(),
)
.unwrap();

b.finish_hugr_with_outputs(unpacked.outputs(), &REGISTRY)
.unwrap()
}
#[test]
fn test_remove_id_tuple() {
let mut h = make_unmake_tuple(type_row![QB_T, BOOL_T]);

assert_eq!(h.node_count(), 5);

remove_identity_tuples(&mut h);

assert_eq!(h.node_count(), 3);
}
}
9 changes: 8 additions & 1 deletion tket2/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ use hugr::extension::{ExtensionId, ExtensionRegistry, SignatureError};
use hugr::hugr::IdentList;
use hugr::ops::custom::{ExternalOp, OpaqueOp};
use hugr::ops::OpName;
use hugr::std_extensions::arithmetic::float_types::extension as float_extension;
use hugr::std_extensions::arithmetic::{
float_types::extension as float_extension, int_ops::extension as int_ops_extension,
int_types::extension as int_types_extension,
};

use hugr::types::type_param::{CustomTypeArg, TypeArg, TypeParam};
use hugr::types::{CustomType, FunctionType, Type, TypeBound};
use hugr::Extension;
Expand Down Expand Up @@ -70,6 +74,9 @@ pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::from([
PRELUDE.clone(),
T2EXTENSION.clone(),
float_extension(),
int_ops_extension(),
int_types_extension(),

]);


Expand Down
1 change: 1 addition & 0 deletions tket2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod rewrite;
#[cfg(feature = "portmatching")]
pub mod portmatching;

pub mod phir;
mod utils;

pub use circuit::Circuit;
Expand Down
7 changes: 7 additions & 0 deletions tket2/src/phir.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//! Rust struct for PHIR and conversion from HUGR.

mod convert;
mod model;

pub use convert::circuit_to_phir;
pub use model::PHIRModel;
Loading