diff --git a/src/core/pytorch/mod.rs b/src/core/pytorch/mod.rs index a6abb61..e8b7876 100644 --- a/src/core/pytorch/mod.rs +++ b/src/core/pytorch/mod.rs @@ -20,6 +20,7 @@ pub(crate) fn is_pytorch(file_path: &Path) -> bool { .to_ascii_lowercase(); file_ext == "pt" + || file_ext == "pth" || file_name.ends_with("pytorch_model.bin") // cases like diffusion_pytorch_model.fp16.bin || (file_name.contains("pytorch_model") && file_name.ends_with(".bin")) @@ -54,6 +55,11 @@ mod tests { assert!(is_pytorch(Path::new("path/to/model.pt"))); assert!(is_pytorch(Path::new("MODEL.PT"))); // Case insensitive + // Standard .pth extension + assert!(is_pytorch(Path::new("model.pth"))); + assert!(is_pytorch(Path::new("path/to/model.pth"))); + assert!(is_pytorch(Path::new("MODEL.PTH"))); // Case insensitive + // Standard pytorch_model.bin filename assert!(is_pytorch(Path::new("pytorch_model.bin"))); assert!(is_pytorch(Path::new("path/to/pytorch_model.bin")));