Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Allow static inputs to extension operations #1628

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
23 changes: 22 additions & 1 deletion hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,28 @@ pub trait Dataflow: Container {

Ok(outs.into())
}
/// Add a dataflow [`OpType`] to the sibling graph, wiring up the `input_wires` to the
/// incoming ports of the resulting node, and the `static_wires` to the static input ports.
///
/// # Errors
///
/// Returns a [`BuildError::OperationWiring`] error if the `input_wires` or `static_wires` cannot be connected.
fn add_dataflow_op_with_static(
&mut self,
nodetype: impl Into<OpType>,
input_wires: impl IntoIterator<Item = Wire>,
static_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let op: OpType = nodetype.into();
let static_in_ports = op.static_input_ports();
let handle = self.add_dataflow_op(op, input_wires)?;

for (src, in_port) in static_wires.into_iter().zip(static_in_ports) {
self.hugr_mut()
.connect(src.node(), src.source(), handle.node(), in_port);
}
Ok(handle)
Comment on lines +206 to +214
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use self.add_node_with_wires with chained input_ and static_ directly?

}
/// Insert a hugr-defined op to the sibling graph, wiring up the
/// `input_wires` to the incoming ports of the resulting root node.
///
Expand Down Expand Up @@ -672,7 +693,7 @@ pub trait Dataflow: Container {
}
};
let op: OpType = ops::Call::try_new(type_scheme, type_args, exts)?.into();
let const_in_port = op.static_input_port().unwrap();
let const_in_port = op.static_input_ports()[0];
let op_id = self.add_dataflow_op(op, input_wires)?;
let src_port = self.hugr_mut().num_outputs(function.node()) - 1;

Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl<'a> Context<'a> {
/// Get the node that declares or defines the function associated with the given
/// node via the static input. Returns `None` if the node is not connected to a function.
fn connected_function(&self, node: Node) -> Option<Node> {
let func_node = self.hugr.static_source(node)?;
let func_node = *self.hugr.static_sources(node).first()?;

match self.hugr.get_optype(func_node) {
OpType::FuncDecl(_) => Some(func_node),
Expand Down Expand Up @@ -417,7 +417,7 @@ impl<'a> Context<'a> {
use std::collections::hash_map::Entry;

let poly_func_type = match opdef.signature_func() {
SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type,
SignatureFunc::PolyFuncType(op_def_sig) => op_def_sig.poly_func_type(),
_ => return self.make_named_global_ref(opdef.extension(), opdef.name()),
};

Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use crate::types::{Signature, TypeNameRef};

mod op_def;
pub use op_def::{
CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc,
ValidateJustArgs, ValidateTypeArgs,
CustomSignatureFunc, CustomValidator, ExtOpSignature, LowerFunc, OpDef, SignatureFromArgs,
SignatureFunc, ValidateJustArgs, ValidateTypeArgs,
};
mod type_def;
pub use type_def::{TypeDef, TypeDefBound};
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/extension/declarative/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use smol_str::SmolStr;
use crate::extension::prelude::PRELUDE_ID;
use crate::extension::{ExtensionSet, SignatureFunc, TypeDef};
use crate::types::type_param::TypeParam;
use crate::types::{CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeRowRV};
use crate::types::{CustomType, FuncValueType, OpDefSignature, Type, TypeRowRV};
use crate::Extension;

use super::{DeclarationContext, ExtensionDeclarationError};
Expand Down Expand Up @@ -56,7 +56,7 @@ impl SignatureDeclaration {
extension_reqs: self.extensions.clone(),
};

let poly_func = PolyFuncTypeRV::new(op_params, body);
let poly_func = OpDefSignature::new(op_params, body);
Ok(poly_func.into())
}
}
Expand Down
Loading
Loading