Skip to content

Commit

Permalink
Replace ft1/ft2 with impl From<(TypeRow,[TypeRow])> for FunctionType
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Jul 8, 2024
1 parent cdc3739 commit 557e7db
Show file tree
Hide file tree
Showing 10 changed files with 33 additions and 36 deletions.
16 changes: 8 additions & 8 deletions hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,16 @@ pub use conditional::{CaseBuilder, ConditionalBuilder};
mod circuit;
pub use circuit::{CircuitBuildError, CircuitBuilder};

/// Return a FunctionType with the same input and output types (specified)
/// whose extension delta, when used in a non-FuncDefn container, will be inferred.
pub fn ft1(types: impl Into<TypeRow>) -> FunctionType {
FunctionType::new_endo(types).with_extension_delta(TO_BE_INFERRED)
impl<T: Into<TypeRow>> From<(T,)> for FunctionType {
fn from(value: (T,)) -> Self {
FunctionType::new_endo(value.0).with_extension_delta(TO_BE_INFERRED)
}
}

/// Return a FunctionType with the specified input and output types
/// whose extension delta, when used in a non-FuncDefn container, will be inferred.
pub fn ft2(inputs: impl Into<TypeRow>, outputs: impl Into<TypeRow>) -> FunctionType {
FunctionType::new(inputs, outputs).with_extension_delta(TO_BE_INFERRED)
impl<T1: Into<TypeRow>, T2: Into<TypeRow>> From<(T1, T2)> for FunctionType {
fn from(value: (T1, T2)) -> Self {
FunctionType::new(value.0, value.1).with_extension_delta(TO_BE_INFERRED)
}
}

#[derive(Debug, Clone, PartialEq, Error)]
Expand Down
3 changes: 2 additions & 1 deletion hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,10 @@ pub trait Dataflow: Container {
// TODO: Should this be one function, or should there be a temporary "op" one like with the others?
fn dfg_builder(
&mut self,
signature: FunctionType,
signature: impl Into<FunctionType>,
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
let signature = signature.into();
let op = ops::DFG {
signature: signature.clone(),
};
Expand Down
11 changes: 6 additions & 5 deletions hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ impl DFGBuilder<Hugr> {
/// # Errors
///
/// Error in adding DFG child nodes.
pub fn new(signature: FunctionType) -> Result<DFGBuilder<Hugr>, BuildError> {
pub fn new(signature: impl Into<FunctionType>) -> Result<DFGBuilder<Hugr>, BuildError> {
let signature = signature.into();
let dfg_op = ops::DFG {
signature: signature.clone(),
};
Expand Down Expand Up @@ -203,7 +204,7 @@ pub(crate) mod test {
use serde_json::json;

use crate::builder::build_traits::DataflowHugr;
use crate::builder::{ft1, BuilderWiringError, DataflowSubContainer, ModuleBuilder};
use crate::builder::{BuilderWiringError, DataflowSubContainer, ModuleBuilder};
use crate::extension::prelude::{BOOL_T, USIZE_T};
use crate::extension::{ExtensionId, SignatureError, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::validate::InterGraphEdgeError;
Expand Down Expand Up @@ -420,12 +421,12 @@ pub(crate) mod test {
let xb: ExtensionId = "B".try_into().unwrap();
let xc: ExtensionId = "C".try_into().unwrap();

let mut parent = DFGBuilder::new(ft1(BIT))?;
let mut parent = DFGBuilder::new((BIT,))?;

let [w] = parent.input_wires_arr();

// A box which adds extensions A and B, via child Lift nodes
let mut add_ab = parent.dfg_builder(ft1(BIT), [w])?;
let mut add_ab = parent.dfg_builder((BIT,), [w])?;
let [w] = add_ab.input_wires_arr();

let lift_a = add_ab.add_dataflow_op(
Expand All @@ -451,7 +452,7 @@ pub(crate) mod test {

// Add another node (a sibling to add_ab) which adds extension C
// via a child lift node
let mut add_c = parent.dfg_builder(ft1(BIT), [w])?;
let mut add_c = parent.dfg_builder((BIT,), [w])?;
let [w] = add_c.input_wires_arr();
let lift_c = add_c.add_dataflow_op(
Lift {
Expand Down
8 changes: 4 additions & 4 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ impl CustomConst for ConstExternalSymbol {
#[cfg(test)]
mod test {
use crate::{
builder::{ft1, DFGBuilder, Dataflow, DataflowHugr},
builder::{DFGBuilder, Dataflow, DataflowHugr},
utils::test_quantum_extension::cx_gate,
Hugr, Wire,
};
Expand Down Expand Up @@ -452,7 +452,7 @@ mod test {
assert!(error_val.equal_consts(&ConstError::new(2, "my message")));
assert!(!error_val.equal_consts(&ConstError::new(3, "my message")));

let mut b = DFGBuilder::new(ft1(type_row![])).unwrap();
let mut b = DFGBuilder::new((type_row![],)).unwrap();

let err = b.add_load_value(error_val);

Expand Down Expand Up @@ -486,7 +486,7 @@ mod test {
)
.unwrap();

let mut b = DFGBuilder::new(ft1(type_row![QB_T, QB_T])).unwrap();
let mut b = DFGBuilder::new((type_row![QB_T, QB_T],)).unwrap();
let [q0, q1] = b.input_wires_arr();
let [q0, q1] = b
.add_dataflow_op(cx_gate(), [q0, q1])
Expand Down Expand Up @@ -524,7 +524,7 @@ mod test {
#[test]
/// Test print operation
fn test_print() {
let mut b: DFGBuilder<Hugr> = DFGBuilder::new(ft1(vec![])).unwrap();
let mut b: DFGBuilder<Hugr> = DFGBuilder::new((vec![],)).unwrap();
let greeting: ConstString = ConstString::new("Hello, world!".into());
let greeting_out: Wire = b.add_load_value(greeting);
let print_op = PRELUDE
Expand Down
10 changes: 5 additions & 5 deletions hugr-core/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ mod test {
use rstest::rstest;

use crate::builder::{
ft1, ft2, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
};
use crate::extension::prelude::QB_T;
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
Expand Down Expand Up @@ -174,7 +174,7 @@ mod test {
.unwrap();
let int_ty = &int_types::INT_TYPES[6];

let mut outer = DFGBuilder::new(ft2(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
let mut outer = DFGBuilder::new((vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
let [a, b] = outer.input_wires_arr();
fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
d: &mut DFGBuilder<T>,
Expand Down Expand Up @@ -244,7 +244,7 @@ mod test {

#[test]
fn permutation() -> Result<(), Box<dyn std::error::Error>> {
let mut h = DFGBuilder::new(ft1(type_row![QB_T, QB_T]))?;
let mut h = DFGBuilder::new((type_row![QB_T, QB_T],))?;
let [p, q] = h.input_wires_arr();
let [p_h] = h
.add_dataflow_op(test_quantum_extension::h_gate(), [p])?
Expand Down Expand Up @@ -339,11 +339,11 @@ mod test {
PRELUDE.to_owned(),
])
.unwrap();
let mut outer = DFGBuilder::new(ft1(type_row![QB_T, QB_T]))?;
let mut outer = DFGBuilder::new((type_row![QB_T, QB_T],))?;
let [a, b] = outer.input_wires_arr();
let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
let mut inner = outer.dfg_builder(ft1(QB_T), h_b.outputs())?;
let mut inner = outer.dfg_builder((QB_T,), h_b.outputs())?;
let [i] = inner.input_wires_arr();
let f = inner.add_load_value(float_types::ConstF64::new(1.0));
inner.add_other_wire(inner.input().node(), f.node());
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;
use crate::builder::{
ft2, test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr,
test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr,
DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{BOOL_T, PRELUDE_ID, QB_T, USIZE_T};
Expand Down Expand Up @@ -351,7 +351,7 @@ fn hierarchy_order() -> Result<(), Box<dyn std::error::Error>> {

#[test]
fn constants_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
let mut builder = DFGBuilder::new(ft2(vec![], vec![INT_TYPES[4].clone()])).unwrap();
let mut builder = DFGBuilder::new((vec![], vec![INT_TYPES[4].clone()])).unwrap();
let w = builder.add_load_value(ConstInt::new_s(4, -2).unwrap());
let hugr = builder.finish_hugr_with_outputs([w], &INT_OPS_REGISTRY)?;

Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use rstest::rstest;
use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
ft2, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer,
};
use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, QB_T, USIZE_T};
Expand Down Expand Up @@ -769,7 +769,7 @@ fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {

let int_pair = Type::new_tuple(type_row![USIZE_T; 2]);
// Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints
let mut d = DFGBuilder::new(ft2(
let mut d = DFGBuilder::new((
vec![utou(PRELUDE_ID), int_pair.clone()],
vec![int_pair.clone()],
))?;
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use portgraph::PortOffset;
use rstest::{fixture, rstest};

use crate::{
builder::{ft2, BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr},
builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
ops::{
handle::{DataflowOpID, NodeHandle},
Expand Down Expand Up @@ -152,7 +152,7 @@ fn value_types() {
fn static_targets() {
use crate::extension::prelude::{ConstUsize, USIZE_T};
use itertools::Itertools;
let mut dfg = DFGBuilder::new(ft2(type_row![], type_row![USIZE_T])).unwrap();
let mut dfg = DFGBuilder::new((type_row![], type_row![USIZE_T])).unwrap();

let c = dfg.add_constant(Value::extension(ConstUsize::new(1)));

Expand Down
6 changes: 1 addition & 5 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,6 @@ pub type ValueNameRef = str;
#[cfg(test)]
mod test {
use super::Value;
use crate::builder::ft2;
use crate::builder::test::simple_dfg_hugr;
use crate::std_extensions::arithmetic::int_types::ConstInt;
use crate::{
Expand Down Expand Up @@ -587,10 +586,7 @@ mod test {
let pred_rows = vec![type_row![USIZE_T, FLOAT64_TYPE], Type::EMPTY_TYPEROW];
let pred_ty = SumType::new(pred_rows.clone());

let mut b = DFGBuilder::new(ft2(
type_row![],
TypeRow::from(vec![pred_ty.clone().into()]),
))?;
let mut b = DFGBuilder::new((type_row![], TypeRow::from(vec![pred_ty.clone().into()])))?;
let c = b.add_constant(Value::sum(
0,
[
Expand Down
3 changes: 1 addition & 2 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use std::collections::{BTreeSet, HashMap};

use hugr_core::builder::ft2;
use itertools::Itertools;
use thiserror::Error;

Expand Down Expand Up @@ -136,7 +135,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
/// against `reg`.
fn const_graph(consts: Vec<Value>, reg: &ExtensionRegistry) -> Hugr {
let const_types = consts.iter().map(Value::get_type).collect_vec();
let mut b = DFGBuilder::new(ft2(type_row![], const_types)).unwrap();
let mut b = DFGBuilder::new((type_row![], const_types)).unwrap();

let outputs = consts
.into_iter()
Expand Down

0 comments on commit 557e7db

Please sign in to comment.