diff --git a/hugr-cli/src/lib.rs b/hugr-cli/src/lib.rs index 76538f034..aabdd1be5 100644 --- a/hugr-cli/src/lib.rs +++ b/hugr-cli/src/lib.rs @@ -3,10 +3,12 @@ use clap::Parser; use clap_verbosity_flag::{InfoLevel, Verbosity}; use clio::Input; +use hugr_core::{Extension, Hugr}; use std::{ffi::OsString, path::PathBuf}; use thiserror::Error; pub mod extensions; +pub mod mermaid; pub mod validate; /// CLI arguments. @@ -20,6 +22,8 @@ pub enum CliArgs { Validate(validate::ValArgs), /// Write standard extensions out in serialized form. GenExtensions(extensions::ExtArgs), + /// Write HUGR as mermaid diagrams. + Mermaid(mermaid::MermaidArgs), /// External commands #[command(external_subcommand)] External(Vec), @@ -30,8 +34,14 @@ pub enum CliArgs { #[error(transparent)] #[non_exhaustive] pub enum CliError { + /// Error reading input. + #[error("Error reading from path: {0}")] + InputFile(#[from] std::io::Error), + /// Error parsing input. + #[error("Error parsing input: {0}")] + Parse(#[from] serde_json::Error), /// Errors produced by the `validate` subcommand. - Validate(#[from] validate::CliError), + Validate(#[from] validate::ValError), } /// Validate and visualise a HUGR file. @@ -50,3 +60,37 @@ pub struct HugrArgs { #[arg(short, long, help = "Skip validation.")] pub extensions: Vec, } + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +/// Package of module HUGRs and extensions. +/// The HUGRs are validated against the extensions. +pub struct Package { + /// Module HUGRs included in the package. + pub modules: Vec, + /// Extensions to validate against. + pub extensions: Vec, +} + +impl Package { + /// Create a new package. + pub fn new(modules: Vec, extensions: Vec) -> Self { + Self { + modules, + extensions, + } + } +} + +impl HugrArgs { + /// Read either a package or a single hugr from the input. + pub fn get_package(&mut self) -> Result { + let val: serde_json::Value = serde_json::from_reader(&mut self.input)?; + // read either a package or a single hugr + if let Ok(p) = serde_json::from_value::(val.clone()) { + Ok(p) + } else { + let hugr: Hugr = serde_json::from_value(val)?; + Ok(Package::new(vec![hugr], vec![])) + } + } +} diff --git a/hugr-cli/src/main.rs b/hugr-cli/src/main.rs index 95474cf6f..c8bb0b56e 100644 --- a/hugr-cli/src/main.rs +++ b/hugr-cli/src/main.rs @@ -10,6 +10,7 @@ fn main() { match CliArgs::parse() { CliArgs::Validate(args) => run_validate(args), CliArgs::GenExtensions(args) => args.run_dump(), + CliArgs::Mermaid(mut args) => args.run_print().unwrap(), CliArgs::External(_) => { // TODO: Implement support for external commands. // Running `hugr COMMAND` would look for `hugr-COMMAND` in the path diff --git a/hugr-cli/src/mermaid.rs b/hugr-cli/src/mermaid.rs new file mode 100644 index 000000000..67f6c5031 --- /dev/null +++ b/hugr-cli/src/mermaid.rs @@ -0,0 +1,32 @@ +//! Render mermaid diagrams. +use std::io::Write; + +use clap::Parser; +use clio::Output; +use hugr_core::HugrView; + +/// Dump the standard extensions. +#[derive(Parser, Debug)] +#[clap(version = "1.0", long_about = None)] +#[clap(about = "Render mermaid diagrams..")] +#[group(id = "hugr")] +#[non_exhaustive] +pub struct MermaidArgs { + /// Common arguments + #[command(flatten)] + pub hugr_args: crate::HugrArgs, + /// Output file '-' for stdout + #[clap(long, short, value_parser, default_value = "-")] + output: Output, +} + +impl MermaidArgs { + /// Write the mermaid diagram to the output. + pub fn run_print(&mut self) -> Result<(), crate::CliError> { + let package = self.hugr_args.get_package()?; + for hugr in package.modules { + write!(self.output, "{}", hugr.mermaid_string())?; + } + Ok(()) + } +} diff --git a/hugr-cli/src/validate.rs b/hugr-cli/src/validate.rs index d73b22ece..62e8df83c 100644 --- a/hugr-cli/src/validate.rs +++ b/hugr-cli/src/validate.rs @@ -2,10 +2,10 @@ use clap::Parser; use clap_verbosity_flag::Level; -use hugr_core::{extension::ExtensionRegistry, Extension, Hugr, HugrView as _}; +use hugr_core::{extension::ExtensionRegistry, Extension, Hugr}; use thiserror::Error; -use crate::HugrArgs; +use crate::{CliError, HugrArgs, Package}; /// Validate and visualise a HUGR file. #[derive(Parser, Debug)] @@ -17,24 +17,12 @@ pub struct ValArgs { #[command(flatten)] /// common arguments pub hugr_args: HugrArgs, - /// Visualise with mermaid. - #[arg(short, long, value_name = "MERMAID", help = "Visualise with mermaid.")] - pub mermaid: bool, - /// Skip validation. - #[arg(short, long, help = "Skip validation.")] - pub no_validate: bool, } /// Error type for the CLI. #[derive(Error, Debug)] #[non_exhaustive] -pub enum CliError { - /// Error reading input. - #[error("Error reading from path: {0}")] - InputFile(#[from] std::io::Error), - /// Error parsing input. - #[error("Error parsing input: {0}")] - Parse(#[from] serde_json::Error), +pub enum ValError { /// Error validating HUGR. #[error("Error validating HUGR: {0}")] Validate(#[from] hugr_core::hugr::ValidationError), @@ -43,45 +31,16 @@ pub enum CliError { ExtReg(#[from] hugr_core::extension::ExtensionRegistryError), } -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -/// Package of module HUGRs and extensions. -/// The HUGRs are validated against the extensions. -pub struct Package { - /// Module HUGRs included in the package. - pub modules: Vec, - /// Extensions to validate against. - pub extensions: Vec, -} - -impl Package { - /// Create a new package. - pub fn new(modules: Vec, extensions: Vec) -> Self { - Self { - modules, - extensions, - } - } -} - /// String to print when validation is successful. pub const VALID_PRINT: &str = "HUGR valid!"; impl ValArgs { /// Run the HUGR cli and validate against an extension registry. pub fn run(&mut self) -> Result, CliError> { - // let rdr = self.input. - let val: serde_json::Value = serde_json::from_reader(&mut self.hugr_args.input)?; - // read either a package or a single hugr - let (mut modules, packed_exts) = if let Ok(Package { - modules, - extensions, - }) = serde_json::from_value::(val.clone()) - { - (modules, extensions) - } else { - let hugr: Hugr = serde_json::from_value(val)?; - (vec![hugr], vec![]) - }; + let Package { + mut modules, + extensions: packed_exts, + } = self.hugr_args.get_package()?; let mut reg: ExtensionRegistry = if self.hugr_args.no_std { hugr_core::extension::PRELUDE_REGISTRY.to_owned() @@ -91,26 +50,20 @@ impl ValArgs { // register packed extensions for ext in packed_exts { - reg.register_updated(ext)?; + reg.register_updated(ext).map_err(ValError::ExtReg)?; } // register external extensions for ext in &self.hugr_args.extensions { let f = std::fs::File::open(ext)?; let ext: Extension = serde_json::from_reader(f)?; - reg.register_updated(ext)?; + reg.register_updated(ext).map_err(ValError::ExtReg)?; } for hugr in modules.iter_mut() { - if self.mermaid { - println!("{}", hugr.mermaid_string()); - } - - if !self.no_validate { - hugr.update_validate(®)?; - if self.verbosity(Level::Info) { - eprintln!("{}", VALID_PRINT); - } + hugr.update_validate(®).map_err(ValError::Validate)?; + if self.verbosity(Level::Info) { + eprintln!("{}", VALID_PRINT); } } Ok(modules) diff --git a/hugr-cli/tests/validate.rs b/hugr-cli/tests/validate.rs index 9f667a858..844a966ee 100644 --- a/hugr-cli/tests/validate.rs +++ b/hugr-cli/tests/validate.rs @@ -6,7 +6,7 @@ use assert_cmd::Command; use assert_fs::{fixture::FileWriteStr, NamedTempFile}; -use hugr_cli::validate::{Package, VALID_PRINT}; +use hugr_cli::{validate::VALID_PRINT, Package}; use hugr_core::builder::DFGBuilder; use hugr_core::types::Type; use hugr_core::{ @@ -22,7 +22,11 @@ use rstest::{fixture, rstest}; #[fixture] fn cmd() -> Command { - let mut cmd = Command::cargo_bin("hugr").unwrap(); + Command::cargo_bin("hugr").unwrap() +} + +#[fixture] +fn val_cmd(mut cmd: Command) -> Command { cmd.arg("validate"); cmd } @@ -48,86 +52,92 @@ fn test_hugr_file(test_hugr_string: String) -> NamedTempFile { } #[rstest] -fn test_doesnt_exist(mut cmd: Command) { - cmd.arg("foobar"); - cmd.assert() +fn test_doesnt_exist(mut val_cmd: Command) { + val_cmd.arg("foobar"); + val_cmd + .assert() .failure() .stderr(contains("No such file or directory")); } #[rstest] -fn test_validate(test_hugr_file: NamedTempFile, mut cmd: Command) { - cmd.arg(test_hugr_file.path()); - cmd.assert().success().stderr(contains(VALID_PRINT)); +fn test_validate(test_hugr_file: NamedTempFile, mut val_cmd: Command) { + val_cmd.arg(test_hugr_file.path()); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); } #[rstest] -fn test_stdin(test_hugr_string: String, mut cmd: Command) { - cmd.write_stdin(test_hugr_string); - cmd.arg("-"); +fn test_stdin(test_hugr_string: String, mut val_cmd: Command) { + val_cmd.write_stdin(test_hugr_string); + val_cmd.arg("-"); - cmd.assert().success().stderr(contains(VALID_PRINT)); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); } #[rstest] -fn test_stdin_silent(test_hugr_string: String, mut cmd: Command) { - cmd.args(["-", "-q"]); - cmd.write_stdin(test_hugr_string); +fn test_stdin_silent(test_hugr_string: String, mut val_cmd: Command) { + val_cmd.args(["-", "-q"]); + val_cmd.write_stdin(test_hugr_string); - cmd.assert().success().stderr(contains(VALID_PRINT).not()); + val_cmd + .assert() + .success() + .stderr(contains(VALID_PRINT).not()); } #[rstest] fn test_mermaid(test_hugr_file: NamedTempFile, mut cmd: Command) { const MERMAID: &str = "graph LR\n subgraph 0 [\"(0) DFG\"]"; + cmd.arg("mermaid"); cmd.arg(test_hugr_file.path()); - cmd.arg("--mermaid"); - cmd.arg("--no-validate"); cmd.assert().success().stdout(contains(MERMAID)); } #[rstest] -fn test_bad_hugr(mut cmd: Command) { +fn test_bad_hugr(mut val_cmd: Command) { let df = DFGBuilder::new(Signature::new_endo(type_row![QB_T])).unwrap(); let bad_hugr = df.hugr().clone(); let bad_hugr_string = serde_json::to_string(&bad_hugr).unwrap(); - cmd.write_stdin(bad_hugr_string); - cmd.arg("-"); + val_cmd.write_stdin(bad_hugr_string); + val_cmd.arg("-"); - cmd.assert() + val_cmd + .assert() .failure() .stderr(contains("Error validating HUGR").and(contains("unconnected port"))); } #[rstest] -fn test_bad_json(mut cmd: Command) { - cmd.write_stdin(r#"{"foo": "bar"}"#); - cmd.arg("-"); +fn test_bad_json(mut val_cmd: Command) { + val_cmd.write_stdin(r#"{"foo": "bar"}"#); + val_cmd.arg("-"); - cmd.assert() + val_cmd + .assert() .failure() .stderr(contains("Error parsing input")); } #[rstest] -fn test_bad_json_silent(mut cmd: Command) { - cmd.write_stdin(r#"{"foo": "bar"}"#); - cmd.args(["-", "-qqq"]); +fn test_bad_json_silent(mut val_cmd: Command) { + val_cmd.write_stdin(r#"{"foo": "bar"}"#); + val_cmd.args(["-", "-qqq"]); - cmd.assert() + val_cmd + .assert() .failure() .stderr(contains("Error parsing input").not()); } #[rstest] -fn test_no_std(test_hugr_string: String, mut cmd: Command) { - cmd.write_stdin(test_hugr_string); - cmd.arg("-"); - cmd.arg("--no-std"); +fn test_no_std(test_hugr_string: String, mut val_cmd: Command) { + val_cmd.write_stdin(test_hugr_string); + val_cmd.arg("-"); + val_cmd.arg("--no-std"); // test hugr doesn't have any standard extensions, so this should succceed - cmd.assert().success().stderr(contains(VALID_PRINT)); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); } #[fixture] @@ -136,12 +146,13 @@ fn float_hugr_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String { } #[rstest] -fn test_no_std_fail(float_hugr_string: String, mut cmd: Command) { - cmd.write_stdin(float_hugr_string); - cmd.arg("-"); - cmd.arg("--no-std"); +fn test_no_std_fail(float_hugr_string: String, mut val_cmd: Command) { + val_cmd.write_stdin(float_hugr_string); + val_cmd.arg("-"); + val_cmd.arg("--no-std"); - cmd.assert() + val_cmd + .assert() .failure() .stderr(contains(" Extension 'arithmetic.float.types' not found")); } @@ -153,14 +164,14 @@ const FLOAT_EXT_FILE: &str = concat!( ); #[rstest] -fn test_float_extension(float_hugr_string: String, mut cmd: Command) { - cmd.write_stdin(float_hugr_string); - cmd.arg("-"); - cmd.arg("--no-std"); - cmd.arg("--extensions"); - cmd.arg(FLOAT_EXT_FILE); +fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) { + val_cmd.write_stdin(float_hugr_string); + val_cmd.arg("-"); + val_cmd.arg("--no-std"); + val_cmd.arg("--extensions"); + val_cmd.arg(FLOAT_EXT_FILE); - cmd.assert().success().stderr(contains(VALID_PRINT)); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); } #[fixture] fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String { @@ -171,11 +182,11 @@ fn package_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String { } #[rstest] -fn test_package(package_string: String, mut cmd: Command) { +fn test_package(package_string: String, mut val_cmd: Command) { // package with float extension and hugr that uses floats can validate - cmd.write_stdin(package_string); - cmd.arg("-"); - cmd.arg("--no-std"); + val_cmd.write_stdin(package_string); + val_cmd.arg("-"); + val_cmd.arg("--no-std"); - cmd.assert().success().stderr(contains(VALID_PRINT)); + val_cmd.assert().success().stderr(contains(VALID_PRINT)); }