From c2c0bbcf0bc952441b08a39f8732b2f96f466b65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Tue, 12 Nov 2024 14:18:21 +0000 Subject: [PATCH] feat: Share `Extension`s under `Arc`s --- .github/workflows/python-wheels.yml | 2 +- hugr-cli/tests/validate.rs | 4 +- hugr-core/src/extension.rs | 61 +++++++++++-------- hugr-core/src/extension/declarative.rs | 7 ++- hugr-core/src/extension/op_def.rs | 2 +- hugr-core/src/extension/prelude.rs | 11 ++-- hugr-core/src/extension/simple_op.rs | 8 ++- hugr-core/src/hugr/rewrite/inline_dfg.rs | 6 +- hugr-core/src/hugr/validate/test.rs | 8 +-- hugr-core/src/ops/custom.rs | 6 +- hugr-core/src/package.rs | 9 +-- hugr-core/src/std_extensions.rs | 2 +- .../std_extensions/arithmetic/conversions.rs | 14 +++-- .../std_extensions/arithmetic/float_ops.rs | 12 ++-- .../std_extensions/arithmetic/float_types.rs | 6 +- .../src/std_extensions/arithmetic/int_ops.rs | 12 ++-- .../std_extensions/arithmetic/int_types.rs | 7 ++- hugr-core/src/std_extensions/collections.rs | 9 +-- hugr-core/src/std_extensions/logic.rs | 16 +++-- hugr-core/src/std_extensions/ptr.rs | 10 +-- hugr-core/src/types/poly_func.rs | 2 +- hugr-core/src/utils.rs | 10 +-- hugr-passes/src/merge_bbs.rs | 4 +- hugr/src/lib.rs | 10 +-- 24 files changed, 139 insertions(+), 99 deletions(-) diff --git a/.github/workflows/python-wheels.yml b/.github/workflows/python-wheels.yml index ec9e1b92e..c1f946e2b 100644 --- a/.github/workflows/python-wheels.yml +++ b/.github/workflows/python-wheels.yml @@ -74,4 +74,4 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1 with: verbose: true - skip-existing: true \ No newline at end of file + skip-existing: true diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index 32b81010e..cd9a3cb79 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -4,6 +4,8 @@ //! calling the CLI binary, which Miri doesn't support. #![cfg(all(test, not(miri)))] +use std::sync::Arc; + use assert_cmd::Command; use assert_fs::{fixture::FileWriteStr, NamedTempFile}; use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder}; @@ -49,7 +51,7 @@ fn test_package(#[default(BOOL_T)] id_type: Type) -> Package { let hugr = module.hugr().clone(); // unvalidated let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap(); - let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap(); + let float_ext: Arc = serde_json::from_reader(rdr).unwrap(); Package::new(vec![hugr], vec![float_ext]).unwrap() } diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index fe30eb5f5..817637493 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -38,11 +38,11 @@ pub mod declarative; /// Extension Registries store extensions to be looked up e.g. during validation. #[derive(Clone, Debug, PartialEq)] -pub struct ExtensionRegistry(BTreeMap); +pub struct ExtensionRegistry(BTreeMap>); impl ExtensionRegistry { /// Gets the Extension with the given name - pub fn get(&self, name: &str) -> Option<&Extension> { + pub fn get(&self, name: &str) -> Option<&Arc> { self.0.get(name) } @@ -51,9 +51,9 @@ impl ExtensionRegistry { self.0.contains_key(name) } - /// Makes a new ExtensionRegistry, validating all the extensions in it + /// Makes a new [ExtensionRegistry], validating all the extensions in it. pub fn try_new( - value: impl IntoIterator, + value: impl IntoIterator>, ) -> Result { let mut res = ExtensionRegistry(BTreeMap::new()); @@ -70,20 +70,28 @@ impl ExtensionRegistry { ext.validate(&res) .map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?; } + Ok(res) } /// Registers a new extension to the registry. /// /// Returns a reference to the registered extension if successful. - pub fn register(&mut self, extension: Extension) -> Result<&Extension, ExtensionRegistryError> { + pub fn register( + &mut self, + extension: impl Into>, + ) -> Result<(), ExtensionRegistryError> { + let extension = extension.into(); match self.0.entry(extension.name().clone()) { btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered( extension.name().clone(), prev.get().version().clone(), extension.version().clone(), )), - btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)), + btree_map::Entry::Vacant(ve) => { + ve.insert(extension); + Ok(()) + } } } @@ -93,21 +101,24 @@ impl ExtensionRegistry { /// If versions match, the original extension is kept. /// Returns a reference to the registered extension if successful. /// - /// Avoids cloning the extension unless required. For a reference version see + /// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, see /// [`ExtensionRegistry::register_updated_ref`]. pub fn register_updated( &mut self, - extension: Extension, - ) -> Result<&Extension, ExtensionRegistryError> { + extension: impl Into>, + ) -> Result<(), ExtensionRegistryError> { + let extension = extension.into(); match self.0.entry(extension.name().clone()) { btree_map::Entry::Occupied(mut prev) => { if prev.get().version() < extension.version() { *prev.get_mut() = extension; } - Ok(prev.into_mut()) } - btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)), + btree_map::Entry::Vacant(ve) => { + ve.insert(extension); + } } + Ok(()) } /// Registers a new extension to the registry, keeping most up to date if @@ -117,21 +128,23 @@ impl ExtensionRegistry { /// If versions match, the original extension is kept. Returns a reference /// to the registered extension if successful. /// - /// Clones the extension if required. For no-cloning version see + /// Clones the Arc only when required. For no-cloning version see /// [`ExtensionRegistry::register_updated`]. pub fn register_updated_ref( &mut self, - extension: &Extension, - ) -> Result<&Extension, ExtensionRegistryError> { + extension: &Arc, + ) -> Result<(), ExtensionRegistryError> { match self.0.entry(extension.name().clone()) { btree_map::Entry::Occupied(mut prev) => { if prev.get().version() < extension.version() { *prev.get_mut() = extension.clone(); } - Ok(prev.into_mut()) } - btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension.clone())), + btree_map::Entry::Vacant(ve) => { + ve.insert(extension.clone()); + } } + Ok(()) } /// Returns the number of extensions in the registry. @@ -145,20 +158,20 @@ impl ExtensionRegistry { } /// Returns an iterator over the extensions in the registry. - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator)> { self.0.iter() } /// Delete an extension from the registry and return it if it was present. - pub fn remove_extension(&mut self, name: &ExtensionId) -> Option { + pub fn remove_extension(&mut self, name: &ExtensionId) -> Option> { self.0.remove(name) } } impl IntoIterator for ExtensionRegistry { - type Item = (ExtensionId, Extension); + type Item = (ExtensionId, Arc); - type IntoIter = as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() @@ -646,10 +659,10 @@ pub mod test { let ext_1_id = ExtensionId::new("ext1").unwrap(); let ext_2_id = ExtensionId::new("ext2").unwrap(); - let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0)); - let ext1_1 = Extension::new(ext_1_id.clone(), Version::new(1, 1, 0)); - let ext1_2 = Extension::new(ext_1_id.clone(), Version::new(0, 2, 0)); - let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0)); + let ext1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 0, 0))); + let ext1_1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 1, 0))); + let ext1_2 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(0, 2, 0))); + let ext2 = Arc::new(Extension::new(ext_2_id, Version::new(1, 0, 0))); reg.register(ext1.clone()).unwrap(); reg_ref.register(ext1.clone()).unwrap(); diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index cb98e8215..c81414c9f 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -136,8 +136,8 @@ impl ExtensionSetDeclaration { registry, }; let ext = decl.make_extension(&self.imports, ctx)?; - let ext = registry.register(ext)?; - scope.insert(ext.name()) + scope.insert(ext.name()); + registry.register(ext)?; } Ok(()) @@ -272,6 +272,7 @@ mod test { use itertools::Itertools; use rstest::rstest; use std::path::PathBuf; + use std::sync::Arc; use crate::extension::PRELUDE_REGISTRY; use crate::std_extensions; @@ -406,7 +407,7 @@ extensions: fn new_extensions<'a>( reg: &'a ExtensionRegistry, dependencies: &'a ExtensionRegistry, - ) -> impl Iterator { + ) -> impl Iterator)> { reg.iter() .filter(move |(id, _)| !dependencies.contains(id) && *id != &PRELUDE_ID) } diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 6c1a49d9e..d3b52eed2 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -611,7 +611,7 @@ pub(super) mod test { assert_eq!(def.misc.len(), 1); let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned(), e]).unwrap(); + ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), e.into()]).unwrap(); let e = reg.get(&EXT_ID).unwrap(); let list_usize = diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index ca338eae3..e649c2686 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -1,5 +1,7 @@ //! Prelude extension - available in all contexts, defining common types, //! operations and constants. +use std::sync::Arc; + use itertools::Itertools; use lazy_static::lazy_static; @@ -34,7 +36,7 @@ pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude"); /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); lazy_static! { - static ref PRELUDE_DEF: Extension = { + static ref PRELUDE_DEF: Arc = { let mut prelude = Extension::new(PRELUDE_ID, VERSION); prelude .add_type( @@ -101,14 +103,15 @@ lazy_static! { NoopDef.add_to_extension(&mut prelude).unwrap(); LiftDef.add_to_extension(&mut prelude).unwrap(); array::ArrayOpDef::load_all_ops(&mut prelude).unwrap(); - prelude + + Arc::new(prelude) }; /// An extension registry containing only the prelude pub static ref PRELUDE_REGISTRY: ExtensionRegistry = - ExtensionRegistry::try_new([PRELUDE_DEF.to_owned()]).unwrap(); + ExtensionRegistry::try_new([PRELUDE_DEF.clone()]).unwrap(); /// Prelude extension - pub static ref PRELUDE: &'static Extension = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap(); + pub static ref PRELUDE: Arc = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap().clone(); } diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index a21f68224..c338a693d 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -272,6 +272,8 @@ impl From for OpType { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::{const_extension_ids, type_row, types::Signature}; use super::*; @@ -313,13 +315,13 @@ mod test { } lazy_static! { - static ref EXT: Extension = { + static ref EXT: Arc = { let mut e = Extension::new_test(EXT_ID.clone()); DummyEnum::Dumb.add_to_extension(&mut e).unwrap(); - e + Arc::new(e) }; static ref DUMMY_REG: ExtensionRegistry = - ExtensionRegistry::try_new([EXT.to_owned()]).unwrap(); + ExtensionRegistry::try_new([EXT.clone()]).unwrap(); } impl MakeRegisteredOp for DummyEnum { fn extension_id(&self) -> ExtensionId { diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index ca3f39cd3..1ef5adf98 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -257,9 +257,9 @@ mod test { let [q, p] = swap.outputs_arr(); let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?; let reg = ExtensionRegistry::try_new([ - test_quantum_extension::EXTENSION.to_owned(), - PRELUDE.to_owned(), - float_types::EXTENSION.to_owned(), + test_quantum_extension::EXTENSION.clone(), + PRELUDE.clone(), + float_types::EXTENSION.clone(), ]) .unwrap(); diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 705649b36..cf934e18b 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -386,7 +386,7 @@ fn invalid_types() { TypeDefBound::any(), ) .unwrap(); - let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()]).unwrap(); + let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()]).unwrap(); let validate_to_sig_error = |t: CustomType| { let (h, def) = identity_hugr_with_type(Type::new_extension(t)); @@ -643,7 +643,7 @@ fn instantiate_row_variables() -> Result<(), Box> { let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?; dfb.finish_hugr_with_outputs( eval2.outputs(), - &ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap(), + &ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(), )?; Ok(()) } @@ -683,7 +683,7 @@ fn row_variables() -> Result<(), Box> { let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs( par_func.outputs(), - &ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap(), + &ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(), )?; Ok(()) } @@ -763,7 +763,7 @@ fn test_polymorphic_call() -> Result<(), Box> { f.finish_with_outputs([tup])? }; - let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()])?; + let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?; let [func, tup] = d.input_wires_arr(); let call = d.call( f.handle(), diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 6d9a3b2b2..eec5f4d34 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -46,7 +46,7 @@ impl ExtensionOp { args: impl Into>, exts: &ExtensionRegistry, ) -> Result { - let args = args.into(); + let args: Vec = args.into(); let signature = def.compute_signature(&args, exts)?; Ok(Self { def, @@ -62,7 +62,7 @@ impl ExtensionOp { opaque: &OpaqueOp, exts: &ExtensionRegistry, ) -> Result { - let args = args.into(); + let args: Vec = args.into(); // TODO skip computation depending on config // see https://github.com/CQCL/hugr/issues/1363 let signature = match def.compute_signature(&args, exts) { @@ -421,7 +421,7 @@ mod test { SignatureFunc::MissingComputeFunc, ) .unwrap(); - let registry = ExtensionRegistry::try_new([ext]).unwrap(); + let registry = ExtensionRegistry::try_new([ext.into()]).unwrap(); let opaque_val = OpaqueOp::new( ext_id.clone(), val_name, diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index 291b61ebe..cec1b2c85 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -3,6 +3,7 @@ use derive_more::{Display, Error, From}; use std::collections::HashMap; use std::path::Path; +use std::sync::Arc; use std::{fs, io, mem}; use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder}; @@ -19,7 +20,7 @@ pub struct Package { /// Module HUGRs included in the package. pub modules: Vec, /// Extensions to validate against. - pub extensions: Vec, + pub extensions: Vec>, } impl Package { @@ -32,7 +33,7 @@ impl Package { /// Returns an error if any of the HUGRs does not have a `Module` root. pub fn new( modules: impl IntoIterator, - extensions: impl IntoIterator, + extensions: impl IntoIterator>, ) -> Result { let modules: Vec = modules.into_iter().collect(); for (idx, module) in modules.iter().enumerate() { @@ -62,7 +63,7 @@ impl Package { /// Returns an error if any of the HUGRs cannot be wrapped in a module. pub fn from_hugrs( modules: impl IntoIterator, - extensions: impl IntoIterator, + extensions: impl IntoIterator>, ) -> Result { let modules: Vec = modules .into_iter() @@ -378,7 +379,7 @@ mod test { Package { modules: vec![hugr0, hugr1], - extensions: vec![ext1, ext2], + extensions: vec![ext1.into(), ext2.into()], } } diff --git a/hugr-core/src/std_extensions.rs b/hugr-core/src/std_extensions.rs index a6738ddff..2896b5b7a 100644 --- a/hugr-core/src/std_extensions.rs +++ b/hugr-core/src/std_extensions.rs @@ -12,7 +12,7 @@ pub mod ptr; /// Extension registry with all standard extensions and prelude. pub fn std_reg() -> ExtensionRegistry { ExtensionRegistry::try_new([ - crate::extension::prelude::PRELUDE.to_owned(), + crate::extension::prelude::PRELUDE.clone(), arithmetic::int_ops::EXTENSION.to_owned(), arithmetic::int_types::EXTENSION.to_owned(), arithmetic::conversions::EXTENSION.to_owned(), diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index 2db324312..deb93f8c2 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -1,5 +1,7 @@ //! Conversions between integer and floating-point values. +use std::sync::Arc; + use strum_macros::{EnumIter, EnumString, IntoStaticStr}; use crate::extension::prelude::{BOOL_T, STRING_TYPE, USIZE_T}; @@ -155,7 +157,7 @@ impl MakeExtensionOp for ConvertOpType { lazy_static! { /// Extension for conversions between integers and floats. - pub static ref EXTENSION: Extension = { + pub static ref EXTENSION: Arc = { let mut extension = Extension::new( EXTENSION_ID, VERSION).with_reqs( @@ -167,15 +169,15 @@ lazy_static! { ConvertOpDef::load_all_ops(&mut extension).unwrap(); - extension + Arc::new(extension) }; /// Registry of extensions required to validate integer operations. pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - super::int_types::EXTENSION.to_owned(), - super::float_types::EXTENSION.to_owned(), - EXTENSION.to_owned(), + PRELUDE.clone(), + super::int_types::EXTENSION.clone(), + super::float_types::EXTENSION.clone(), + EXTENSION.clone(), ]) .unwrap(); } diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 8ef8850a8..7d353e71a 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -1,5 +1,7 @@ //! Basic floating-point operations. +use std::sync::Arc; + use strum_macros::{EnumIter, EnumString, IntoStaticStr}; use super::float_types::FLOAT64_TYPE; @@ -104,7 +106,7 @@ impl MakeOpDef for FloatOps { lazy_static! { /// Extension for basic float operations. - pub static ref EXTENSION: Extension = { + pub static ref EXTENSION: Arc = { let mut extension = Extension::new( EXTENSION_ID, VERSION).with_reqs( @@ -113,14 +115,14 @@ lazy_static! { FloatOps::load_all_ops(&mut extension).unwrap(); - extension + Arc::new(extension) }; /// Registry of extensions required to validate float operations. pub static ref FLOAT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - super::float_types::EXTENSION.to_owned(), - EXTENSION.to_owned(), + PRELUDE.clone(), + super::float_types::EXTENSION.clone(), + EXTENSION.clone(), ]) .unwrap(); } diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index a046ebe0e..ec145008f 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -1,5 +1,7 @@ //! Basic floating-point types +use std::sync::Arc; + use crate::ops::constant::{TryHash, ValueName}; use crate::types::TypeName; use crate::{ @@ -79,7 +81,7 @@ impl CustomConst for ConstF64 { lazy_static! { /// Extension defining the float type. - pub static ref EXTENSION: Extension = { + pub static ref EXTENSION: Arc = { let mut extension = Extension::new(EXTENSION_ID, VERSION); extension @@ -91,7 +93,7 @@ lazy_static! { ) .unwrap(); - extension + Arc::new(extension) }; } #[cfg(test)] diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 51d3e3885..97bb247a2 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -1,5 +1,7 @@ //! Basic integer operations. +use std::sync::Arc; + use super::int_types::{get_log_width, int_tv, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{sum_with_error, BOOL_T}; use crate::extension::simple_op::{ @@ -247,7 +249,7 @@ fn iunop_sig() -> PolyFuncTypeRV { lazy_static! { /// Extension for basic integer operations. - pub static ref EXTENSION: Extension = { + pub static ref EXTENSION: Arc = { let mut extension = Extension::new( EXTENSION_ID, VERSION).with_reqs( @@ -256,14 +258,14 @@ lazy_static! { IntOpDef::load_all_ops(&mut extension).unwrap(); - extension + Arc::new(extension) }; /// Registry of extensions required to validate integer operations. pub static ref INT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - super::int_types::EXTENSION.to_owned(), - EXTENSION.to_owned(), + PRELUDE.clone(), + super::int_types::EXTENSION.clone(), + EXTENSION.clone(), ]) .unwrap(); } diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 522f8b2b9..3d257b9d0 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -1,6 +1,7 @@ //! Basic integer types use std::num::NonZeroU64; +use std::sync::Arc; use crate::ops::constant::ValueName; use crate::types::TypeName; @@ -186,7 +187,7 @@ impl CustomConst for ConstInt { } /// Extension for basic integer types. -pub fn extension() -> Extension { +pub fn extension() -> Arc { let mut extension = Extension::new(EXTENSION_ID, VERSION); extension @@ -198,12 +199,12 @@ pub fn extension() -> Extension { ) .unwrap(); - extension + Arc::new(extension) } lazy_static! { /// Lazy reference to int types extension. - pub static ref EXTENSION: Extension = extension(); + pub static ref EXTENSION: Arc = extension(); } /// get an integer type with width corresponding to a type variable with id `var_id` diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index d00a6f21d..17a1b0d03 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -5,6 +5,7 @@ use std::hash::{Hash, Hasher}; mod list_fold; use std::str::FromStr; +use std::sync::Arc; use itertools::Itertools; use lazy_static::lazy_static; @@ -249,7 +250,7 @@ impl MakeOpDef for ListOp { lazy_static! { /// Extension for list operations. - pub static ref EXTENSION: Extension = { + pub static ref EXTENSION: Arc = { let mut extension = Extension::new(EXTENSION_ID, VERSION); // The list type must be defined before the operations are added. @@ -263,13 +264,13 @@ lazy_static! { ListOp::load_all_ops(&mut extension).unwrap(); - extension + Arc::new(extension) }; /// Registry of extensions required to validate list operations. pub static ref COLLECTIONS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - EXTENSION.to_owned(), + PRELUDE.clone(), + EXTENSION.clone(), ]) .unwrap(); } diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index fbc068672..89f9dfa8b 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -1,5 +1,7 @@ //! Basic logical operations. +use std::sync::Arc; + use strum_macros::{EnumIter, EnumString, IntoStaticStr}; use crate::extension::{ConstFold, ConstFoldResult}; @@ -107,7 +109,7 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic"); pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); /// Extension for basic logical operations. -fn extension() -> Extension { +fn extension() -> Arc { let mut extension = Extension::new(EXTENSION_ID, VERSION); LogicOp::load_all_ops(&mut extension).unwrap(); @@ -117,15 +119,15 @@ fn extension() -> Extension { extension .add_value(TRUE_NAME, ops::Value::true_val()) .unwrap(); - extension + Arc::new(extension) } lazy_static! { /// Reference to the logic Extension. - pub static ref EXTENSION: Extension = extension(); + pub static ref EXTENSION: Arc = extension(); /// Registry required to validate logic extension. pub static ref LOGIC_REG: ExtensionRegistry = - ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap(); + ExtensionRegistry::try_new([EXTENSION.clone()]).unwrap(); } impl MakeRegisteredOp for LogicOp { @@ -159,6 +161,8 @@ fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option> { #[cfg(test)] pub(crate) mod test { + use std::sync::Arc; + use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME}; use crate::{ extension::{ @@ -174,7 +178,7 @@ pub(crate) mod test { #[test] fn test_logic_extension() { - let r: Extension = extension(); + let r: Arc = extension(); assert_eq!(r.name() as &str, "logic"); assert_eq!(r.operations().count(), 4); @@ -196,7 +200,7 @@ pub(crate) mod test { #[test] fn test_values() { - let r: Extension = extension(); + let r: Arc = extension(); let false_val = r.get_value(&FALSE_NAME).unwrap(); let true_val = r.get_value(&TRUE_NAME).unwrap(); diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index afa840699..e3023e4b5 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -1,5 +1,7 @@ //! Pointer type and operations. +use std::sync::Arc; + use strum_macros::{EnumIter, EnumString, IntoStaticStr}; use crate::builder::{BuildError, Dataflow}; @@ -84,7 +86,7 @@ const TYPE_PARAMS: [TypeParam; 1] = [TypeParam::Type { pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); /// Extension for pointer operations. -fn extension() -> Extension { +fn extension() -> Arc { let mut extension = Extension::new(EXTENSION_ID, VERSION); extension .add_type( @@ -95,15 +97,15 @@ fn extension() -> Extension { ) .unwrap(); PtrOpDef::load_all_ops(&mut extension).unwrap(); - extension + Arc::new(extension) } lazy_static! { /// Reference to the pointer Extension. - pub static ref EXTENSION: Extension = extension(); + pub static ref EXTENSION: Arc = extension(); /// Registry required to validate pointer extension. pub static ref PTR_REG: ExtensionRegistry = - ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap(); + ExtensionRegistry::try_new([EXTENSION.clone()]).unwrap(); } /// Integer type of a given bit width (specified by the TypeArg). Depending on diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index b168556bf..7e3f4f664 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -330,7 +330,7 @@ pub(crate) mod test { ) .unwrap(); - let reg = ExtensionRegistry::try_new([e]).unwrap(); + let reg = ExtensionRegistry::try_new([e.into()]).unwrap(); let make_scheme = |tp: TypeParam| { PolyFuncTypeBase::new_validated( diff --git a/hugr-core/src/utils.rs b/hugr-core/src/utils.rs index 72b99fd76..f1f97cc5a 100644 --- a/hugr-core/src/utils.rs +++ b/hugr-core/src/utils.rs @@ -103,6 +103,8 @@ pub(crate) fn is_default(t: &T) -> bool { #[cfg(test)] pub(crate) mod test_quantum_extension { + use std::sync::Arc; + use crate::ops::{OpName, OpNameRef}; use crate::types::FuncValueType; use crate::{ @@ -128,7 +130,7 @@ pub(crate) mod test_quantum_extension { } /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum"); - fn extension() -> Extension { + fn extension() -> Arc { let mut extension = Extension::new_test(EXTENSION_ID); extension @@ -170,13 +172,13 @@ pub(crate) mod test_quantum_extension { ) .unwrap(); - extension + Arc::new(extension) } lazy_static! { /// Quantum extension definition. - pub static ref EXTENSION: Extension = extension(); - static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap(); + pub static ref EXTENSION: Arc = extension(); + static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.clone(), PRELUDE.clone(), float_types::EXTENSION.clone()]).unwrap(); } diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 51fd07d9b..5df428671 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -228,7 +228,7 @@ mod test { let exit_types = type_row![USIZE_T]; let e = extension(); let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; - let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), e])?; + let reg = ExtensionRegistry::try_new([PRELUDE.clone(), e.into()])?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; let n = no_b1.add_dataflow_op(Noop::new(QB_T), no_b1.input_wires())?; @@ -355,7 +355,7 @@ mod test { h.branch(&bb2, 0, &bb3)?; h.branch(&bb3, 0, &h.exit_block())?; - let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()])?; + let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?; let mut h = h.finish_hugr(®)?; let root = h.root(); merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index e4fa3ee99..8c59220ce 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -47,6 +47,7 @@ //! Extension, //! }; //! +//! use std::sync::Arc; //! use lazy_static::lazy_static; //! //! fn one_qb_func() -> PolyFuncTypeRV { @@ -59,7 +60,7 @@ //! /// The extension identifier. //! pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("mini.quantum"); //! pub const VERSION: Version = Version::new(0, 1, 0); -//! fn extension() -> Extension { +//! fn extension() -> Arc { //! let mut extension = Extension::new(EXTENSION_ID, VERSION); //! //! extension @@ -78,15 +79,14 @@ //! ) //! .unwrap(); //! -//! extension +//! Arc::new(extension) //! } //! //! lazy_static! { //! /// Quantum extension definition. -//! pub static ref EXTENSION: Extension = extension(); +//! pub static ref EXTENSION: Arc = extension(); //! static ref REG: ExtensionRegistry = -//! ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned()]).unwrap(); -//! +//! ExtensionRegistry::try_new([EXTENSION.clone(), PRELUDE.clone()]).unwrap(); //! } //! fn get_gate(gate_name: impl Into) -> ExtensionOp { //! EXTENSION