Skip to content

Commit

Permalink
feat!: Extension requires a version (#1367)
Browse files Browse the repository at this point in the history
Closes #1357 

BREAKING CHANGE: `Extension::new` now requires a `Version` argument.
`Extension::new_with_reqs` replaced with `with_reqs` which takes `self`
(chain with `new`).
  • Loading branch information
ss2165 authored Jul 29, 2024
1 parent c72e288 commit b2d4013
Show file tree
Hide file tree
Showing 19 changed files with 150 additions and 42 deletions.
1 change: 1 addition & 0 deletions hugr-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ delegate = { workspace = true }
paste = { workspace = true }
strum = { workspace = true }
strum_macros = { workspace = true }
semver = { version = "1.0.23", features = ["serde"] }

[dev-dependencies]
rstest = { workspace = true }
Expand Down
118 changes: 100 additions & 18 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! TODO: YAML declaration and parsing. This should be similar to a plugin
//! system (outside the `types` module), which also parses nested [`OpDef`]s.

pub use semver::Version;
use std::collections::btree_map;
use std::collections::hash_map;
use std::collections::{BTreeMap, BTreeSet, HashMap};
Expand Down Expand Up @@ -55,21 +56,17 @@ impl ExtensionRegistry {
pub fn try_new(
value: impl IntoIterator<Item = Extension>,
) -> Result<Self, ExtensionRegistryError> {
let mut exts = BTreeMap::new();
let mut res = ExtensionRegistry(BTreeMap::new());

for ext in value.into_iter() {
let prev = exts.insert(ext.name.clone(), ext);
if let Some(prev) = prev {
return Err(ExtensionRegistryError::AlreadyRegistered(
prev.name().clone(),
));
};
res.register(ext)?;
}

// Note this potentially asks extensions to validate themselves against other extensions that
// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
// or at least to validate the types first - which we don't do at all yet:
// TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be
// cyclically dependent, so there is no perfect solution, and this is at least simple.
let res = ExtensionRegistry(exts);
for ext in res.0.values() {
ext.validate(&res)
.map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
Expand All @@ -82,13 +79,35 @@ impl ExtensionRegistry {
/// Returns a reference to the registered extension if successful.
pub fn register(&mut self, extension: Extension) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(_) => Err(ExtensionRegistryError::AlreadyRegistered(
btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
extension.name().clone(),
prev.get().version().clone(),
extension.version().clone(),
)),
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
}
}

/// Registers a new extension to the registry, keeping most up to date if extension exists.
///
/// If extension IDs match, the extension with the higher version is kept.
/// If versions match, the original extension is kept.
/// Returns a reference to the registered extension if successful.
pub fn register_updated(
&mut self,
extension: Extension,
) -> Result<&Extension, ExtensionRegistryError> {
match self.0.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension;
}
Ok(prev.into_mut())
}
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
}
}

/// Returns the number of extensions in the registry.
pub fn len(&self) -> usize {
self.0.len()
Expand All @@ -103,6 +122,11 @@ impl ExtensionRegistry {
pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Extension)> {
self.0.iter()
}

/// Delete an extension from the registry and return it if it was present.
pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Extension> {
self.0.remove(name)
}
}

