diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index 586d31c3e..d6fbc49ff 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -9,7 +9,7 @@ use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder}; use crate::extension::{ExtensionRegistry, ExtensionRegistryError}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::{HugrView, ValidationError}; -use crate::ops::{Module, NamedOp, OpTag, OpTrait, OpType}; +use crate::ops::{FuncDefn, Module, NamedOp, OpTag, OpTrait, OpType}; use crate::{Extension, Hugr}; #[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)] @@ -211,7 +211,7 @@ impl AsRef<[Hugr]> for Package { /// Returns [PackageError::] fn to_module_hugr(mut hugr: Hugr) -> Result { let root = hugr.root(); - let root_op = hugr.get_optype(root); + let root_op = hugr.get_optype(root).clone(); let tag = root_op.tag(); // Modules can be returned as is. @@ -225,6 +225,29 @@ fn to_module_hugr(mut hugr: Hugr) -> Result { hugr.set_parent(root, new_root); return Ok(hugr); } + // If it is a DFG, make it into a "main" function definition and insert it into a module. + if OpTag::Dfg.is_superset(tag) { + let signature = root_op + .dataflow_signature() + .unwrap_or_else(|| panic!("Dataflow child {} without signature", root_op.name())); + + // Convert the DFG into a `FuncDefn` + hugr.set_num_ports(root, 0, 1); + hugr.replace_op( + root, + FuncDefn { + name: "main".to_string(), + signature: signature.into(), + }, + ) + .expect("Hugr accepts any root node"); + + // Wrap it in a module. + let new_root = hugr.add_node(Module::new().into()); + hugr.set_root(new_root); + hugr.set_parent(root, new_root); + return Ok(hugr); + } // Wrap it in a function definition named "main" inside the module otherwise. if OpTag::DataflowChild.is_superset(tag) && !root_op.is_input() && !root_op.is_output() { let signature = root_op @@ -374,17 +397,20 @@ mod test { } #[rstest] - #[case::module(simple_module_hugr(), false)] - #[case::funcdef(simple_funcdef_hugr(), false)] - #[case::dfg(simple_dfg_hugr(), false)] - #[case::cfg(simple_cfg_hugr(), false)] - #[case::unsupported_input(simple_input_node(), true)] - fn hugr_to_package(#[case] hugr: Hugr, #[case] errors: bool) { + #[case::module("module", simple_module_hugr(), false)] + #[case::funcdef("funcdef", simple_funcdef_hugr(), false)] + #[case::dfg("dfg", simple_dfg_hugr(), false)] + #[case::cfg("cfg", simple_cfg_hugr(), false)] + #[case::unsupported_input("input", simple_input_node(), true)] + fn hugr_to_package(#[case] test_name: &str, #[case] hugr: Hugr, #[case] errors: bool) { match (&Package::from_hugr(hugr), errors) { (Ok(package), false) => { assert_eq!(package.modules.len(), 1); - let root_op = package.modules[0].get_optype(package.modules[0].root()); + let hugr = &package.modules[0]; + let root_op = hugr.get_optype(hugr.root()); assert!(root_op.is_module()); + + insta::assert_snapshot!(test_name, hugr.mermaid_string()); } (Err(_), true) => {} (p, _) => panic!("Unexpected result {:?}", p), diff --git a/hugr-core/src/snapshots/hugr_core__package__test__cfg.snap b/hugr-core/src/snapshots/hugr_core__package__test__cfg.snap new file mode 100644 index 000000000..3262c52fc --- /dev/null +++ b/hugr-core/src/snapshots/hugr_core__package__test__cfg.snap @@ -0,0 +1,40 @@ +--- +source: hugr-core/src/package.rs +expression: hugr.mermaid_string() +--- +graph LR + subgraph 0 ["(0) Module"] + direction LR + subgraph 1 ["(1) FuncDefn: #quot;main#quot;"] + direction LR + 2["(2) Input"] + 3["(3) Output"] + subgraph 4 ["(4) CFG"] + direction LR + subgraph 6 ["(6) DataflowBlock"] + direction LR + 7["(7) Input"] + 8["(8) Output"] + 9["(9) Tag"] + 7--"0:0
usize"-->9 + 9--"0:0
[usize]+[usize]"-->8 + end + 5["(5) ExitBlock"] + subgraph 10 ["(10) DataflowBlock"] + direction LR + 12["(12) Input"] + 13["(13) Output"] + 14["(14) const:seq:{}"] + 15["(15) LoadConstant"] + 12--"0:1
usize"-->13 + 14--"0:0
[]"-->15 + 15--"0:0
[]"-->13 + end + 6-."0:0".->10 + 6-."1:0".->5 + 10-."0:0".->5 + end + 2--"0:0
usize"-->4 + 4--"0:0
usize"-->3 + end + end diff --git a/hugr-core/src/snapshots/hugr_core__package__test__dfg.snap b/hugr-core/src/snapshots/hugr_core__package__test__dfg.snap new file mode 100644 index 000000000..fd1a0c0ef --- /dev/null +++ b/hugr-core/src/snapshots/hugr_core__package__test__dfg.snap @@ -0,0 +1,14 @@ +--- +source: hugr-core/src/package.rs +expression: hugr.mermaid_string() +--- +graph LR + subgraph 3 ["(3) Module"] + direction LR + subgraph 0 ["(0) FuncDefn: #quot;main#quot;"] + direction LR + 1["(1) Input"] + 2["(2) Output"] + 1--"0:0
[]+[]"-->2 + end + end diff --git a/hugr-core/src/snapshots/hugr_core__package__test__funcdef.snap b/hugr-core/src/snapshots/hugr_core__package__test__funcdef.snap new file mode 100644 index 000000000..edeb76dbb --- /dev/null +++ b/hugr-core/src/snapshots/hugr_core__package__test__funcdef.snap @@ -0,0 +1,14 @@ +--- +source: hugr-core/src/package.rs +expression: hugr.mermaid_string() +--- +graph LR + subgraph 3 ["(3) Module"] + direction LR + subgraph 0 ["(0) FuncDefn: #quot;test#quot;"] + direction LR + 1["(1) Input"] + 2["(2) Output"] + 1--"0:0
[]+[]"-->2 + end + end diff --git a/hugr-core/src/snapshots/hugr_core__package__test__module.snap b/hugr-core/src/snapshots/hugr_core__package__test__module.snap new file mode 100644 index 000000000..62d54bd96 --- /dev/null +++ b/hugr-core/src/snapshots/hugr_core__package__test__module.snap @@ -0,0 +1,9 @@ +--- +source: hugr-core/src/package.rs +expression: hugr.mermaid_string() +--- +graph LR + subgraph 0 ["(0) Module"] + direction LR + 1["(1) FuncDecl: #quot;test#quot;"] + end