Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Emulate TypeBounds on parameters via constraints. #1624

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 50 additions & 14 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
TypeBase, TypeEnum,
TypeBase, TypeBound, TypeEnum,
},
Direction, Hugr, HugrView, IncomingPort, Node, Port,
};
Expand Down Expand Up @@ -46,6 +46,8 @@ struct Context<'a> {
term_map: FxHashMap<model::Term<'a>, model::TermId>,
/// The current scope for local variables.
local_scope: Option<model::NodeId>,
/// Constraints to be added to the local scope.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is local scope == [Self::local_scope]? Do we actually plan to modify the local scope by adding these? Or are these "Constraints to be added to those from the local scope" (i.e., extra constraints from somewhere else)? "Constraints in addition to those from the local scope" would be clearer still that we don't intend to mutate some list.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When exporting a PolyFuncType, we gather constraints in local_constraints. PolyFuncType itself can not express constraints, but if one of the parameters takes a copyable type, local_constraints is where we record the corresponding constraint for that parameter.

local_constraints: Vec<model::TermId>,
/// Mapping from extension operations to their declarations.
decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>,
}
Expand All @@ -63,6 +65,7 @@ impl<'a> Context<'a> {
term_map: FxHashMap::default(),
local_scope: None,
decl_operations: FxHashMap::default(),
local_constraints: Vec::new(),
}
}

Expand Down Expand Up @@ -173,9 +176,11 @@ impl<'a> Context<'a> {
}

fn with_local_scope<T>(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
let old_scope = self.local_scope.replace(node);
let prev_local_scope = self.local_scope.replace(node);
let prev_local_constraints = std::mem::take(&mut self.local_constraints);
let result = f(self);
self.local_scope = old_scope;
self.local_scope = prev_local_scope;
self.local_constraints = prev_local_constraints;
result
}

