Skip to content

Commit

Permalink
feat!: Serialised extensions (#1371)
Browse files Browse the repository at this point in the history
Fear not the diff, mostly schema update.


- Define a Pydantic model and serialised schema for extensions.
- Update the rust `SignatureFunc` serialisation to be compatible with
this.
- Serialized extension "declarations" can state they require binary
compute or validation functions. `SignatureFunc` reports this if they
are missing, and it is up to the caller (typically validation) to decide
how to recover from this.

Closes #1360 
Closes #1361

Addresses parts of #1228 

Best reviewed as individual commits.

BREAKING CHANGE: `TypeDefBound` uses struct-variants for serialization.
`SignatureFunc` now has variants for missing binary functions, and
serializes in to a new format that indicates expected binaries.
  • Loading branch information
ss2165 authored Jul 30, 2024
1 parent 5ee2e04 commit 31be204
Show file tree
Hide file tree
Showing 25 changed files with 1,864 additions and 525 deletions.
10 changes: 10 additions & 0 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ pub enum SignatureError {
cached: Signature,
expected: Signature,
},

/// Extension declaration specifies a binary compute signature function, but none
/// was loaded.
#[error("Binary compute signature function not loaded.")]
MissingComputeFunc,

/// Extension declaration specifies a binary compute signature function, but none
/// was loaded.
#[error("Binary validate signature function not loaded.")]
MissingValidateFunc,
}

/// Concrete instantiations of types and operations defined in extensions.
Expand Down
8 changes: 6 additions & 2 deletions hugr-core/src/extension/declarative/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ enum TypeDefBoundDeclaration {
impl From<TypeDefBoundDeclaration> for TypeDefBound {
fn from(bound: TypeDefBoundDeclaration) -> Self {
match bound {
TypeDefBoundDeclaration::Copyable => Self::Explicit(TypeBound::Copyable),
TypeDefBoundDeclaration::Any => Self::Explicit(TypeBound::Any),
TypeDefBoundDeclaration::Copyable => Self::Explicit {
bound: TypeBound::Copyable,
},
TypeDefBoundDeclaration::Any => Self::Explicit {
bound: TypeBound::Any,
},
}
}
}
Expand Down
121 changes: 59 additions & 62 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
use crate::Hugr;
mod serialize_signature_func;

/// Trait necessary for binary computations of OpDef signature
pub trait CustomSignatureFunc: Send + Sync {
Expand Down Expand Up @@ -114,29 +115,19 @@ pub trait CustomLowerFunc: Send + Sync {
) -> Option<Hugr>;
}

/// Encode a signature as `PolyFuncTypeRV` but optionally allow validating type
/// Encode a signature as [PolyFuncTypeRV] but with additional validation of type
/// arguments via a custom binary. The binary cannot be serialized so will be
/// lost over a serialization round-trip.
#[derive(serde::Deserialize, serde::Serialize)]
pub struct CustomValidator {
#[serde(flatten)]
poly_func: PolyFuncTypeRV,
#[serde(skip)]
/// Custom function for validating type arguments before returning the signature.
pub(crate) validate: Box<dyn ValidateTypeArgs>,
}

impl CustomValidator {
/// Encode a signature using a `PolyFuncTypeRV`
pub fn from_polyfunc(poly_func: impl Into<PolyFuncTypeRV>) -> Self {
Self {
poly_func: poly_func.into(),
validate: Default::default(),
}
}

/// Encode a signature using a `PolyFuncTypeRV`, with a custom function for
/// validating type arguments before returning the signature.
pub fn new_with_validator(
pub fn new(
poly_func: impl Into<PolyFuncTypeRV>,
validate: impl ValidateTypeArgs + 'static,
) -> Self {
Expand All @@ -147,37 +138,19 @@ impl CustomValidator {
}
}

/// The two ways in which an OpDef may compute the Signature of each operation node.
#[derive(serde::Deserialize, serde::Serialize)]
/// The ways in which an OpDef may compute the Signature of each operation node.
pub enum SignatureFunc {
// Note: except for serialization, we could have type schemes just implement the same
// CustomSignatureFunc trait too, and replace this enum with Box<dyn CustomSignatureFunc>.
// However instead we treat all CustomFunc's as non-serializable.
/// A PolyFuncType (polymorphic function type), with optional custom
/// validation for provided type arguments,
#[serde(rename = "signature")]
PolyFuncType(CustomValidator),
#[serde(skip)]
/// An explicit polymorphic function type.
PolyFuncType(PolyFuncTypeRV),
/// A polymorphic function type (like [Self::PolyFuncType] but also with a custom binary for validating type arguments.
CustomValidator(CustomValidator),
/// Serialized declaration specified a custom validate binary but it was not provided.
MissingValidateFunc(PolyFuncTypeRV),
/// A custom binary which computes a polymorphic function type given values
/// for its static type parameters.
CustomFunc(Box<dyn CustomSignatureFunc>),
}
struct NoValidate;
impl ValidateTypeArgs for NoValidate {
fn validate<'o, 'a: 'o>(
&self,
_arg_values: &[TypeArg],
_def: &'o OpDef,
_extension_registry: &ExtensionRegistry,
) -> Result<(), SignatureError> {
Ok(())
}
}

impl Default for Box<dyn ValidateTypeArgs> {
fn default() -> Self {
Box::new(NoValidate)
}
/// Serialized declaration specified a custom compute binary but it was not provided.
MissingComputeFunc,
}

impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {
Expand All @@ -188,39 +161,50 @@ impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {

impl From<PolyFuncType> for SignatureFunc {
fn from(value: PolyFuncType) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(value))
Self::PolyFuncType(value.into())
}
}

impl From<PolyFuncTypeRV> for SignatureFunc {
fn from(v: PolyFuncTypeRV) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(v))
Self::PolyFuncType(v)
}
}

