Skip to content

Commit

Permalink
feat: FunctionBuilder::add_{in,out}put (#1570)
Browse files Browse the repository at this point in the history
Closes #1562.

It's a chunk of code fiddling with the hugr internals, but it is
something specific to function definitions so I think it should be ok to
have it under `FunctionBuilder`.
  • Loading branch information
aborgna-q authored Oct 14, 2024
1 parent 36548ab commit 3d7405c
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 1 deletion.
143 changes: 142 additions & 1 deletion hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use itertools::Itertools;

use super::build_traits::{HugrBuilder, SubContainer};
use super::handle::BuildHandle;
use super::{BuildError, Container, Dataflow, DfgID, FuncID};

use std::marker::PhantomData;

use crate::hugr::internal::HugrMutInternals;
use crate::hugr::{HugrView, ValidationError};
use crate::ops;
use crate::ops::{DataflowParent, Input, Output};
use crate::{Direction, IncomingPort, OutgoingPort, Wire};

use crate::types::{PolyFuncType, Signature};
use crate::types::{PolyFuncType, Signature, Type};

use crate::extension::ExtensionRegistry;
use crate::Node;
Expand Down Expand Up @@ -155,6 +160,109 @@ impl FunctionBuilder<Hugr> {
let db = DFGBuilder::create_with_io(base, root, body)?;
Ok(Self::from_dfg_builder(db))
}

/// Add a new input to the function being constructed.
///
/// Returns the new wire from the input node.
pub fn add_input(&mut self, input_type: Type) -> Wire {
let [inp_node, _] = self.io();

// Update the parent's root type
let new_optype = self.update_fn_signature(|mut s| {
s.input.to_mut().push(input_type);
s
});

// Update the inner input node
let types = new_optype.signature.body().input.clone();
self.hugr_mut()
.replace_op(inp_node, Input { types })
.unwrap();
let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1);
let new_port = new_port.next().unwrap();

// The last port in an input/output node is an order edge port, so we must shift any connections to it.
let new_value_port: OutgoingPort = (new_port - 1).into();
let new_order_port: OutgoingPort = new_port.into();
let order_edge_targets = self
.hugr()
.linked_inputs(inp_node, new_value_port)
.collect_vec();
self.hugr_mut().disconnect(inp_node, new_value_port);
for (tgt_node, tgt_port) in order_edge_targets {
self.hugr_mut()
.connect(inp_node, new_order_port, tgt_node, tgt_port);
}

// Update the builder metadata
self.0.num_in_wires += 1;

self.input_wires().last().unwrap()
}

/// Add a new output to the function being constructed.
pub fn add_output(&mut self, output_type: Type) {
let [_, out_node] = self.io();

// Update the parent's root type
let new_optype = self.update_fn_signature(|mut s| {
s.output.to_mut().push(output_type);
s
});

// Update the inner input node
let types = new_optype.signature.body().output.clone();
self.hugr_mut()
.replace_op(out_node, Output { types })
.unwrap();
let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1);
let new_port = new_port.next().unwrap();

// The last port in an input/output node is an order edge port, so we must shift any connections to it.
let new_value_port: IncomingPort = (new_port - 1).into();
let new_order_port: IncomingPort = new_port.into();
let order_edge_sources = self
.hugr()
.linked_outputs(out_node, new_value_port)
.collect_vec();
self.hugr_mut().disconnect(out_node, new_value_port);
for (src_node, src_port) in order_edge_sources {
self.hugr_mut()
.connect(src_node, src_port, out_node, new_order_port);
}

// Update the builder metadata
self.0.num_out_wires += 1;
}

/// Update the function builder's parent signature.
///
/// Internal function used in [add_input] and [add_output].
///
/// Does not update the input and output nodes.
///
/// Returns a reference to the new optype.
fn update_fn_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
let parent = self.container_node();
let old_optype = self
.hugr()
.get_optype(parent)
.as_func_defn()
.expect("FunctionBuilder node must be a FuncDefn");
let signature = old_optype.inner_signature();
let name = old_optype.name.clone();
self.hugr_mut()
.replace_op(
parent,
ops::FuncDefn {
signature: f(signature).into(),
name,
},
)
.expect("Could not replace FunctionBuilder operation");

self.hugr().get_optype(parent).as_func_defn().unwrap()
}
}

impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Container for DFGWrapper<B, T> {
Expand Down Expand Up @@ -199,6 +307,7 @@ impl<T> HugrBuilder for DFGWrapper<Hugr, T> {
#[cfg(test)]
pub(crate) mod test {
use cool_asserts::assert_matches;
use ops::OpParent;
use rstest::rstest;
use serde_json::json;

Expand Down Expand Up @@ -344,6 +453,38 @@ pub(crate) mod test {
assert_matches!(builder(), Ok(_));
}

#[test]
fn add_inputs_outputs() {
let builder = || -> Result<(Hugr, Node), BuildError> {
let mut f_build = FunctionBuilder::new(
"main",
Signature::new(type_row![BIT], type_row![BIT]).with_prelude(),
)?;
let f_node = f_build.container_node();

let [i0] = f_build.input_wires_arr();
let noop0 = f_build.add_dataflow_op(Noop(BIT), [i0])?;

// Some some order edges
f_build.set_order(&f_build.io()[0], &noop0.node());
f_build.set_order(&noop0.node(), &f_build.io()[1]);

// Add a new input and output, and connect them with a noop in between
f_build.add_output(QB);
let i1 = f_build.add_input(QB);
let noop1 = f_build.add_dataflow_op(Noop(QB), [i1])?;

let hugr =
f_build.finish_prelude_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?;
Ok((hugr, f_node))
};

let (hugr, f_node) = builder().unwrap_or_else(|e| panic!("{e}"));

let func_sig = hugr.get_optype(f_node).inner_function_type().unwrap();
assert_eq!(func_sig.io(), (&type_row![BIT, QB], &type_row![BIT, QB]));
}

#[test]
fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> {
let mut f_build =
Expand Down
5 changes: 5 additions & 0 deletions hugr-core/src/hugr/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ pub trait HugrMutInternals: RootTagged {
/// The `direction` parameter specifies whether to add ports to the incoming
/// or outgoing list.
///
/// Returns the range of newly created ports.
///
/// # Panics
///
/// If the node is not in the graph.
Expand Down Expand Up @@ -129,6 +131,9 @@ pub trait HugrMutInternals: RootTagged {
/// Replace the OpType at node and return the old OpType.
/// In general this invalidates the ports, which may need to be resized to
/// match the OpType signature.
///
/// Returns the old OpType.
///
/// TODO: Add a version which ignores input extensions
///
/// # Errors
Expand Down

0 comments on commit 3d7405c

Please sign in to comment.