Expand Down Expand Up @@ -232,10 +237,11 @@ impl<'a> Context<'a> {

OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| {
let name = this.get_func_name(node).unwrap();
let (params, signature) = this.export_poly_func_type(&func.signature);
let (params, constraints, signature) = this.export_poly_func_type(&func.signature);
let decl = this.bump.alloc(model::FuncDecl {
name,
params,
constraints,
signature,
});
let extensions = this.export_ext_set(&func.signature.body().extension_reqs);
Expand All @@ -247,10 +253,11 @@ impl<'a> Context<'a> {

OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| {
let name = this.get_func_name(node).unwrap();
let (params, func) = this.export_poly_func_type(&func.signature);
let (params, constraints, func) = this.export_poly_func_type(&func.signature);
let decl = this.bump.alloc(model::FuncDecl {
name,
params,
constraints,
signature: func,
});
model::Operation::DeclareFunc { decl }
Expand All @@ -262,6 +269,7 @@ impl<'a> Context<'a> {
let decl = this.bump.alloc(model::AliasDecl {
name: &alias.name,
params: &[],
constraints: &[],
r#type,
});
model::Operation::DeclareAlias { decl }
Expand All @@ -274,6 +282,7 @@ impl<'a> Context<'a> {
let decl = this.bump.alloc(model::AliasDecl {
name: &alias.name,
params: &[],
constraints: &[],
r#type,
});
model::Operation::DefineAlias { decl, value }
Expand Down Expand Up @@ -450,10 +459,11 @@ impl<'a> Context<'a> {

let decl = self.with_local_scope(node, |this| {
let name = this.make_qualified_name(opdef.extension(), opdef.name());
let (params, r#type) = this.export_poly_func_type(poly_func_type);
let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type);
let decl = this.bump.alloc(model::OperationDecl {
name,
params,
constraints,
r#type,
});
decl
Expand Down Expand Up @@ -674,19 +684,30 @@ impl<'a> Context<'a> {
pub fn export_poly_func_type<RV: MaybeRV>(
&mut self,
t: &PolyFuncTypeBase<RV>,
) -> (&'a [model::Param<'a>], model::TermId) {
) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) {
let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump);
let scope = self
.local_scope
.expect("exporting poly func type outside of local scope");

for (i, param) in t.params().iter().enumerate() {
let name = self.bump.alloc_str(&i.to_string());
let r#type = self.export_type_param(param);
let param = model::Param::Implicit { name, r#type };
let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _)));
let param = model::Param {
name,
r#type,
sort: model::ParamSort::Implicit,
};
params.push(param)
}

let constraints = self
.bump
.alloc_slice_fill_iter(self.local_constraints.drain(..));

let body = self.export_func_type(t.body());

(params.into_bump_slice(), body)
(params.into_bump_slice(), constraints, body)
}

pub fn export_type<RV: MaybeRV>(&mut self, t: &TypeBase<RV>) -> model::TermId {
Expand Down Expand Up @@ -794,20 +815,35 @@ impl<'a> Context<'a> {
self.make_term(model::Term::List { items, tail: None })
}

pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId {
pub fn export_type_param(
&mut self,
t: &TypeParam,
var: Option<model::LocalRef<'static>>,
) -> model::TermId {
match t {
// This ignores the type bound for now.
TypeParam::Type { .. } => self.make_term(model::Term::Type),
TypeParam::Type { b } => {
if let (Some(var), TypeBound::Copyable) = (var, b) {
let term = self.make_term(model::Term::Var(var));
let copy = self.make_term(model::Term::CopyConstraint { term });
let discard = self.make_term(model::Term::DiscardConstraint { term });
self.local_constraints.extend([copy, discard]);
}

self.make_term(model::Term::Type)
}
// This ignores the type bound for now.
TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType),
TypeParam::String => self.make_term(model::Term::StrType),
TypeParam::List { param } => {
let item_type = self.export_type_param(param);
let item_type = self.export_type_param(param, None);
self.make_term(model::Term::ListType { item_type })
}
TypeParam::Tuple { params } => {
let items = self.bump.alloc_slice_fill_iter(
params.iter().map(|param| self.export_type_param(param)),
params
.iter()
.map(|param| self.export_type_param(param, None)),
);
let types = self.make_term(model::Term::List { items, tail: None });
self.make_term(model::Term::ApplyFull {
Expand Down
125 changes: 93 additions & 32 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ struct Context<'a> {
nodes: FxHashMap<model::NodeId, Node>,

/// The types of the local variables that are currently in scope.
local_variables: FxIndexMap<&'a str, model::TermId>,
local_variables: FxIndexMap<&'a str, LocalVar>,

custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>,
}
Expand Down Expand Up @@ -159,16 +159,16 @@ impl<'a> Context<'a> {
fn resolve_local_ref(
&self,
local_ref: &model::LocalRef,
) -> Result<(usize, model::TermId), ImportError> {
) -> Result<(usize, LocalVar), ImportError> {
let term = match local_ref {
model::LocalRef::Index(_, index) => self
.local_variables
.get_index(*index as usize)
.map(|(_, term)| (*index as usize, *term)),
.map(|(_, v)| (*index as usize, *v)),
model::LocalRef::Named(name) => self
.local_variables
.get_full(name)
.map(|(index, _, term)| (index, *term)),
.map(|(index, _, v)| (index, *v)),
};

term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into())
Expand Down Expand Up @@ -892,41 +892,59 @@ impl<'a> Context<'a> {
self.with_local_socpe(|ctx| {
let mut imported_params = Vec::with_capacity(decl.params.len());

for param in decl.params {
// TODO: `PolyFuncType` should be able to handle constraints
// and distinguish between implicit and explicit parameters.
match param {
model::Param::Implicit { name, r#type } => {
imported_params.push(ctx.import_type_param(*r#type)?);
ctx.local_variables.insert(name, *r#type);
}
model::Param::Explicit { name, r#type } => {
imported_params.push(ctx.import_type_param(*r#type)?);
ctx.local_variables.insert(name, *r#type);
ctx.local_variables.extend(
decl.params
.iter()
.map(|param| (param.name, LocalVar::new(param.r#type))),
);

for constraint in decl.constraints {
match ctx.get_term(*constraint)? {
model::Term::CopyConstraint { term } => {
let model::Term::Var(var) = ctx.get_term(*term)? else {
return Err(error_unsupported!(
"constraint on term that is not a variable"
));
};

let var = ctx.resolve_local_ref(var)?.0;
ctx.local_variables.get_index_mut(var).unwrap().1.copy = true;
}
model::Param::Constraint { constraint: _ } => {
return Err(error_unsupported!("constraints"));
model::Term::DiscardConstraint { term } => {
let model::Term::Var(var) = ctx.get_term(*term)? else {
return Err(error_unsupported!(
"constraint on term that is not a variable"
));
};

let var = ctx.resolve_local_ref(var)?.0;
ctx.local_variables.get_index_mut(var).unwrap().1.discard = true;
}
_ => return Err(error_unsupported!("constraint other than copy or discard")),
}
}

for (index, param) in decl.params.iter().enumerate() {
// TODO: `PolyFuncType` should be able to distinguish between implicit and explicit parameters.
let bound = ctx.local_variables.get_index(index).unwrap().1.bound()?;
imported_params.push(ctx.import_type_param(param.r#type, bound)?);
}

let body = ctx.import_func_type::<RV>(decl.signature)?;
in_scope(ctx, PolyFuncTypeBase::new(imported_params, body))
})
}

/// Import a [`TypeParam`] from a term that represents a static type.
fn import_type_param(&mut self, term_id: model::TermId) -> Result<TypeParam, ImportError> {
fn import_type_param(
&mut self,
term_id: model::TermId,
bound: TypeBound,
) -> Result<TypeParam, ImportError> {
match self.get_term(term_id)? {
model::Term::Wildcard => Err(error_uninferred!("wildcard")),

model::Term::Type => {
// As part of the migration from `TypeBound`s to constraints, we pretend that all
// `TypeBound`s are copyable.
Ok(TypeParam::Type {
b: TypeBound::Copyable,
})
}
model::Term::Type => Ok(TypeParam::Type { b: bound }),

model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")),
model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")),
Expand All @@ -938,7 +956,7 @@ impl<'a> Context<'a> {
model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")),

model::Term::ListType { item_type } => {
let param = Box::new(self.import_type_param(*item_type)?);
let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why this is any (also not covered by tests), is this a todo waiting for a list constraint?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The item type is nested within the list and since we only constrain parameters directly here, we can not express that the items of the list should be copyable. Therefore the TypeBound::Any.

Ok(TypeParam::List { param })
}

Expand All @@ -952,15 +970,19 @@ impl<'a> Context<'a> {
| model::Term::List { .. }
| model::Term::ExtSet { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()),
| model::Term::Control { .. }
| model::Term::CopyConstraint { .. }
| model::Term::DiscardConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}

model::Term::ControlType => {
Err(error_unsupported!("type of control types as `TypeParam`"))
}
}
}

/// Import a `TypeArg` froma term that represents a static type or value.
/// Import a `TypeArg` from a term that represents a static type or value.
fn import_type_arg(&mut self, term_id: model::TermId) -> Result<TypeArg, ImportError> {
match self.get_term(term_id)? {
model::Term::Wildcard => Err(error_uninferred!("wildcard")),
Expand All @@ -969,8 +991,9 @@ impl<'a> Context<'a> {
}

model::Term::Var(var) => {
let (index, var_type) = self.resolve_local_ref(var)?;
let decl = self.import_type_param(var_type)?;
let (index, var) = self.resolve_local_ref(var)?;
let bound = var.bound()?;
let decl = self.import_type_param(var.r#type, bound)?;
Ok(TypeArg::new_var_use(index, decl))
}

Expand Down Expand Up @@ -1008,7 +1031,11 @@ impl<'a> Context<'a> {

model::Term::FuncType { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()),
| model::Term::Control { .. }
| model::Term::CopyConstraint { .. }
| model::Term::DiscardConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
}
}

Expand Down Expand Up @@ -1109,7 +1136,11 @@ impl<'a> Context<'a> {
| model::Term::List { .. }
| model::Term::Control { .. }
| model::Term::ControlType
| model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()),
| model::Term::Nat(_)
| model::Term::DiscardConstraint { .. }
| model::Term::CopyConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
}
}

Expand Down Expand Up @@ -1285,3 +1316,33 @@ impl<'a> Names<'a> {
Ok(Self { items })
}
}

#[derive(Debug, Clone, Copy)]
struct LocalVar {
r#type: model::TermId,
copy: bool,
discard: bool,
}

impl LocalVar {
pub fn new(r#type: model::TermId) -> Self {
Self {
r#type,
copy: false,
discard: false,
}
}

pub fn bound(&self) -> Result<TypeBound, ImportError> {
match (self.copy, self.discard) {
(true, true) => Ok(TypeBound::Copyable),
(false, false) => Ok(TypeBound::Any),
(true, false) => Err(error_unsupported!(
"type that is copyable but not discardable"
)),
(false, true) => Err(error_unsupported!(
"type that is discardable but not copyable"
)),
}
}
}
7 changes: 7 additions & 0 deletions hugr-core/tests/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,10 @@ pub fn test_roundtrip_params() {
"../../hugr-model/tests/fixtures/model-params.edn"
)));
}

#[test]
pub fn test_roundtrip_constraints() {
insta::assert_snapshot!(roundtrip(include_str!(
"../../hugr-model/tests/fixtures/model-constraints.edn"
)));
}
Loading
Loading