diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index bf70268d8..04ec38b4b 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -1,13 +1,18 @@ +use itertools::Itertools; + use super::build_traits::{HugrBuilder, SubContainer}; use super::handle::BuildHandle; use super::{BuildError, Container, Dataflow, DfgID, FuncID}; use std::marker::PhantomData; +use crate::hugr::internal::HugrMutInternals; use crate::hugr::{HugrView, ValidationError}; use crate::ops; +use crate::ops::{DataflowParent, Input, Output}; +use crate::{Direction, IncomingPort, OutgoingPort, Wire}; -use crate::types::{PolyFuncType, Signature}; +use crate::types::{PolyFuncType, Signature, Type}; use crate::extension::ExtensionRegistry; use crate::Node; @@ -155,6 +160,109 @@ impl FunctionBuilder { let db = DFGBuilder::create_with_io(base, root, body)?; Ok(Self::from_dfg_builder(db)) } + + /// Add a new input to the function being constructed. + /// + /// Returns the new wire from the input node. + pub fn add_input(&mut self, input_type: Type) -> Wire { + let [inp_node, _] = self.io(); + + // Update the parent's root type + let new_optype = self.update_fn_signature(|mut s| { + s.input.to_mut().push(input_type); + s + }); + + // Update the inner input node + let types = new_optype.signature.body().input.clone(); + self.hugr_mut() + .replace_op(inp_node, Input { types }) + .unwrap(); + let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1); + let new_port = new_port.next().unwrap(); + + // The last port in an input/output node is an order edge port, so we must shift any connections to it. + let new_value_port: OutgoingPort = (new_port - 1).into(); + let new_order_port: OutgoingPort = new_port.into(); + let order_edge_targets = self + .hugr() + .linked_inputs(inp_node, new_value_port) + .collect_vec(); + self.hugr_mut().disconnect(inp_node, new_value_port); + for (tgt_node, tgt_port) in order_edge_targets { + self.hugr_mut() + .connect(inp_node, new_order_port, tgt_node, tgt_port); + } + + // Update the builder metadata + self.0.num_in_wires += 1; + + self.input_wires().last().unwrap() + } + + /// Add a new output to the function being constructed. + pub fn add_output(&mut self, output_type: Type) { + let [_, out_node] = self.io(); + + // Update the parent's root type + let new_optype = self.update_fn_signature(|mut s| { + s.output.to_mut().push(output_type); + s + }); + + // Update the inner input node + let types = new_optype.signature.body().output.clone(); + self.hugr_mut() + .replace_op(out_node, Output { types }) + .unwrap(); + let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1); + let new_port = new_port.next().unwrap(); + + // The last port in an input/output node is an order edge port, so we must shift any connections to it. + let new_value_port: IncomingPort = (new_port - 1).into(); + let new_order_port: IncomingPort = new_port.into(); + let order_edge_sources = self + .hugr() + .linked_outputs(out_node, new_value_port) + .collect_vec(); + self.hugr_mut().disconnect(out_node, new_value_port); + for (src_node, src_port) in order_edge_sources { + self.hugr_mut() + .connect(src_node, src_port, out_node, new_order_port); + } + + // Update the builder metadata + self.0.num_out_wires += 1; + } + + /// Update the function builder's parent signature. + /// + /// Internal function used in [add_input] and [add_output]. + /// + /// Does not update the input and output nodes. + /// + /// Returns a reference to the new optype. + fn update_fn_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn { + let parent = self.container_node(); + let old_optype = self + .hugr() + .get_optype(parent) + .as_func_defn() + .expect("FunctionBuilder node must be a FuncDefn"); + let signature = old_optype.inner_signature(); + let name = old_optype.name.clone(); + self.hugr_mut() + .replace_op( + parent, + ops::FuncDefn { + signature: f(signature).into(), + name, + }, + ) + .expect("Could not replace FunctionBuilder operation"); + + self.hugr().get_optype(parent).as_func_defn().unwrap() + } } impl + AsRef, T> Container for DFGWrapper { @@ -199,6 +307,7 @@ impl HugrBuilder for DFGWrapper { #[cfg(test)] pub(crate) mod test { use cool_asserts::assert_matches; + use ops::OpParent; use rstest::rstest; use serde_json::json; @@ -344,6 +453,38 @@ pub(crate) mod test { assert_matches!(builder(), Ok(_)); } + #[test] + fn add_inputs_outputs() { + let builder = || -> Result<(Hugr, Node), BuildError> { + let mut f_build = FunctionBuilder::new( + "main", + Signature::new(type_row![BIT], type_row![BIT]).with_prelude(), + )?; + let f_node = f_build.container_node(); + + let [i0] = f_build.input_wires_arr(); + let noop0 = f_build.add_dataflow_op(Noop(BIT), [i0])?; + + // Some some order edges + f_build.set_order(&f_build.io()[0], &noop0.node()); + f_build.set_order(&noop0.node(), &f_build.io()[1]); + + // Add a new input and output, and connect them with a noop in between + f_build.add_output(QB); + let i1 = f_build.add_input(QB); + let noop1 = f_build.add_dataflow_op(Noop(QB), [i1])?; + + let hugr = + f_build.finish_prelude_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?; + Ok((hugr, f_node)) + }; + + let (hugr, f_node) = builder().unwrap_or_else(|e| panic!("{e}")); + + let func_sig = hugr.get_optype(f_node).inner_function_type().unwrap(); + assert_eq!(func_sig.io(), (&type_row![BIT, QB], &type_row![BIT, QB])); + } + #[test] fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> { let mut f_build = diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index b8ac050c3..6062e4084 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -74,6 +74,8 @@ pub trait HugrMutInternals: RootTagged { /// The `direction` parameter specifies whether to add ports to the incoming /// or outgoing list. /// + /// Returns the range of newly created ports. + /// /// # Panics /// /// If the node is not in the graph. @@ -129,6 +131,9 @@ pub trait HugrMutInternals: RootTagged { /// Replace the OpType at node and return the old OpType. /// In general this invalidates the ports, which may need to be resized to /// match the OpType signature. + /// + /// Returns the old OpType. + /// /// TODO: Add a version which ignores input extensions /// /// # Errors