From fb66d76e9ad2a54b26e25998bb331eb7059be891 Mon Sep 17 00:00:00 2001 From: evilsocket Date: Tue, 29 Oct 2024 12:00:30 +0100 Subject: [PATCH] new: refactored to use trait object --- build.rs | 2 +- src/cli/graph.rs | 18 +- src/cli/inspect.rs | 23 +- src/cli/signing.rs | 34 +- src/core/docker/inspection.rs | 4 +- src/core/gguf/inspect.rs | 125 ------ src/core/gguf/mod.rs | 35 -- src/core/handlers/gguf/mod.rs | 168 ++++++++ src/core/handlers/mod.rs | 160 ++++++++ src/core/handlers/onnx/mod.rs | 372 ++++++++++++++++++ src/core/{ => handlers}/onnx/protos/mod.rs | 0 .../{ => handlers}/onnx/protos/onnx.proto | 0 .../{ => handlers}/pytorch/inspect.Dockerfile | 0 src/core/{ => handlers}/pytorch/inspect.py | 0 .../pytorch/inspect.requirements | 0 src/core/handlers/pytorch/mod.rs | 108 +++++ src/core/handlers/safetensors/mod.rs | 235 +++++++++++ src/core/mod.rs | 6 +- src/core/onnx/graph.rs | 92 ----- src/core/onnx/inspect.rs | 206 ---------- src/core/onnx/mod.rs | 68 ---- src/core/pytorch/mod.rs | 82 ---- src/core/safetensors/mod.rs | 157 -------- 23 files changed, 1064 insertions(+), 831 deletions(-) delete mode 100644 src/core/gguf/inspect.rs delete mode 100644 src/core/gguf/mod.rs create mode 100644 src/core/handlers/gguf/mod.rs create mode 100644 src/core/handlers/mod.rs create mode 100644 src/core/handlers/onnx/mod.rs rename src/core/{ => handlers}/onnx/protos/mod.rs (100%) rename src/core/{ => handlers}/onnx/protos/onnx.proto (100%) rename src/core/{ => handlers}/pytorch/inspect.Dockerfile (100%) rename src/core/{ => handlers}/pytorch/inspect.py (100%) rename src/core/{ => handlers}/pytorch/inspect.requirements (100%) create mode 100644 src/core/handlers/pytorch/mod.rs create mode 100644 src/core/handlers/safetensors/mod.rs delete mode 100644 src/core/onnx/graph.rs delete mode 100644 src/core/onnx/inspect.rs delete mode 100644 src/core/onnx/mod.rs delete mode 100644 src/core/pytorch/mod.rs delete mode 100644 src/core/safetensors/mod.rs diff --git a/build.rs b/build.rs index 0e08101..4609614 100644 --- a/build.rs +++ b/build.rs @@ -3,7 +3,7 @@ fn main() { protobuf_codegen::Codegen::new() .pure() .includes(["src"]) - .input("src/core/onnx/protos/onnx.proto") + .input("src/core/handlers/onnx/protos/onnx.proto") .cargo_out_dir("onnx-protos") .run_from_script(); } diff --git a/src/cli/graph.rs b/src/cli/graph.rs index 1ffc128..1b11040 100644 --- a/src/cli/graph.rs +++ b/src/cli/graph.rs @@ -1,4 +1,4 @@ -use crate::core::FileType; +use crate::core::handlers::Scope; use super::GraphArgs; @@ -9,18 +9,6 @@ pub(crate) fn graph(args: GraphArgs) -> anyhow::Result<()> { args.output.display() ); - let forced_format = args.format.unwrap_or(FileType::Unknown); - let file_ext = args - .file_path - .extension() - .unwrap_or_default() - .to_str() - .unwrap_or("") - .to_ascii_lowercase(); - - if !forced_format.is_onnx() && file_ext != "onnx" { - anyhow::bail!("this format does not embed graph information"); - } - - crate::core::onnx::create_graph(args.file_path, args.output) + crate::core::handlers::handler_for(args.format, &args.file_path, Scope::Inspection)? 
+ .create_graph(&args.file_path, &args.output) } diff --git a/src/cli/inspect.rs b/src/cli/inspect.rs index 3b90517..54d2eba 100644 --- a/src/cli/inspect.rs +++ b/src/cli/inspect.rs @@ -1,12 +1,16 @@ -use crate::core::FileType; +use crate::core::handlers::Scope; use super::InspectArgs; pub(crate) fn inspect(args: InspectArgs) -> anyhow::Result<()> { + let handler = + crate::core::handlers::handler_for(args.format, &args.file_path, Scope::Inspection)?; + if !args.quiet { println!( - "Inspecting {:?} (detail={:?}{}):\n", + "Inspecting {:?} (format={}, detail={:?}{}):\n", args.file_path, + handler.file_type(), args.detail, args.filter .as_ref() @@ -15,20 +19,7 @@ pub(crate) fn inspect(args: InspectArgs) -> anyhow::Result<()> { ); } - let forced_format = args.format.unwrap_or(FileType::Unknown); - let inspection = if forced_format.is_safetensors() - || crate::core::safetensors::is_safetensors(&args.file_path) - { - crate::core::safetensors::inspect(args.file_path, args.detail, args.filter)? - } else if forced_format.is_onnx() || crate::core::onnx::is_onnx(&args.file_path) { - crate::core::onnx::inspect(args.file_path, args.detail, args.filter)? - } else if forced_format.is_gguf() || crate::core::gguf::is_gguf(&args.file_path) { - crate::core::gguf::inspect(args.file_path, args.detail, args.filter)? - } else if forced_format.is_pytorch() || crate::core::pytorch::is_pytorch(&args.file_path) { - crate::core::pytorch::inspect(args.file_path, args.detail, args.filter)? - } else { - anyhow::bail!("unsupported file format") - }; + let inspection = handler.inspect(&args.file_path, args.detail, args.filter)?; if !args.quiet { println!("file type: {}", inspection.file_type); diff --git a/src/cli/signing.rs b/src/cli/signing.rs index 6e11604..9755414 100644 --- a/src/cli/signing.rs +++ b/src/cli/signing.rs @@ -1,4 +1,4 @@ -use crate::core::{signing::Manifest, FileType}; +use crate::core::{handlers::Scope, signing::Manifest}; use super::{CreateKeyArgs, SignArgs, VerifyArgs}; @@ -8,19 +8,9 @@ pub(crate) fn create_key(args: CreateKeyArgs) -> anyhow::Result<()> { pub(crate) fn sign(args: SignArgs) -> anyhow::Result<()> { let signing_key = crate::core::signing::load_key(&args.key_path)?; - - let forced_format = args.format.unwrap_or(FileType::Unknown); - let mut paths_to_sign = if forced_format.is_safetensors() - || crate::core::safetensors::is_safetensors(&args.file_path) - || crate::core::safetensors::is_safetensors_index(&args.file_path) - { - crate::core::safetensors::paths_to_sign(&args.file_path)? - } else if forced_format.is_onnx() || crate::core::onnx::is_onnx(&args.file_path) { - crate::core::onnx::paths_to_sign(&args.file_path)? - } else if forced_format.is_gguf() || crate::core::gguf::is_gguf(&args.file_path) { - crate::core::gguf::paths_to_sign(&args.file_path)? - } else if forced_format.is_pytorch() || crate::core::pytorch::is_pytorch(&args.file_path) { - crate::core::pytorch::paths_to_sign(&args.file_path)? + let handler = crate::core::handlers::handler_for(args.format, &args.file_path, Scope::Signing); + let mut paths_to_sign = if let Ok(handler) = handler { + handler.paths_to_sign(&args.file_path)? } else { println!("Warning: Unrecognized file format. 
Signing this file does not ensure that the model data will be signed in its entirety."); vec![args.file_path.clone()] @@ -65,22 +55,12 @@ pub(crate) fn verify(args: VerifyArgs) -> anyhow::Result<()> { let raw = std::fs::read_to_string(&manifest_path)?; let ref_manifest: Manifest = serde_json::from_str(&raw)?; - let forced_format = args.format.unwrap_or(FileType::Unknown); - let raw = std::fs::read(&args.key_path)?; let mut manifest = Manifest::for_verifying(raw); - let mut paths_to_verify = if forced_format.is_safetensors() - || crate::core::safetensors::is_safetensors(&args.file_path) - || crate::core::safetensors::is_safetensors_index(&args.file_path) - { - crate::core::safetensors::paths_to_sign(&args.file_path)? - } else if forced_format.is_onnx() || crate::core::onnx::is_onnx(&args.file_path) { - crate::core::onnx::paths_to_sign(&args.file_path)? - } else if forced_format.is_gguf() || crate::core::gguf::is_gguf(&args.file_path) { - crate::core::gguf::paths_to_sign(&args.file_path)? - } else if forced_format.is_pytorch() || crate::core::pytorch::is_pytorch(&args.file_path) { - crate::core::pytorch::paths_to_sign(&args.file_path)? + let handler = crate::core::handlers::handler_for(args.format, &args.file_path, Scope::Signing); + let mut paths_to_verify = if let Ok(handler) = handler { + handler.paths_to_sign(&args.file_path)? } else { println!("Warning: Unrecognized file format. Verifying this file does not ensure that the model data will be verified in its entirety."); vec![args.file_path.clone()] diff --git a/src/core/docker/inspection.rs b/src/core/docker/inspection.rs index dcf1ba7..dc4e2d0 100644 --- a/src/core/docker/inspection.rs +++ b/src/core/docker/inspection.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::path::Path; use blake2::{Blake2b512, Digest}; @@ -50,7 +50,7 @@ impl Inspector { pub fn run( &self, - file_path: PathBuf, + file_path: &Path, additional_files: Vec, detail: DetailLevel, filter: Option<String>, diff --git a/src/core/gguf/inspect.rs b/src/core/gguf/inspect.rs deleted file mode 100644 index b8714a6..0000000 --- a/src/core/gguf/inspect.rs +++ /dev/null @@ -1,125 +0,0 @@ -use std::{ - collections::HashSet, - path::{Path, PathBuf}, -}; - -use gguf::GGUFTensorInfo; -use rayon::prelude::*; - -use crate::{ - cli::DetailLevel, - core::{FileType, Inspection, Metadata, TensorDescriptor}, -}; - -use super::data_type_bits; - -pub(crate) fn is_gguf(file_path: &Path) -> bool { - file_path - .extension() - .unwrap_or_default() - .to_str() - .unwrap_or("") - .to_ascii_lowercase() - == "gguf" -} - -fn build_tensor_descriptor(t_info: &GGUFTensorInfo) -> TensorDescriptor { - TensorDescriptor { - id: Some(t_info.name.to_string()), - shape: t_info.dimensions.iter().map(|d| *d as usize).collect(), - dtype: format!("{:?}", t_info.tensor_type), - size: if t_info.dimensions.is_empty() { - 0 - } else { - (data_type_bits(t_info.tensor_type) - * t_info - .dimensions - .iter() - .map(|d| *d as usize) - .product::<usize>()) - / 8 - }, - metadata: Metadata::new(), - } -} - -pub(crate) fn inspect( - file_path: PathBuf, - detail: DetailLevel, - filter: Option<String>, -) -> anyhow::Result<Inspection> { - let mut inspection = Inspection::default(); - - let file = std::fs::File::open(&file_path)?; - let buffer = unsafe { - memmap2::MmapOptions::new() - .map(&file) - .unwrap_or_else(|_| panic!("failed to map file {}", file_path.display())) - }; - - inspection.file_path = file_path.canonicalize()?; - inspection.file_size = file.metadata()?.len(); - - let gguf = gguf::GGUFFile::read(&buffer) - .map_err(|e| anyhow::anyhow!("failed
to read GGUF file: {}", e))? - .unwrap_or_else(|| panic!("failed to read GGUF file {}", file_path.display())); - - inspection.file_type = FileType::GGUF; - inspection.version = format!("{}", gguf.header.version); - inspection.num_tensors = gguf.header.tensor_count as usize; - inspection.unique_shapes = gguf - .tensors - .par_iter() - .map(|t| t.dimensions.iter().map(|d| *d as usize).collect::>()) - .filter(|shape| !shape.is_empty()) - .collect::>() - .into_iter() - .collect(); - - // sort shapes by volume - inspection.unique_shapes.sort_by(|a, b| { - let size_a: usize = a.iter().product(); - let size_b: usize = b.iter().product(); - size_a.cmp(&size_b) - }); - - inspection.unique_dtypes = gguf - .tensors - .par_iter() - .map(|t| format!("{:?}", t.tensor_type)) - .collect::>() - .into_iter() - .collect(); - - inspection.data_size = gguf - .tensors - .par_iter() - .map(|t| { - if t.dimensions.is_empty() { - 0 - } else { - data_type_bits(t.tensor_type) - * t.dimensions.iter().map(|d| *d as usize).product::() - } - }) - .sum::() - / 8; - - for meta in &gguf.header.metadata { - inspection - .metadata - .insert(meta.key.clone(), format!("{:?}", meta.value)); - } - - if matches!(detail, DetailLevel::Full) { - inspection.tensors = Some( - gguf.tensors - .par_iter() - .filter(|t_info| filter.as_ref().map_or(true, |f| t_info.name.contains(f))) - .map(build_tensor_descriptor) - .collect(), - ); - } - - Ok(inspection) -} diff --git a/src/core/gguf/mod.rs b/src/core/gguf/mod.rs deleted file mode 100644 index 739dc11..0000000 --- a/src/core/gguf/mod.rs +++ /dev/null @@ -1,35 +0,0 @@ -mod inspect; - -use std::path::{Path, PathBuf}; - -use gguf::GGMLType; -pub(crate) use inspect::*; - -#[inline] -fn data_type_bits(dtype: GGMLType) -> usize { - match dtype { - GGMLType::F32 => 32, - GGMLType::F16 => 16, - GGMLType::Q4_0 => 4, - GGMLType::Q4_1 => 4, - GGMLType::Q5_0 => 5, - GGMLType::Q5_1 => 5, - GGMLType::Q8_0 => 8, - GGMLType::Q8_1 => 8, - GGMLType::Q2K => 2, - GGMLType::Q3K => 3, - GGMLType::Q4K => 4, - GGMLType::Q5K => 5, - GGMLType::Q6K => 6, - GGMLType::Q8K => 8, - GGMLType::I8 => 8, - GGMLType::I16 => 16, - GGMLType::I32 => 32, - GGMLType::Count => 32, // Assuming Count is 32-bit, adjust if needed - } -} - -pub(crate) fn paths_to_sign(file_path: &Path) -> anyhow::Result> { - // GGUF are self contained - Ok(vec![file_path.to_path_buf()]) -} diff --git a/src/core/handlers/gguf/mod.rs b/src/core/handlers/gguf/mod.rs new file mode 100644 index 0000000..7450243 --- /dev/null +++ b/src/core/handlers/gguf/mod.rs @@ -0,0 +1,168 @@ +use std::{ + collections::HashSet, + path::{Path, PathBuf}, +}; + +use gguf::{GGMLType, GGUFTensorInfo}; +use rayon::prelude::*; + +use super::{Handler, Scope}; +use crate::{ + cli::DetailLevel, + core::{FileType, Inspection, Metadata, TensorDescriptor}, +}; + +#[inline] +fn data_type_bits(dtype: GGMLType) -> usize { + match dtype { + GGMLType::F32 => 32, + GGMLType::F16 => 16, + GGMLType::Q4_0 => 4, + GGMLType::Q4_1 => 4, + GGMLType::Q5_0 => 5, + GGMLType::Q5_1 => 5, + GGMLType::Q8_0 => 8, + GGMLType::Q8_1 => 8, + GGMLType::Q2K => 2, + GGMLType::Q3K => 3, + GGMLType::Q4K => 4, + GGMLType::Q5K => 5, + GGMLType::Q6K => 6, + GGMLType::Q8K => 8, + GGMLType::I8 => 8, + GGMLType::I16 => 16, + GGMLType::I32 => 32, + GGMLType::Count => 32, // Assuming Count is 32-bit, adjust if needed + } +} + +fn build_tensor_descriptor(t_info: &GGUFTensorInfo) -> TensorDescriptor { + TensorDescriptor { + id: Some(t_info.name.to_string()), + shape: t_info.dimensions.iter().map(|d| *d as 
usize).collect(), + dtype: format!("{:?}", t_info.tensor_type), + size: if t_info.dimensions.is_empty() { + 0 + } else { + (data_type_bits(t_info.tensor_type) + * t_info + .dimensions + .iter() + .map(|d| *d as usize) + .product::()) + / 8 + }, + metadata: Metadata::new(), + } +} + +pub(crate) struct GGUFHandler {} + +impl GGUFHandler { + pub(crate) fn new() -> Self { + Self {} + } +} + +impl Handler for GGUFHandler { + fn file_type(&self) -> FileType { + FileType::GGUF + } + + fn is_handler_for(&self, file_path: &Path, _scope: &Scope) -> bool { + file_path + .extension() + .unwrap_or_default() + .to_str() + .unwrap_or("") + .to_ascii_lowercase() + == "gguf" + } + + fn paths_to_sign(&self, file_path: &Path) -> anyhow::Result> { + // GGUF are self contained + Ok(vec![file_path.to_path_buf()]) + } + + fn inspect( + &self, + file_path: &Path, + detail: crate::cli::DetailLevel, + filter: Option, + ) -> anyhow::Result { + let mut inspection = Inspection::default(); + + let file = std::fs::File::open(file_path)?; + let buffer = unsafe { + memmap2::MmapOptions::new() + .map(&file) + .unwrap_or_else(|_| panic!("failed to map file {}", file_path.display())) + }; + + inspection.file_path = file_path.canonicalize()?; + inspection.file_size = file.metadata()?.len(); + + let gguf = gguf::GGUFFile::read(&buffer) + .map_err(|e| anyhow::anyhow!("failed to read GGUF file: {}", e))? + .unwrap_or_else(|| panic!("failed to read GGUF file {}", file_path.display())); + + inspection.file_type = FileType::GGUF; + inspection.version = format!("{}", gguf.header.version); + inspection.num_tensors = gguf.header.tensor_count as usize; + inspection.unique_shapes = gguf + .tensors + .par_iter() + .map(|t| t.dimensions.iter().map(|d| *d as usize).collect::>()) + .filter(|shape| !shape.is_empty()) + .collect::>() + .into_iter() + .collect(); + + // sort shapes by volume + inspection.unique_shapes.sort_by(|a, b| { + let size_a: usize = a.iter().product(); + let size_b: usize = b.iter().product(); + size_a.cmp(&size_b) + }); + + inspection.unique_dtypes = gguf + .tensors + .par_iter() + .map(|t| format!("{:?}", t.tensor_type)) + .collect::>() + .into_iter() + .collect(); + + inspection.data_size = gguf + .tensors + .par_iter() + .map(|t| { + if t.dimensions.is_empty() { + 0 + } else { + data_type_bits(t.tensor_type) + * t.dimensions.iter().map(|d| *d as usize).product::() + } + }) + .sum::() + / 8; + + for meta in &gguf.header.metadata { + inspection + .metadata + .insert(meta.key.clone(), format!("{:?}", meta.value)); + } + + if matches!(detail, DetailLevel::Full) { + inspection.tensors = Some( + gguf.tensors + .par_iter() + .filter(|t_info| filter.as_ref().map_or(true, |f| t_info.name.contains(f))) + .map(build_tensor_descriptor) + .collect(), + ); + } + + Ok(inspection) + } +} diff --git a/src/core/handlers/mod.rs b/src/core/handlers/mod.rs new file mode 100644 index 0000000..a7f2dff --- /dev/null +++ b/src/core/handlers/mod.rs @@ -0,0 +1,160 @@ +use std::path::{Path, PathBuf}; + +use crate::cli::DetailLevel; + +use super::{FileType, Inspection}; + +pub(crate) mod gguf; +pub(crate) mod onnx; +pub(crate) mod pytorch; +pub(crate) mod safetensors; + +pub(crate) enum Scope { + Inspection, + Signing, +} + +pub(crate) trait Handler { + fn file_type(&self) -> FileType; + + fn is_handler_for(&self, file_path: &Path, scope: &Scope) -> bool; + fn paths_to_sign(&self, file_path: &Path) -> anyhow::Result>; + fn inspect( + &self, + file_path: &Path, + detail: DetailLevel, + filter: Option, + ) -> anyhow::Result; + + fn 
create_graph(&self, _file_path: &Path, _output_path: &Path) -> anyhow::Result<()> { + Err(anyhow::anyhow!( + "graph generation not supported for this format" + )) + } +} + +pub(crate) fn handler_for( + format: Option, + file_path: &Path, + scope: Scope, +) -> anyhow::Result> { + let safetensors_handler = safetensors::SafeTensorsHandler::new(); + let onnx_handler = onnx::OnnxHandler::new(); + let gguf_handler = gguf::GGUFHandler::new(); + let pytorch_handler = pytorch::PyTorchHandler::new(); + + match &format { + None => { + if safetensors_handler.is_handler_for(file_path, &scope) { + Ok(Box::new(safetensors_handler)) + } else if onnx_handler.is_handler_for(file_path, &scope) { + Ok(Box::new(onnx_handler)) + } else if gguf_handler.is_handler_for(file_path, &scope) { + Ok(Box::new(gguf_handler)) + } else if pytorch_handler.is_handler_for(file_path, &scope) { + Ok(Box::new(pytorch_handler)) + } else { + anyhow::bail!("unsupported file format") + } + } + Some(forced_format) => { + if forced_format.is_safetensors() { + Ok(Box::new(safetensors_handler)) + } else if forced_format.is_onnx() { + Ok(Box::new(onnx_handler)) + } else if forced_format.is_gguf() { + Ok(Box::new(gguf_handler)) + } else if forced_format.is_pytorch() { + Ok(Box::new(pytorch_handler)) + } else { + anyhow::bail!("unsupported file format") + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::Path; + + #[test] + fn test_handler_for_with_forced_format() { + // Test each forced format + let path = Path::new("test.bin"); + let handler = handler_for(Some(FileType::SafeTensors), path, Scope::Inspection).unwrap(); + assert!(matches!(handler.file_type(), FileType::SafeTensors)); + + let handler = handler_for(Some(FileType::ONNX), path, Scope::Inspection).unwrap(); + assert!(matches!(handler.file_type(), FileType::ONNX)); + + let handler = handler_for(Some(FileType::GGUF), path, Scope::Inspection).unwrap(); + assert!(matches!(handler.file_type(), FileType::GGUF)); + + let handler = handler_for(Some(FileType::PyTorch), path, Scope::Inspection).unwrap(); + assert!(matches!(handler.file_type(), FileType::PyTorch)); + } + + #[test] + fn test_handler_for_with_file_extension() { + // Test auto-detection by file extension + let handler = handler_for(None, Path::new("model.safetensors"), Scope::Inspection).unwrap(); + assert!(matches!(handler.file_type(), FileType::SafeTensors)); + + let handler = handler_for(None, Path::new("model.onnx"), Scope::Inspection).unwrap(); + assert!(matches!(handler.file_type(), FileType::ONNX)); + + let handler = handler_for(None, Path::new("model.gguf"), Scope::Inspection).unwrap(); + assert!(matches!(handler.file_type(), FileType::GGUF)); + + let handler = handler_for(None, Path::new("model.pt"), Scope::Inspection).unwrap(); + assert!(matches!(handler.file_type(), FileType::PyTorch)); + } + + #[test] + fn test_handler_for_unknown_format() { + // Test handling of unknown format + let result = handler_for(None, Path::new("model.unknown"), Scope::Inspection); + assert!(result.is_err()); + } + + #[test] + fn test_handler_for_different_scopes() { + // Test that handlers work with different scopes + let handler = handler_for( + Some(FileType::SafeTensors), + Path::new("test.bin"), + Scope::Signing, + ) + .unwrap(); + assert!(matches!(handler.file_type(), FileType::SafeTensors)); + + let handler = handler_for( + Some(FileType::SafeTensors), + Path::new("test.bin"), + Scope::Inspection, + ) + .unwrap(); + assert!(matches!(handler.file_type(), FileType::SafeTensors)); + } + + #[test] + fn 
test_format_override() { + // Test that forced format overrides file extension + let handler = handler_for( + Some(FileType::ONNX), + Path::new("model.safetensors"), + Scope::Inspection, + ) + .unwrap(); + assert!(matches!(handler.file_type(), FileType::ONNX)); + + let handler = handler_for( + Some(FileType::GGUF), + Path::new("model.pt"), + Scope::Inspection, + ) + .unwrap(); + assert!(matches!(handler.file_type(), FileType::GGUF)); + } +} diff --git a/src/core/handlers/onnx/mod.rs b/src/core/handlers/onnx/mod.rs new file mode 100644 index 0000000..6e86f50 --- /dev/null +++ b/src/core/handlers/onnx/mod.rs @@ -0,0 +1,372 @@ +use std::{ + collections::{HashMap, HashSet}, + path::{Path, PathBuf}, +}; + +mod protos; + +use dot_graph::Graph; +use protobuf::Message; + +use protos::{tensor_proto::DataLocation, ModelProto, NodeProto, TensorProto}; +use rayon::prelude::*; + +use crate::{ + cli::DetailLevel, + core::{handlers::Handler, FileType, Inspection, Metadata, TensorDescriptor}, +}; + +use super::Scope; + +#[inline] +fn data_type_bits(dtype: i32) -> usize { + match dtype { + 1 => 32, // float + 2 => 8, // uint8_t + 3 => 8, // int8_t + 4 => 16, // uint16_t + 5 => 16, // int16_t + 6 => 32, // int32_t + 7 => 64, // int64_t + 8 => 8, // string (assuming 8 bits per character) + 9 => 8, // bool (typically 8 bits in most systems) + 10 => 16, // FLOAT16 + 11 => 64, // DOUBLE + 12 => 32, // UINT32 + 13 => 64, // UINT64 + 14 => 64, // COMPLEX64 (two 32-bit floats) + 15 => 128, // COMPLEX128 (two 64-bit floats) + 16 => 16, // BFLOAT16 + 17 => 8, // FLOAT8E4M3FN + 18 => 8, // FLOAT8E4M3FNUZ + 19 => 8, // FLOAT8E5M2 + 20 => 8, // FLOAT8E5M2FNUZ + 21 => 4, // UINT4 + 22 => 4, // INT4 + 23 => 4, // FLOAT4E2M1 + _ => panic!("Unsupported data type: {}", dtype), + } +} + +#[inline] +pub(crate) fn data_type_string(dtype: i32) -> &'static str { + match dtype { + 1 => "FLOAT", + 2 => "UINT8", + 3 => "INT8", + 4 => "UINT16", + 5 => "INT16", + 6 => "INT32", + 7 => "INT64", + 8 => "STRING", + 9 => "BOOL", + 10 => "FLOAT16", + 11 => "DOUBLE", + 12 => "UINT32", + 13 => "UINT64", + 14 => "COMPLEX64", + 15 => "COMPLEX128", + 16 => "BFLOAT16", + 17 => "FLOAT8E4M3FN", + 18 => "FLOAT8E4M3FNUZ", + 19 => "FLOAT8E5M2", + 20 => "FLOAT8E5M2FNUZ", + 21 => "UINT4", + 22 => "INT4", + 23 => "FLOAT4E2M1", + _ => "UNKNOWN", + } +} + +fn build_tensor_descriptor(tensor: &TensorProto) -> TensorDescriptor { + let mut metadata = Metadata::new(); + if !tensor.doc_string.is_empty() { + metadata.insert("doc_string".to_string(), tensor.doc_string.clone()); + } + + if tensor.data_location.value() == DataLocation::EXTERNAL as i32 { + metadata.insert("data_location".to_string(), "external".to_string()); + if let Some(external_data) = tensor.external_data.first() { + metadata.insert("location".to_string(), external_data.value.clone()); + } + } + + tensor.metadata_props.iter().for_each(|prop| { + metadata.insert(prop.key.clone(), prop.value.clone()); + }); + + TensorDescriptor { + id: Some(tensor.name.to_string()), + shape: tensor.dims.iter().map(|d| *d as usize).collect(), + dtype: data_type_string(tensor.data_type).to_string(), + size: if tensor.dims.is_empty() { + 0 + } else { + (data_type_bits(tensor.data_type) + * tensor.dims.iter().map(|d| *d as usize).product::()) + / 8 + }, + metadata, + } +} + +#[inline] +fn is_letter_or_underscore_or_dot(c: char) -> bool { + in_range('a', c, 'z') || in_range('A', c, 'Z') || c == '_' || c == '.' 
+} + +#[inline] +fn is_constituent(c: char) -> bool { + is_letter_or_underscore_or_dot(c) || in_range('0', c, '9') +} + +#[inline] +fn in_range(low: char, c: char, high: char) -> bool { + low as usize <= c as usize && c as usize <= high as usize +} + +fn str_to_node_name(s: &str) -> String { + let mut result = String::new(); + for c in s.chars() { + if is_constituent(c) { + result.push(c); + } else { + result.push('_'); + } + } + result.trim_matches('_').to_string() +} + +fn op_to_dot_node(op: &NodeProto, op_id: usize) -> dot_graph::Node { + let node_label = if !op.name.is_empty() { + format!("{}/{} (op#{})", op.name, op.op_type, op_id) + } else { + format!("{} (op#{})", op.op_type, op_id) + }; + let node_name = str_to_node_name(&node_label); + + dot_graph::Node::new(&node_name).label(&node_label) +} + +pub(crate) struct OnnxHandler; + +impl OnnxHandler { + pub(crate) fn new() -> Self { + Self + } +} + +impl Handler for OnnxHandler { + fn file_type(&self) -> FileType { + FileType::ONNX + } + + fn is_handler_for(&self, file_path: &Path, _scope: &Scope) -> bool { + file_path + .extension() + .unwrap_or_default() + .to_str() + .unwrap_or("") + .to_ascii_lowercase() + == "onnx" + } + + fn paths_to_sign(&self, file_path: &Path) -> anyhow::Result> { + let base_path = file_path + .parent() + .ok_or_else(|| anyhow::anyhow!("no parent path"))?; + let mut file = std::fs::File::open(file_path)?; + let onnx_model: ModelProto = Message::parse_from_reader(&mut file)?; + + // ONNX files can contain external data + let external_paths: HashSet = onnx_model + .graph + .initializer + .par_iter() + .filter(|t| t.data_location.value() == DataLocation::EXTERNAL as i32) + .filter_map(|t| { + t.external_data + .first() + .map(|data| PathBuf::from(&data.value)) + .map(|p| { + if p.is_relative() { + base_path.join(p) + } else { + p + } + }) + }) + .collect(); + + let mut paths = vec![file_path.to_path_buf()]; + paths.extend(external_paths); + + Ok(paths) + } + + fn inspect( + &self, + file_path: &Path, + detail: DetailLevel, + filter: Option, + ) -> anyhow::Result { + let mut inspection = Inspection::default(); + + let mut file = std::fs::File::open(file_path)?; + + inspection.file_path = file_path.canonicalize()?; + inspection.file_size = file.metadata()?.len(); + + let onnx_model: ModelProto = Message::parse_from_reader(&mut file)?; + + inspection.file_type = FileType::ONNX; + + if onnx_model.model_version != 0 { + inspection.version = format!( + "{} (IR v{})", + onnx_model.model_version, onnx_model.ir_version + ); + } else { + inspection.version = format!("IR v{}", onnx_model.ir_version); + } + + // TODO: check the presence of sparse tensors from graph.sparse_initializer + + inspection.num_tensors = onnx_model.graph.initializer.len(); + inspection.data_size = onnx_model + .graph + .initializer + .par_iter() + .map(|t| { + if t.dims.is_empty() { + 0 + } else { + data_type_bits(t.data_type) + * t.dims.iter().map(|d| *d as usize).product::() + } + }) + .sum::() + / 8; + + inspection.unique_shapes = onnx_model + .graph + .initializer + .par_iter() + .map(|t| t.dims.iter().map(|d| *d as usize).collect::>()) + .filter(|shape| !shape.is_empty()) + .collect::>() + .into_iter() + .collect(); + + // sort shapes by volume + inspection.unique_shapes.sort_by(|a, b| { + let size_a: usize = a.iter().product(); + let size_b: usize = b.iter().product(); + size_a.cmp(&size_b) + }); + + inspection.unique_dtypes = onnx_model + .graph + .initializer + .par_iter() + .map(|t| data_type_string(t.data_type).to_string()) + .collect::>() 
+ .into_iter() + .collect(); + + if !onnx_model.producer_name.is_empty() { + inspection.metadata.insert( + "producer_name".to_string(), + onnx_model.producer_name.clone(), + ); + } + + if !onnx_model.producer_version.is_empty() { + inspection.metadata.insert( + "producer_version".to_string(), + onnx_model.producer_version.clone(), + ); + } + + if !onnx_model.domain.is_empty() { + inspection + .metadata + .insert("domain".to_string(), onnx_model.domain.clone()); + } + + if !onnx_model.doc_string.is_empty() { + inspection + .metadata + .insert("doc_string".to_string(), onnx_model.doc_string.clone()); + } + + onnx_model.metadata_props.iter().for_each(|prop| { + inspection + .metadata + .insert(prop.key.clone(), prop.value.clone()); + }); + + if matches!(detail, DetailLevel::Full) { + inspection.tensors = Some( + onnx_model + .graph + .initializer + .par_iter() + .filter(|t_info| filter.as_ref().map_or(true, |f| t_info.name.contains(f))) + .map(build_tensor_descriptor) + .collect(), + ); + } + + Ok(inspection) + } + + // adapted from https://github.com/onnx/onnx/blob/main/onnx/tools/net_drawer.py + fn create_graph(&self, file_path: &Path, output_path: &Path) -> anyhow::Result<()> { + let mut file = std::fs::File::open(file_path)?; + let onnx_model: ModelProto = Message::parse_from_reader(&mut file)?; + let mut dot_graph = Graph::new( + // make sure the name is quoted + &format!( + "{:?}", + file_path.file_stem().unwrap().to_string_lossy().as_ref() + ), + dot_graph::Kind::Digraph, + ); + let mut dot_nodes = HashMap::new(); + let mut dot_node_counts = HashMap::new(); + + for (op_id, op) in onnx_model.graph.node.iter().enumerate() { + let op_node = op_to_dot_node(op, op_id); + dot_graph.add_node(op_node.clone()); + for input_name in &op.input { + let input_node = dot_nodes.entry(input_name.clone()).or_insert_with(|| { + let count = dot_node_counts.entry(input_name.clone()).or_insert(0); + let node = dot_graph::Node::new(&str_to_node_name(&format!( + "{}{}", + input_name, count + ))); + node.label(input_name); + *count += 1; + node + }); + dot_graph.add_node(input_node.clone()); + dot_graph.add_edge(dot_graph::Edge::new(&input_node.name, &op_node.name, "")); + } + for output_name in &op.output { + let count = dot_node_counts.entry(output_name.clone()).or_insert(0); + let output_node = + dot_graph::Node::new(&str_to_node_name(&format!("{}{}", output_name, count))); + output_node.label(output_name); + dot_nodes.insert(output_name.clone(), output_node.clone()); + dot_graph.add_node(output_node.clone()); + dot_graph.add_edge(dot_graph::Edge::new(&op_node.name, &output_node.name, "")); + } + } + + let dot_string = dot_graph.to_dot_string()?; + + std::fs::write(output_path, dot_string) + .map_err(|e| anyhow::anyhow!("failed to write dot string to output path: {:?}", e)) + } +} diff --git a/src/core/onnx/protos/mod.rs b/src/core/handlers/onnx/protos/mod.rs similarity index 100% rename from src/core/onnx/protos/mod.rs rename to src/core/handlers/onnx/protos/mod.rs diff --git a/src/core/onnx/protos/onnx.proto b/src/core/handlers/onnx/protos/onnx.proto similarity index 100% rename from src/core/onnx/protos/onnx.proto rename to src/core/handlers/onnx/protos/onnx.proto diff --git a/src/core/pytorch/inspect.Dockerfile b/src/core/handlers/pytorch/inspect.Dockerfile similarity index 100% rename from src/core/pytorch/inspect.Dockerfile rename to src/core/handlers/pytorch/inspect.Dockerfile diff --git a/src/core/pytorch/inspect.py b/src/core/handlers/pytorch/inspect.py similarity index 100% rename from 
src/core/pytorch/inspect.py rename to src/core/handlers/pytorch/inspect.py diff --git a/src/core/pytorch/inspect.requirements b/src/core/handlers/pytorch/inspect.requirements similarity index 100% rename from src/core/pytorch/inspect.requirements rename to src/core/handlers/pytorch/inspect.requirements diff --git a/src/core/handlers/pytorch/mod.rs b/src/core/handlers/pytorch/mod.rs new file mode 100644 index 0000000..c17a51e --- /dev/null +++ b/src/core/handlers/pytorch/mod.rs @@ -0,0 +1,108 @@ +use std::path::{Path, PathBuf}; + +use crate::{ + cli::DetailLevel, + core::{docker, FileType, Inspection}, +}; + +use super::{Handler, Scope}; + +pub(crate) struct PyTorchHandler; + +impl PyTorchHandler { + pub(crate) fn new() -> Self { + Self + } +} + +impl Handler for PyTorchHandler { + fn file_type(&self) -> FileType { + FileType::PyTorch + } + + fn is_handler_for(&self, file_path: &Path, _scope: &Scope) -> bool { + let file_ext = file_path + .extension() + .unwrap_or_default() + .to_str() + .unwrap_or("") + .to_ascii_lowercase(); + + let file_name = file_path + .file_name() + .unwrap_or_default() + .to_str() + .unwrap_or_default() + .to_ascii_lowercase(); + + file_ext == "pt" + || file_ext == "pth" + || file_name.ends_with("pytorch_model.bin") + // cases like diffusion_pytorch_model.fp16.bin + || (file_name.contains("pytorch_model") && file_name.ends_with(".bin")) + } + + fn paths_to_sign(&self, file_path: &Path) -> anyhow::Result> { + // TODO: can a pytorch model reference external files? + Ok(vec![file_path.to_path_buf()]) + } + + fn inspect( + &self, + file_path: &Path, + detail: DetailLevel, + filter: Option, + ) -> anyhow::Result { + docker::Inspector::new( + include_str!("inspect.Dockerfile"), + include_str!("inspect.py"), + include_str!("inspect.requirements"), + ) + .run(file_path, vec![], detail, filter) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_pytorch() { + // Standard .pt extension + let handler = PyTorchHandler {}; + + assert!(handler.is_handler_for(Path::new("model.pt"), &Scope::Inspection)); + assert!(handler.is_handler_for(Path::new("path/to/model.pt"), &Scope::Inspection)); + assert!(handler.is_handler_for(Path::new("MODEL.PT"), &Scope::Inspection)); // Case insensitive + + // Standard .pth extension + assert!(handler.is_handler_for(Path::new("model.pth"), &Scope::Inspection)); + assert!(handler.is_handler_for(Path::new("path/to/model.pth"), &Scope::Inspection)); + assert!(handler.is_handler_for(Path::new("MODEL.PTH"), &Scope::Inspection)); // Case insensitive + + // Standard pytorch_model.bin filename + assert!(handler.is_handler_for(Path::new("pytorch_model.bin"), &Scope::Inspection)); + assert!(handler.is_handler_for(Path::new("path/to/pytorch_model.bin"), &Scope::Inspection)); + assert!(handler.is_handler_for(Path::new("PYTORCH_MODEL.BIN"), &Scope::Inspection)); // Case insensitive + + // Variants of pytorch_model.*.bin + assert!( + handler.is_handler_for(Path::new("diffusion_pytorch_model.bin"), &Scope::Inspection) + ); + assert!(handler.is_handler_for( + Path::new("diffusion_pytorch_model.fp16.bin"), + &Scope::Inspection + )); + assert!(handler.is_handler_for( + Path::new("text_encoder_pytorch_model.safetensors.bin"), + &Scope::Inspection + )); + + // Non-matching cases + assert!(!handler.is_handler_for(Path::new("model.onnx"), &Scope::Inspection)); + assert!(!handler.is_handler_for(Path::new("model.safetensors"), &Scope::Inspection)); + assert!(!handler.is_handler_for(Path::new("model.bin"), &Scope::Inspection)); // Just .bin isn't 
enough + assert!(!handler.is_handler_for(Path::new("pytorch.txt"), &Scope::Inspection)); + assert!(!handler.is_handler_for(Path::new(""), &Scope::Inspection)); + } +} diff --git a/src/core/handlers/safetensors/mod.rs b/src/core/handlers/safetensors/mod.rs new file mode 100644 index 0000000..c36f1e9 --- /dev/null +++ b/src/core/handlers/safetensors/mod.rs @@ -0,0 +1,235 @@ +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + path::{Path, PathBuf}, +}; + +use rayon::prelude::*; + +use safetensors::{tensor::TensorInfo, SafeTensors}; +use serde::Deserialize; + +use crate::{ + cli::DetailLevel, + core::{FileType, Inspection, Metadata, TensorDescriptor}, +}; + +use super::{Handler, Scope}; + +#[derive(Debug, Deserialize)] +struct TensorIndex { + weight_map: HashMap, +} + +pub(crate) struct SafeTensorsHandler; + +impl SafeTensorsHandler { + pub(crate) fn new() -> Self { + Self + } +} + +fn is_safetensors_index(file_path: &Path) -> bool { + file_path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .ends_with(".safetensors.index.json") +} + +fn build_tensor_descriptor(tensor_id: &str, tensor_info: &TensorInfo) -> TensorDescriptor { + TensorDescriptor { + id: Some(tensor_id.to_string()), + shape: tensor_info.shape.clone(), + dtype: format!("{:?}", &tensor_info.dtype), + size: tensor_info.data_offsets.1 - tensor_info.data_offsets.0, + metadata: Metadata::new(), + } +} + +impl Handler for SafeTensorsHandler { + fn file_type(&self) -> FileType { + FileType::SafeTensors + } + + fn is_handler_for(&self, file_path: &Path, scope: &Scope) -> bool { + let is_safetensors = file_path + .extension() + .unwrap_or_default() + .to_str() + .unwrap_or("") + .to_ascii_lowercase() + == "safetensors"; + + match scope { + // can only inspect safetensors files + Scope::Inspection => is_safetensors, + // can sign safetensors files directly or an index referencing multiple files + Scope::Signing => is_safetensors || is_safetensors_index(file_path), + } + } + + fn paths_to_sign(&self, file_path: &Path) -> anyhow::Result> { + if is_safetensors_index(file_path) { + // load unique paths from index + let base_path = file_path + .parent() + .ok_or_else(|| anyhow::anyhow!("no parent path"))?; + + let index = std::fs::read_to_string(file_path)?; + let index: TensorIndex = serde_json::from_str(&index)?; + + let unique: HashSet = index + .weight_map + .values() + .map(PathBuf::from) + .map(|p| { + if p.is_relative() { + base_path.join(p) + } else { + p + } + }) + .collect(); + + let mut paths = vec![file_path.to_path_buf()]; + paths.extend(unique); + Ok(paths) + } else { + // safetensors are self contained + Ok(vec![file_path.to_path_buf()]) + } + } + + fn inspect( + &self, + file_path: &Path, + detail: DetailLevel, + filter: Option, + ) -> anyhow::Result { + let mut inspection = Inspection::default(); + + let file = std::fs::File::open(file_path)?; + let buffer = unsafe { + memmap2::MmapOptions::new() + .map(&file) + .unwrap_or_else(|_| panic!("failed to map file {}", file_path.display())) + }; + + inspection.file_path = file_path.canonicalize()?; + inspection.file_size = file.metadata()?.len(); + + // read header + let (header_size, header) = SafeTensors::read_metadata(&buffer)?; + + inspection.file_type = FileType::SafeTensors; + inspection.header_size = header_size; + inspection.version = "0.x".to_string(); + + let tensors = header.tensors(); + + // transform tensors to a vector + let mut tensors: Vec<_> = tensors.into_iter().collect(); + + inspection.num_tensors = tensors.len(); + inspection.data_size = 
tensors + .par_iter() + .map(|t| t.1.data_offsets.1 - t.1.data_offsets.0) + .sum::(); + + inspection.unique_shapes = tensors + .par_iter() + .map(|t| t.1.shape.clone()) + .collect::>() + .into_iter() + .collect(); + // sort shapes by volume + inspection.unique_shapes.sort_by(|a, b| { + let size_a: usize = a.iter().product(); + let size_b: usize = b.iter().product(); + size_a.cmp(&size_b) + }); + + inspection.unique_dtypes = tensors + .par_iter() + .map(|t| format!("{:?}", t.1.dtype)) + .collect::>() + .into_iter() + .collect(); + + if let Some(block_metadata) = header.metadata() { + inspection.metadata = BTreeMap::from_iter( + block_metadata + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())), + ); + } + + if matches!(detail, DetailLevel::Full) { + // sort by offset + tensors.sort_by_key(|(_, info)| info.data_offsets.0); + + inspection.tensors = Some( + tensors + .par_iter() + .filter(|(tensor_id, _)| { + filter.as_ref().map_or(true, |f| tensor_id.contains(f)) + }) + .map(|(tensor_id, tensor_info)| build_tensor_descriptor(tensor_id, tensor_info)) + .collect(), + ); + } + + Ok(inspection) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_handler_for() { + let handler = SafeTensorsHandler::new(); + + // Test inspection scope + assert!(handler.is_handler_for(Path::new("model.safetensors"), &Scope::Inspection)); + assert!(handler.is_handler_for(Path::new("path/to/model.safetensors"), &Scope::Inspection)); + assert!(handler.is_handler_for(Path::new("MODEL.SAFETENSORS"), &Scope::Inspection)); // Case insensitive + + // Test signing scope + assert!(handler.is_handler_for(Path::new("model.safetensors"), &Scope::Signing)); + assert!(handler.is_handler_for(Path::new("path/to/model.safetensors"), &Scope::Signing)); + assert!(handler.is_handler_for(Path::new("MODEL.SAFETENSORS"), &Scope::Signing)); // Case insensitive + + // Test non-matching cases for both scopes + for scope in [Scope::Inspection, Scope::Signing] { + assert!(!handler.is_handler_for(Path::new("model.onnx"), &scope)); + assert!(!handler.is_handler_for(Path::new("model.pt"), &scope)); + assert!(!handler.is_handler_for(Path::new("model.bin"), &scope)); + assert!(!handler.is_handler_for(Path::new("safetensors.txt"), &scope)); + assert!(!handler.is_handler_for(Path::new(""), &scope)); + } + } + + #[test] + fn test_is_handler_for_index() { + let handler = SafeTensorsHandler::new(); + + // Index files should only be handled in signing scope + assert!(handler.is_handler_for(Path::new("model.safetensors.index.json"), &Scope::Signing)); + assert!(handler.is_handler_for( + Path::new("path/to/model.safetensors.index.json"), + &Scope::Signing + )); + + // Index files should not be handled in inspection scope + assert!(!handler.is_handler_for( + Path::new("model.safetensors.index.json"), + &Scope::Inspection + )); + assert!(!handler.is_handler_for( + Path::new("path/to/model.safetensors.index.json"), + &Scope::Inspection + )); + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 7dca992..fe9c756 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -3,12 +3,8 @@ use std::{collections::BTreeMap, fmt, path::PathBuf}; use clap::ValueEnum; use serde::{Deserialize, Serialize}; -pub(crate) mod gguf; -pub(crate) mod onnx; -pub(crate) mod pytorch; -pub(crate) mod safetensors; - pub(crate) mod docker; +pub(crate) mod handlers; pub(crate) mod signing; pub(crate) type Metadata = BTreeMap; diff --git a/src/core/onnx/graph.rs b/src/core/onnx/graph.rs deleted file mode 100644 index 468aaf6..0000000 --- 
a/src/core/onnx/graph.rs +++ /dev/null @@ -1,92 +0,0 @@ -use std::{collections::HashMap, path::PathBuf}; - -use dot_graph::Graph; -use protobuf::Message; - -use super::protos::{ModelProto, NodeProto}; - -#[inline] -fn is_letter_or_underscore_or_dot(c: char) -> bool { - in_range('a', c, 'z') || in_range('A', c, 'Z') || c == '_' || c == '.' -} - -#[inline] -fn is_constituent(c: char) -> bool { - is_letter_or_underscore_or_dot(c) || in_range('0', c, '9') -} - -#[inline] -fn in_range(low: char, c: char, high: char) -> bool { - low as usize <= c as usize && c as usize <= high as usize -} - -fn str_to_node_name(s: &str) -> String { - let mut result = String::new(); - for c in s.chars() { - if is_constituent(c) { - result.push(c); - } else { - result.push('_'); - } - } - result.trim_matches('_').to_string() -} - -fn op_to_dot_node(op: &NodeProto, op_id: usize) -> dot_graph::Node { - let node_label = if !op.name.is_empty() { - format!("{}/{} (op#{})", op.name, op.op_type, op_id) - } else { - format!("{} (op#{})", op.op_type, op_id) - }; - let node_name = str_to_node_name(&node_label); - - dot_graph::Node::new(&node_name).label(&node_label) -} - -// adapted from https://github.com/onnx/onnx/blob/main/onnx/tools/net_drawer.py - -pub(crate) fn create_graph(file_path: PathBuf, output_path: PathBuf) -> anyhow::Result<()> { - let mut file = std::fs::File::open(&file_path)?; - let onnx_model: ModelProto = Message::parse_from_reader(&mut file)?; - let mut dot_graph = Graph::new( - // make sure the name is quoted - &format!( - "{:?}", - file_path.file_stem().unwrap().to_string_lossy().as_ref() - ), - dot_graph::Kind::Digraph, - ); - let mut dot_nodes = HashMap::new(); - let mut dot_node_counts = HashMap::new(); - - for (op_id, op) in onnx_model.graph.node.iter().enumerate() { - let op_node = op_to_dot_node(op, op_id); - dot_graph.add_node(op_node.clone()); - for input_name in &op.input { - let input_node = dot_nodes.entry(input_name.clone()).or_insert_with(|| { - let count = dot_node_counts.entry(input_name.clone()).or_insert(0); - let node = - dot_graph::Node::new(&str_to_node_name(&format!("{}{}", input_name, count))); - node.label(input_name); - *count += 1; - node - }); - dot_graph.add_node(input_node.clone()); - dot_graph.add_edge(dot_graph::Edge::new(&input_node.name, &op_node.name, "")); - } - for output_name in &op.output { - let count = dot_node_counts.entry(output_name.clone()).or_insert(0); - let output_node = - dot_graph::Node::new(&str_to_node_name(&format!("{}{}", output_name, count))); - output_node.label(output_name); - dot_nodes.insert(output_name.clone(), output_node.clone()); - dot_graph.add_node(output_node.clone()); - dot_graph.add_edge(dot_graph::Edge::new(&op_node.name, &output_node.name, "")); - } - } - - let dot_string = dot_graph.to_dot_string()?; - - std::fs::write(&output_path, dot_string) - .map_err(|e| anyhow::anyhow!("failed to write dot string to output path: {:?}", e)) -} diff --git a/src/core/onnx/inspect.rs b/src/core/onnx/inspect.rs deleted file mode 100644 index b8a0bb3..0000000 --- a/src/core/onnx/inspect.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::{ - collections::HashSet, - path::{Path, PathBuf}, -}; - -use protobuf::Message; -use protos::{tensor_proto::DataLocation, ModelProto}; -use rayon::prelude::*; - -use crate::{cli::DetailLevel, core::Metadata}; - -use super::{ - data_type_bits, data_type_string, - protos::{self, TensorProto}, - FileType, Inspection, TensorDescriptor, -}; - -pub(crate) fn is_onnx(file_path: &Path) -> bool { - file_path - .extension() - 
.unwrap_or_default() - .to_str() - .unwrap_or("") - .to_ascii_lowercase() - == "onnx" -} - -pub(crate) fn paths_to_sign(file_path: &PathBuf) -> anyhow::Result> { - let base_path = file_path - .parent() - .ok_or_else(|| anyhow::anyhow!("no parent path"))?; - let mut file = std::fs::File::open(file_path)?; - let onnx_model: ModelProto = Message::parse_from_reader(&mut file)?; - - // ONNX files can contain external data - let external_paths: HashSet = onnx_model - .graph - .initializer - .par_iter() - .filter(|t| t.data_location.value() == DataLocation::EXTERNAL as i32) - .filter_map(|t| { - t.external_data - .first() - .map(|data| PathBuf::from(&data.value)) - .map(|p| { - if p.is_relative() { - base_path.join(p) - } else { - p - } - }) - }) - .collect(); - - let mut paths = vec![file_path.clone()]; - paths.extend(external_paths); - - Ok(paths) -} - -fn build_tensor_descriptor(tensor: &TensorProto) -> TensorDescriptor { - let mut metadata = Metadata::new(); - if !tensor.doc_string.is_empty() { - metadata.insert("doc_string".to_string(), tensor.doc_string.clone()); - } - - if tensor.data_location.value() == DataLocation::EXTERNAL as i32 { - metadata.insert("data_location".to_string(), "external".to_string()); - if let Some(external_data) = tensor.external_data.first() { - metadata.insert("location".to_string(), external_data.value.clone()); - } - } - - tensor.metadata_props.iter().for_each(|prop| { - metadata.insert(prop.key.clone(), prop.value.clone()); - }); - - TensorDescriptor { - id: Some(tensor.name.to_string()), - shape: tensor.dims.iter().map(|d| *d as usize).collect(), - dtype: data_type_string(tensor.data_type).to_string(), - size: if tensor.dims.is_empty() { - 0 - } else { - (data_type_bits(tensor.data_type) - * tensor.dims.iter().map(|d| *d as usize).product::()) - / 8 - }, - metadata, - } -} - -pub(crate) fn inspect( - file_path: PathBuf, - detail: DetailLevel, - filter: Option, -) -> anyhow::Result { - let mut inspection = Inspection::default(); - - let mut file = std::fs::File::open(&file_path)?; - - inspection.file_path = file_path.canonicalize()?; - inspection.file_size = file.metadata()?.len(); - - let onnx_model: ModelProto = Message::parse_from_reader(&mut file)?; - - inspection.file_type = FileType::ONNX; - - if onnx_model.model_version != 0 { - inspection.version = format!( - "{} (IR v{})", - onnx_model.model_version, onnx_model.ir_version - ); - } else { - inspection.version = format!("IR v{}", onnx_model.ir_version); - } - - // TODO: check the presence of sparse tensors from graph.sparse_initializer - - inspection.num_tensors = onnx_model.graph.initializer.len(); - inspection.data_size = onnx_model - .graph - .initializer - .par_iter() - .map(|t| { - if t.dims.is_empty() { - 0 - } else { - data_type_bits(t.data_type) * t.dims.iter().map(|d| *d as usize).product::() - } - }) - .sum::() - / 8; - - inspection.unique_shapes = onnx_model - .graph - .initializer - .par_iter() - .map(|t| t.dims.iter().map(|d| *d as usize).collect::>()) - .filter(|shape| !shape.is_empty()) - .collect::>() - .into_iter() - .collect(); - - // sort shapes by volume - inspection.unique_shapes.sort_by(|a, b| { - let size_a: usize = a.iter().product(); - let size_b: usize = b.iter().product(); - size_a.cmp(&size_b) - }); - - inspection.unique_dtypes = onnx_model - .graph - .initializer - .par_iter() - .map(|t| data_type_string(t.data_type).to_string()) - .collect::>() - .into_iter() - .collect(); - - if !onnx_model.producer_name.is_empty() { - inspection.metadata.insert( - 
"producer_name".to_string(), - onnx_model.producer_name.clone(), - ); - } - - if !onnx_model.producer_version.is_empty() { - inspection.metadata.insert( - "producer_version".to_string(), - onnx_model.producer_version.clone(), - ); - } - - if !onnx_model.domain.is_empty() { - inspection - .metadata - .insert("domain".to_string(), onnx_model.domain.clone()); - } - - if !onnx_model.doc_string.is_empty() { - inspection - .metadata - .insert("doc_string".to_string(), onnx_model.doc_string.clone()); - } - - onnx_model.metadata_props.iter().for_each(|prop| { - inspection - .metadata - .insert(prop.key.clone(), prop.value.clone()); - }); - - if matches!(detail, DetailLevel::Full) { - inspection.tensors = Some( - onnx_model - .graph - .initializer - .par_iter() - .filter(|t_info| filter.as_ref().map_or(true, |f| t_info.name.contains(f))) - .map(build_tensor_descriptor) - .collect(), - ); - } - - Ok(inspection) -} diff --git a/src/core/onnx/mod.rs b/src/core/onnx/mod.rs deleted file mode 100644 index 4275d93..0000000 --- a/src/core/onnx/mod.rs +++ /dev/null @@ -1,68 +0,0 @@ -use super::{FileType, Inspection, TensorDescriptor}; - -mod graph; -mod inspect; -mod protos; - -pub(crate) use graph::*; -pub(crate) use inspect::*; - -#[inline] -fn data_type_bits(dtype: i32) -> usize { - match dtype { - 1 => 32, // float - 2 => 8, // uint8_t - 3 => 8, // int8_t - 4 => 16, // uint16_t - 5 => 16, // int16_t - 6 => 32, // int32_t - 7 => 64, // int64_t - 8 => 8, // string (assuming 8 bits per character) - 9 => 8, // bool (typically 8 bits in most systems) - 10 => 16, // FLOAT16 - 11 => 64, // DOUBLE - 12 => 32, // UINT32 - 13 => 64, // UINT64 - 14 => 64, // COMPLEX64 (two 32-bit floats) - 15 => 128, // COMPLEX128 (two 64-bit floats) - 16 => 16, // BFLOAT16 - 17 => 8, // FLOAT8E4M3FN - 18 => 8, // FLOAT8E4M3FNUZ - 19 => 8, // FLOAT8E5M2 - 20 => 8, // FLOAT8E5M2FNUZ - 21 => 4, // UINT4 - 22 => 4, // INT4 - 23 => 4, // FLOAT4E2M1 - _ => panic!("Unsupported data type: {}", dtype), - } -} - -#[inline] -pub(crate) fn data_type_string(dtype: i32) -> &'static str { - match dtype { - 1 => "FLOAT", - 2 => "UINT8", - 3 => "INT8", - 4 => "UINT16", - 5 => "INT16", - 6 => "INT32", - 7 => "INT64", - 8 => "STRING", - 9 => "BOOL", - 10 => "FLOAT16", - 11 => "DOUBLE", - 12 => "UINT32", - 13 => "UINT64", - 14 => "COMPLEX64", - 15 => "COMPLEX128", - 16 => "BFLOAT16", - 17 => "FLOAT8E4M3FN", - 18 => "FLOAT8E4M3FNUZ", - 19 => "FLOAT8E5M2", - 20 => "FLOAT8E5M2FNUZ", - 21 => "UINT4", - 22 => "INT4", - 23 => "FLOAT4E2M1", - _ => "UNKNOWN", - } -} diff --git a/src/core/pytorch/mod.rs b/src/core/pytorch/mod.rs deleted file mode 100644 index e8b7876..0000000 --- a/src/core/pytorch/mod.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::path::{Path, PathBuf}; - -use crate::cli::DetailLevel; - -use super::{docker, Inspection}; - -pub(crate) fn is_pytorch(file_path: &Path) -> bool { - let file_ext = file_path - .extension() - .unwrap_or_default() - .to_str() - .unwrap_or("") - .to_ascii_lowercase(); - - let file_name = file_path - .file_name() - .unwrap_or_default() - .to_str() - .unwrap_or_default() - .to_ascii_lowercase(); - - file_ext == "pt" - || file_ext == "pth" - || file_name.ends_with("pytorch_model.bin") - // cases like diffusion_pytorch_model.fp16.bin - || (file_name.contains("pytorch_model") && file_name.ends_with(".bin")) -} - -pub(crate) fn paths_to_sign(file_path: &Path) -> anyhow::Result> { - // TODO: can a pytorch model reference external files? 
- Ok(vec![file_path.to_path_buf()]) -} - -pub(crate) fn inspect( - file_path: PathBuf, - detail: DetailLevel, - filter: Option, -) -> anyhow::Result { - docker::Inspector::new( - include_str!("inspect.Dockerfile"), - include_str!("inspect.py"), - include_str!("inspect.requirements"), - ) - .run(file_path, vec![], detail, filter) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_is_pytorch() { - // Standard .pt extension - assert!(is_pytorch(Path::new("model.pt"))); - assert!(is_pytorch(Path::new("path/to/model.pt"))); - assert!(is_pytorch(Path::new("MODEL.PT"))); // Case insensitive - - // Standard .pth extension - assert!(is_pytorch(Path::new("model.pth"))); - assert!(is_pytorch(Path::new("path/to/model.pth"))); - assert!(is_pytorch(Path::new("MODEL.PTH"))); // Case insensitive - - // Standard pytorch_model.bin filename - assert!(is_pytorch(Path::new("pytorch_model.bin"))); - assert!(is_pytorch(Path::new("path/to/pytorch_model.bin"))); - assert!(is_pytorch(Path::new("PYTORCH_MODEL.BIN"))); // Case insensitive - - // Variants of pytorch_model.*.bin - assert!(is_pytorch(Path::new("diffusion_pytorch_model.bin"))); - assert!(is_pytorch(Path::new("diffusion_pytorch_model.fp16.bin"))); - assert!(is_pytorch(Path::new( - "text_encoder_pytorch_model.safetensors.bin" - ))); - - // Non-matching cases - assert!(!is_pytorch(Path::new("model.onnx"))); - assert!(!is_pytorch(Path::new("model.safetensors"))); - assert!(!is_pytorch(Path::new("model.bin"))); // Just .bin isn't enough - assert!(!is_pytorch(Path::new("pytorch.txt"))); - assert!(!is_pytorch(Path::new(""))); - } -} diff --git a/src/core/safetensors/mod.rs b/src/core/safetensors/mod.rs deleted file mode 100644 index 0d2f0df..0000000 --- a/src/core/safetensors/mod.rs +++ /dev/null @@ -1,157 +0,0 @@ -use std::{ - collections::{BTreeMap, HashMap, HashSet}, - path::{Path, PathBuf}, -}; - -use rayon::prelude::*; - -use safetensors::{tensor::TensorInfo, SafeTensors}; -use serde::Deserialize; - -use crate::{cli::DetailLevel, core::TensorDescriptor}; - -use super::{FileType, Inspection}; - -#[derive(Debug, Deserialize)] -struct TensorIndex { - weight_map: HashMap, -} - -pub(crate) fn is_safetensors(file_path: &Path) -> bool { - file_path - .extension() - .unwrap_or_default() - .to_str() - .unwrap_or("") - .to_ascii_lowercase() - == "safetensors" -} - -pub(crate) fn is_safetensors_index(file_path: &Path) -> bool { - file_path - .file_name() - .unwrap() - .to_string_lossy() - .ends_with(".safetensors.index.json") -} - -pub(crate) fn paths_to_sign(file_path: &Path) -> anyhow::Result> { - if is_safetensors_index(file_path) { - // load unique paths from index - let base_path = file_path - .parent() - .ok_or_else(|| anyhow::anyhow!("no parent path"))?; - - let index = std::fs::read_to_string(file_path)?; - let index: TensorIndex = serde_json::from_str(&index)?; - - let unique: HashSet = index - .weight_map - .values() - .map(PathBuf::from) - .map(|p| { - if p.is_relative() { - base_path.join(p) - } else { - p - } - }) - .collect(); - - let mut paths = vec![file_path.to_path_buf()]; - paths.extend(unique); - Ok(paths) - } else { - // safetensors are self contained - Ok(vec![file_path.to_path_buf()]) - } -} - -fn build_tensor_descriptor(tensor_id: &str, tensor_info: &TensorInfo) -> TensorDescriptor { - TensorDescriptor { - id: Some(tensor_id.to_string()), - shape: tensor_info.shape.clone(), - dtype: format!("{:?}", &tensor_info.dtype), - size: tensor_info.data_offsets.1 - tensor_info.data_offsets.0, - metadata: super::Metadata::new(), - } -} 
- -pub(crate) fn inspect( - file_path: PathBuf, - detail: DetailLevel, - filter: Option, -) -> anyhow::Result { - let mut inspection = Inspection::default(); - - let file = std::fs::File::open(&file_path)?; - let buffer = unsafe { - memmap2::MmapOptions::new() - .map(&file) - .unwrap_or_else(|_| panic!("failed to map file {}", file_path.display())) - }; - - inspection.file_path = file_path.canonicalize()?; - inspection.file_size = file.metadata()?.len(); - - // read header - let (header_size, header) = SafeTensors::read_metadata(&buffer)?; - - inspection.file_type = FileType::SafeTensors; - inspection.header_size = header_size; - inspection.version = "0.x".to_string(); - - let tensors = header.tensors(); - - // transform tensors to a vector - let mut tensors: Vec<_> = tensors.into_iter().collect(); - - inspection.num_tensors = tensors.len(); - inspection.data_size = tensors - .par_iter() - .map(|t| t.1.data_offsets.1 - t.1.data_offsets.0) - .sum::(); - - inspection.unique_shapes = tensors - .par_iter() - .map(|t| t.1.shape.clone()) - .collect::>() - .into_iter() - .collect(); - // sort shapes by volume - inspection.unique_shapes.sort_by(|a, b| { - let size_a: usize = a.iter().product(); - let size_b: usize = b.iter().product(); - size_a.cmp(&size_b) - }); - - inspection.unique_dtypes = tensors - .par_iter() - .map(|t| format!("{:?}", t.1.dtype)) - .collect::>() - .into_iter() - .collect(); - - if let Some(block_metadata) = header.metadata() { - inspection.metadata = BTreeMap::from_iter( - block_metadata - .iter() - .map(|(k, v)| (k.to_string(), v.to_string())), - ); - } - - if matches!(detail, DetailLevel::Full) { - // sort by offset - tensors.sort_by_key(|(_, info)| info.data_offsets.0); - - inspection.tensors = Some( - tensors - .par_iter() - .filter(|(tensor_id, _)| filter.as_ref().map_or(true, |f| tensor_id.contains(f))) - .map(|(tensor_id, tensor_info)| build_tensor_descriptor(tensor_id, tensor_info)) - .collect(), - ); - } - - Ok(inspection) -}
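
Design note (not part of the patch): the core of this refactor is the `Handler` trait plus the `handler_for` factory, which together replace four per-format `if/else` chains (inspect, sign, verify, graph) with a single `Box<dyn Handler>` dispatch. The following is a minimal, self-contained Rust sketch of that pattern. The names mirror the patch, but everything here is simplified and hypothetical: two toy formats, `String` errors instead of `anyhow`, and a probing loop in place of the patch's explicit if/else chain. Treat it as an illustration of the design, not the crate's actual code.

use std::path::Path;

#[derive(Clone, Copy, Debug, PartialEq)]
enum FileType {
    Gguf,
    Onnx,
}

#[allow(dead_code)]
enum Scope {
    Inspection,
    Signing,
}

trait Handler {
    fn file_type(&self) -> FileType;
    fn is_handler_for(&self, file_path: &Path, scope: &Scope) -> bool;

    // Default method, as in the patch: formats that cannot render a graph
    // simply inherit this error instead of each CLI command special-casing ONNX.
    fn create_graph(&self, _file_path: &Path, _output_path: &Path) -> Result<(), String> {
        Err("graph generation not supported for this format".to_string())
    }
}

struct GgufHandler;

impl Handler for GgufHandler {
    fn file_type(&self) -> FileType {
        FileType::Gguf
    }

    fn is_handler_for(&self, file_path: &Path, _scope: &Scope) -> bool {
        file_path
            .extension()
            .map_or(false, |ext| ext.eq_ignore_ascii_case("gguf"))
    }
}

struct OnnxHandler;

impl Handler for OnnxHandler {
    fn file_type(&self) -> FileType {
        FileType::Onnx
    }

    fn is_handler_for(&self, file_path: &Path, _scope: &Scope) -> bool {
        file_path
            .extension()
            .map_or(false, |ext| ext.eq_ignore_ascii_case("onnx"))
    }

    fn create_graph(&self, _file_path: &Path, _output_path: &Path) -> Result<(), String> {
        Ok(()) // ONNX embeds graph information, so this format opts in.
    }
}

// Factory: a forced format wins; otherwise each handler is probed in turn.
fn handler_for(
    format: Option<FileType>,
    file_path: &Path,
    scope: Scope,
) -> Result<Box<dyn Handler>, String> {
    let handlers: Vec<Box<dyn Handler>> = vec![Box::new(GgufHandler), Box::new(OnnxHandler)];
    for handler in handlers {
        let selected = match format {
            Some(forced) => handler.file_type() == forced,
            None => handler.is_handler_for(file_path, &scope),
        };
        if selected {
            return Ok(handler);
        }
    }
    Err("unsupported file format".to_string())
}

fn main() {
    // Auto-detection by extension picks the GGUF handler...
    let handler = handler_for(None, Path::new("model.gguf"), Scope::Inspection).unwrap();
    assert_eq!(handler.file_type(), FileType::Gguf);
    // ...and graph generation falls back to the trait's default error.
    assert!(handler
        .create_graph(Path::new("model.gguf"), Path::new("out.dot"))
        .is_err());
    println!("dispatched to {:?}", handler.file_type());
}

Because `create_graph` has a default body, graph support is opt-in per format, and callers such as the `graph` subcommand no longer need to hard-code the "only ONNX embeds a graph" check that the old `cli/graph.rs` carried.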