From 45d652c5c51ec13c2dbf1bb190ed409f576e4942 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 30 Oct 2024 17:18:35 +0000 Subject: [PATCH 01/14] feat: static_input returns vector --- hugr-core/src/ops.rs | 13 ++++++++++--- hugr-core/src/ops/dataflow.rs | 18 +++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 24ce8492e..14ceb29c4 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -159,7 +159,14 @@ impl OpType { #[inline] pub fn static_port_kind(&self, dir: Direction) -> Option { match dir { - Direction::Incoming => self.static_input(), + Direction::Incoming => { + let mut v = self.static_input(); + if v.is_empty() { + None + } else { + Some(v.remove(0)) + } + } Direction::Outgoing => self.static_output(), } } @@ -384,8 +391,8 @@ pub trait OpTrait { /// /// If not None, an extra input port of that kind will be present after the /// dataflow input ports and before any [`OpTrait::other_input`] ports. - fn static_input(&self) -> Option { - None + fn static_input(&self) -> Vec { + vec![] } /// The edge kind for a single constant output of the operation, not diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 364429784..f359026bf 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -46,8 +46,8 @@ pub trait DataflowOpTrait { /// If not None, an extra input port of that kind will be present after the /// dataflow input ports and before any [`DataflowOpTrait::other_input`] ports. #[inline] - fn static_input(&self) -> Option { - None + fn static_input(&self) -> Vec { + vec![] } } @@ -148,7 +148,7 @@ impl OpTrait for T { DataflowOpTrait::other_output(self) } - fn static_input(&self) -> Option { + fn static_input(&self) ->Vec { DataflowOpTrait::static_input(self) } } @@ -184,8 +184,8 @@ impl DataflowOpTrait for Call { self.instantiation.clone() } - fn static_input(&self) -> Option { - Some(EdgeKind::Function(self.called_function_type().clone())) + fn static_input(&self) -> Vec { + vec![EdgeKind::Function(self.called_function_type().clone())] } } impl Call { @@ -300,8 +300,8 @@ impl DataflowOpTrait for LoadConstant { Signature::new(TypeRow::new(), vec![self.datatype.clone()]) } - fn static_input(&self) -> Option { - Some(EdgeKind::Const(self.constant_type().clone())) + fn static_input(&self) -> Vec { + vec![EdgeKind::Const(self.constant_type().clone())] } } impl LoadConstant { @@ -355,8 +355,8 @@ impl DataflowOpTrait for LoadFunction { self.signature.clone() } - fn static_input(&self) -> Option { - Some(EdgeKind::Function(self.func_sig.clone())) + fn static_input(&self) -> Vec { + vec![EdgeKind::Function(self.func_sig.clone())] } } impl LoadFunction { From 093c594effa98443e6e210ace1acf58ac245b72c Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 30 Oct 2024 17:47:48 +0000 Subject: [PATCH 02/14] feat!: use vector interface for static inputs --- hugr-core/src/builder/build_traits.rs | 2 +- hugr-core/src/export.rs | 2 +- hugr-core/src/hugr/serialize.rs | 7 +++- hugr-core/src/hugr/validate/test.rs | 2 +- hugr-core/src/hugr/views.rs | 12 +++--- hugr-core/src/hugr/views/tests.rs | 2 +- hugr-core/src/import.rs | 7 +++- hugr-core/src/ops.rs | 55 +++++++++++++-------------- hugr-core/src/ops/dataflow.rs | 18 ++++----- 9 files changed, 58 insertions(+), 49 deletions(-) diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 2950bdc47..13ef1db4b 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -672,7 +672,7 @@ pub trait Dataflow: Container { } }; let op: OpType = ops::Call::try_new(type_scheme, type_args, exts)?.into(); - let const_in_port = op.static_input_port().unwrap(); + let const_in_port = op.static_input_ports()[0]; let op_id = self.add_dataflow_op(op, input_wires)?; let src_port = self.hugr_mut().num_outputs(function.node()) - 1; diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index e7a85c98f..1488e851e 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -151,7 +151,7 @@ impl<'a> Context<'a> { /// Get the node that declares or defines the function associated with the given /// node via the static input. Returns `None` if the node is not connected to a function. fn connected_function(&self, node: Node) -> Option { - let func_node = self.hugr.static_source(node)?; + let func_node = *self.hugr.static_sources(node).first()?; match self.hugr.get_optype(func_node) { OpType::FuncDecl(_) => Some(func_node), diff --git a/hugr-core/src/hugr/serialize.rs b/hugr-core/src/hugr/serialize.rs index 0e8f44985..9a213f8ef 100644 --- a/hugr-core/src/hugr/serialize.rs +++ b/hugr-core/src/hugr/serialize.rs @@ -181,8 +181,11 @@ impl TryFrom<&Hugr> for SerHugrLatest { let find_offset = |node: Node, offset: usize, dir: Direction, hugr: &Hugr| { let op = hugr.get_optype(node); - let is_value_port = offset < op.value_port_count(dir); - let is_static_input = op.static_port(dir).map_or(false, |p| p.index() == offset); + let value_count = op.value_port_count(dir); + let is_value_port = offset < value_count; + let static_len = op.static_ports(dir).len(); + let is_static_input = offset < (value_count + static_len); + // let is_static_input = op.static_port(dir).map_or(false, |p| p.index() == offset); let offset = (is_value_port || is_static_input).then_some(offset as u16); (node_rekey[&node], offset) }; diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 705649b36..9190713dd 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -322,7 +322,7 @@ fn test_local_const() { h.connect(cst, 0, lcst, 0); h.connect(lcst, 0, and, 1); - assert_eq!(h.static_source(lcst), Some(cst)); + assert_eq!(h.static_sources(lcst), vec![cst]); // There is no edge from Input to LoadConstant, but that's OK: h.update_validate(&EMPTY_REG).unwrap(); } diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index d17eaf44f..40f65b07b 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -416,11 +416,13 @@ pub trait HugrView: HugrInternals { .finish() } - /// If a node has a static input, return the source node. - fn static_source(&self, node: Node) -> Option { - self.linked_outputs(node, self.get_optype(node).static_input_port()?) - .next() - .map(|(n, _)| n) + /// If a node has static inputs, return the source nodes. + fn static_sources(&self, node: Node) -> Vec { + self.get_optype(node) + .static_input_ports() + .into_iter() + .filter_map(|port| self.linked_outputs(node, port).next().map(|(n, _)| n)) + .collect() } /// If a node has a static output, return the targets. diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index c52957b2c..70f2e296d 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -160,7 +160,7 @@ fn static_targets() { let h = dfg.finish_prelude_hugr_with_outputs([load]).unwrap(); - assert_eq!(h.static_source(load.node()), Some(c.node())); + assert_eq!(h.static_sources(load.node()), vec![c.node()]); assert_eq!( &h.static_targets(c.node()).unwrap().collect_vec()[..], diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index d981049fb..2d36ec873 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -251,7 +251,12 @@ impl<'a> Context<'a> { let src = self.nodes[&src_id]; let dst = self.nodes[&dst_id]; let src_port = self.hugr.get_optype(src).static_output_port().unwrap(); - let dst_port = self.hugr.get_optype(dst).static_input_port().unwrap(); + let dst_port = *self + .hugr + .get_optype(dst) + .static_input_ports() + .first() + .unwrap(); self.hugr.connect(src, src_port, dst, dst_port); } diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 14ceb29c4..c6fa12e89 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -157,17 +157,10 @@ impl OpType { /// given direction after any dataflow ports and before any /// [`OpType::other_port_kind`] ports. #[inline] - pub fn static_port_kind(&self, dir: Direction) -> Option { + pub fn static_port_kind(&self, dir: Direction) -> Vec { match dir { - Direction::Incoming => { - let mut v = self.static_input(); - if v.is_empty() { - None - } else { - Some(v.remove(0)) - } - } - Direction::Outgoing => self.static_output(), + Direction::Incoming => self.static_input(), + Direction::Outgoing => self.static_output().map(|k| vec![k]).unwrap_or_default(), } } @@ -188,11 +181,10 @@ impl OpType { } // Constant port - let static_kind = self.static_port_kind(dir); - if port.index() == port_count { - if let Some(kind) = static_kind { - return Some(kind); - } + let mut static_kind = self.static_port_kind(dir); + let static_offset = port.index() - port_count; + if static_offset < static_kind.len() { + return Some(static_kind.remove(static_offset)); } // Non-dataflow ports @@ -232,28 +224,35 @@ impl OpType { .map(|p| p.as_outgoing().unwrap()) } - /// If the op has a static port, the port of that input. + /// The static ports of the operation. /// - /// See [`OpType::static_input_port`] and [`OpType::static_output_port`]. + /// See [`OpType::static_input_ports`] and [`OpType::static_output_port`]. #[inline] - pub fn static_port(&self, dir: Direction) -> Option { - self.static_port_kind(dir)?; - Some(Port::new(dir, self.value_port_count(dir))) + pub fn static_ports(&self, dir: Direction) -> Vec { + let static_len = self.static_port_kind(dir).len(); + (0..static_len) + .map(|i| Port::new(dir, self.value_port_count(dir) + i)) + .collect() } - /// If the op has a static input ([`Call`], [`LoadConstant`], and [`LoadFunction`]), the port of - /// that input. + /// If the op has static inputs, the port of those inputs. #[inline] - pub fn static_input_port(&self) -> Option { - self.static_port(Direction::Incoming) + pub fn static_input_ports(&self) -> Vec { + self.static_ports(Direction::Incoming) + .into_iter() .map(|p| p.as_incoming().unwrap()) + .collect() } /// If the op has a static output ([`Const`], [`FuncDefn`], [`FuncDecl`]), the port of that output. #[inline] pub fn static_output_port(&self) -> Option { - self.static_port(Direction::Outgoing) - .map(|p| p.as_outgoing().unwrap()) + if let [p] = self.static_ports(Direction::Outgoing).as_slice() { + Some(p.as_outgoing().unwrap()) + } else { + None + } + // .map(|p| p.as_outgoing().unwrap()).collect() } /// The number of Value ports in given direction. @@ -279,9 +278,9 @@ impl OpType { /// Returns the number of ports for the given direction. #[inline] pub fn port_count(&self, dir: Direction) -> usize { - let has_static_port = self.static_port_kind(dir).is_some(); + let static_len = self.static_port_kind(dir).len(); let non_df_count = self.non_df_port_count(dir); - self.value_port_count(dir) + has_static_port as usize + non_df_count + self.value_port_count(dir) + static_len + non_df_count } /// Returns the number of inputs ports for the operation. diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index f359026bf..1a2c2fc06 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -148,7 +148,7 @@ impl OpTrait for T { DataflowOpTrait::other_output(self) } - fn static_input(&self) ->Vec { + fn static_input(&self) -> Vec { DataflowOpTrait::static_input(self) } } @@ -215,7 +215,7 @@ impl Call { /// The IncomingPort which links to the function being called. /// - /// This matches [`OpType::static_input_port`]. + /// This matches [`OpType::static_input_ports`]. /// /// ``` /// # use hugr::ops::dataflow::Call; @@ -226,10 +226,10 @@ impl Call { /// let signature = Signature::new(vec![QB_T, QB_T], vec![QB_T, QB_T]); /// let call = Call::try_new(signature.into(), &[], &PRELUDE_REGISTRY).unwrap(); /// let op = OpType::Call(call.clone()); - /// assert_eq!(op.static_input_port(), Some(call.called_function_port())); + /// assert_eq!(op.static_input_ports(), vec![call.called_function_port()]); /// ``` /// - /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port + /// [`OpType::static_input_ports`]: crate::ops::OpType::static_input_ports #[inline] pub fn called_function_port(&self) -> IncomingPort { self.instantiation.input_count().into() @@ -313,7 +313,7 @@ impl LoadConstant { /// The IncomingPort which links to the loaded constant. /// - /// This matches [`OpType::static_input_port`]. + /// This matches [`OpType::static_input_ports`]. /// /// ``` /// # use hugr::ops::dataflow::LoadConstant; @@ -322,10 +322,10 @@ impl LoadConstant { /// let datatype = Type::UNIT; /// let load_constant = LoadConstant { datatype }; /// let op = OpType::LoadConstant(load_constant.clone()); - /// assert_eq!(op.static_input_port(), Some(load_constant.constant_port())); + /// assert_eq!(op.static_input_ports(), vec![load_constant.constant_port()]); /// ``` /// - /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port + /// [`OpType::static_input_ports`]: crate::ops::OpType::static_input_ports #[inline] pub fn constant_port(&self) -> IncomingPort { 0.into() @@ -387,9 +387,9 @@ impl LoadFunction { /// The IncomingPort which links to the loaded function. /// - /// This matches [`OpType::static_input_port`]. + /// This matches [`OpType::static_input_ports`]. /// - /// [`OpType::static_input_port`]: crate::ops::OpType::static_input_port + /// [`OpType::static_input_ports`]: crate::ops::OpType::static_input_ports #[inline] pub fn function_port(&self) -> IncomingPort { 0.into() From d6a20946048945b1b9b98ce3119f6b7ca5a160aa Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 1 Nov 2024 13:53:29 +0000 Subject: [PATCH 03/14] feat!: allow static inputs to custom ops BREAKING CHANGE: `PolyFuncTypeRV` replaced with new type `OpDefSignature` that also holds static input types. --- .../src/extension/declarative/signature.rs | 4 +- hugr-core/src/extension/op_def.rs | 54 +++---- .../op_def/serialize_signature_func.rs | 8 +- hugr-core/src/extension/prelude.rs | 10 +- hugr-core/src/extension/prelude/array.rs | 16 +- hugr-core/src/hugr/serialize/test.rs | 28 ++-- hugr-core/src/hugr/validate/test.rs | 8 +- hugr-core/src/import.rs | 1 + hugr-core/src/ops/custom.rs | 23 ++- .../src/std_extensions/arithmetic/int_ops.rs | 10 +- hugr-core/src/std_extensions/collections.rs | 6 +- hugr-core/src/types.rs | 2 +- hugr-core/src/types/poly_func.rs | 137 +++++++++++++++++- hugr-core/src/types/type_param.rs | 4 +- hugr-core/src/utils.rs | 6 +- hugr/src/lib.rs | 6 +- 16 files changed, 236 insertions(+), 87 deletions(-) diff --git a/hugr-core/src/extension/declarative/signature.rs b/hugr-core/src/extension/declarative/signature.rs index d1479ae13..fc94577a8 100644 --- a/hugr-core/src/extension/declarative/signature.rs +++ b/hugr-core/src/extension/declarative/signature.rs @@ -14,7 +14,7 @@ use smol_str::SmolStr; use crate::extension::prelude::PRELUDE_ID; use crate::extension::{ExtensionSet, SignatureFunc, TypeDef}; use crate::types::type_param::TypeParam; -use crate::types::{CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeRowRV}; +use crate::types::{CustomType, FuncValueType, OpDefSignature, Type, TypeRowRV}; use crate::Extension; use super::{DeclarationContext, ExtensionDeclarationError}; @@ -56,7 +56,7 @@ impl SignatureDeclaration { extension_reqs: self.extensions.clone(), }; - let poly_func = PolyFuncTypeRV::new(op_params, body); + let poly_func = OpDefSignature::new(op_params, body); Ok(poly_func.into()) } } diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 6c1a49d9e..5b9ecaab2 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -11,7 +11,7 @@ use super::{ use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; -use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; +use crate::types::{FuncValueType, OpDefSignature, PolyFuncType, Signature}; use crate::Hugr; mod serialize_signature_func; @@ -25,7 +25,7 @@ pub trait CustomSignatureFunc: Send + Sync { arg_values: &[TypeArg], def: &'o OpDef, extension_registry: &ExtensionRegistry, - ) -> Result; + ) -> Result; /// The declared type parameters which require values in order for signature to /// be computed. fn static_params(&self) -> &[TypeParam]; @@ -35,7 +35,7 @@ pub trait CustomSignatureFunc: Send + Sync { pub trait SignatureFromArgs: Send + Sync { /// Compute signature of node given /// values for the type parameters. - fn compute_signature(&self, arg_values: &[TypeArg]) -> Result; + fn compute_signature(&self, arg_values: &[TypeArg]) -> Result; /// The declared type parameters which require values in order for signature to /// be computed. fn static_params(&self) -> &[TypeParam]; @@ -48,7 +48,7 @@ impl CustomSignatureFunc for T { arg_values: &[TypeArg], _def: &'o OpDef, _extension_registry: &ExtensionRegistry, - ) -> Result { + ) -> Result { SignatureFromArgs::compute_signature(self, arg_values) } @@ -58,7 +58,7 @@ impl CustomSignatureFunc for T { } } -/// Trait for validating type arguments to a PolyFuncTypeRV beyond conformation to +/// Trait for validating type arguments to a OpDefSignature beyond conformation to /// declared type parameter (which should have been checked beforehand). pub trait ValidateTypeArgs: Send + Sync { /// Validate the type arguments of node given @@ -72,7 +72,7 @@ pub trait ValidateTypeArgs: Send + Sync { ) -> Result<(), SignatureError>; } -/// Trait for validating type arguments to a PolyFuncTypeRV beyond conformation to +/// Trait for validating type arguments to a OpDefSignature beyond conformation to /// declared type parameter (which should have been checked beforehand), given just the arguments. pub trait ValidateJustArgs: Send + Sync { /// Validate the type arguments of node given @@ -115,20 +115,20 @@ pub trait CustomLowerFunc: Send + Sync { ) -> Option; } -/// Encode a signature as [PolyFuncTypeRV] but with additional validation of type +/// Encode a signature as [OpDefSignature] but with additional validation of type /// arguments via a custom binary. The binary cannot be serialized so will be /// lost over a serialization round-trip. pub struct CustomValidator { - poly_func: PolyFuncTypeRV, + poly_func: OpDefSignature, /// Custom function for validating type arguments before returning the signature. pub(crate) validate: Box, } impl CustomValidator { - /// Encode a signature using a `PolyFuncTypeRV`, with a custom function for + /// Encode a signature using a `OpDefSignature`, with a custom function for /// validating type arguments before returning the signature. pub fn new( - poly_func: impl Into, + poly_func: impl Into, validate: impl ValidateTypeArgs + 'static, ) -> Self { Self { @@ -141,11 +141,11 @@ impl CustomValidator { /// The ways in which an OpDef may compute the Signature of each operation node. pub enum SignatureFunc { /// An explicit polymorphic function type. - PolyFuncType(PolyFuncTypeRV), + PolyFuncType(OpDefSignature), /// A polymorphic function type (like [Self::PolyFuncType] but also with a custom binary for validating type arguments. CustomValidator(CustomValidator), /// Serialized declaration specified a custom validate binary but it was not provided. - MissingValidateFunc(PolyFuncTypeRV), + MissingValidateFunc(OpDefSignature), /// A custom binary which computes a polymorphic function type given values /// for its static type parameters. CustomFunc(Box), @@ -165,8 +165,8 @@ impl From for SignatureFunc { } } -impl From for SignatureFunc { - fn from(v: PolyFuncTypeRV) -> Self { +impl From for SignatureFunc { + fn from(v: OpDefSignature) -> Self { Self::PolyFuncType(v) } } @@ -225,7 +225,7 @@ impl SignatureFunc { args: &[TypeArg], exts: &ExtensionRegistry, ) -> Result { - let temp: PolyFuncTypeRV; // to keep alive + let temp: OpDefSignature; // to keep alive let (pf, args) = match &self { SignatureFunc::CustomValidator(custom) => { custom.validate.validate(args, def, exts)?; @@ -333,7 +333,7 @@ impl OpDef { exts: &ExtensionRegistry, var_decls: &[TypeParam], ) -> Result<(), SignatureError> { - let temp: PolyFuncTypeRV; // to keep alive + let temp: OpDefSignature; // to keep alive let (pf, args) = match &self.signature_func { SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args), SignatureFunc::PolyFuncType(ts) => (ts, args), @@ -459,7 +459,7 @@ impl OpDef { impl Extension { /// Add an operation definition to the extension. Must be a type scheme - /// (defined by a [`PolyFuncTypeRV`]), a type scheme along with binary + /// (defined by a [`OpDefSignature`]), a type scheme along with binary /// validation for type arguments ([`CustomValidator`]), or a custom binary /// function for computing the signature given type arguments (`impl [CustomSignatureFunc]`). pub fn add_op( @@ -501,7 +501,7 @@ pub(super) mod test { use crate::ops::OpName; use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; use crate::types::type_param::{TypeArgError, TypeParam}; - use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV}; + use crate::types::{OpDefSignature, Signature, Type, TypeArg, TypeBound, TypeRV}; use crate::{const_extension_ids, Extension}; const_extension_ids! { @@ -598,7 +598,7 @@ pub(super) mod test { let list_of_var = Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); const OP_NAME: OpName = OpName::new_inline("Reverse"); - let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); + let type_scheme = OpDefSignature::new(vec![TP], Signature::new_endo(vec![list_of_var])); let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?; def.add_lower_func(LowerFunc::FixedHugr { @@ -629,7 +629,7 @@ pub(super) mod test { #[test] fn binary_polyfunc() -> Result<(), Box> { - // Test a custom binary `compute_signature` that returns a PolyFuncTypeRV + // Test a custom binary `compute_signature` that returns a OpDefSignature // where the latter declares more type params itself. In particular, // we should be able to substitute (external) type variables into the latter, // but not pass them into the former (custom binary function). @@ -638,7 +638,7 @@ pub(super) mod test { fn compute_signature( &self, arg_values: &[TypeArg], - ) -> Result { + ) -> Result { const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; let [TypeArg::BoundedNat { n }] = arg_values else { return Err(SignatureError::InvalidTypeArgs); @@ -647,7 +647,7 @@ pub(super) mod test { let tvs: Vec = (0..n) .map(|_| Type::new_var_use(0, TypeBound::Any)) .collect(); - Ok(PolyFuncTypeRV::new( + Ok(OpDefSignature::new( vec![TP.to_owned()], Signature::new(tvs.clone(), vec![Type::new_tuple(tvs)]), )) @@ -718,13 +718,13 @@ pub(super) mod test { #[test] fn type_scheme_instantiate_var() -> Result<(), Box> { - // Check that we can instantiate a PolyFuncTypeRV-scheme with an (external) + // Check that we can instantiate a OpDefSignature-scheme with an (external) // type variable let mut e = Extension::new_test(EXT_ID); let def = e.add_op( "SimpleOp".into(), "".into(), - PolyFuncTypeRV::new( + OpDefSignature::new( vec![TypeBound::Any.into()], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), ), @@ -764,7 +764,7 @@ pub(super) mod test { let def = e.add_op( "SimpleOp".into(), "".into(), - PolyFuncTypeRV::new(params.clone(), fun_ty), + OpDefSignature::new(params.clone(), fun_ty), )?; // Concrete extension set @@ -788,7 +788,7 @@ pub(super) mod test { use crate::{ builder::test::simple_dfg_hugr, extension::{op_def::LowerFunc, ExtensionId, ExtensionSet, OpDef, SignatureFunc}, - types::PolyFuncTypeRV, + types::OpDefSignature, }; impl Arbitrary for SignatureFunc { @@ -798,7 +798,7 @@ pub(super) mod test { // TODO there is also SignatureFunc::CustomFunc, but for now // this is not serialized. When it is, we should generate // examples here . - any::() + any::() .prop_map(SignatureFunc::PolyFuncType) .boxed() } diff --git a/hugr-core/src/extension/op_def/serialize_signature_func.rs b/hugr-core/src/extension/op_def/serialize_signature_func.rs index 88c8c30de..0877e808f 100644 --- a/hugr-core/src/extension/op_def/serialize_signature_func.rs +++ b/hugr-core/src/extension/op_def/serialize_signature_func.rs @@ -1,10 +1,10 @@ use serde::{Deserialize, Serialize}; -use super::{CustomValidator, PolyFuncTypeRV, SignatureFunc}; +use super::{CustomValidator, OpDefSignature, SignatureFunc}; #[derive(serde::Deserialize, serde::Serialize, PartialEq, Debug, Clone)] struct SerSignatureFunc { /// If the type scheme is available explicitly, store it. - signature: Option, + signature: Option, /// Whether an associated binary function is expected. /// If `signature` is `None`, a true value here indicates a custom compute function. /// If `signature` is not `None`, a true value here indicates a custom validation function. @@ -97,7 +97,7 @@ mod test { _arg_values: &[TypeArg], _def: &'o crate::extension::op_def::OpDef, _extension_registry: &crate::extension::ExtensionRegistry, - ) -> Result { + ) -> Result { Ok(Default::default()) } @@ -146,7 +146,7 @@ mod test { let mut deser = SignatureFunc::try_from(ser.clone()).unwrap(); assert_matches!(&deser, SignatureFunc::MissingValidateFunc(poly_func) => { - assert_eq!(poly_func, &PolyFuncTypeRV::from(sig.clone())); + assert_eq!(poly_func, &OpDefSignature::from(sig.clone())); } ); diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index ca338eae3..4b60e7a89 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -15,7 +15,7 @@ use crate::ops::OpName; use crate::ops::{NamedOp, Value}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{ - CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeBound, + CustomType, FuncValueType, OpDefSignature, PolyFuncType, Signature, SumType, Type, TypeBound, TypeName, TypeRV, TypeRow, TypeRowRV, }; use crate::utils::sorted_consts; @@ -87,7 +87,7 @@ lazy_static! { .add_op( PANIC_OP_ID, "Panic with input error".to_string(), - PolyFuncTypeRV::new( + OpDefSignature::new( [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], FuncValueType::new( vec![TypeRV::new_extension(ERROR_CUSTOM_TYPE), TypeRV::new_row_var_use(0, TypeBound::Any)], @@ -502,10 +502,10 @@ impl MakeOpDef for TupleOpDef { let param = TypeParam::new_list(TypeBound::Any); match self { TupleOpDef::MakeTuple => { - PolyFuncTypeRV::new([param], FuncValueType::new(rv, tuple_type)) + OpDefSignature::new([param], FuncValueType::new(rv, tuple_type)) } TupleOpDef::UnpackTuple => { - PolyFuncTypeRV::new([param], FuncValueType::new(tuple_type, rv)) + OpDefSignature::new([param], FuncValueType::new(tuple_type, rv)) } } .into() @@ -784,7 +784,7 @@ impl std::str::FromStr for LiftDef { impl MakeOpDef for LiftDef { fn signature(&self) -> SignatureFunc { - PolyFuncTypeRV::new( + OpDefSignature::new( vec![TypeParam::Extensions, TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(TypeRV::new_row_var_use(1, TypeBound::Any)) .with_extension_delta(ExtensionSet::type_var(0)), diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index a15bf23cc..9baf24933 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -25,7 +25,7 @@ use crate::types::Type; use crate::extension::SignatureError; -use crate::types::PolyFuncTypeRV; +use crate::types::OpDefSignature; use crate::types::type_param::TypeArg; use crate::Extension; @@ -52,7 +52,7 @@ pub enum ArrayOpDef { const STATIC_SIZE_PARAM: &[TypeParam; 1] = &[TypeParam::max_nat()]; impl SignatureFromArgs for ArrayOpDef { - fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { + fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { let [TypeArg::BoundedNat { n }] = *arg_values else { return Err(SignatureError::InvalidTypeArgs); }; @@ -60,14 +60,14 @@ impl SignatureFromArgs for ArrayOpDef { let array_ty = array_type(TypeArg::BoundedNat { n }, elem_ty_var.clone()); let params = vec![TypeBound::Any.into()]; let poly_func_ty = match self { - ArrayOpDef::new_array => PolyFuncTypeRV::new( + ArrayOpDef::new_array => OpDefSignature::new( params, FuncValueType::new(vec![elem_ty_var.clone(); n as usize], array_ty), ), ArrayOpDef::pop_left | ArrayOpDef::pop_right => { let popped_array_ty = array_type(TypeArg::BoundedNat { n: n - 1 }, elem_ty_var.clone()); - PolyFuncTypeRV::new( + OpDefSignature::new( params, FuncValueType::new( array_ty, @@ -123,7 +123,7 @@ impl ArrayOpDef { let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); let copy_array_ty = instantiate(array_def, size_var, copy_elem_ty.clone()); let option_type: Type = option_type(copy_elem_ty).into(); - PolyFuncTypeRV::new( + OpDefSignature::new( params, FuncValueType::new(vec![copy_array_ty, USIZE_T], option_type), ) @@ -131,7 +131,7 @@ impl ArrayOpDef { set => { let result_row = vec![elem_ty_var.clone(), array_ty.clone()]; let result_type: Type = either_type(result_row.clone(), result_row).into(); - PolyFuncTypeRV::new( + OpDefSignature::new( standard_params, FuncValueType::new( vec![array_ty.clone(), USIZE_T, elem_ty_var], @@ -141,12 +141,12 @@ impl ArrayOpDef { } swap => { let result_type: Type = either_type(array_ty.clone(), array_ty.clone()).into(); - PolyFuncTypeRV::new( + OpDefSignature::new( standard_params, FuncValueType::new(vec![array_ty, USIZE_T, USIZE_T], result_type), ) } - discard_empty => PolyFuncTypeRV::new( + discard_empty => OpDefSignature::new( vec![TypeBound::Any.into()], FuncValueType::new( instantiate(array_def, 0, Type::new_var_use(0, TypeBound::Any)), diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 4da2197ef..e2f809c6c 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -17,7 +17,7 @@ use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use crate::std_extensions::logic::LogicOp; use crate::types::type_param::TypeParam; use crate::types::{ - FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, + FuncValueType, OpDefSignature, PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRV, }; use crate::{type_row, OutgoingPort}; @@ -37,7 +37,7 @@ const QB: Type = crate::extension::prelude::QB_T; struct SerTestingLatest { typ: Option, sum_type: Option, - poly_func_type: Option, + poly_func_type: Option, value: Option, optype: Option, op_def: Option, @@ -124,14 +124,14 @@ macro_rules! impl_sertesting_from { impl_sertesting_from!(crate::types::TypeRV, typ); impl_sertesting_from!(crate::types::SumType, sum_type); -impl_sertesting_from!(crate::types::PolyFuncTypeRV, poly_func_type); +impl_sertesting_from!(crate::types::OpDefSignature, poly_func_type); impl_sertesting_from!(crate::ops::Value, value); impl_sertesting_from!(NodeSer, optype); impl_sertesting_from!(SimpleOpDef, op_def); impl From for SerTestingLatest { fn from(v: PolyFuncType) -> Self { - let v: PolyFuncTypeRV = v.into(); + let v: OpDefSignature = v.into(); v.into() } } @@ -484,7 +484,7 @@ fn polyfunctype1() -> PolyFuncType { PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type) } -fn polyfunctype2() -> PolyFuncTypeRV { +fn polyfunctype2() -> OpDefSignature { let tv0 = TypeRV::new_row_var_use(0, TypeBound::Any); let tv1 = TypeRV::new_row_var_use(1, TypeBound::Copyable); let params = [TypeBound::Any, TypeBound::Copyable].map(TypeParam::new_list); @@ -492,7 +492,7 @@ fn polyfunctype2() -> PolyFuncTypeRV { TypeRV::new_function(FuncValueType::new(tv0.clone(), tv1.clone())), tv0, ]; - let res = PolyFuncTypeRV::new(params, FuncValueType::new(inputs, tv1)); + let res = OpDefSignature::new(params, FuncValueType::new(inputs, tv1)); // Just check we've got the arguments the right way round // (not that it really matters for the serialization schema we have) res.validate(&EMPTY_REG).unwrap(); @@ -515,15 +515,15 @@ fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { #[rstest] #[case(FuncValueType::new_endo(type_row![]).into())] -#[case(PolyFuncTypeRV::new([TypeParam::String], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncTypeRV::new([TypeBound::Copyable.into()], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(PolyFuncTypeRV::new([TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] -#[case(PolyFuncTypeRV::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FuncValueType::new_endo(type_row![])))] -#[case(PolyFuncTypeRV::new( +#[case(OpDefSignature::new([TypeParam::String], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(OpDefSignature::new([TypeBound::Copyable.into()], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(OpDefSignature::new([TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] +#[case(OpDefSignature::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FuncValueType::new_endo(type_row![])))] +#[case(OpDefSignature::new( [TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Any))))] #[case(polyfunctype2())] -fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { +fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: OpDefSignature) { check_testing_roundtrip(poly_func_type) } @@ -563,7 +563,7 @@ mod proptest { use super::check_testing_roundtrip; use super::{NodeSer, SimpleOpDef}; use crate::ops::{OpType, OpaqueOp, Value}; - use crate::types::{PolyFuncTypeRV, Type}; + use crate::types::{OpDefSignature, Type}; use proptest::prelude::*; impl Arbitrary for NodeSer { @@ -596,7 +596,7 @@ mod proptest { } #[test] - fn prop_roundtrip_poly_func_type(t: PolyFuncTypeRV) { + fn prop_roundtrip_poly_func_type(t: OpDefSignature) { check_testing_roundtrip(t) } diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 9190713dd..5cf44b90a 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -22,7 +22,7 @@ use crate::std_extensions::logic::LogicOp; use crate::std_extensions::logic::{self}; use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{ - CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, + CustomType, FuncValueType, OpDefSignature, PolyFuncType, Signature, Type, TypeBound, TypeRV, TypeRow, }; use crate::{ @@ -594,14 +594,14 @@ pub(crate) fn extension_with_eval_parallel() -> Extension { let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); - let pf = PolyFuncTypeRV::new( + let pf = OpDefSignature::new( [rowp.clone(), rowp.clone()], FuncValueType::new(vec![evaled_fn, inputs], outputs), ); e.add_op("eval".into(), "".into(), pf).unwrap(); let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Any); - let pf = PolyFuncTypeRV::new( + let pf = OpDefSignature::new( [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], Signature::new( vec![ @@ -709,7 +709,7 @@ fn test_polymorphic_call() -> Result<(), Box> { e.add_op( "eval".into(), "".into(), - PolyFuncTypeRV::new( + OpDefSignature::new( params.clone(), Signature::new( vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 2d36ec873..db1dcfefd 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -500,6 +500,7 @@ impl<'a> Context<'a> { String::default(), args, signature, + vec![], )); let node = self.make_node(node_id, optype, parent)?; diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 6d9a3b2b2..048554795 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -10,10 +10,13 @@ use { ::proptest_derive::Arbitrary, }; -use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::HugrView; use crate::types::{type_param::TypeArg, Signature}; +use crate::{ + extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}, + types::Type, +}; use crate::{ops, Hugr, IncomingPort, Node}; use super::dataflow::DataflowOpTrait; @@ -109,6 +112,7 @@ impl ExtensionOp { description: self.def.description().into(), args: self.args.clone(), signature: self.signature.clone(), + static_inputs: vec![], // Initialize with empty vector } } } @@ -126,6 +130,7 @@ impl From for OpaqueOp { description: def.description().into(), args, signature, + static_inputs: vec![], // Initialize with empty vector } } } @@ -179,6 +184,8 @@ pub struct OpaqueOp { // remain private, and should be accessed through // `DataflowOpTrait::signature`. signature: Signature, + #[serde(default)] + static_inputs: Vec, } fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName { @@ -193,6 +200,7 @@ impl OpaqueOp { description: String, args: impl Into>, signature: Signature, + static_inputs: Vec, // New parameter added ) -> Self { Self { extension, @@ -200,6 +208,7 @@ impl OpaqueOp { description, args: args.into(), signature, + static_inputs, // Initialize new field } } } @@ -373,6 +382,7 @@ mod test { "desc".into(), vec![TypeArg::Type { ty: USIZE_T }], sig.clone(), + vec![], // Initialize with empty vector ); assert_eq!(op.name(), "res.op"); assert_eq!(DataflowOpTrait::description(&op), "desc"); @@ -393,6 +403,7 @@ mod test { "description".into(), vec![], Signature::new(i0.clone(), BOOL_T), + vec![], // Initialize with empty vector ); let resolved = super::resolve_opaque_op(Node::from(portgraph::NodeIndex::new(1)), &opaque, registry) @@ -428,8 +439,16 @@ mod test { "".into(), vec![], endo_sig.clone(), + vec![], // Initialize with empty vector + ); + let opaque_comp = OpaqueOp::new( + ext_id.clone(), + comp_name, + "".into(), + vec![], + endo_sig, + vec![], // Initialize with empty vector ); - let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig); let resolved_val = super::resolve_opaque_op( Node::from(portgraph::NodeIndex::new(1)), &opaque_val, diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 51d3e3885..b28dcf338 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -11,7 +11,7 @@ use crate::extension::{ use crate::ops::custom::ExtensionOp; use crate::ops::{NamedOp, OpName}; use crate::type_row; -use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV}; +use crate::types::{FuncValueType, OpDefSignature, TypeRowRV}; use crate::utils::collect_array; use crate::{ @@ -227,20 +227,20 @@ pub(in crate::std_extensions::arithmetic) fn int_polytype( n_vars: usize, input: impl Into, output: impl Into, -) -> PolyFuncTypeRV { - PolyFuncTypeRV::new( +) -> OpDefSignature { + OpDefSignature::new( vec![LOG_WIDTH_TYPE_PARAM; n_vars], FuncValueType::new(input, output), ) } -fn ibinop_sig() -> PolyFuncTypeRV { +fn ibinop_sig() -> OpDefSignature { let int_type_var = int_tv(0); int_polytype(1, vec![int_type_var.clone(); 2], vec![int_type_var]) } -fn iunop_sig() -> PolyFuncTypeRV { +fn iunop_sig() -> OpDefSignature { let int_type_var = int_tv(0); int_polytype(1, vec![int_type_var.clone()], vec![int_type_var]) } diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index d00a6f21d..f15687e58 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -26,7 +26,7 @@ use crate::{ ops::{custom::ExtensionOp, NamedOp}, types::{ type_param::{TypeArg, TypeParam}, - CustomCheckFailure, CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeBound, + CustomCheckFailure, CustomType, FuncValueType, OpDefSignature, Type, TypeBound, }, Extension, }; @@ -187,8 +187,8 @@ impl ListOp { self, input: impl Into, output: impl Into, - ) -> PolyFuncTypeRV { - PolyFuncTypeRV::new(vec![Self::TP], FuncValueType::new(input, output)) + ) -> OpDefSignature { + OpDefSignature::new(vec![Self::TP], FuncValueType::new(input, output)) } /// Returns the type of a generic list, associated with the element type parameter at index `idx`. diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index be45c1985..228e407b4 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -16,7 +16,7 @@ use crate::types::type_param::check_type_arg; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; -pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; +pub use poly_func::{OpDefSignature, PolyFuncType}; pub use signature::{FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::TypeArg; diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index b168556bf..1ab198693 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -1,8 +1,12 @@ //! Polymorphic Function Types +use delegate::delegate; use itertools::Itertools; -use crate::extension::{ExtensionRegistry, SignatureError}; +use crate::{ + extension::{ExtensionRegistry, SignatureError}, + type_row, +}; #[cfg(test)] use { crate::proptest::RecursionDepth, @@ -10,9 +14,12 @@ use { proptest_derive::Arbitrary, }; -use super::type_param::{check_type_args, TypeArg, TypeParam}; use super::Substitution; use super::{signature::FuncTypeBase, MaybeRV, NoRV, RowVariable}; +use super::{ + type_param::{check_type_args, TypeArg, TypeParam}, + TypeRow, +}; /// A polymorphic type scheme, i.e. of a [FuncDecl], [FuncDefn] or [OpDef]. /// (Nodes/operations in the Hugr are not polymorphic.) @@ -47,9 +54,105 @@ pub type PolyFuncType = PolyFuncTypeBase; /// The polymorphic type of an [OpDef], whose number of input and outputs /// may vary according to how [RowVariable]s therein are instantiated. /// +/// It may also have a row of static inputs. +/// /// [OpDef]: crate::extension::OpDef -pub type PolyFuncTypeRV = PolyFuncTypeBase; +#[derive( + Clone, + PartialEq, + Debug, + Default, + Eq, + Hash, + derive_more::Display, + serde::Serialize, + serde::Deserialize, +)] +#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] +#[display("{}{}{}", self.display_params(), self.static_inputs, self.body())] +pub struct OpDefSignature { + #[serde(flatten)] + signature: PolyFuncTypeBase, + #[serde(default, skip_serializing_if = "TypeRow::is_empty")] + static_inputs: TypeRow, +} + +impl OpDefSignature { + /// Creates a new `OpDefSignature` with the given parameters and body. + /// + /// + /// Use `OpDefSignature::with_static_inputs` to set the static inputs. + /// + /// # Arguments + /// + /// * `params` - A collection of type parameters. + /// * `body` - The function type base. + pub fn new( + params: impl Into>, + body: impl Into>, + ) -> Self { + Self { + signature: PolyFuncTypeBase::new(params, body), + static_inputs: type_row![], + } + } + + /// Sets the static inputs for the `OpDefSignature`. + /// + /// # Arguments + /// + /// * `static_inputs` - A `TypeRow` representing the static input types. + pub fn with_static_inputs(mut self, static_inputs: TypeRow) -> Self { + self.static_inputs = static_inputs; + self + } + + /// Returns a reference to the static inputs of the `OpDefSignature`. + pub fn static_inputs(&self) -> &TypeRow { + &self.static_inputs + } + /// Returns a reference to the polymorphic function type of the `OpDefSignature`. + pub fn poly_func_type(&self) -> &PolyFuncTypeBase { + &self.signature + } + + delegate! { + to self.signature { + /// The type parameters. + pub fn params(&self) -> &[TypeParam]; + + /// Returns a reference to the function body type. + pub fn body(&self) -> &FuncTypeBase; + + /// Validates the signature against the given extension registry. + /// + /// # Arguments + /// + /// * `reg` - A reference to the `ExtensionRegistry`. + /// + /// # Raises + /// + /// A `SignatureError` if validation fails. + pub fn validate(&self, reg: &ExtensionRegistry) -> Result<(), SignatureError>; + + /// Displays the type parameters as a string. + pub fn display_params(&self) -> String; + + /// Instantiates the function type base with the given arguments and extension registry. + /// + /// # Arguments + /// + /// * `args` - A slice of `TypeArg`. + /// * `ext_reg` - A reference to the `ExtensionRegistry`. + /// + /// # Raises + /// + /// A `SignatureError` if instantiation fails. + pub fn instantiate(&self, args: &[TypeArg], ext_reg: &ExtensionRegistry) -> Result, SignatureError>; + } + } +} // deriving Default leads to an impl that only applies for RV: Default impl Default for PolyFuncTypeBase { fn default() -> Self { @@ -69,7 +172,7 @@ impl From> for PolyFuncTypeBase { } } -impl From for PolyFuncTypeRV { +impl From for PolyFuncTypeBase { fn from(value: PolyFuncType) -> Self { Self { params: value.params, @@ -77,6 +180,32 @@ impl From for PolyFuncTypeRV { } } } +impl From for OpDefSignature { + fn from(value: PolyFuncType) -> Self { + Self { + signature: value.into(), + static_inputs: type_row![], + } + } +} + +impl From> for OpDefSignature { + fn from(value: PolyFuncTypeBase) -> Self { + Self { + signature: value, + static_inputs: type_row![], + } + } +} + +impl From> for OpDefSignature { + fn from(body: FuncTypeBase) -> Self { + Self { + signature: body.into(), + static_inputs: type_row![], + } + } +} impl TryFrom> for FuncTypeBase { /// If the PolyFuncTypeBase is not monomorphic, fail with its binders diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index 4ffeaecf4..a2a6d1059 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -45,11 +45,11 @@ impl UpperBound { } } -/// A *kind* of [TypeArg]. Thus, a parameter declared by a [PolyFuncType] or [PolyFuncTypeRV], +/// A *kind* of [TypeArg]. Thus, a parameter declared by a [PolyFuncType] or [OpDefSignature], /// specifying a value that must be provided statically in order to instantiate it. /// /// [PolyFuncType]: super::PolyFuncType -/// [PolyFuncTypeRV]: super::PolyFuncTypeRV +/// [OpDefSignature]: super::OpDefSignature #[derive( Clone, Debug, PartialEq, Eq, Hash, derive_more::Display, serde::Deserialize, serde::Serialize, )] diff --git a/hugr-core/src/utils.rs b/hugr-core/src/utils.rs index 72b99fd76..352f64fa3 100644 --- a/hugr-core/src/utils.rs +++ b/hugr-core/src/utils.rs @@ -113,17 +113,17 @@ pub(crate) mod test_quantum_extension { ops::ExtensionOp, std_extensions::arithmetic::float_types, type_row, - types::{PolyFuncTypeRV, Signature}, + types::{OpDefSignature, Signature}, Extension, }; use lazy_static::lazy_static; - fn one_qb_func() -> PolyFuncTypeRV { + fn one_qb_func() -> OpDefSignature { FuncValueType::new_endo(QB_T).into() } - fn two_qb_func() -> PolyFuncTypeRV { + fn two_qb_func() -> OpDefSignature { FuncValueType::new_endo(type_row![QB_T, QB_T]).into() } /// The extension identifier. diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index e4fa3ee99..d18b4e28c 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -43,17 +43,17 @@ //! }, //! ops::{ExtensionOp, OpName}, //! type_row, -//! types::{FuncValueType, PolyFuncTypeRV}, +//! types::{FuncValueType, OpDefSignature}, //! Extension, //! }; //! //! use lazy_static::lazy_static; //! -//! fn one_qb_func() -> PolyFuncTypeRV { +//! fn one_qb_func() -> OpDefSignature { //! FuncValueType::new_endo(type_row![QB_T]).into() //! } //! -//! fn two_qb_func() -> PolyFuncTypeRV { +//! fn two_qb_func() -> OpDefSignature { //! FuncValueType::new_endo(type_row![QB_T, QB_T]).into() //! } //! /// The extension identifier. From 0c40deb6ec2c9d9e6d185c3e03542ffcffc78508 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 1 Nov 2024 14:00:04 +0000 Subject: [PATCH 04/14] refactor!: static_input -> static_inputs --- hugr-core/src/ops.rs | 10 +++++----- hugr-core/src/ops/dataflow.rs | 16 ++++++++-------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index c6fa12e89..b79fd679a 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -159,7 +159,7 @@ impl OpType { #[inline] pub fn static_port_kind(&self, dir: Direction) -> Vec { match dir { - Direction::Incoming => self.static_input(), + Direction::Incoming => self.static_inputs(), Direction::Outgoing => self.static_output().map(|k| vec![k]).unwrap_or_default(), } } @@ -385,12 +385,12 @@ pub trait OpTrait { None } - /// The edge kind for a single constant input of the operation, not + /// The edge kinds for static inputs to the operation, not /// described by the dataflow signature. /// - /// If not None, an extra input port of that kind will be present after the - /// dataflow input ports and before any [`OpTrait::other_input`] ports. - fn static_input(&self) -> Vec { + /// If not empty, extra input ports of those kinds will be present after the + /// dataflow input ports and before any [`DataflowOpTrait::other_input`] ports. + fn static_inputs(&self) -> Vec { vec![] } diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 1a2c2fc06..7947dd8d2 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -40,13 +40,13 @@ pub trait DataflowOpTrait { Some(EdgeKind::StateOrder) } - /// The edge kind for a single constant input of the operation, not + /// The edge kinds for static inputs to the operation, not /// described by the dataflow signature. /// - /// If not None, an extra input port of that kind will be present after the + /// If not empty, extra input ports of those kinds will be present after the /// dataflow input ports and before any [`DataflowOpTrait::other_input`] ports. #[inline] - fn static_input(&self) -> Vec { + fn static_inputs(&self) -> Vec { vec![] } } @@ -148,8 +148,8 @@ impl OpTrait for T { DataflowOpTrait::other_output(self) } - fn static_input(&self) -> Vec { - DataflowOpTrait::static_input(self) + fn static_inputs(&self) -> Vec { + DataflowOpTrait::static_inputs(self) } } impl StaticTag for T { @@ -184,7 +184,7 @@ impl DataflowOpTrait for Call { self.instantiation.clone() } - fn static_input(&self) -> Vec { + fn static_inputs(&self) -> Vec { vec![EdgeKind::Function(self.called_function_type().clone())] } } @@ -300,7 +300,7 @@ impl DataflowOpTrait for LoadConstant { Signature::new(TypeRow::new(), vec![self.datatype.clone()]) } - fn static_input(&self) -> Vec { + fn static_inputs(&self) -> Vec { vec![EdgeKind::Const(self.constant_type().clone())] } } @@ -355,7 +355,7 @@ impl DataflowOpTrait for LoadFunction { self.signature.clone() } - fn static_input(&self) -> Vec { + fn static_inputs(&self) -> Vec { vec![EdgeKind::Function(self.func_sig.clone())] } } From 73c344a50a969ffd76bb1bfe2c7b5d5bbee20800 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 1 Nov 2024 15:29:46 +0000 Subject: [PATCH 05/14] feat!: include static inputs in return when computing signature --- hugr-core/src/extension.rs | 4 +- hugr-core/src/extension/op_def.rs | 71 ++++++++++++++++++++++++++++--- hugr-core/src/import.rs | 1 - hugr-core/src/ops/custom.rs | 66 ++++++++++++++-------------- hugr-core/src/types/type_row.rs | 2 + 5 files changed, 101 insertions(+), 43 deletions(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index fe30eb5f5..ea92052ae 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -22,8 +22,8 @@ use crate::types::{Signature, TypeNameRef}; mod op_def; pub use op_def::{ - CustomSignatureFunc, CustomValidator, LowerFunc, OpDef, SignatureFromArgs, SignatureFunc, - ValidateJustArgs, ValidateTypeArgs, + CustomSignatureFunc, CustomValidator, ExtOpSignature, LowerFunc, OpDef, SignatureFromArgs, + SignatureFunc, ValidateJustArgs, ValidateTypeArgs, }; mod type_def; pub use type_def::{TypeDef, TypeDefBound}; diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 5b9ecaab2..78d42ee43 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -4,6 +4,9 @@ use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::sync::Arc; +#[cfg(test)] +use {crate::proptest::RecursionDepth, ::proptest::prelude::*, proptest_derive::Arbitrary}; + use super::{ ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, @@ -11,7 +14,7 @@ use super::{ use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; -use crate::types::{FuncValueType, OpDefSignature, PolyFuncType, Signature}; +use crate::types::{FuncValueType, OpDefSignature, PolyFuncType, Signature, TypeRow}; use crate::Hugr; mod serialize_signature_func; @@ -224,7 +227,7 @@ impl SignatureFunc { def: &OpDef, args: &[TypeArg], exts: &ExtensionRegistry, - ) -> Result { + ) -> Result { let temp: OpDefSignature; // to keep alive let (pf, args) = match &self { SignatureFunc::CustomValidator(custom) => { @@ -244,11 +247,65 @@ impl SignatureFunc { // TODO raise warning: https://github.com/CQCL/hugr/issues/1432 SignatureFunc::MissingValidateFunc(ts) => (ts, args), }; + let static_inputs = pf.static_inputs().clone(); let mut res = pf.instantiate(args, exts)?; res.extension_reqs.insert(&def.extension); // If there are any row variables left, this will fail with an error: - res.try_into() + let func_type = res.try_into()?; + + Ok(ExtOpSignature { + func_type, + static_inputs, + }) + } +} + +/// Instantiated [OpDef] signature. +#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] +#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] +pub struct ExtOpSignature { + #[serde(flatten)] + /// The dataflow function type of the signature. + pub func_type: Signature, + #[serde(default, skip_serializing_if = "TypeRow::is_empty")] + /// The static inputs of the signature. + pub static_inputs: TypeRow, +} + +impl ExtOpSignature { + /// Returns the function type of the signature. + pub fn func_type(&self) -> &Signature { + &self.func_type + } + + /// Returns the static inputs of the signature. + pub fn static_inputs(&self) -> &TypeRow { + &self.static_inputs + } +} +impl From for Signature { + fn from(v: ExtOpSignature) -> Self { + v.func_type + } +} + +impl From for ExtOpSignature { + fn from(v: Signature) -> Self { + Self { + func_type: v, + static_inputs: TypeRow::new(), + } + } +} + +impl std::fmt::Display for ExtOpSignature { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.static_inputs.is_empty() { + write!(f, "{}", self.func_type) + } else { + write!(f, "<{}> {}", self.static_inputs, self.func_type) + } } } @@ -364,7 +421,7 @@ impl OpDef { &self, args: &[TypeArg], exts: &ExtensionRegistry, - ) -> Result { + ) -> Result { self.signature_func.compute_signature(self, args, exts) } @@ -669,6 +726,7 @@ pub(super) mod test { Ok( Signature::new(vec![USIZE_T; 3], vec![Type::new_tuple(vec![USIZE_T; 3])]) .with_extension_delta(EXT_ID) + .into() ) ); assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); @@ -682,6 +740,7 @@ pub(super) mod test { Ok( Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) .with_extension_delta(EXT_ID) + .into() ) ); def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Copyable.into()]) @@ -735,7 +794,7 @@ pub(super) mod test { def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); assert_eq!( def.compute_signature(&args, &EMPTY_REG), - Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) + Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID).into()) ); // But not with an external row variable let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); @@ -776,7 +835,7 @@ pub(super) mod test { .unwrap(); assert_eq!( def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok(exp_fun_ty) + Ok(exp_fun_ty.into()) ); Ok(()) } diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index db1dcfefd..2d36ec873 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -500,7 +500,6 @@ impl<'a> Context<'a> { String::default(), args, signature, - vec![], )); let node = self.make_node(node_id, optype, parent)?; diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 048554795..da8742fe6 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -10,13 +10,10 @@ use { ::proptest_derive::Arbitrary, }; -use crate::hugr::internal::HugrMutInternals; +use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::HugrView; use crate::types::{type_param::TypeArg, Signature}; -use crate::{ - extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}, - types::Type, -}; +use crate::{extension::ExtOpSignature, hugr::internal::HugrMutInternals, types::EdgeKind}; use crate::{ops, Hugr, IncomingPort, Node}; use super::dataflow::DataflowOpTrait; @@ -39,7 +36,7 @@ pub struct ExtensionOp { )] def: Arc, args: Vec, - signature: Signature, // Cache + signature: ExtOpSignature, // Cache } impl ExtensionOp { @@ -72,7 +69,7 @@ impl ExtensionOp { Ok(sig) => sig, Err(SignatureError::MissingComputeFunc) => { // TODO raise warning: https://github.com/CQCL/hugr/issues/1432 - opaque.signature() + opaque.ext_op_signature() } Err(e) => return Err(e), }; @@ -112,7 +109,6 @@ impl ExtensionOp { description: self.def.description().into(), args: self.args.clone(), signature: self.signature.clone(), - static_inputs: vec![], // Initialize with empty vector } } } @@ -130,7 +126,6 @@ impl From for OpaqueOp { description: def.description().into(), args, signature, - static_inputs: vec![], // Initialize with empty vector } } } @@ -158,7 +153,16 @@ impl DataflowOpTrait for ExtensionOp { } fn signature(&self) -> Signature { - self.signature.clone() + self.signature.func_type().clone() + } + + fn static_inputs(&self) -> Vec { + self.signature + .static_inputs() + .iter() + .cloned() + .map(EdgeKind::Const) + .collect() } } @@ -183,9 +187,7 @@ pub struct OpaqueOp { // note that the `signature` field might not include `extension`. Thus this must // remain private, and should be accessed through // `DataflowOpTrait::signature`. - signature: Signature, - #[serde(default)] - static_inputs: Vec, + signature: ExtOpSignature, } fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName { @@ -199,16 +201,14 @@ impl OpaqueOp { name: impl Into, description: String, args: impl Into>, - signature: Signature, - static_inputs: Vec, // New parameter added + signature: impl Into, ) -> Self { Self { extension, name: name.into(), description, args: args.into(), - signature, - static_inputs, // Initialize new field + signature: signature.into(), } } } @@ -234,6 +234,13 @@ impl OpaqueOp { pub fn extension(&self) -> &ExtensionId { &self.extension } + + /// Instantiated signature of the operation. + pub fn ext_op_signature(&self) -> ExtOpSignature { + let mut sig = self.signature.clone(); + sig.func_type = sig.func_type.with_extension_delta(self.extension.clone()); + sig + } } impl DataflowOpTrait for OpaqueOp { @@ -244,9 +251,7 @@ impl DataflowOpTrait for OpaqueOp { } fn signature(&self) -> Signature { - self.signature - .clone() - .with_extension_delta(self.extension().clone()) + self.ext_op_signature().func_type } } @@ -295,6 +300,7 @@ pub fn resolve_opaque_op( r.name().clone(), )); }; + dbg!(opaque.signature().extension_reqs); let ext_op = ExtensionOp::new_with_cached( def.clone(), opaque.args.clone(), @@ -307,12 +313,14 @@ pub fn resolve_opaque_op( cause: e, })?; if opaque.signature() != ext_op.signature() { + dbg!(opaque.signature().extension_reqs); + dbg!(ext_op.signature().extension_reqs); return Err(OpaqueOpError::SignatureMismatch { node, extension: opaque.extension.clone(), op: def.name().clone(), - computed: ext_op.signature.clone(), - stored: opaque.signature.clone(), + computed: ext_op.signature(), + stored: opaque.signature(), }); }; Ok(ext_op) @@ -382,7 +390,6 @@ mod test { "desc".into(), vec![TypeArg::Type { ty: USIZE_T }], sig.clone(), - vec![], // Initialize with empty vector ); assert_eq!(op.name(), "res.op"); assert_eq!(DataflowOpTrait::description(&op), "desc"); @@ -403,7 +410,6 @@ mod test { "description".into(), vec![], Signature::new(i0.clone(), BOOL_T), - vec![], // Initialize with empty vector ); let resolved = super::resolve_opaque_op(Node::from(portgraph::NodeIndex::new(1)), &opaque, registry) @@ -439,16 +445,8 @@ mod test { "".into(), vec![], endo_sig.clone(), - vec![], // Initialize with empty vector - ); - let opaque_comp = OpaqueOp::new( - ext_id.clone(), - comp_name, - "".into(), - vec![], - endo_sig, - vec![], // Initialize with empty vector ); + let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig); let resolved_val = super::resolve_opaque_op( Node::from(portgraph::NodeIndex::new(1)), &opaque_val, @@ -462,7 +460,7 @@ mod test { &opaque_comp, ®istry, ) - .unwrap(); + .unwrap_or_else(|e| panic!("{}", e)); assert_eq!(resolved_comp.def().name(), comp_name); } } diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index b8cb6d116..ac50bf151 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -102,6 +102,8 @@ impl TypeRowBase { self.iter().try_for_each(|t| t.validate(exts, var_decls)) } } +/// Empty row of types. +pub static EMPTY_ROW: TypeRow = TypeRow::new(); impl TypeRow { delegate! { From adc6cb440f578677c2ea7ff41f7984fa5529aed8 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 1 Nov 2024 15:55:55 +0000 Subject: [PATCH 06/14] test: add static op instantiation test --- hugr-core/src/extension/op_def.rs | 42 +++++++++++++++++++++++++++++-- hugr-core/src/ops/custom.rs | 5 ++++ hugr-core/src/types/poly_func.rs | 4 +-- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 78d42ee43..af1c1a145 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -274,6 +274,14 @@ pub struct ExtOpSignature { } impl ExtOpSignature { + /// Create a new [ExtOpSignature] from a [Signature] and a [TypeRow]. + pub fn new(func_type: Signature, static_inputs: impl Into) -> Self { + Self { + func_type, + static_inputs: static_inputs.into(), + } + } + /// Returns the function type of the signature. pub fn func_type(&self) -> &Signature { &self.func_type @@ -549,10 +557,10 @@ pub(super) mod test { use itertools::Itertools; - use super::SignatureFromArgs; + use super::{ExtOpSignature, SignatureFromArgs}; use crate::builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc}; - use crate::extension::prelude::USIZE_T; + use crate::extension::prelude::{QB_T, USIZE_T}; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; use crate::ops::OpName; @@ -840,6 +848,36 @@ pub(super) mod test { Ok(()) } + pub(crate) fn static_input_op() -> (Extension, OpName) { + let mut e = Extension::new_test(EXT_ID); + let def = e + .add_op( + "StaticInOp".into(), + "".into(), + OpDefSignature::new(vec![], Signature::new(vec![], vec![QB_T])) + .with_static_inputs(USIZE_T), + ) + .unwrap(); + let op_name = def.name().clone(); + (e, op_name) + } + + #[test] + fn test_static_input_op() -> Result<(), Box> { + let (e, op_name) = static_input_op(); + let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), e])?; + let e = reg.get(&EXT_ID).unwrap(); + + let ext_op = e.instantiate_extension_op(&op_name, vec![], ®).unwrap(); + let sig = ext_op.ext_op_signature(); + let expected = ExtOpSignature::new( + Signature::new(vec![], vec![QB_T]).with_extension_delta(e.name().clone()), + USIZE_T, + ); + assert_eq!(sig, &expected); + Ok(()) + } + mod proptest { use super::SimpleOpDef; use ::proptest::prelude::*; diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index da8742fe6..3eb0f9524 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -111,6 +111,11 @@ impl ExtensionOp { signature: self.signature.clone(), } } + + /// Returns the [`ExtOpSignature`] of this [`ExtensionOp`]. + pub fn ext_op_signature(&self) -> &ExtOpSignature { + &self.signature + } } impl From for OpaqueOp { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 1ab198693..d760a854b 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -102,8 +102,8 @@ impl OpDefSignature { /// # Arguments /// /// * `static_inputs` - A `TypeRow` representing the static input types. - pub fn with_static_inputs(mut self, static_inputs: TypeRow) -> Self { - self.static_inputs = static_inputs; + pub fn with_static_inputs(mut self, static_inputs: impl Into) -> Self { + self.static_inputs = static_inputs.into(); self } From ecb5295a53e3a18216c35859e6df511eb9a2289b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 4 Nov 2024 11:05:20 +0000 Subject: [PATCH 07/14] feat!: allow static input ext ops to be built + static port counting trait methods --- hugr-core/src/builder/build_traits.rs | 21 +++++++++++++++++ hugr-core/src/extension/op_def.rs | 19 ++++++++++----- hugr-core/src/hugr/serialize.rs | 5 ++-- hugr-core/src/ops.rs | 27 ++++++++++++++-------- hugr-core/src/ops/custom.rs | 33 +++++++++++++++++++++++---- hugr-core/src/ops/dataflow.rs | 15 +++++++++++- hugr-core/src/ops/handle.rs | 8 ++++++- hugr-core/src/ops/tag.rs | 10 +++----- 8 files changed, 107 insertions(+), 31 deletions(-) diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 13ef1db4b..ace7aad27 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -191,7 +191,28 @@ pub trait Dataflow: Container { Ok(outs.into()) } + /// Add a dataflow [`OpType`] to the sibling graph, wiring up the `input_wires` to the + /// incoming ports of the resulting node, and the `static_wires` to the static input ports. + /// + /// # Errors + /// + /// Returns a [`BuildError::OperationWiring`] error if the `input_wires` or `static_wires` cannot be connected. + fn add_dataflow_op_with_static( + &mut self, + nodetype: impl Into, + input_wires: impl IntoIterator, + static_wires: impl IntoIterator, + ) -> Result, BuildError> { + let op: OpType = nodetype.into(); + let static_in_ports = op.static_input_ports(); + let handle = self.add_dataflow_op(op, input_wires)?; + for (src, in_port) in static_wires.into_iter().zip(static_in_ports) { + self.hugr_mut() + .connect(src.node(), src.source(), handle.node(), in_port); + } + Ok(handle) + } /// Insert a hugr-defined op to the sibling graph, wiring up the /// `input_wires` to the incoming ports of the resulting root node. /// diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index af1c1a145..0fe68db88 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -558,12 +558,12 @@ pub(super) mod test { use itertools::Itertools; use super::{ExtOpSignature, SignatureFromArgs}; - use crate::builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}; + use crate::builder::{endo_sig, Container, DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc}; use crate::extension::prelude::{QB_T, USIZE_T}; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; - use crate::ops::OpName; + use crate::ops::{OpName, Value}; use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; use crate::types::type_param::{TypeArgError, TypeParam}; use crate::types::{OpDefSignature, Signature, Type, TypeArg, TypeBound, TypeRV}; @@ -848,14 +848,13 @@ pub(super) mod test { Ok(()) } - pub(crate) fn static_input_op() -> (Extension, OpName) { + fn static_input_op() -> (Extension, OpName) { let mut e = Extension::new_test(EXT_ID); let def = e .add_op( "StaticInOp".into(), "".into(), - OpDefSignature::new(vec![], Signature::new(vec![], vec![QB_T])) - .with_static_inputs(USIZE_T), + OpDefSignature::new(vec![], Signature::new_endo(QB_T)).with_static_inputs(USIZE_T), ) .unwrap(); let op_name = def.name().clone(); @@ -871,10 +870,18 @@ pub(super) mod test { let ext_op = e.instantiate_extension_op(&op_name, vec![], ®).unwrap(); let sig = ext_op.ext_op_signature(); let expected = ExtOpSignature::new( - Signature::new(vec![], vec![QB_T]).with_extension_delta(e.name().clone()), + Signature::new_endo(QB_T).with_extension_delta(e.name().clone()), USIZE_T, ); assert_eq!(sig, &expected); + + let mut dfg = DFGBuilder::new(expected.func_type.with_prelude())?; + let cnst = dfg.add_constant(Value::extension( + crate::extension::prelude::ConstUsize::new(42), + )); + let [inq] = dfg.input_wires_arr(); + let ext_op_node = dfg.add_dataflow_op_with_static(ext_op, vec![inq], [cnst.wire()])?; + dfg.finish_hugr_with_outputs(ext_op_node.outputs(), ®)?; Ok(()) } diff --git a/hugr-core/src/hugr/serialize.rs b/hugr-core/src/hugr/serialize.rs index 9a213f8ef..721f7da51 100644 --- a/hugr-core/src/hugr/serialize.rs +++ b/hugr-core/src/hugr/serialize.rs @@ -184,9 +184,8 @@ impl TryFrom<&Hugr> for SerHugrLatest { let value_count = op.value_port_count(dir); let is_value_port = offset < value_count; let static_len = op.static_ports(dir).len(); - let is_static_input = offset < (value_count + static_len); - // let is_static_input = op.static_port(dir).map_or(false, |p| p.index() == offset); - let offset = (is_value_port || is_static_input).then_some(offset as u16); + let value_or_static = offset < (value_count + static_len); + let offset = (is_value_port || value_or_static).then_some(offset as u16); (node_rekey[&node], offset) }; diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index b79fd679a..3d8647490 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -181,9 +181,10 @@ impl OpType { } // Constant port - let mut static_kind = self.static_port_kind(dir); let static_offset = port.index() - port_count; - if static_offset < static_kind.len() { + let static_port_count = self.static_port_count(dir); + if static_offset < static_port_count { + let mut static_kind = self.static_port_kind(dir); return Some(static_kind.remove(static_offset)); } @@ -197,12 +198,11 @@ impl OpType { /// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports. pub fn other_port(&self, dir: Direction) -> Option { let df_count = self.value_port_count(dir); + let static_count = self.static_port_count(dir); let non_df_count = self.non_df_port_count(dir); - // if there is a static input it comes before the non_df_ports - let static_input = - (dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize; + if self.other_port_kind(dir).is_some() && non_df_count >= 1 { - Some(Port::new(dir, df_count + static_input)) + Some(Port::new(dir, df_count + static_count)) } else { None } @@ -229,9 +229,10 @@ impl OpType { /// See [`OpType::static_input_ports`] and [`OpType::static_output_port`]. #[inline] pub fn static_ports(&self, dir: Direction) -> Vec { - let static_len = self.static_port_kind(dir).len(); + let static_len = self.static_port_count(dir); + let value_port_count = self.value_port_count(dir); (0..static_len) - .map(|i| Port::new(dir, self.value_port_count(dir) + i)) + .map(|i| Port::new(dir, value_port_count + i)) .collect() } @@ -278,7 +279,7 @@ impl OpType { /// Returns the number of ports for the given direction. #[inline] pub fn port_count(&self, dir: Direction) -> usize { - let static_len = self.static_port_kind(dir).len(); + let static_len = self.static_port_count(dir); let non_df_count = self.non_df_port_count(dir); self.value_port_count(dir) + static_len + non_df_count } @@ -411,6 +412,14 @@ pub trait OpTrait { } .is_some() as usize } + + /// Get the number of static multiports. + fn static_port_count(&self, dir: Direction) -> usize { + match dir { + Direction::Incoming => self.static_inputs().len(), + Direction::Outgoing => self.static_output().is_some() as usize, + } + } } /// Properties of child graphs of ops, if the op has children. diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 3eb0f9524..4bc3e0cc0 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -10,10 +10,13 @@ use { ::proptest_derive::Arbitrary, }; -use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::HugrView; use crate::types::{type_param::TypeArg, Signature}; use crate::{extension::ExtOpSignature, hugr::internal::HugrMutInternals, types::EdgeKind}; +use crate::{ + extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}, + Direction, +}; use crate::{ops, Hugr, IncomingPort, Node}; use super::dataflow::DataflowOpTrait; @@ -169,6 +172,14 @@ impl DataflowOpTrait for ExtensionOp { .map(EdgeKind::Const) .collect() } + + fn static_port_count(&self, dir: Direction) -> usize { + // specialise as we can count without allocating + match dir { + Direction::Incoming => self.signature.static_inputs().len(), + Direction::Outgoing => 0, + } + } } /// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`]. @@ -258,6 +269,23 @@ impl DataflowOpTrait for OpaqueOp { fn signature(&self) -> Signature { self.ext_op_signature().func_type } + + fn static_inputs(&self) -> Vec { + self.signature + .static_inputs() + .iter() + .cloned() + .map(EdgeKind::Const) + .collect() + } + + fn static_port_count(&self, dir: Direction) -> usize { + // specialise as we can count without allocating + match dir { + Direction::Incoming => self.signature.static_inputs().len(), + Direction::Outgoing => 0, + } + } } /// Resolve serialized names of operations into concrete implementation (OpDefs) where possible @@ -305,7 +333,6 @@ pub fn resolve_opaque_op( r.name().clone(), )); }; - dbg!(opaque.signature().extension_reqs); let ext_op = ExtensionOp::new_with_cached( def.clone(), opaque.args.clone(), @@ -318,8 +345,6 @@ pub fn resolve_opaque_op( cause: e, })?; if opaque.signature() != ext_op.signature() { - dbg!(opaque.signature().extension_reqs); - dbg!(ext_op.signature().extension_reqs); return Err(OpaqueOpError::SignatureMismatch { node, extension: opaque.extension.clone(), diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 7947dd8d2..119f58c6f 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -5,7 +5,7 @@ use super::{impl_op_name, OpTag, OpTrait}; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; use crate::ops::StaticTag; use crate::types::{EdgeKind, PolyFuncType, Signature, Type, TypeArg, TypeRow}; -use crate::IncomingPort; +use crate::{Direction, IncomingPort}; #[cfg(test)] use ::proptest_derive::Arbitrary; @@ -49,6 +49,15 @@ pub trait DataflowOpTrait { fn static_inputs(&self) -> Vec { vec![] } + + /// The number of static input ports in the given direction. + #[inline] + fn static_port_count(&self, dir: Direction) -> usize { + match dir { + Direction::Incoming => self.static_inputs().len(), + Direction::Outgoing => 0, + } + } } /// Helpers to construct input and output nodes @@ -151,6 +160,10 @@ impl OpTrait for T { fn static_inputs(&self) -> Vec { DataflowOpTrait::static_inputs(self) } + + fn static_port_count(&self, dir: Direction) -> usize { + DataflowOpTrait::static_port_count(self, dir) + } } impl StaticTag for T { const TAG: OpTag = T::TAG; diff --git a/hugr-core/src/ops/handle.rs b/hugr-core/src/ops/handle.rs index d7fe16419..428cd04e5 100644 --- a/hugr-core/src/ops/handle.rs +++ b/hugr-core/src/ops/handle.rs @@ -1,6 +1,6 @@ //! Handles to nodes in HUGR. use crate::types::{Type, TypeBound}; -use crate::Node; +use crate::{Node, OutgoingPort, Wire}; use derive_more::From as DerFrom; use smol_str::SmolStr; @@ -100,6 +100,12 @@ impl AliasID { #[derive(DerFrom, Debug, Clone, PartialEq, Eq)] /// Handle to a [Const](crate::ops::OpType::Const) node. pub struct ConstID(Node); +impl ConstID { + /// Retrieve the outgoing wire for the constant node. + pub fn wire(&self) -> Wire { + Wire::new(self.node(), OutgoingPort::from(0)) + } +} #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowBlock](crate::ops::DataflowBlock) or [Exit](crate::ops::ExitBlock) node. diff --git a/hugr-core/src/ops/tag.rs b/hugr-core/src/ops/tag.rs index b9edcf84e..1dd4bb022 100644 --- a/hugr-core/src/ops/tag.rs +++ b/hugr-core/src/ops/tag.rs @@ -46,8 +46,6 @@ pub enum OpTag { Input, /// A dataflow output. Output, - /// Dataflow node that has a static input - StaticInput, /// Node that has a static output StaticOutput, /// A function call. @@ -127,11 +125,10 @@ impl OpTag { ], OpTag::TailLoop => &[OpTag::DataflowChild, OpTag::DataflowParent], OpTag::Conditional => &[OpTag::DataflowChild], - OpTag::StaticInput => &[OpTag::Any], OpTag::StaticOutput => &[OpTag::Any], - OpTag::FnCall => &[OpTag::StaticInput, OpTag::DataflowChild], - OpTag::LoadConst => &[OpTag::StaticInput, OpTag::DataflowChild], - OpTag::LoadFunc => &[OpTag::StaticInput, OpTag::DataflowChild], + OpTag::FnCall => &[OpTag::DataflowChild], + OpTag::LoadConst => &[OpTag::DataflowChild], + OpTag::LoadFunc => &[OpTag::DataflowChild], OpTag::Leaf => &[OpTag::DataflowChild], OpTag::DataflowParent => &[OpTag::Any], } @@ -159,7 +156,6 @@ impl OpTag { OpTag::Cfg => "Nested control-flow operation", OpTag::TailLoop => "Tail-recursive loop", OpTag::Conditional => "Conditional operation", - OpTag::StaticInput => "Node with static input (LoadConst, LoadFunc, or FnCall)", OpTag::StaticOutput => "Node with static output (FuncDefn, FuncDecl, Const)", OpTag::FnCall => "Function call", OpTag::LoadConst => "Constant load operation", From d3fbc858995b887a12088c7ebce78f3f368fb556 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 4 Nov 2024 12:14:29 +0000 Subject: [PATCH 08/14] refactor!: serialize statics next to `binary` --- .../op_def/serialize_signature_func.rs | 54 +++++++++++++------ hugr-core/src/hugr/serialize/test.rs | 33 ++++++------ hugr-core/src/types/poly_func.rs | 14 +---- 3 files changed, 56 insertions(+), 45 deletions(-) diff --git a/hugr-core/src/extension/op_def/serialize_signature_func.rs b/hugr-core/src/extension/op_def/serialize_signature_func.rs index 0877e808f..2daa3545b 100644 --- a/hugr-core/src/extension/op_def/serialize_signature_func.rs +++ b/hugr-core/src/extension/op_def/serialize_signature_func.rs @@ -1,14 +1,18 @@ use serde::{Deserialize, Serialize}; +use crate::types::{PolyFuncTypeBase, RowVariable, TypeRow}; + use super::{CustomValidator, OpDefSignature, SignatureFunc}; #[derive(serde::Deserialize, serde::Serialize, PartialEq, Debug, Clone)] struct SerSignatureFunc { /// If the type scheme is available explicitly, store it. - signature: Option, + signature: Option>, /// Whether an associated binary function is expected. /// If `signature` is `None`, a true value here indicates a custom compute function. /// If `signature` is not `None`, a true value here indicates a custom validation function. binary: bool, + #[serde(default, skip_serializing_if = "TypeRow::is_empty")] + static_inputs: TypeRow, } pub(super) fn serialize(value: &super::SignatureFunc, serializer: S) -> Result @@ -16,17 +20,20 @@ where S: serde::Serializer, { match value { - SignatureFunc::PolyFuncType(poly) => SerSignatureFunc { - signature: Some(poly.clone()), + SignatureFunc::PolyFuncType(op_sig) => SerSignatureFunc { + signature: Some(op_sig.poly_func_type().clone()), + static_inputs: op_sig.static_inputs().clone(), binary: false, }, SignatureFunc::CustomValidator(CustomValidator { poly_func, .. }) | SignatureFunc::MissingValidateFunc(poly_func) => SerSignatureFunc { - signature: Some(poly_func.clone()), + signature: Some(poly_func.poly_func_type().clone()), + static_inputs: poly_func.static_inputs().clone(), binary: true, }, SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => SerSignatureFunc { signature: None, + static_inputs: TypeRow::new(), binary: true, }, } @@ -37,11 +44,19 @@ pub(super) fn deserialize<'de, D>(deserializer: D) -> Result, { - let SerSignatureFunc { signature, binary } = SerSignatureFunc::deserialize(deserializer)?; + let SerSignatureFunc { + signature, + binary, + static_inputs, + } = SerSignatureFunc::deserialize(deserializer)?; match (signature, binary) { - (Some(sig), false) => Ok(sig.into()), - (Some(sig), true) => Ok(SignatureFunc::MissingValidateFunc(sig)), + (Some(sig), false) => Ok(OpDefSignature::from(sig) + .with_static_inputs(static_inputs) + .into()), + (Some(sig), true) => Ok(SignatureFunc::MissingValidateFunc( + OpDefSignature::from(sig).with_static_inputs(static_inputs), + )), (None, true) => Ok(SignatureFunc::MissingComputeFunc), (None, false) => Err(serde::de::Error::custom( "No signature provided and custom computation not expected.", @@ -57,10 +72,12 @@ mod test { use super::*; use crate::{ extension::{ - prelude::USIZE_T, CustomSignatureFunc, CustomValidator, ExtensionRegistry, OpDef, - SignatureError, ValidateTypeArgs, + prelude::{BOOL_T, USIZE_T}, + CustomSignatureFunc, CustomValidator, ExtensionRegistry, OpDef, SignatureError, + ValidateTypeArgs, }, - types::{FuncValueType, Signature, TypeArg}, + type_row, + types::{Signature, TypeArg}, }; #[derive(serde::Deserialize, serde::Serialize, Debug)] @@ -121,32 +138,35 @@ mod test { #[test] fn test_serial_sig_func() { // test round-trip - let sig: FuncValueType = Signature::new_endo(USIZE_T.clone()).into(); + let sig = OpDefSignature::new([], Signature::new_endo(USIZE_T.clone())) + .with_static_inputs(BOOL_T); let simple: SignatureFunc = sig.clone().into(); let ser: SerSignatureFunc = simple.into(); let expected_ser = SerSignatureFunc { - signature: Some(sig.clone().into()), + signature: Some(sig.poly_func_type().clone()), binary: false, + static_inputs: type_row![BOOL_T], }; assert_eq!(ser, expected_ser); let deser = SignatureFunc::try_from(ser).unwrap(); assert_matches!( deser, - SignatureFunc::PolyFuncType(poly_func) => { - assert_eq!(poly_func, sig.clone().into()); + SignatureFunc::PolyFuncType(op_def_sig) => { + assert_eq!(op_def_sig, sig.clone()); }); let with_custom: SignatureFunc = CustomValidator::new(sig.clone(), NoValidate).into(); let ser: SerSignatureFunc = with_custom.into(); let expected_ser = SerSignatureFunc { - signature: Some(sig.clone().into()), + signature: Some(sig.poly_func_type().clone()), + static_inputs: type_row![BOOL_T], binary: true, }; assert_eq!(ser, expected_ser); let mut deser = SignatureFunc::try_from(ser.clone()).unwrap(); assert_matches!(&deser, SignatureFunc::MissingValidateFunc(poly_func) => { - assert_eq!(poly_func, &OpDefSignature::from(sig.clone())); + assert_eq!(poly_func, &sig.clone()); } ); @@ -162,6 +182,7 @@ mod test { let custom: SignatureFunc = CustomSig.into(); let ser: SerSignatureFunc = custom.into(); let expected_ser = SerSignatureFunc { + static_inputs: type_row![], signature: None, binary: true, }; @@ -174,6 +195,7 @@ mod test { let bad_ser = SerSignatureFunc { signature: None, + static_inputs: type_row![], binary: false, }; diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index e2f809c6c..2250c0458 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -17,8 +17,8 @@ use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use crate::std_extensions::logic::LogicOp; use crate::types::type_param::TypeParam; use crate::types::{ - FuncValueType, OpDefSignature, PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, - TypeRV, + FuncValueType, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, SumType, Type, TypeArg, + TypeBound, TypeRV, }; use crate::{type_row, OutgoingPort}; @@ -32,12 +32,13 @@ use rstest::rstest; const NAT: Type = crate::extension::prelude::USIZE_T; const QB: Type = crate::extension::prelude::QB_T; +type PolyFuncTypeRV = PolyFuncTypeBase; /// Version 1 of the Testing HUGR serialization format, see `testing_hugr.py`. #[derive(Serialize, Deserialize, PartialEq, Debug, Default)] struct SerTestingLatest { typ: Option, sum_type: Option, - poly_func_type: Option, + poly_func_type: Option, value: Option, optype: Option, op_def: Option, @@ -124,14 +125,14 @@ macro_rules! impl_sertesting_from { impl_sertesting_from!(crate::types::TypeRV, typ); impl_sertesting_from!(crate::types::SumType, sum_type); -impl_sertesting_from!(crate::types::OpDefSignature, poly_func_type); +impl_sertesting_from!(PolyFuncTypeRV, poly_func_type); impl_sertesting_from!(crate::ops::Value, value); impl_sertesting_from!(NodeSer, optype); impl_sertesting_from!(SimpleOpDef, op_def); impl From for SerTestingLatest { fn from(v: PolyFuncType) -> Self { - let v: OpDefSignature = v.into(); + let v: PolyFuncTypeRV = v.into(); v.into() } } @@ -484,7 +485,7 @@ fn polyfunctype1() -> PolyFuncType { PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type) } -fn polyfunctype2() -> OpDefSignature { +fn polyfunctype2() -> PolyFuncTypeRV { let tv0 = TypeRV::new_row_var_use(0, TypeBound::Any); let tv1 = TypeRV::new_row_var_use(1, TypeBound::Copyable); let params = [TypeBound::Any, TypeBound::Copyable].map(TypeParam::new_list); @@ -492,7 +493,7 @@ fn polyfunctype2() -> OpDefSignature { TypeRV::new_function(FuncValueType::new(tv0.clone(), tv1.clone())), tv0, ]; - let res = OpDefSignature::new(params, FuncValueType::new(inputs, tv1)); + let res = PolyFuncTypeRV::new(params, FuncValueType::new(inputs, tv1)); // Just check we've got the arguments the right way round // (not that it really matters for the serialization schema we have) res.validate(&EMPTY_REG).unwrap(); @@ -515,15 +516,15 @@ fn roundtrip_polyfunctype_fixedlen(#[case] poly_func_type: PolyFuncType) { #[rstest] #[case(FuncValueType::new_endo(type_row![]).into())] -#[case(OpDefSignature::new([TypeParam::String], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(OpDefSignature::new([TypeBound::Copyable.into()], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] -#[case(OpDefSignature::new([TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] -#[case(OpDefSignature::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FuncValueType::new_endo(type_row![])))] -#[case(OpDefSignature::new( +#[case(PolyFuncTypeRV::new([TypeParam::String], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncTypeRV::new([TypeBound::Copyable.into()], FuncValueType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] +#[case(PolyFuncTypeRV::new([TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(type_row![])))] +#[case(PolyFuncTypeRV::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FuncValueType::new_endo(type_row![])))] +#[case(PolyFuncTypeRV::new( [TypeParam::new_list(TypeBound::Any)], FuncValueType::new_endo(TypeRV::new_row_var_use(0, TypeBound::Any))))] #[case(polyfunctype2())] -fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: OpDefSignature) { +fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { check_testing_roundtrip(poly_func_type) } @@ -561,9 +562,9 @@ fn std_extensions_valid() { mod proptest { use super::check_testing_roundtrip; - use super::{NodeSer, SimpleOpDef}; + use super::{NodeSer, PolyFuncTypeRV, SimpleOpDef}; use crate::ops::{OpType, OpaqueOp, Value}; - use crate::types::{OpDefSignature, Type}; + use crate::types::Type; use proptest::prelude::*; impl Arbitrary for NodeSer { @@ -596,7 +597,7 @@ mod proptest { } #[test] - fn prop_roundtrip_poly_func_type(t: OpDefSignature) { + fn prop_roundtrip_poly_func_type(t: PolyFuncTypeRV) { check_testing_roundtrip(t) } diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index d760a854b..628453524 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -57,23 +57,11 @@ pub type PolyFuncType = PolyFuncTypeBase; /// It may also have a row of static inputs. /// /// [OpDef]: crate::extension::OpDef -#[derive( - Clone, - PartialEq, - Debug, - Default, - Eq, - Hash, - derive_more::Display, - serde::Serialize, - serde::Deserialize, -)] +#[derive(Clone, PartialEq, Debug, Default, Eq, Hash, derive_more::Display)] #[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] #[display("{}{}{}", self.display_params(), self.static_inputs, self.body())] pub struct OpDefSignature { - #[serde(flatten)] signature: PolyFuncTypeBase, - #[serde(default, skip_serializing_if = "TypeRow::is_empty")] static_inputs: TypeRow, } From 416416b0193d0f020e63ba9fe92266299e3e78e5 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 4 Nov 2024 14:38:12 +0000 Subject: [PATCH 09/14] refactor: move static inputs outside signature in serialised extop --- hugr-core/src/extension/op_def.rs | 8 ++--- hugr-core/src/import.rs | 2 ++ hugr-core/src/ops/custom.rs | 56 +++++++++++++++++++++++-------- 3 files changed, 48 insertions(+), 18 deletions(-) diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 0fe68db88..2ad4b424d 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -262,13 +262,12 @@ impl SignatureFunc { } /// Instantiated [OpDef] signature. -#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] pub struct ExtOpSignature { - #[serde(flatten)] + // #[serde(flatten)] /// The dataflow function type of the signature. pub func_type: Signature, - #[serde(default, skip_serializing_if = "TypeRow::is_empty")] /// The static inputs of the signature. pub static_inputs: TypeRow, } @@ -881,7 +880,8 @@ pub(super) mod test { )); let [inq] = dfg.input_wires_arr(); let ext_op_node = dfg.add_dataflow_op_with_static(ext_op, vec![inq], [cnst.wire()])?; - dfg.finish_hugr_with_outputs(ext_op_node.outputs(), ®)?; + let h = dfg.finish_hugr_with_outputs(ext_op_node.outputs(), ®)?; + println!("{}", serde_json::to_string_pretty(&h).unwrap()); Ok(()) } diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 2d36ec873..36f9e17c1 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -12,6 +12,7 @@ use crate::{ FuncDecl, FuncDefn, Input, LoadFunction, Module, OpType, OpaqueOp, Output, Tag, TailLoop, CFG, DFG, }, + type_row, types::{ type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, NoRV, PolyFuncType, PolyFuncTypeBase, RowVariable, Signature, Type, TypeArg, TypeBase, TypeBound, @@ -500,6 +501,7 @@ impl<'a> Context<'a> { String::default(), args, signature, + type_row![], )); let node = self.make_node(node_id, optype, parent)?; diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 4bc3e0cc0..8d2bcddb7 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -10,13 +10,13 @@ use { ::proptest_derive::Arbitrary, }; -use crate::hugr::HugrView; -use crate::types::{type_param::TypeArg, Signature}; +use crate::types::{type_param::TypeArg, Signature, Type}; use crate::{extension::ExtOpSignature, hugr::internal::HugrMutInternals, types::EdgeKind}; use crate::{ extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}, Direction, }; +use crate::{hugr::HugrView, types::TypeRow}; use crate::{ops, Hugr, IncomingPort, Node}; use super::dataflow::DataflowOpTrait; @@ -111,7 +111,8 @@ impl ExtensionOp { name: self.def.name().clone(), description: self.def.description().into(), args: self.args.clone(), - signature: self.signature.clone(), + signature: self.signature.func_type.clone(), + static_inputs: self.signature.static_inputs.clone(), } } @@ -126,7 +127,11 @@ impl From for OpaqueOp { let ExtensionOp { def, args, - signature, + signature: + ExtOpSignature { + func_type: signature, + static_inputs, + }, } = op; OpaqueOp { extension: def.extension().clone(), @@ -134,6 +139,7 @@ impl From for OpaqueOp { description: def.description().into(), args, signature, + static_inputs, } } } @@ -203,7 +209,9 @@ pub struct OpaqueOp { // note that the `signature` field might not include `extension`. Thus this must // remain private, and should be accessed through // `DataflowOpTrait::signature`. - signature: ExtOpSignature, + signature: Signature, + #[serde(default, skip_serializing_if = "TypeRow::is_empty")] + static_inputs: TypeRow, } fn qualify_name(res_id: &ExtensionId, name: &OpNameRef) -> OpName { @@ -217,7 +225,8 @@ impl OpaqueOp { name: impl Into, description: String, args: impl Into>, - signature: impl Into, + signature: impl Into, + static_inputs: impl Into, ) -> Self { Self { extension, @@ -225,6 +234,7 @@ impl OpaqueOp { description, args: args.into(), signature: signature.into(), + static_inputs: static_inputs.into(), } } } @@ -251,11 +261,17 @@ impl OpaqueOp { &self.extension } + /// Static inputs. + pub fn static_inputs(&self) -> &[Type] { + &self.static_inputs + } + /// Instantiated signature of the operation. pub fn ext_op_signature(&self) -> ExtOpSignature { - let mut sig = self.signature.clone(); - sig.func_type = sig.func_type.with_extension_delta(self.extension.clone()); - sig + ExtOpSignature { + func_type: self.signature().clone(), + static_inputs: self.static_inputs.clone(), + } } } @@ -267,12 +283,13 @@ impl DataflowOpTrait for OpaqueOp { } fn signature(&self) -> Signature { - self.ext_op_signature().func_type + self.signature + .clone() + .with_extension_delta(self.extension.clone()) } fn static_inputs(&self) -> Vec { - self.signature - .static_inputs() + self.static_inputs .iter() .cloned() .map(EdgeKind::Const) @@ -282,7 +299,7 @@ impl DataflowOpTrait for OpaqueOp { fn static_port_count(&self, dir: Direction) -> usize { // specialise as we can count without allocating match dir { - Direction::Incoming => self.signature.static_inputs().len(), + Direction::Incoming => self.static_inputs.len(), Direction::Outgoing => 0, } } @@ -399,6 +416,7 @@ pub enum OpaqueOpError { mod test { use crate::std_extensions::arithmetic::conversions::{self, CONVERT_OPS_REGISTRY}; + use crate::type_row; use crate::{ extension::{ prelude::{BOOL_T, QB_T, USIZE_T}, @@ -420,6 +438,7 @@ mod test { "desc".into(), vec![TypeArg::Type { ty: USIZE_T }], sig.clone(), + type_row![], ); assert_eq!(op.name(), "res.op"); assert_eq!(DataflowOpTrait::description(&op), "desc"); @@ -440,6 +459,7 @@ mod test { "description".into(), vec![], Signature::new(i0.clone(), BOOL_T), + type_row![], ); let resolved = super::resolve_opaque_op(Node::from(portgraph::NodeIndex::new(1)), &opaque, registry) @@ -475,8 +495,16 @@ mod test { "".into(), vec![], endo_sig.clone(), + type_row![], + ); + let opaque_comp = OpaqueOp::new( + ext_id.clone(), + comp_name, + "".into(), + vec![], + endo_sig, + type_row![], ); - let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig); let resolved_val = super::resolve_opaque_op( Node::from(portgraph::NodeIndex::new(1)), &opaque_val, From 218bb0b5daa234c908e3c2cfe9b40f3089d4b84d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 4 Nov 2024 14:41:09 +0000 Subject: [PATCH 10/14] feat: add static inputs to python --- hugr-py/src/hugr/_serialization/extension.py | 5 ++- hugr-py/src/hugr/_serialization/ops.py | 2 + hugr-py/src/hugr/build/dfg.py | 19 +++++++++- hugr-py/src/hugr/build/tracked_dfg.py | 9 ++++- hugr-py/src/hugr/ext.py | 7 +++- hugr-py/src/hugr/ops.py | 4 +- hugr-py/tests/test_custom.py | 39 ++++++++++++++++---- 7 files changed, 71 insertions(+), 14 deletions(-) diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 6420bffff..96c5f6c74 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -15,6 +15,7 @@ ExtensionId, ExtensionSet, PolyFuncType, + Type, TypeBound, TypeParam, ) @@ -98,6 +99,7 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): misc: dict[str, Any] | None = None signature: PolyFuncType | None = None binary: bool = False + static_inputs: list[Type] = pd.Field(default_factory=list) lower_funcs: list[FixedHugr] = pd.Field(default_factory=list) def deserialize(self, extension: ext.Extension) -> ext.OpDef: @@ -108,7 +110,8 @@ def deserialize(self, extension: ext.Extension) -> ext.OpDef: misc=self.misc or {}, signature=ext.OpDefSig( self.signature.deserialize() if self.signature else None, - self.binary, + static_inputs=[t.deserialize() for t in self.static_inputs], + binary=self.binary, ), lower_funcs=[f.deserialize() for f in self.lower_funcs], ) diff --git a/hugr-py/src/hugr/_serialization/ops.py b/hugr-py/src/hugr/_serialization/ops.py index aa573cb64..49dca79a2 100644 --- a/hugr-py/src/hugr/_serialization/ops.py +++ b/hugr-py/src/hugr/_serialization/ops.py @@ -512,6 +512,7 @@ class ExtensionOp(DataflowOp): signature: stys.FunctionType = Field(default_factory=stys.FunctionType.empty) description: str = "" args: list[stys.TypeArg] = Field(default_factory=list) + static_inputs: list[Type] = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.signature = stys.FunctionType(input=list(in_types), output=list(out_types)) @@ -525,6 +526,7 @@ def deserialize(self) -> ops.Custom: op_name=self.name, signature=self.signature.deserialize(), args=deser_it(self.args), + static_inputs=deser_it(self.static_inputs), ) model_config = ConfigDict( diff --git a/hugr-py/src/hugr/build/dfg.py b/hugr-py/src/hugr/build/dfg.py index 8af4ae1ad..69d458690 100644 --- a/hugr-py/src/hugr/build/dfg.py +++ b/hugr-py/src/hugr/build/dfg.py @@ -195,12 +195,19 @@ def add_op( return replace(new_n, _num_out_ports=op.num_out) - def add(self, com: ops.Command, *, metadata: dict[str, Any] | None = None) -> Node: + def add( + self, + com: ops.Command, + *, + static_in: Iterable[Wire] | None = None, + metadata: dict[str, Any] | None = None, + ) -> Node: """Add a command (holding a dataflow operation and the incoming wires) to the graph. Args: com: The command to add. + static_in: Any static input wires to the command. metadata: Metadata to attach to the function definition. Defaults to None. Example: @@ -218,7 +225,15 @@ def raise_no_ints(): wires = ( (w if not isinstance(w, int) else raise_no_ints()) for w in com.incoming ) - return self.add_op(com.op, *wires, metadata=metadata) + node = self.add_op(com.op, *wires, metadata=metadata) + + # wire up static inputs + static_inputs = list(static_in or []) + dataflow_in = self.hugr.num_incoming(node) + for i, w in enumerate(static_inputs): + # static inputs always come after dataflow inputs + self.hugr.add_link(w.out_port(), node.inp(dataflow_in + i)) + return node def extend(self, *coms: ops.Command) -> list[Node]: """Add a series of commands to the DFG. diff --git a/hugr-py/src/hugr/build/tracked_dfg.py b/hugr-py/src/hugr/build/tracked_dfg.py index 439d5680c..43f918131 100644 --- a/hugr-py/src/hugr/build/tracked_dfg.py +++ b/hugr-py/src/hugr/build/tracked_dfg.py @@ -124,7 +124,13 @@ def tracked_wire(self, index: int) -> Wire: raise IndexError(msg) return tracked - def add(self, com: Command, *, metadata: dict[str, Any] | None = None) -> Node: + def add( + self, + com: Command, + *, + static_in: Iterable[Wire] | None = None, + metadata: dict[str, Any] | None = None, + ) -> Node: """Add a command to the DFG. Overrides :meth:`Dfg.add ` to allow Command inputs @@ -139,6 +145,7 @@ def add(self, com: Command, *, metadata: dict[str, Any] | None = None) -> Node: Args: com: Command to append. + static_in: Any static input wires to the command. metadata: Metadata to attach to the function definition. Defaults to None. Returns: diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 533e55cd7..23145116d 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -164,12 +164,15 @@ class OpDefSig: #: The polymorphic function type of the operation (type scheme). poly_func: tys.PolyFuncType | None + #: Static input types of the operation. + static_inputs: tys.TypeRow #: If no static type scheme known, flag indicates a computation of the signature. binary: bool def __init__( self, poly_func: tys.PolyFuncType | tys.FunctionType | None, + static_inputs: tys.TypeRow | None = None, binary: bool = False, ) -> None: if poly_func is None and not binary: @@ -182,6 +185,7 @@ def __init__( poly_func = tys.PolyFuncType([], poly_func) self.poly_func = poly_func self.binary = binary + self.static_inputs = static_inputs or tys.TypeRow() @dataclass @@ -209,6 +213,7 @@ def _to_serial(self) -> ext_s.OpDef: if self.signature.poly_func else None, binary=self.signature.binary, + static_inputs=[t._to_serial_root() for t in self.signature.static_inputs], lower_funcs=[f._to_serial() for f in self.lower_funcs], ) @@ -413,7 +418,7 @@ def register_op( """ if not isinstance(signature, OpDefSig): binary = signature is None - signature = OpDefSig(signature, binary) + signature = OpDefSig(signature, binary=binary) def _inner(cls: type[T]) -> type[T]: new_description = cls.__doc__ if description is None and cls.__doc__ else "" diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index a97470adf..3768772f1 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -313,6 +313,7 @@ class Custom(DataflowOp): description: str = "" extension: tys.ExtensionId = "" args: list[tys.TypeArg] = field(default_factory=list) + static_inputs: tys.TypeRow = field(default_factory=tys.TypeRow) def _to_serial(self, parent: Node) -> sops.ExtensionOp: return sops.ExtensionOp( @@ -322,6 +323,7 @@ def _to_serial(self, parent: Node) -> sops.ExtensionOp: signature=self.signature._to_serial(), description=self.description, args=ser_it(self.args), + static_inputs=ser_it(self.static_inputs), ) def outer_signature(self) -> tys.FunctionType: @@ -378,12 +380,12 @@ def to_custom_op(self) -> Custom: sig = poly_func.body else: sig = self.signature - return Custom( op_name=self._op_def.name, signature=sig, extension=ext.name if ext else "", args=self.args, + static_inputs=self._op_def.signature.static_inputs, ) def _to_serial(self, parent: Node) -> sops.ExtensionOp: diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 48f57de7a..8ec9574e6 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -2,7 +2,7 @@ import pytest -from hugr import ext, ops, tys +from hugr import ext, ops, tys, val from hugr.build.dfg import Dfg from hugr.hugr import Hugr, Node from hugr.ops import AsExtOp, Custom, ExtOp @@ -15,8 +15,8 @@ from .conftest import CX, QUANTUM_EXT, H, Measure, Rz, validate -STRINGLY_EXT = ext.Extension("my_extension", ext.Version(0, 0, 0)) -_STRINGLY_DEF = STRINGLY_EXT.add_op_def( +TEST_EXT = ext.Extension("my_extension", ext.Version(0, 0, 0)) +_STRINGLY_DEF = TEST_EXT.add_op_def( ext.OpDef( "StringlyOp", signature=ext.OpDefSig( @@ -24,6 +24,14 @@ ), ) ) +_STATIC_DEF = TEST_EXT.add_op_def( + ext.OpDef( + "StaticInOp", + signature=ext.OpDefSig( + tys.FunctionType.endo([FLOAT_T]), static_inputs=[tys.Bool] + ), + ) +) @dataclass @@ -31,7 +39,7 @@ class StringlyOp(AsExtOp): tag: str def op_def(self) -> ext.OpDef: - return STRINGLY_EXT.get_op("StringlyOp") + return TEST_EXT.get_op("StringlyOp") def type_args(self) -> list[tys.TypeArg]: return [tys.StringArg(self.tag)] @@ -57,7 +65,7 @@ def test_stringly_typed(): n = dfg.add(StringlyOp("world")()) dfg.set_outputs() assert dfg.hugr[n].op == StringlyOp("world") - validate(Package([dfg.hugr], [STRINGLY_EXT])) + validate(Package([dfg.hugr], [TEST_EXT])) new_h = Hugr._from_serial(dfg.hugr._to_serial()) @@ -69,12 +77,27 @@ def test_stringly_typed(): # doesn't resolve without extension assert isinstance(new_h[n].op, Custom) - registry.add_extension(STRINGLY_EXT) + registry.add_extension(TEST_EXT) new_h.resolve_extensions(registry) assert isinstance(new_h[n].op, ExtOp) +@dataclass +class StaticInOp(AsExtOp): + def op_def(self) -> ext.OpDef: + return TEST_EXT.get_op("StaticInOp") + + +def test_static_in_op(): + dfg = Dfg(FLOAT_T) + tr = dfg.add_const(val.TRUE) + n = dfg.add(StaticInOp()(dfg.inputs()[0]), static_in=[tr]) + dfg.set_outputs(n) + assert dfg.hugr[n].op == StaticInOp() + validate(Package([dfg.hugr], [TEST_EXT])) + + def test_registry(): reg = ext.ExtensionRegistry() reg.add_extension(LOGIC_EXT) @@ -92,7 +115,7 @@ def registry() -> ext.ExtensionRegistry: reg = ext.ExtensionRegistry() reg.add_extension(LOGIC_EXT) reg.add_extension(QUANTUM_EXT) - reg.add_extension(STRINGLY_EXT) + reg.add_extension(TEST_EXT) reg.add_extension(INT_TYPES_EXTENSION) reg.add_extension(INT_OPS_EXTENSION) reg.add_extension(FLOAT_EXT) @@ -135,7 +158,7 @@ def test_custom_bad_eq(): assert Not != bad_custom_args -_LIST_T = STRINGLY_EXT.add_type_def( +_LIST_T = TEST_EXT.add_type_def( ext.TypeDef( "List", description="A list of elements.", From 3c7574a253f412e76c4ebda81036fc060c300187 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 4 Nov 2024 14:43:12 +0000 Subject: [PATCH 11/14] update schema --- specification/schema/hugr_schema_live.json | 14 ++++++++++++++ specification/schema/hugr_schema_strict_live.json | 14 ++++++++++++++ specification/schema/testing_hugr_schema_live.json | 14 ++++++++++++++ .../schema/testing_hugr_schema_strict_live.json | 14 ++++++++++++++ 4 files changed, 56 insertions(+) diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index 26c2f3bfc..743f6c7f9 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -657,6 +657,13 @@ }, "title": "Args", "type": "array" + }, + "static_inputs": { + "items": { + "$ref": "#/$defs/Type" + }, + "title": "Static Inputs", + "type": "array" } }, "required": [ @@ -1134,6 +1141,13 @@ "title": "Binary", "type": "boolean" }, + "static_inputs": { + "items": { + "$ref": "#/$defs/Type" + }, + "title": "Static Inputs", + "type": "array" + }, "lower_funcs": { "items": { "$ref": "#/$defs/FixedHugr" diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 3ae0e3c28..23b589090 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -657,6 +657,13 @@ }, "title": "Args", "type": "array" + }, + "static_inputs": { + "items": { + "$ref": "#/$defs/Type" + }, + "title": "Static Inputs", + "type": "array" } }, "required": [ @@ -1134,6 +1141,13 @@ "title": "Binary", "type": "boolean" }, + "static_inputs": { + "items": { + "$ref": "#/$defs/Type" + }, + "title": "Static Inputs", + "type": "array" + }, "lower_funcs": { "items": { "$ref": "#/$defs/FixedHugr" diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index 7696790dc..080ba2402 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -657,6 +657,13 @@ }, "title": "Args", "type": "array" + }, + "static_inputs": { + "items": { + "$ref": "#/$defs/Type" + }, + "title": "Static Inputs", + "type": "array" } }, "required": [ @@ -1134,6 +1141,13 @@ "title": "Binary", "type": "boolean" }, + "static_inputs": { + "items": { + "$ref": "#/$defs/Type" + }, + "title": "Static Inputs", + "type": "array" + }, "lower_funcs": { "items": { "$ref": "#/$defs/FixedHugr" diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index 0490b157d..f70a2533f 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -657,6 +657,13 @@ }, "title": "Args", "type": "array" + }, + "static_inputs": { + "items": { + "$ref": "#/$defs/Type" + }, + "title": "Static Inputs", + "type": "array" } }, "required": [ @@ -1134,6 +1141,13 @@ "title": "Binary", "type": "boolean" }, + "static_inputs": { + "items": { + "$ref": "#/$defs/Type" + }, + "title": "Static Inputs", + "type": "array" + }, "lower_funcs": { "items": { "$ref": "#/$defs/FixedHugr" From 4f992546b9b9c075a3b23acc47d2ad39770dcc68 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 4 Nov 2024 14:52:35 +0000 Subject: [PATCH 12/14] main merge fixup --- hugr-core/src/export.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 1488e851e..f2c7fb2ef 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -417,7 +417,7 @@ impl<'a> Context<'a> { use std::collections::hash_map::Entry; let poly_func_type = match opdef.signature_func() { - SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type, + SignatureFunc::PolyFuncType(op_def_sig) => op_def_sig.poly_func_type(), _ => return self.make_named_global_ref(opdef.extension(), opdef.name()), }; From 8a13c64e880b047b37a214871fe1e545e3faff88 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 4 Nov 2024 17:32:43 +0000 Subject: [PATCH 13/14] common up static input wiring --- hugr-py/src/hugr/build/dfg.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/hugr-py/src/hugr/build/dfg.py b/hugr-py/src/hugr/build/dfg.py index 69d458690..ae1692f40 100644 --- a/hugr-py/src/hugr/build/dfg.py +++ b/hugr-py/src/hugr/build/dfg.py @@ -173,13 +173,19 @@ def inputs(self) -> list[OutPort]: return [self.input_node.out(i) for i in range(len(self._input_op().types))] def add_op( - self, op: ops.DataflowOp, /, *args: Wire, metadata: dict[str, Any] | None = None + self, + op: ops.DataflowOp, + /, + *args: Wire, + static_in: Iterable[Wire] | None = None, + metadata: dict[str, Any] | None = None, ) -> Node: """Add a dataflow operation to the graph, wiring in input ports. Args: op: The operation to add. args: The input wires to the operation. + static_in: Any static input wires to the command. metadata: Metadata to attach to the function definition. Defaults to None. Returns: @@ -191,7 +197,7 @@ def add_op( Node(3) """ new_n = self.hugr.add_node(op, self.parent_node, metadata=metadata) - self._wire_up(new_n, args) + self._wire_up(new_n, args, static_in=static_in) return replace(new_n, _num_out_ports=op.num_out) @@ -225,15 +231,9 @@ def raise_no_ints(): wires = ( (w if not isinstance(w, int) else raise_no_ints()) for w in com.incoming ) - node = self.add_op(com.op, *wires, metadata=metadata) + return self.add_op(com.op, *wires, metadata=metadata, static_in=static_in) + - # wire up static inputs - static_inputs = list(static_in or []) - dataflow_in = self.hugr.num_incoming(node) - for i, w in enumerate(static_inputs): - # static inputs always come after dataflow inputs - self.hugr.add_link(w.out_port(), node.inp(dataflow_in + i)) - return node def extend(self, *coms: ops.Command) -> list[Node]: """Add a series of commands to the DFG. @@ -627,10 +627,19 @@ def _fn_sig(self, func: ToNode) -> tys.PolyFuncType: raise ValueError(msg) return signature - def _wire_up(self, node: Node, ports: Iterable[Wire]) -> tys.TypeRow: + def _wire_up( + self, node: Node, ports: Iterable[Wire], static_in: Iterable[Wire] | None = None + ) -> tys.TypeRow: tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)] if isinstance(op := self.hugr[node].op, ops._PartialOp): op._set_in_types(tys) + + # wire up static inputs + static_inputs = list(static_in or []) + dataflow_in = self.hugr.num_incoming(node) + for i, w in enumerate(static_inputs): + # static inputs always come after dataflow inputs + self.hugr.add_link(w.out_port(), node.inp(dataflow_in + i)) return tys def _get_dataflow_type(self, wire: Wire) -> tys.Type: From 6b3085dcfe56d32361dae872c8dc0336938c27d8 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 5 Nov 2024 13:57:08 +0000 Subject: [PATCH 14/14] Update hugr-core/src/extension/op_def.rs Co-authored-by: Alan Lawrence --- hugr-core/src/extension/op_def.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 2ad4b424d..87f7161ce 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -145,7 +145,7 @@ impl CustomValidator { pub enum SignatureFunc { /// An explicit polymorphic function type. PolyFuncType(OpDefSignature), - /// A polymorphic function type (like [Self::PolyFuncType] but also with a custom binary for validating type arguments. + /// A polymorphic function type (like [Self::PolyFuncType]) but also with a custom binary for validating type arguments. CustomValidator(CustomValidator), /// Serialized declaration specified a custom validate binary but it was not provided. MissingValidateFunc(OpDefSignature),