diff --git a/Cargo.toml b/Cargo.toml index 6c2c565..a8d0afb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "usls" -version = "0.0.17" +version = "0.0.18" edition = "2021" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." repository = "https://github.com/jamjamjon/usls" diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index 5ae4178..96c51c0 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -160,6 +160,8 @@ fn main() -> Result<()> { // .with_names(&COCO_CLASS_NAMES_80) // .with_names2(&COCO_KEYPOINTS_17) .with_find_contours(!args.no_contours) // find contours or not + .exclude_classes(&[0]) + // .retain_classes(&[0, 5]) .with_profile(args.profile); // build model diff --git a/src/core/options.rs b/src/core/options.rs index 5ea92ea..4e906b5 100644 --- a/src/core/options.rs +++ b/src/core/options.rs @@ -48,6 +48,8 @@ pub struct Options { pub sam_kind: Option, pub use_low_res_mask: Option, pub sapiens_task: Option, + pub classes_excluded: Vec, + pub classes_retained: Vec, } impl Default for Options { @@ -88,6 +90,8 @@ impl Default for Options { use_low_res_mask: None, sapiens_task: None, task: Task::Untitled, + classes_excluded: vec![], + classes_retained: vec![], } } } @@ -276,4 +280,16 @@ impl Options { self.iiixs.push(Iiix::from((i, ii, x))); self } + + pub fn exclude_classes(mut self, xs: &[isize]) -> Self { + self.classes_retained.clear(); + self.classes_excluded.extend_from_slice(xs); + self + } + + pub fn retain_classes(mut self, xs: &[isize]) -> Self { + self.classes_excluded.clear(); + self.classes_retained.extend_from_slice(xs); + self + } } diff --git a/src/models/yolo.rs b/src/models/yolo.rs index 546e35e..3288950 100644 --- a/src/models/yolo.rs +++ b/src/models/yolo.rs @@ -26,6 +26,8 @@ pub struct YOLO { layout: YOLOPreds, find_contours: bool, version: Option, + classes_excluded: Vec, + classes_retained: Vec, } impl Vision for YOLO { @@ -157,6 +159,10 @@ impl Vision for YOLO { let kconfs = DynConf::new(&options.kconfs, nk); let iou = options.iou.unwrap_or(0.45); + // Classes excluded and retained + let classes_excluded = options.classes_excluded; + let classes_retained = options.classes_retained; + // Summary tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version); @@ -179,6 +185,8 @@ impl Vision for YOLO { layout, version, find_contours: options.find_contours, + classes_excluded, + classes_retained, }) } @@ -276,7 +284,19 @@ impl Vision for YOLO { } }; - // filtering + // filtering by class id + if !self.classes_excluded.is_empty() + && self.classes_excluded.contains(&(class_id as isize)) + { + return None; + } + if !self.classes_retained.is_empty() + && !self.classes_retained.contains(&(class_id as isize)) + { + return None; + } + + // filtering by conf if confidence < self.confs[class_id] { return None; }