Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: FunctionBuilder::add_{in,out}put #1570

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 142 additions & 1 deletion hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use itertools::Itertools;

use super::build_traits::{HugrBuilder, SubContainer};
use super::handle::BuildHandle;
use super::{BuildError, Container, Dataflow, DfgID, FuncID};

use std::marker::PhantomData;

use crate::hugr::internal::HugrMutInternals;
use crate::hugr::{HugrView, ValidationError};
use crate::ops;
use crate::ops::{DataflowParent, Input, Output};
use crate::{Direction, IncomingPort, OutgoingPort, Wire};

use crate::types::{PolyFuncType, Signature};
use crate::types::{PolyFuncType, Signature, Type};

use crate::extension::ExtensionRegistry;
use crate::Node;
Expand Down Expand Up @@ -155,6 +160,109 @@ impl FunctionBuilder<Hugr> {
let db = DFGBuilder::create_with_io(base, root, body)?;
Ok(Self::from_dfg_builder(db))
}

/// Add a new input to the function being constructed.
///
/// Returns the new wire from the input node.
pub fn add_input(&mut self, input_type: Type) -> Wire {
let [inp_node, _] = self.io();

// Update the parent's root type
let new_optype = self.update_fn_signature(|mut s| {
s.input.to_mut().push(input_type);
s
});

// Update the inner input node
let types = new_optype.signature.body().input.clone();
self.hugr_mut()
.replace_op(inp_node, Input { types })
.unwrap();
let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1);
let new_port = new_port.next().unwrap();

// The last port in an input/output node is an order edge port, so we must shift any connections to it.
let new_value_port: OutgoingPort = (new_port - 1).into();
let new_order_port: OutgoingPort = new_port.into();
let order_edge_targets = self
.hugr()
.linked_inputs(inp_node, new_value_port)
.collect_vec();
self.hugr_mut().disconnect(inp_node, new_value_port);
for (tgt_node, tgt_port) in order_edge_targets {
self.hugr_mut()
.connect(inp_node, new_order_port, tgt_node, tgt_port);
}

// Update the builder metadata
self.0.num_in_wires += 1;

self.input_wires().last().unwrap()
}

/// Add a new output to the function being constructed.
pub fn add_output(&mut self, output_type: Type) {
let [_, out_node] = self.io();

// Update the parent's root type
let new_optype = self.update_fn_signature(|mut s| {
s.output.to_mut().push(output_type);
s
});

// Update the inner input node
let types = new_optype.signature.body().output.clone();
self.hugr_mut()
.replace_op(out_node, Output { types })
.unwrap();
let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1);
let new_port = new_port.next().unwrap();

// The last port in an input/output node is an order edge port, so we must shift any connections to it.
let new_value_port: IncomingPort = (new_port - 1).into();
let new_order_port: IncomingPort = new_port.into();
let order_edge_sources = self
.hugr()
.linked_outputs(out_node, new_value_port)
.collect_vec();
self.hugr_mut().disconnect(out_node, new_value_port);
for (src_node, src_port) in order_edge_sources {
self.hugr_mut()
.connect(src_node, src_port, out_node, new_order_port);
}

// Update the builder metadata
self.0.num_out_wires += 1;
}

/// Update the function builder's parent signature.
///
/// Internal function used in [add_input] and [add_output].
///
/// Does not update the input and output nodes.
///
/// Returns a reference to the new optype.
fn update_fn_signature(&mut self, f: impl FnOnce(Signature) -> Signature) -> &ops::FuncDefn {
let parent = self.container_node();
let old_optype = self
.hugr()
.get_optype(parent)
.as_func_defn()
.expect("FunctionBuilder node must be a FuncDefn");
let signature = old_optype.inner_signature();
let name = old_optype.name.clone();
self.hugr_mut()
.replace_op(
parent,
ops::FuncDefn {
signature: f(signature).into(),
name,
},
)
.expect("Could not replace FunctionBuilder operation");

self.hugr().get_optype(parent).as_func_defn().unwrap()
}
}

impl<B: AsMut<Hugr> + AsRef<Hugr>, T> Container for DFGWrapper<B, T> {
Expand Down Expand Up @@ -199,6 +307,7 @@ impl<T> HugrBuilder for DFGWrapper<Hugr, T> {
#[cfg(test)]
pub(crate) mod test {
use cool_asserts::assert_matches;
use ops::OpParent;
use rstest::rstest;
use serde_json::json;

Expand Down Expand Up @@ -344,6 +453,38 @@ pub(crate) mod test {
assert_matches!(builder(), Ok(_));
}

#[test]
fn add_inputs_outputs() {
let builder = || -> Result<(Hugr, Node), BuildError> {
let mut f_build = FunctionBuilder::new(
"main",
Signature::new(type_row![BIT], type_row![BIT]).with_prelude(),
)?;
let f_node = f_build.container_node();

let [i0] = f_build.input_wires_arr();
let noop0 = f_build.add_dataflow_op(Noop(BIT), [i0])?;

// Some some order edges
f_build.set_order(&f_build.io()[0], &noop0.node());
f_build.set_order(&noop0.node(), &f_build.io()[1]);

// Add a new input and output, and connect them with a noop in between
f_build.add_output(QB);
let i1 = f_build.add_input(QB);
let noop1 = f_build.add_dataflow_op(Noop(QB), [i1])?;

let hugr =
f_build.finish_prelude_hugr_with_outputs([noop0.out_wire(0), noop1.out_wire(0)])?;
Ok((hugr, f_node))
};

let (hugr, f_node) = builder().unwrap_or_else(|e| panic!("{e}"));

let func_sig = hugr.get_optype(f_node).inner_function_type().unwrap();
assert_eq!(func_sig.io(), (&type_row![BIT, QB], &type_row![BIT, QB]));
}

#[test]
fn error_on_linear_inter_graph_edge() -> Result<(), BuildError> {
let mut f_build =
Expand Down
5 changes: 5 additions & 0 deletions hugr-core/src/hugr/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ pub trait HugrMutInternals: RootTagged {
/// The `direction` parameter specifies whether to add ports to the incoming
/// or outgoing list.
///
/// Returns the range of newly created ports.
///
/// # Panics
///
/// If the node is not in the graph.
Expand Down Expand Up @@ -129,6 +131,9 @@ pub trait HugrMutInternals: RootTagged {
/// Replace the OpType at node and return the old OpType.
/// In general this invalidates the ports, which may need to be resized to
/// match the OpType signature.
///
/// Returns the old OpType.
///
/// TODO: Add a version which ignores input extensions
///
/// # Errors
Expand Down
Loading