impl IntoIterator for ExtensionRegistry {
Expand Down Expand Up @@ -264,6 +288,8 @@ pub type ExtensionId = IdentList;
/// A extension is a set of capabilities required to execute a graph.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Extension {
/// Extension version, follows semver.
pub version: Version,
/// Unique identifier for the extension.
pub name: ExtensionId,
/// Other extensions defining types used by this extension.
Expand All @@ -286,21 +312,25 @@ pub struct Extension {

impl Extension {
/// Creates a new extension with the given name.
pub fn new(name: ExtensionId) -> Self {
Self::new_with_reqs(name, ExtensionSet::default())
}

/// Creates a new extension with the given name and requirements.
pub fn new_with_reqs(name: ExtensionId, extension_reqs: impl Into<ExtensionSet>) -> Self {
pub fn new(name: ExtensionId, version: Version) -> Self {
Self {
name,
extension_reqs: extension_reqs.into(),
version,
extension_reqs: Default::default(),
types: Default::default(),
values: Default::default(),
operations: Default::default(),
}
}

/// Extend the requirements of this extension with another set of extensions.
pub fn with_reqs(self, extension_reqs: impl Into<ExtensionSet>) -> Self {
Self {
extension_reqs: self.extension_reqs.union(extension_reqs.into()),
..self
}
}

/// Allows read-only access to the operations in this Extension
pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc<op_def::OpDef>> {
self.operations.get(name)
Expand All @@ -321,6 +351,11 @@ impl Extension {
&self.name
}

/// Returns the version of the extension.
pub fn version(&self) -> &Version {
&self.version
}

/// Iterator over the operations of this [`Extension`].
pub fn operations(&self) -> impl Iterator<Item = (&OpName, &Arc<OpDef>)> {
self.operations.iter()
Expand Down Expand Up @@ -382,8 +417,8 @@ impl PartialEq for Extension {
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum ExtensionRegistryError {
/// Extension already defined.
#[error("The registry already contains an extension with id {0}.")]
AlreadyRegistered(ExtensionId),
#[error("The registry already contains an extension with id {0} and version {1}. New extension has version {2}.")]
AlreadyRegistered(ExtensionId, Version, Version),
/// A registered extension has invalid signatures.
#[error("The extension {0} contains an invalid signature, {1}.")]
InvalidSignature(ExtensionId, #[source] SignatureError),
Expand Down Expand Up @@ -544,6 +579,53 @@ pub mod test {
// We re-export this here because mod op_def is private.
pub use super::op_def::test::SimpleOpDef;

use super::*;

impl Extension {
/// Create a new extension for testing, with a 0 version.
pub(crate) fn new_test(name: ExtensionId) -> Self {
Self::new(name, Version::new(0, 0, 0))
}
}

#[test]
fn test_register_update() {
let mut reg = ExtensionRegistry::try_new([]).unwrap();
let ext_1_id = ExtensionId::new("ext1").unwrap();
let ext_2_id = ExtensionId::new("ext2").unwrap();
let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0));
let ext1_1 = Extension::new(ext_1_id.clone(), Version::new(1, 1, 0));
let ext1_2 = Extension::new(ext_1_id.clone(), Version::new(0, 2, 0));
let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0));

reg.register(ext1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 0, 0));

// normal registration fails
assert_eq!(
reg.register(ext1_1.clone()),
Err(ExtensionRegistryError::AlreadyRegistered(
ext_1_id.clone(),
Version::new(1, 0, 0),
Version::new(1, 1, 0)
))
);

// register with update works
reg.register_updated(ext1_1.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));

// register with lower version does not change version
reg.register_updated(ext1_2.clone()).unwrap();
assert_eq!(reg.get("ext1").unwrap().version(), &Version::new(1, 1, 0));

reg.register(ext2.clone()).unwrap();
assert_eq!(reg.get("ext2").unwrap().version(), &Version::new(1, 0, 0));
assert_eq!(reg.len(), 2);

assert!(reg.remove_extension(&ext_1_id).unwrap().version() == &Version::new(1, 1, 0));
assert_eq!(reg.len(), 1);
}
mod proptest {

use ::proptest::{collection::hash_set, prelude::*};
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ struct ExtensionSetDeclaration {
/// A declarative extension definition.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
struct ExtensionDeclaration {
// TODO add version
/// The name of the extension.
name: ExtensionId,
/// A list of types that this extension provides.
Expand Down Expand Up @@ -150,7 +151,8 @@ impl ExtensionDeclaration {
imports: &ExtensionSet,
ctx: DeclarationContext<'_>,
) -> Result<Extension, ExtensionDeclarationError> {
let mut ext = Extension::new_with_reqs(self.name.clone(), imports.clone());
let mut ext = Extension::new(self.name.clone(), crate::extension::Version::new(0, 0, 0))
.with_reqs(imports.clone());

for t in &self.types {
t.register(&mut ext, ctx)?;
Expand Down
8 changes: 4 additions & 4 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ pub(super) mod test {
#[test]
fn op_def_with_type_scheme() -> Result<(), Box<dyn std::error::Error>> {
let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
let list_of_var =
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
Expand Down Expand Up @@ -658,7 +658,7 @@ pub(super) mod test {
MAX_NAT
}
}
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);
let def: &mut crate::extension::OpDef =
e.add_op("MyOp".into(), "".to_string(), SigFun())?;

Expand Down Expand Up @@ -720,7 +720,7 @@ pub(super) mod test {
fn type_scheme_instantiate_var() -> Result<(), Box<dyn std::error::Error>> {
// Check that we can instantiate a PolyFuncTypeRV-scheme with an (external)
// type variable
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);
let def = e.add_op(
"SimpleOp".into(),
"".into(),
Expand Down Expand Up @@ -755,7 +755,7 @@ pub(super) mod test {
fn instantiate_extension_delta() -> Result<(), Box<dyn std::error::Error>> {
use crate::extension::prelude::{BOOL_T, PRELUDE_REGISTRY};

let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);

let params: Vec<TypeParam> = vec![TypeParam::Extensions];
let db_set = ExtensionSet::type_var(0);
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ impl SignatureFromArgs for GenericOpCustom {

/// Name of prelude extension.
pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
lazy_static! {
static ref PRELUDE_DEF: Extension = {
let mut prelude = Extension::new(PRELUDE_ID);
let mut prelude = Extension::new(PRELUDE_ID, VERSION);
prelude
.add_type(
TypeName::new_inline("usize"),
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ mod test {

lazy_static! {
static ref EXT: Extension = {
let mut e = Extension::new(EXT_ID.clone());
let mut e = Extension::new_test(EXT_ID.clone());
DummyEnum::Dumb.add_to_extension(&mut e).unwrap();
e
};
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ const_extension_ids! {
}
#[test]
fn invalid_types() {
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);
e.add_type(
"MyContainer".into(),
vec![TypeBound::Copyable.into()],
Expand Down Expand Up @@ -570,7 +570,7 @@ fn no_polymorphic_consts() -> Result<(), Box<dyn std::error::Error>> {

pub(crate) fn extension_with_eval_parallel() -> Extension {
let rowp = TypeParam::new_list(TypeBound::Any);
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);

let inputs = TypeRV::new_row_var_use(0, TypeBound::Any);
let outputs = TypeRV::new_row_var_use(1, TypeBound::Any);
Expand Down Expand Up @@ -671,7 +671,7 @@ fn row_variables() -> Result<(), Box<dyn std::error::Error>> {

#[test]
fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
let mut e = Extension::new(EXT_ID);
let mut e = Extension::new_test(EXT_ID);

let params: Vec<TypeParam> = vec![
TypeBound::Any.into(),
Expand Down
5 changes: 4 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use lazy_static::lazy_static;
mod const_fold;
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

/// Extension for conversions between floats and integers.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
Expand Down Expand Up @@ -121,8 +123,9 @@ impl MakeExtensionOp for ConvertOpType {
lazy_static! {
/// Extension for conversions between integers and floats.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new_with_reqs(
let mut extension = Extension::new(
EXTENSION_ID,
VERSION).with_reqs(
ExtensionSet::from_iter(vec![
super::int_types::EXTENSION_ID,
super::float_types::EXTENSION_ID,
Expand Down
5 changes: 4 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use lazy_static::lazy_static;
mod const_fold;
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

/// Integer extension operation definitions.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
Expand Down Expand Up @@ -99,8 +101,9 @@ impl MakeOpDef for FloatOps {
lazy_static! {
/// Extension for basic float operations.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new_with_reqs(
let mut extension = Extension::new(
EXTENSION_ID,
VERSION).with_reqs(
ExtensionSet::singleton(&super::int_types::EXTENSION_ID),
);

Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use lazy_static::lazy_static;

/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float.types");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

/// Identifier for the 64-bit IEEE 754-2019 floating-point type.
const FLOAT_TYPE_ID: TypeName = TypeName::new_inline("float64");
Expand Down Expand Up @@ -76,7 +78,7 @@ impl CustomConst for ConstF64 {
lazy_static! {
/// Extension defining the float type.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new(EXTENSION_ID);
let mut extension = Extension::new(EXTENSION_ID, VERSION);

extension
.add_type(
Expand Down
7 changes: 5 additions & 2 deletions hugr-core/src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ mod const_fold;

/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

struct IOValidator {
// whether the first type argument should be greater than or equal to the second
Expand Down Expand Up @@ -261,9 +263,10 @@ fn iunop_sig() -> PolyFuncTypeRV {
lazy_static! {
/// Extension for basic integer operations.
pub static ref EXTENSION: Extension = {
let mut extension = Extension::new_with_reqs(
let mut extension = Extension::new(
EXTENSION_ID,
ExtensionSet::singleton(&super::int_types::EXTENSION_ID),
VERSION).with_reqs(
ExtensionSet::singleton(&super::int_types::EXTENSION_ID)
);

IntOpDef::load_all_ops(&mut extension).unwrap();
Expand Down
Loading

0 comments on commit b2d4013

Please sign in to comment.