impl From<FuncValueType> for SignatureFunc {
fn from(v: FuncValueType) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(v))
Self::PolyFuncType(v.into())
}
}

impl From<Signature> for SignatureFunc {
fn from(v: Signature) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(FuncValueType::from(v)))
Self::PolyFuncType(FuncValueType::from(v).into())
}
}

impl From<CustomValidator> for SignatureFunc {
fn from(v: CustomValidator) -> Self {
Self::PolyFuncType(v)
Self::CustomValidator(v)
}
}

impl SignatureFunc {
fn static_params(&self) -> &[TypeParam] {
match self {
SignatureFunc::PolyFuncType(ts) => ts.poly_func.params(),
fn static_params(&self) -> Result<&[TypeParam], SignatureError> {
Ok(match self {
SignatureFunc::PolyFuncType(ts)
| SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. })
| SignatureFunc::MissingValidateFunc(ts) => ts.params(),
SignatureFunc::CustomFunc(func) => func.static_params(),
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
})
}

/// If the signature is missing a custom validation function, ignore and treat as
/// self-contained type scheme (with no custom validation).
pub fn ignore_missing_validation(&mut self) {
if let SignatureFunc::MissingValidateFunc(ts) = self {
*self = SignatureFunc::PolyFuncType(ts.clone());
}
}

Expand All @@ -243,10 +227,11 @@ impl SignatureFunc {
) -> Result<Signature, SignatureError> {
let temp: PolyFuncTypeRV; // to keep alive
let (pf, args) = match &self {
SignatureFunc::PolyFuncType(custom) => {
SignatureFunc::CustomValidator(custom) => {
custom.validate.validate(args, def, exts)?;
(&custom.poly_func, args)
}
SignatureFunc::PolyFuncType(ts) => (ts, args),
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
Expand All @@ -255,6 +240,10 @@ impl SignatureFunc {
temp = func.compute_signature(static_args, def, exts)?;
(&temp, other_args)
}
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
SignatureFunc::MissingValidateFunc(_) => {
return Err(SignatureError::MissingValidateFunc)
}
};

let mut res = pf.instantiate(args, exts)?;
Expand All @@ -268,8 +257,11 @@ impl SignatureFunc {
impl Debug for SignatureFunc {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::PolyFuncType(ts) => ts.poly_func.fmt(f),
Self::CustomValidator(ts) => ts.poly_func.fmt(f),
Self::PolyFuncType(ts) => ts.fmt(f),
Self::CustomFunc { .. } => f.write_str("<custom sig>"),
Self::MissingComputeFunc => f.write_str("<missing custom sig>"),
Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"),
}
}
}
Expand Down Expand Up @@ -321,10 +313,11 @@ pub struct OpDef {
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
misc: HashMap<String, serde_json::Value>,

#[serde(flatten)]
#[serde(with = "serialize_signature_func", flatten)]
signature_func: SignatureFunc,
// Some operations cannot lower themselves and tools that do not understand them
// can only treat them as opaque/black-box ops.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub(crate) lower_funcs: Vec<LowerFunc>,

/// Operations can optionally implement [`ConstFold`] to implement constant folding.
Expand All @@ -344,7 +337,8 @@ impl OpDef {
) -> Result<(), SignatureError> {
let temp: PolyFuncTypeRV; // to keep alive
let (pf, args) = match &self.signature_func {
SignatureFunc::PolyFuncType(ts) => (&ts.poly_func, args),
SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args),
SignatureFunc::PolyFuncType(ts) => (ts, args),
SignatureFunc::CustomFunc(custom) => {
let (static_args, other_args) =
args.split_at(min(custom.static_params().len(), args.len()));
Expand All @@ -355,6 +349,10 @@ impl OpDef {
temp = custom.compute_signature(static_args, self, exts)?;
(&temp, other_args)
}
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
SignatureFunc::MissingValidateFunc(_) => {
return Err(SignatureError::MissingValidateFunc)
}
};
args.iter()
.try_for_each(|ta| ta.validate(exts, var_decls))?;
Expand Down Expand Up @@ -409,14 +407,14 @@ impl OpDef {
}

/// Returns a reference to the params of this [`OpDef`].
pub fn params(&self) -> &[TypeParam] {
pub fn params(&self) -> Result<&[TypeParam], SignatureError> {
self.signature_func.static_params()
}

pub(super) fn validate(&self, exts: &ExtensionRegistry) -> Result<(), SignatureError> {
// TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
// for both type scheme and custom binary
if let SignatureFunc::PolyFuncType(ts) = &self.signature_func {
if let SignatureFunc::CustomValidator(ts) = &self.signature_func {
// The type scheme may contain row variables so be of variable length;
// these will have to be substituted to fixed-length concrete types when
// the OpDef is instantiated into an actual OpType.
Expand Down Expand Up @@ -557,12 +555,13 @@ pub(super) mod test {
// a compile error here. To fix: modify the fields matched on here,
// maintaining the lack of `..` and, for each part that is
// serializable, ensure we are checking it for equality below.
SignatureFunc::PolyFuncType(CustomValidator {
SignatureFunc::CustomValidator(CustomValidator {
poly_func,
validate: _,
}) => Some(poly_func.clone()),
// This is ruled out by `new()` but leave it here for later.
SignatureFunc::CustomFunc(_) => None,
})
| SignatureFunc::PolyFuncType(poly_func)
| SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()),
SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
};

let get_lower_funcs = |lfs: &Vec<LowerFunc>| {
Expand Down Expand Up @@ -787,9 +786,7 @@ pub(super) mod test {

use crate::{
builder::test::simple_dfg_hugr,
extension::{
op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc,
},
extension::{op_def::LowerFunc, ExtensionId, ExtensionSet, OpDef, SignatureFunc},
types::PolyFuncTypeRV,
};

Expand All @@ -801,7 +798,7 @@ pub(super) mod test {
// this is not serialized. When it is, we should generate
// examples here .
any::<PolyFuncTypeRV>()
.prop_map(|x| SignatureFunc::PolyFuncType(CustomValidator::from_polyfunc(x)))
.prop_map(SignatureFunc::PolyFuncType)
.boxed()
}
}
Expand Down
Loading

0 comments on commit 31be204

Please sign in to comment.