From 9962f97a93b95f1c0d2ad668e6d481c0991684d3 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 31 Oct 2024 16:06:59 +0000 Subject: [PATCH] feat: Operation and constructor declarations in `hugr-model` (#1605) This PR adds the ability to declare custom operations and constructors (so static types, runtime types, constraints, etc.) to `hugr-model`. In the case of operations this is used for deduplication when exporting. --- hugr-core/src/export.rs | 86 ++++++++++++++++--- hugr-core/src/extension/op_def.rs | 5 ++ hugr-core/src/import.rs | 76 ++++++++++------ hugr-core/tests/model.rs | 28 ++++-- hugr-model/Cargo.toml | 1 + hugr-model/capnp/hugr-v0.capnp | 15 ++++ hugr-model/src/v0/binary/read.rs | 25 ++++++ hugr-model/src/v0/binary/write.rs | 15 ++++ hugr-model/src/v0/mod.rs | 44 +++++++++- hugr-model/src/v0/text/hugr.pest | 38 ++++---- hugr-model/src/v0/text/parse.rs | 64 +++++++++++++- hugr-model/src/v0/text/print.rs | 60 +++++++++++-- hugr-model/tests/binary.rs | 29 +++---- .../tests/fixtures/model-add.edn | 0 .../tests/fixtures/model-alias.edn | 0 .../tests/fixtures/model-call.edn | 0 .../tests/fixtures/model-cfg.edn | 0 .../tests/fixtures/model-cond.edn | 0 hugr-model/tests/fixtures/model-decl-exts.edn | 13 +++ .../tests/fixtures/model-loop.edn | 0 .../tests/fixtures/model-params.edn | 0 .../text__declarative_extensions.snap | 18 ++++ hugr-model/tests/text.rs | 12 +++ 23 files changed, 442 insertions(+), 87 deletions(-) rename {hugr-core => hugr-model}/tests/fixtures/model-add.edn (100%) rename {hugr-core => hugr-model}/tests/fixtures/model-alias.edn (100%) rename {hugr-core => hugr-model}/tests/fixtures/model-call.edn (100%) rename {hugr-core => hugr-model}/tests/fixtures/model-cfg.edn (100%) rename {hugr-core => hugr-model}/tests/fixtures/model-cond.edn (100%) create mode 100644 hugr-model/tests/fixtures/model-decl-exts.edn rename {hugr-core => hugr-model}/tests/fixtures/model-loop.edn (100%) rename {hugr-core => hugr-model}/tests/fixtures/model-params.edn (100%) create mode 100644 hugr-model/tests/snapshots/text__declarative_extensions.snap create mode 100644 hugr-model/tests/text.rs diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 0d266db39..e7a85c98f 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1,8 +1,8 @@ //! Exporting HUGR graphs to their `hugr-model` representation. use crate::{ - extension::ExtensionSet, + extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc}, hugr::IdentList, - ops::{DataflowBlock, OpTrait, OpType}, + ops::{DataflowBlock, OpName, OpTrait, OpType}, types::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, @@ -38,11 +38,14 @@ struct Context<'a> { /// Mapping from ports to link indices. /// This only includes the minimum port among groups of linked ports. links: FxIndexSet<(Node, Port)>, + /// The arena in which the model is allocated. bump: &'a Bump, /// Stores the terms that we have already seen to avoid duplicates. term_map: FxHashMap, model::TermId>, /// The current scope for local variables. local_scope: Option, + /// Mapping from extension operations to their declarations. + decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>, } impl<'a> Context<'a> { @@ -57,23 +60,26 @@ impl<'a> Context<'a> { links: IndexSet::default(), term_map: FxHashMap::default(), local_scope: None, + decl_operations: FxHashMap::default(), } } /// Exports the root module of the HUGR graph. pub fn export_root(&mut self) { let hugr_children = self.hugr.children(self.hugr.root()); - let mut children = BumpVec::with_capacity_in(hugr_children.len(), self.bump); + let mut children = Vec::with_capacity(hugr_children.len()); for child in self.hugr.children(self.hugr.root()) { children.push(self.export_node(child)); } + children.extend(self.decl_operations.values().copied()); + let root = self.module.insert_region(model::Region { - kind: model::RegionKind::DataFlow, + kind: model::RegionKind::Module, sources: &[], targets: &[], - children: children.into_bump_slice(), + children: self.bump.alloc_slice_copy(&children), meta: &[], // TODO: Export metadata signature: None, }); @@ -123,15 +129,23 @@ impl<'a> Context<'a> { .or_insert_with(|| self.module.insert_term(term)) } - pub fn make_named_global_ref( + pub fn make_qualified_name( &mut self, - extension: &IdentList, + extension: &ExtensionId, name: impl AsRef, - ) -> model::GlobalRef<'a> { + ) -> &'a str { let capacity = extension.len() + name.as_ref().len() + 1; let mut output = BumpString::with_capacity_in(capacity, self.bump); let _ = write!(&mut output, "{}.{}", extension, name.as_ref()); - model::GlobalRef::Named(output.into_bump_str()) + output.into_bump_str() + } + + pub fn make_named_global_ref( + &mut self, + extension: &IdentList, + name: impl AsRef, + ) -> model::GlobalRef<'a> { + model::GlobalRef::Named(self.make_qualified_name(extension, name)) } /// Get the node that declares or defines the function associated with the given @@ -315,7 +329,7 @@ impl<'a> Context<'a> { // regions of potentially different kinds. At the moment, we check if the node has any // children, in which case we create a dataflow region with those children. OpType::ExtensionOp(op) => { - let operation = self.make_named_global_ref(op.def().extension(), op.def().name()); + let operation = self.export_opdef(op.def()); params = self .bump @@ -392,6 +406,58 @@ impl<'a> Context<'a> { node_id } + /// Export an `OpDef` as an operation declaration. + /// + /// Operations that allow a declarative form are exported as a reference to + /// an operation declaration node, and this node is reused for all instances + /// of the operation. The node is added to the `decl_operations` map so that + /// at the end of the export, the operation declaration nodes can be added + /// to the module as children of the module region. + pub fn export_opdef(&mut self, opdef: &OpDef) -> model::GlobalRef<'a> { + use std::collections::hash_map::Entry; + + let poly_func_type = match opdef.signature_func() { + SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type, + _ => return self.make_named_global_ref(opdef.extension(), opdef.name()), + }; + + let key = (opdef.extension().clone(), opdef.name().clone()); + let entry = self.decl_operations.entry(key); + + let node = match entry { + Entry::Occupied(occupied_entry) => { + return model::GlobalRef::Direct(*occupied_entry.get()) + } + Entry::Vacant(vacant_entry) => { + *vacant_entry.insert(self.module.insert_node(model::Node { + operation: model::Operation::Invalid, + inputs: &[], + outputs: &[], + params: &[], + regions: &[], + meta: &[], // TODO: Metadata + signature: None, + })) + } + }; + + let decl = self.with_local_scope(node, |this| { + let name = this.make_qualified_name(opdef.extension(), opdef.name()); + let (params, r#type) = this.export_poly_func_type(poly_func_type); + let decl = this.bump.alloc(model::OperationDecl { + name, + params, + r#type, + }); + decl + }); + + self.module.get_node_mut(node).unwrap().operation = + model::Operation::DeclareOperation { decl }; + + model::GlobalRef::Direct(node) + } + /// Export the signature of a `DataflowBlock`. Here we can't use `OpType::dataflow_signature` /// like for the other nodes since the ports are control flow ports. pub fn export_block_signature(&mut self, block: &DataflowBlock) -> model::TermId { diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index de114fed0..6c1a49d9e 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -450,6 +450,11 @@ impl OpDef { ) -> ConstFoldResult { (self.constant_folder.as_ref())?.fold(type_args, consts) } + + /// Returns a reference to the signature function of this [`OpDef`]. + pub fn signature_func(&self) -> &SignatureFunc { + &self.signature_func + } } impl Extension { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 347a26fd8..d981049fb 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -284,6 +284,8 @@ impl<'a> Context<'a> { match item { NamedItem::FuncDecl(node) => Ok(*node), NamedItem::FuncDefn(node) => Ok(*node), + NamedItem::CtrDecl(node) => Ok(*node), + NamedItem::OperationDecl(node) => Ok(*node), } } } @@ -299,6 +301,8 @@ impl<'a> Context<'a> { model::Operation::DeclareFunc { decl } => decl.name, model::Operation::DefineAlias { decl, .. } => decl.name, model::Operation::DeclareAlias { decl } => decl.name, + model::Operation::DeclareConstructor { decl } => decl.name, + model::Operation::DeclareOperation { decl } => decl.name, _ => { return Err(model::ModelError::InvalidGlobal(global_ref.to_string()).into()); } @@ -334,7 +338,11 @@ impl<'a> Context<'a> { Ok(()) } - fn import_node(&mut self, node_id: model::NodeId, parent: Node) -> Result { + fn import_node( + &mut self, + node_id: model::NodeId, + parent: Node, + ) -> Result, ImportError> { let node_data = self.get_node(node_id)?; match node_data.operation { @@ -349,7 +357,7 @@ impl<'a> Context<'a> { }; self.import_dfg_region(node_id, *region, node)?; - Ok(node) + Ok(Some(node)) } model::Operation::Cfg => { @@ -362,10 +370,13 @@ impl<'a> Context<'a> { }; self.import_cfg_region(node_id, *region, node)?; - Ok(node) + Ok(Some(node)) } - model::Operation::Block => self.import_cfg_block(node_id, parent), + model::Operation::Block => { + let node = self.import_cfg_block(node_id, parent)?; + Ok(Some(node)) + } model::Operation::DefineFunc { decl } => { self.import_poly_func_type(*decl, |ctx, signature| { @@ -382,7 +393,7 @@ impl<'a> Context<'a> { ctx.import_dfg_region(node_id, *region, node)?; - Ok(node) + Ok(Some(node)) }) } @@ -395,7 +406,7 @@ impl<'a> Context<'a> { let node = ctx.make_node(node_id, optype, parent)?; - Ok(node) + Ok(Some(node)) }) } @@ -415,7 +426,8 @@ impl<'a> Context<'a> { self.static_edges.push((func_node, node_id)); let optype = OpType::Call(Call::try_new(func_sig, type_args, self.extensions)?); - self.make_node(node_id, optype, parent) + let node = self.make_node(node_id, optype, parent)?; + Ok(Some(node)) } model::Operation::LoadFunc { func } => { @@ -439,18 +451,26 @@ impl<'a> Context<'a> { self.extensions, )?); - self.make_node(node_id, optype, parent) + let node = self.make_node(node_id, optype, parent)?; + Ok(Some(node)) } - model::Operation::TailLoop => self.import_tail_loop(node_id, parent), - model::Operation::Conditional => self.import_conditional(node_id, parent), + model::Operation::TailLoop => { + let node = self.import_tail_loop(node_id, parent)?; + Ok(Some(node)) + } + model::Operation::Conditional => { + let node = self.import_conditional(node_id, parent)?; + Ok(Some(node)) + } model::Operation::CustomFull { operation: GlobalRef::Named(name), } if name == OP_FUNC_CALL_INDIRECT => { let signature = self.get_node_signature(node_id)?; let optype = OpType::CallIndirect(CallIndirect { signature }); - self.make_node(node_id, optype, parent) + let node = self.make_node(node_id, optype, parent)?; + Ok(Some(node)) } model::Operation::CustomFull { operation } => { @@ -461,15 +481,7 @@ impl<'a> Context<'a> { .map(|param| self.import_type_arg(*param)) .collect::, _>>()?; - let name = match operation { - GlobalRef::Direct(_) => { - return Err(error_unsupported!( - "custom operation with direct reference to declaring node" - )) - } - GlobalRef::Named(name) => name, - }; - + let name = self.get_global_name(operation)?; let (extension, name) = self.import_custom_name(name)?; // TODO: Currently we do not have the description or any other metadata for @@ -493,7 +505,7 @@ impl<'a> Context<'a> { _ => return Err(error_unsupported!("multiple regions in custom operation")), } - Ok(node) + Ok(Some(node)) } model::Operation::Custom { .. } => Err(error_unsupported!( @@ -512,7 +524,8 @@ impl<'a> Context<'a> { definition: ctx.import_type(value)?, }); - ctx.make_node(node_id, optype, parent) + let node = ctx.make_node(node_id, optype, parent)?; + Ok(Some(node)) }), model::Operation::DeclareAlias { decl } => self.with_local_socpe(|ctx| { @@ -527,7 +540,8 @@ impl<'a> Context<'a> { bound: TypeBound::Copyable, }); - ctx.make_node(node_id, optype, parent) + let node = ctx.make_node(node_id, optype, parent)?; + Ok(Some(node)) }), model::Operation::Tag { tag } => { @@ -536,15 +550,19 @@ impl<'a> Context<'a> { .ok_or_else(|| error_uninferred!("node signature"))?; let (_, outputs, _) = self.get_func_type(signature)?; let (variants, _) = self.import_adt_and_rest(node_id, outputs)?; - self.make_node( + let node = self.make_node( node_id, OpType::Tag(Tag { variants, tag: tag as _, }), parent, - ) + )?; + Ok(Some(node)) } + + model::Operation::DeclareConstructor { .. } => Ok(None), + model::Operation::DeclareOperation { .. } => Ok(None), } } @@ -1188,6 +1206,8 @@ impl<'a> Context<'a> { enum NamedItem { FuncDecl(model::NodeId), FuncDefn(model::NodeId), + CtrDecl(model::NodeId), + OperationDecl(model::NodeId), } struct Names<'a> { @@ -1208,6 +1228,12 @@ impl<'a> Names<'a> { model::Operation::DeclareFunc { decl } => { Some((decl.name, NamedItem::FuncDefn(node_id))) } + model::Operation::DeclareConstructor { decl } => { + Some((decl.name, NamedItem::CtrDecl(node_id))) + } + model::Operation::DeclareOperation { decl } => { + Some((decl.name, NamedItem::OperationDecl(node_id))) + } _ => None, }; diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index f59f463a8..611eda660 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -12,35 +12,49 @@ fn roundtrip(source: &str) -> String { #[test] pub fn test_roundtrip_add() { - insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-add.edn"))); + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-add.edn" + ))); } #[test] pub fn test_roundtrip_call() { - insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-call.edn"))); + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-call.edn" + ))); } #[test] pub fn test_roundtrip_alias() { - insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-alias.edn"))); + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-alias.edn" + ))); } #[test] pub fn test_roundtrip_cfg() { - insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-cfg.edn"))); + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-cfg.edn" + ))); } #[test] pub fn test_roundtrip_cond() { - insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-cond.edn"))); + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-cond.edn" + ))); } #[test] pub fn test_roundtrip_loop() { - insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-loop.edn"))); + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-loop.edn" + ))); } #[test] pub fn test_roundtrip_params() { - insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-params.edn"))); + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-params.edn" + ))); } diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml index e6ca11801..328fcb7d6 100644 --- a/hugr-model/Cargo.toml +++ b/hugr-model/Cargo.toml @@ -31,4 +31,5 @@ workspace = true capnpc = "0.20.0" [dev-dependencies] +insta.workspace = true pretty_assertions = "1.4.1" diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 296468e8e..95db81205 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -49,6 +49,8 @@ struct Operation { conditional @12 :Void; callFunc @13 :TermId; loadFunc @14 :TermId; + constructorDecl @15 :ConstructorDecl; + operationDecl @16 :OperationDecl; } struct FuncDefn { @@ -75,6 +77,18 @@ struct Operation { params @1 :List(Param); type @2 :TermId; } + + struct ConstructorDecl { + name @0 :Text; + params @1 :List(Param); + type @2 :TermId; + } + + struct OperationDecl { + name @0 :Text; + params @1 :List(Param); + type @2 :TermId; + } } struct Region { @@ -89,6 +103,7 @@ struct Region { enum RegionKind { dataFlow @0; controlFlow @1; + module @2; } struct MetaItem { diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 082d7b3db..681bd4ea9 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -185,6 +185,30 @@ fn read_operation<'a>( }); model::Operation::DeclareAlias { decl } } + Which::ConstructorDecl(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader, get_params, read_param); + let r#type = model::TermId(reader.get_type()); + let decl = bump.alloc(model::ConstructorDecl { + name, + params, + r#type, + }); + model::Operation::DeclareConstructor { decl } + } + Which::OperationDecl(reader) => { + let reader = reader?; + let name = bump.alloc_str(reader.get_name()?.to_str()?); + let params = read_list!(bump, reader, get_params, read_param); + let r#type = model::TermId(reader.get_type()); + let decl = bump.alloc(model::OperationDecl { + name, + params, + r#type, + }); + model::Operation::DeclareOperation { decl } + } Which::Custom(name) => model::Operation::Custom { operation: read_global_ref(bump, name?)?, }, @@ -210,6 +234,7 @@ fn read_region<'a>( let kind = match reader.get_kind()? { hugr_capnp::RegionKind::DataFlow => model::RegionKind::DataFlow, hugr_capnp::RegionKind::ControlFlow => model::RegionKind::ControlFlow, + hugr_capnp::RegionKind::Module => model::RegionKind::Module, }; let sources = read_list!(bump, reader, get_sources, read_link_ref); diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index 4e7ccc0bc..a4b64d646 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -82,6 +82,20 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode write_list!(builder, init_params, write_param, decl.params); builder.set_type(decl.r#type.0); } + + model::Operation::DeclareConstructor { decl } => { + let mut builder = builder.init_constructor_decl(); + builder.set_name(decl.name); + write_list!(builder, init_params, write_param, decl.params); + builder.set_type(decl.r#type.0); + } + model::Operation::DeclareOperation { decl } => { + let mut builder = builder.init_operation_decl(); + builder.set_name(decl.name); + write_list!(builder, init_params, write_param, decl.params); + builder.set_type(decl.r#type.0); + } + model::Operation::Invalid => builder.set_invalid(()), } } @@ -136,6 +150,7 @@ fn write_region(mut builder: hugr_capnp::region::Builder, region: &model::Region builder.set_kind(match region.kind { model::RegionKind::DataFlow => hugr_capnp::RegionKind::DataFlow, model::RegionKind::ControlFlow => hugr_capnp::RegionKind::ControlFlow, + model::RegionKind::Module => hugr_capnp::RegionKind::Module, }); write_list!(builder, init_sources, write_link_ref, region.sources); diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index d29bab5eb..cb8713b32 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -342,6 +342,22 @@ pub enum Operation<'a> { /// The tag of the ADT value. tag: u16, }, + + /// Declaration for a term constructor. + /// + /// Nodes with this operation must be within a module region. + DeclareConstructor { + /// The declaration of the constructor. + decl: &'a ConstructorDecl<'a>, + }, + + /// Declaration for a operation. + /// + /// Nodes with this operation must be within a module region. + DeclareOperation { + /// The declaration of the operation. + decl: &'a OperationDecl<'a>, + }, } /// A region in the hugr. @@ -367,9 +383,11 @@ pub struct Region<'a> { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum RegionKind { /// Data flow region. - DataFlow, + DataFlow = 0, /// Control flow region. - ControlFlow, + ControlFlow = 1, + /// Module region. + Module = 2, } /// A function declaration. @@ -394,6 +412,28 @@ pub struct AliasDecl<'a> { pub r#type: TermId, } +/// A term constructor declaration. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ConstructorDecl<'a> { + /// The name of the constructor to be declared. + pub name: &'a str, + /// The static parameters of the constructor. + pub params: &'a [Param<'a>], + /// The type of the constructed term. + pub r#type: TermId, +} + +/// An operation declaration. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct OperationDecl<'a> { + /// The name of the operation to be declared. + pub name: &'a str, + /// The static parameters of the operation. + pub params: &'a [Param<'a>], + /// The type of the operation. This must be a function type. + pub r#type: TermId, +} + /// A metadata item. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct MetaItem<'a> { diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 5772efb51..33974a76a 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -27,29 +27,35 @@ node = { | node_load_func | node_define_alias | node_declare_alias + | node_declare_ctr + | node_declare_operation | node_tail_loop | node_cond | node_tag | node_custom } -node_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -node_cfg = { "(" ~ "cfg" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -node_block = { "(" ~ "block" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -node_define_func = { "(" ~ "define-func" ~ func_header ~ meta* ~ region* ~ ")" } -node_declare_func = { "(" ~ "declare-func" ~ func_header ~ meta* ~ ")" } -node_call_func = { "(" ~ "call" ~ term ~ port_lists? ~ signature? ~ meta* ~ ")" } -node_load_func = { "(" ~ "load-func" ~ term ~ port_lists? ~ signature? ~ meta* ~ ")" } -node_define_alias = { "(" ~ "define-alias" ~ alias_header ~ term ~ meta* ~ ")" } -node_declare_alias = { "(" ~ "declare-alias" ~ alias_header ~ meta* ~ ")" } -node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -node_cond = { "(" ~ "cond" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_dfg = { "(" ~ "dfg" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_cfg = { "(" ~ "cfg" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_block = { "(" ~ "block" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_define_func = { "(" ~ "define-func" ~ func_header ~ meta* ~ region* ~ ")" } +node_declare_func = { "(" ~ "declare-func" ~ func_header ~ meta* ~ ")" } +node_call_func = { "(" ~ "call" ~ term ~ port_lists? ~ signature? ~ meta* ~ ")" } +node_load_func = { "(" ~ "load-func" ~ term ~ port_lists? ~ signature? ~ meta* ~ ")" } +node_define_alias = { "(" ~ "define-alias" ~ alias_header ~ term ~ meta* ~ ")" } +node_declare_alias = { "(" ~ "declare-alias" ~ alias_header ~ meta* ~ ")" } +node_declare_ctr = { "(" ~ "declare-ctr" ~ ctr_header ~ meta* ~ ")" } +node_declare_operation = { "(" ~ "declare-operation" ~ operation_header ~ meta* ~ ")" } +node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_cond = { "(" ~ "cond" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } +node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } -signature = { "(" ~ "signature" ~ term ~ ")" } -func_header = { symbol ~ param* ~ term ~ term ~ term } -alias_header = { symbol ~ param* ~ term } +signature = { "(" ~ "signature" ~ term ~ ")" } +func_header = { symbol ~ param* ~ term ~ term ~ term } +alias_header = { symbol ~ param* ~ term } +ctr_header = { symbol ~ param* ~ term } +operation_header = { symbol ~ param* ~ term } param = { param_implicit | param_explicit | param_constraint } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index c139f3a3e..b669ce38c 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -6,8 +6,8 @@ use pest::{ use thiserror::Error; use crate::v0::{ - AliasDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, NodeId, Operation, - Param, Region, RegionId, RegionKind, Term, TermId, + AliasDecl, ConstructorDecl, FuncDecl, GlobalRef, LinkRef, LocalRef, MetaItem, Module, Node, + NodeId, Operation, OperationDecl, Param, Region, RegionId, RegionKind, Term, TermId, }; mod pest_parser { @@ -67,7 +67,7 @@ impl<'a> ParseContext<'a> { let children = self.parse_nodes(&mut inner)?; let root_region = self.module.insert_region(Region { - kind: RegionKind::DataFlow, + kind: RegionKind::Module, sources: &[], targets: &[], children, @@ -458,6 +458,34 @@ impl<'a> ParseContext<'a> { } } + Rule::node_declare_ctr => { + let decl = self.parse_ctr_header(inner.next().unwrap())?; + let meta = self.parse_meta(&mut inner)?; + Node { + operation: Operation::DeclareConstructor { decl }, + inputs: &[], + outputs: &[], + params: &[], + regions: &[], + meta, + signature: None, + } + } + + Rule::node_declare_operation => { + let decl = self.parse_op_header(inner.next().unwrap())?; + let meta = self.parse_meta(&mut inner)?; + Node { + operation: Operation::DeclareOperation { decl }, + inputs: &[], + outputs: &[], + params: &[], + regions: &[], + meta, + signature: None, + } + } + _ => unreachable!(), }; @@ -552,6 +580,36 @@ impl<'a> ParseContext<'a> { })) } + fn parse_ctr_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a ConstructorDecl<'a>> { + debug_assert!(matches!(pair.as_rule(), Rule::ctr_header)); + + let mut inner = pair.into_inner(); + let name = self.parse_symbol(&mut inner)?; + let params = self.parse_params(&mut inner)?; + let r#type = self.parse_term(inner.next().unwrap())?; + + Ok(self.bump.alloc(ConstructorDecl { + name, + params, + r#type, + })) + } + + fn parse_op_header(&mut self, pair: Pair<'a, Rule>) -> ParseResult<&'a OperationDecl<'a>> { + debug_assert!(matches!(pair.as_rule(), Rule::operation_header)); + + let mut inner = pair.into_inner(); + let name = self.parse_symbol(&mut inner)?; + let params = self.parse_params(&mut inner)?; + let r#type = self.parse_term(inner.next().unwrap())?; + + Ok(self.bump.alloc(OperationDecl { + name, + params, + r#type, + })) + } + fn parse_params(&mut self, pairs: &mut Pairs<'a, Rule>) -> ParseResult<&'a [Param<'a>]> { let mut params = Vec::new(); diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index c70dfd401..494c10df2 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -1,4 +1,4 @@ -use pretty::{Arena, DocAllocator, RefDoc}; +use pretty::{docs, Arena, DocAllocator, RefDoc}; use std::borrow::Cow; use crate::v0::{ @@ -327,6 +327,36 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) }), + Operation::DeclareConstructor { decl } => this.with_local_scope(decl.params, |this| { + this.print_group(|this| { + this.print_text("declare-ctr"); + this.print_text(decl.name); + }); + + for param in decl.params { + this.print_param(*param)?; + } + + this.print_term(decl.r#type)?; + this.print_meta(node_data.meta)?; + Ok(()) + }), + + Operation::DeclareOperation { decl } => this.with_local_scope(decl.params, |this| { + this.print_group(|this| { + this.print_text("declare-operation"); + this.print_text(decl.name); + }); + + for param in decl.params { + this.print_param(*param)?; + } + + this.print_term(decl.r#type)?; + this.print_meta(node_data.meta)?; + Ok(()) + }), + Operation::TailLoop => { this.print_text("tail-loop"); this.print_port_lists(node_data.inputs, node_data.outputs)?; @@ -374,6 +404,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { RegionKind::ControlFlow => { this.print_text("cfg"); } + RegionKind::Module => { + this.print_text("module"); + } }; this.print_port_lists(region_data.sources, region_data.targets)?; @@ -513,10 +546,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_term(*item_type) }), Term::Str(str) => { - // TODO: escape - self.print_text("\""); - self.print_text(*str); - self.print_text("\""); + self.print_string(str); Ok(()) } Term::StrType => { @@ -615,8 +645,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { fn print_meta(&mut self, meta: &'a [MetaItem<'a>]) -> PrintResult<()> { for item in meta { self.print_parens(|this| { - this.print_text("meta"); - this.print_text(item.name); + this.print_group(|this| { + this.print_text("meta"); + this.print_text(item.name); + }); this.print_term(item.value) })?; } @@ -634,4 +666,18 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } + + /// Print a string literal. + fn print_string(&mut self, string: &'p str) { + // TODO: escape + self.docs.push( + docs![ + self.arena, + self.arena.text("\""), + self.arena.text(string), + self.arena.text("\"") + ] + .into_doc(), + ); + } } diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 57123aa4d..043061677 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -14,45 +14,40 @@ pub fn binary_roundtrip(input: &str) { #[test] pub fn test_add() { - binary_roundtrip(include_str!("../../hugr-core/tests/fixtures/model-add.edn")); + binary_roundtrip(include_str!("fixtures/model-add.edn")); } #[test] pub fn test_alias() { - binary_roundtrip(include_str!( - "../../hugr-core/tests/fixtures/model-alias.edn" - )); + binary_roundtrip(include_str!("fixtures/model-alias.edn")); } #[test] pub fn test_call() { - binary_roundtrip(include_str!( - "../../hugr-core/tests/fixtures/model-call.edn" - )); + binary_roundtrip(include_str!("fixtures/model-call.edn")); } #[test] pub fn test_cfg() { - binary_roundtrip(include_str!("../../hugr-core/tests/fixtures/model-cfg.edn")); + binary_roundtrip(include_str!("fixtures/model-cfg.edn")); } #[test] pub fn test_cond() { - binary_roundtrip(include_str!( - "../../hugr-core/tests/fixtures/model-cond.edn" - )); + binary_roundtrip(include_str!("fixtures/model-cond.edn")); } #[test] pub fn test_loop() { - binary_roundtrip(include_str!( - "../../hugr-core/tests/fixtures/model-loop.edn" - )); + binary_roundtrip(include_str!("fixtures/model-loop.edn")); } #[test] pub fn test_params() { - binary_roundtrip(include_str!( - "../../hugr-core/tests/fixtures/model-params.edn" - )); + binary_roundtrip(include_str!("fixtures/model-params.edn")); +} + +#[test] +pub fn test_decl_exts() { + binary_roundtrip(include_str!("fixtures/model-decl-exts.edn")); } diff --git a/hugr-core/tests/fixtures/model-add.edn b/hugr-model/tests/fixtures/model-add.edn similarity index 100% rename from hugr-core/tests/fixtures/model-add.edn rename to hugr-model/tests/fixtures/model-add.edn diff --git a/hugr-core/tests/fixtures/model-alias.edn b/hugr-model/tests/fixtures/model-alias.edn similarity index 100% rename from hugr-core/tests/fixtures/model-alias.edn rename to hugr-model/tests/fixtures/model-alias.edn diff --git a/hugr-core/tests/fixtures/model-call.edn b/hugr-model/tests/fixtures/model-call.edn similarity index 100% rename from hugr-core/tests/fixtures/model-call.edn rename to hugr-model/tests/fixtures/model-call.edn diff --git a/hugr-core/tests/fixtures/model-cfg.edn b/hugr-model/tests/fixtures/model-cfg.edn similarity index 100% rename from hugr-core/tests/fixtures/model-cfg.edn rename to hugr-model/tests/fixtures/model-cfg.edn diff --git a/hugr-core/tests/fixtures/model-cond.edn b/hugr-model/tests/fixtures/model-cond.edn similarity index 100% rename from hugr-core/tests/fixtures/model-cond.edn rename to hugr-model/tests/fixtures/model-cond.edn diff --git a/hugr-model/tests/fixtures/model-decl-exts.edn b/hugr-model/tests/fixtures/model-decl-exts.edn new file mode 100644 index 000000000..c38c78cdf --- /dev/null +++ b/hugr-model/tests/fixtures/model-decl-exts.edn @@ -0,0 +1,13 @@ +(hugr 0) + +(declare-ctr array.Array + (param ?t type) + (param ?n nat) + type + (meta docs.description "Fixed size array.")) + +(declare-operation array.Init + (param ?t type) + (param ?n nat) + (fn [?t] [(array.Array ?t ?n)] (ext array)) + (meta docs.description "Initialize an array of size ?n with copies of a default value.")) diff --git a/hugr-core/tests/fixtures/model-loop.edn b/hugr-model/tests/fixtures/model-loop.edn similarity index 100% rename from hugr-core/tests/fixtures/model-loop.edn rename to hugr-model/tests/fixtures/model-loop.edn diff --git a/hugr-core/tests/fixtures/model-params.edn b/hugr-model/tests/fixtures/model-params.edn similarity index 100% rename from hugr-core/tests/fixtures/model-params.edn rename to hugr-model/tests/fixtures/model-params.edn diff --git a/hugr-model/tests/snapshots/text__declarative_extensions.snap b/hugr-model/tests/snapshots/text__declarative_extensions.snap new file mode 100644 index 000000000..d26909912 --- /dev/null +++ b/hugr-model/tests/snapshots/text__declarative_extensions.snap @@ -0,0 +1,18 @@ +--- +source: hugr-model/tests/text.rs +expression: "roundtrip(include_str!(\"fixtures/model-decl-exts.edn\"))" +--- +(hugr 0) + +(declare-ctr array.Array + (param ?t type) + (param ?n nat) + type + (meta docs.description "Fixed size array.")) + +(declare-operation array.Init + (param ?t type) + (param ?n nat) + (fn [?t] [(array.Array ?t ?n)] (ext array)) + (meta docs.description + "Initialize an array of size ?n with copies of a default value.")) diff --git a/hugr-model/tests/text.rs b/hugr-model/tests/text.rs new file mode 100644 index 000000000..3e7c7b7a6 --- /dev/null +++ b/hugr-model/tests/text.rs @@ -0,0 +1,12 @@ +use hugr_model::v0 as model; + +fn roundtrip(source: &str) -> String { + let bump = bumpalo::Bump::new(); + let parsed_model = model::text::parse(source, &bump).unwrap(); + model::text::print_to_string(&parsed_model.module, 80).unwrap() +} + +#[test] +pub fn test_declarative_extensions() { + insta::assert_snapshot!(roundtrip(include_str!("fixtures/model-decl-exts.edn"))) +}