Skip to content

Commit

Permalink
WIP: segmentation masks, lots of refactoring needed
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Conos committed Sep 1, 2024
1 parent a1f495f commit 12c68d7
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 31 deletions.
31 changes: 24 additions & 7 deletions examples/predict/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use tch::TchError;
use tch::{TchError, Tensor};
use yolo_v8::{Image, YoloV8Classifier, YoloV8ObjectDetection, YoloV8Segmentation};

fn object_detection(path: &str) {
Expand Down Expand Up @@ -30,19 +30,36 @@ fn image_classification(path: &str) {
println!("classes={:?}", classes);
}

fn image_segmentation() {
let image = Image::new("images/test.jpg", YoloV8Segmentation::input_dimension());
/// Runs the segmentation model on the image at `path` and saves every returned
/// mask as a JPEG named `mask-{name}-{mask_no}.jpg` (relative file name).
///
/// `0.25` and `0.7` are handed to `YoloV8Segmentation::predict` as the
/// confidence and IoU thresholds.
///
/// Panics if a mask cannot be converted into a `Vec<f64>` or if saving an
/// image fails.
fn image_segmentation(path: &str) {
let image = Image::new(path, YoloV8Segmentation::input_dimension());

// Load exported torchscript for image segmentation
let yolo = YoloV8Segmentation::new();

let classes = yolo.predict(&image);
let segmentation = yolo.predict(&image, 0.25, 0.7);
println!("segmentation={:?}", segmentation);
// Running counter that keeps the output file names unique.
let mut mask_no = 0;
for seg in segmentation {
// Flatten the mask to one dimension so it can be copied into a Vec.
let mask = seg.mask.reshape([-1]);
let name = seg.segbox.name;
// Concatenate the flattened mask three times so the data fills three
// identical channels after the reshape below.
let mut rgb = Vec::new();
let mut vec = Vec::<f64>::try_from(&mask).unwrap();
rgb.append(&mut vec.clone());
rgb.append(&mut vec.clone());
rgb.append(&mut vec);
let im = Tensor::from_slice(&rgb)
// The 160x160 size is hard-coded; it matches the mask shape produced by
// `YoloV8Segmentation::predict`.
.reshape([3, 160, 160])
// NOTE(review): presumably mask values are 0/1, so this scales them to
// 0/255 for saving — confirm.
.g_mul_scalar(255.);
let imgname = format!("mask-{name}-{mask_no}.jpg");
tch::vision::image::save(&im, imgname).expect("can't save image");
mask_no += 1;
}
}

// YOLOv8 example entry point: calls the demo functions defined above on sample
// images under `images/` and then returns `Ok(())`. The commented-out calls
// are the demos that are currently disabled.
fn main() -> Result<(), TchError> {
object_detection("images/katri.jpg");
// image_classification("images/katri.jpg");
// image_segmentation();
// object_detection("images/bus.jpg");
// image_classification("images/bus.jpg");
image_segmentation("images/test.jpg");
Ok(())
}
196 changes: 172 additions & 24 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ pub struct BBox {
pub name: &'static str,
}

/// A detected box from the segmentation model, together with the per-box
/// weights later used to build its mask.
///
/// Instances are created by `YoloV8Segmentation::non_max_suppression`.
#[derive(Debug, Clone, Copy)]
pub struct SegBBox {
// Box corners, computed as centre minus/plus half the width/height.
// NOTE(review): presumably in scaled-image pixel coordinates — confirm.
pub xmin: f32,
pub ymin: f32,
pub xmax: f32,
pub ymax: f32,
// Class score that exceeded the confidence threshold.
pub conf: f32,
// Zero-based class index (score column minus the 4 leading box columns).
pub cls: usize,
// The 32 values read from prediction columns 84..116; `predict` multiplies
// them with the model's 32 x (160*160) second output to form the mask.
pub cls_weight: [f32; 32],
// Class label looked up as `DETECT_CLASSES[cls]`.
pub name: &'static str,
}

/// One segmentation output: a detected box paired with its mask.
#[derive(Debug)]
pub struct SegmentationResult {
// The detection this mask belongs to.
pub segbox: SegBBox,
// Mask tensor built in `YoloV8Segmentation::predict`: the box weights times
// the model's second output, reshaped to [1, 160, 160] and passed through
// `gt_(0.0)`.
pub mask: Tensor,
}

#[derive(Debug)]
pub struct ClassConfidence {
pub name: &'static str,
Expand All @@ -44,7 +62,7 @@ pub struct YoloV8Classifier {
impl YoloV8Classifier {
/// Creates a classifier by loading a classification TorchScript model from
/// the `models/` directory.
///
/// Panics with "can't load model" if `YOLOv8::new` fails.
pub fn new() -> Self {
Self {
yolo: YOLOv8::new("models/yolov8x-cls.torchscript").expect("can't load model"),
yolo: YOLOv8::new("models/yolov8n-cls.torchscript").expect("can't load model"),
}
}

Expand Down Expand Up @@ -104,17 +122,6 @@ impl YoloV8ObjectDetection {
self.non_max_suppression(image, &pred.get(0), conf_thresh, iou_thresh)
}

/// Returns the intersection-over-union of `b1` and `b2`.
///
/// Every extent is measured as `max - min + 1`, and a negative overlap is
/// clamped to zero so disjoint boxes yield `0`.
fn iou(&self, b1: &BBox, b2: &BBox) -> f64 {
    // Overlap along each axis; clamped so non-overlapping boxes contribute nothing.
    let overlap_w = (b1.xmax.min(b2.xmax) - b1.xmin.max(b2.xmin) + 1.).max(0.);
    let overlap_h = (b1.ymax.min(b2.ymax) - b1.ymin.max(b2.ymin) + 1.).max(0.);
    let intersection = overlap_w * overlap_h;

    let area_first = (b1.xmax - b1.xmin + 1.) * (b1.ymax - b1.ymin + 1.);
    let area_second = (b2.xmax - b2.xmin + 1.) * (b2.ymax - b2.ymin + 1.);

    // Union = sum of both areas minus the part counted twice.
    intersection / (area_first + area_second - intersection)
}

fn non_max_suppression(
&self,
image: &Image,
Expand Down Expand Up @@ -187,7 +194,7 @@ impl YoloV8ObjectDetection {
for index in 0..bboxes_for_class.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = self.iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
let iou = YOLOv8::iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);

if iou > iou_thresh {
drop = true;
Expand Down Expand Up @@ -225,8 +232,14 @@ impl YoloV8Segmentation {
}
}

pub fn predict(&self, image: &Image) {
pub fn predict(
&self,
image: &Image,
conf_threshold: f32,
iou_threshold: f32,
) -> Vec<SegmentationResult> {
let img = &image.scaled_image;
let mut result = Vec::new();

// println!("img={:?}", img);

Expand All @@ -239,21 +252,133 @@ impl YoloV8Segmentation {
let t = tch::IValue::Tensor(img);
let pred = self.yolo.model.forward_is(&[t]).unwrap();
println!("pred={:?}", pred);

// https://github.com/ultralytics/ultralytics/issues/2953
if let IValue::Tuple(iv) = pred {
let mut segboxes = Vec::new();
if let IValue::Tensor(bboxes) = &iv[0] {
let t = bboxes.get(0);
println!("bboxes={:?}", t);
segboxes = self.non_max_suppression(image, &t, conf_threshold, iou_threshold);
println!("r={:?}", segboxes);
}

if let IValue::Tensor(seg) = &iv[1] {
let t = seg.get(0);
println!("seg={:?}", t);
let (nclass, w, h) = t.size3().unwrap();
for i in 0..nclass {
let img = t.get(i);
let mut vec: Vec<f32> = vec![0.0; (img.size()[0] * img.size()[1]) as usize];
let l = vec.len();
img.copy_data(&mut vec, l);
println!("i={i}, v={:?}", vec);
for segbox in segboxes {
let weights = Tensor::from_slice(&segbox.cls_weight).reshape([1, 32]);
println!("weights={:?}", weights);

let t = seg.get(0).reshape([32, 160 * 160]);
println!("seg={:?}", t);
let mask = weights.matmul(&t).reshape([1, 160, 160]).gt_(0.0);
println!("r={}", mask);
result.push(SegmentationResult { segbox, mask });
}
}
}
result
}

/// Decodes the raw segmentation-head output into boxes and applies per-class
/// non-max suppression.
///
/// `prediction` is transposed so that every row describes one anchor, laid out
/// as `[cx, cy, w, h, class scores.., 32 mask weights]`. The number of classes
/// is derived from the row length instead of being fixed at 80.
///
/// * `image` - supplies the original and scaled image sizes used to correct
///   the box geometry for the aspect ratio of the original image.
/// * `conf_thresh` - a class score must be strictly greater than this to
///   produce a box.
/// * `iou_thresh` - a box is dropped when its IoU with an already kept,
///   higher-confidence box of the same class is strictly greater than this.
///
/// Returns the surviving boxes grouped by class index and, inside each class,
/// ordered by descending confidence.
///
/// # Panics
///
/// Panics if `prediction` is not 2-dimensional, if a row is too short to hold
/// the box and mask weights, if the tensor cannot be read as `f32`, if either
/// image tensor is not 3-dimensional, or if a class index has no entry in
/// `DETECT_CLASSES`.
fn non_max_suppression(
    &self,
    image: &Image,
    prediction: &tch::Tensor,
    conf_thresh: f32,
    iou_thresh: f32,
) -> Vec<SegBBox> {
    // Leading box columns (cx, cy, w, h) and trailing mask-weight columns.
    const BOX_FIELDS: usize = 4;
    const MASK_WEIGHTS: usize = 32;

    let prediction = prediction.transpose(1, 0);
    let (anchors, classes_no) = prediction.size2().unwrap();
    println!("classes_no={classes_no}, anchors={anchors}");

    // Everything between the box columns and the mask weights is a class score.
    let nclasses = usize::try_from(classes_no)
        .ok()
        .and_then(|row_len| row_len.checked_sub(BOX_FIELDS + MASK_WEIGHTS))
        .expect("prediction rows are too short to hold a box and mask weights");

    // Image geometry does not change between anchors, so compute it once.
    let (_, orig_h, orig_w) = image.image.size3().unwrap();
    let (_, sh, sw) = image.scaled_image.size3().unwrap();
    let cx = sw as f32 / 2.0;
    let cy = sh as f32 / 2.0;
    let aspect = orig_w as f32 / orig_h as f32;
    let landscape = orig_w > orig_h;

    let mut bboxes: Vec<Vec<SegBBox>> = vec![Vec::new(); nclasses];

    for index in 0..anchors {
        let pred = Vec::<f32>::try_from(prediction.get(index)).expect("wrong type of tensor");

        let (scores, mask_weights) = pred[BOX_FIELDS..].split_at(nclasses);
        let weights: [f32; MASK_WEIGHTS] = mask_weights
            .try_into()
            .expect("unexpected number of mask weights");

        // Stretch the box around the image centre along the shorter side of
        // the original image.
        let mut dx = pred[0] - cx;
        let mut dy = pred[1] - cy;
        let mut w = pred[2];
        let mut h = pred[3];
        if landscape {
            dy *= aspect;
            h *= aspect;
        } else {
            dx /= aspect;
            w /= aspect;
        }
        let x = cx + dx;
        let y = cy + dy;

        for (class_index, &confidence) in scores.iter().enumerate() {
            if confidence > conf_thresh {
                bboxes[class_index].push(SegBBox {
                    xmin: x - w / 2.,
                    ymin: y - h / 2.,
                    xmax: x + w / 2.,
                    ymax: y + h / 2.,
                    conf: confidence,
                    cls: class_index,
                    name: DETECT_CLASSES[class_index],
                    cls_weight: weights,
                });
            }
        }
    }

    let mut result = Vec::new();
    for mut candidates in bboxes {
        // Highest confidence first; `total_cmp` cannot panic on NaN scores.
        candidates.sort_by(|b1, b2| b2.conf.total_cmp(&b1.conf));

        // Boxes of the current class start here in `result`.
        let first_kept = result.len();
        for candidate in candidates {
            let suppressed = result[first_kept..]
                .iter()
                .any(|kept| YOLOv8::iou_seg(kept, &candidate) > iou_thresh);
            if !suppressed {
                result.push(candidate);
            }
        }
    }

    result
}

pub fn input_dimension() -> (i64, i64) {
Expand Down Expand Up @@ -290,6 +415,29 @@ impl YOLOv8 {

pred
}

/// Computes the intersection-over-union of two detection boxes.
///
/// Extents are measured as `max - min + 1`; a negative overlap is clamped to
/// zero, so boxes that do not overlap produce `0`.
fn iou(b1: &BBox, b2: &BBox) -> f64 {
    let area = |b: &BBox| (b.xmax - b.xmin + 1.) * (b.ymax - b.ymin + 1.);

    // Corners of the overlapping rectangle (may be inverted when disjoint).
    let left = b1.xmin.max(b2.xmin);
    let right = b1.xmax.min(b2.xmax);
    let top = b1.ymin.max(b2.ymin);
    let bottom = b1.ymax.min(b2.ymax);

    let intersection = (right - left + 1.).max(0.) * (bottom - top + 1.).max(0.);
    let union = area(b1) + area(b2) - intersection;
    intersection / union
}

// FIXME (carried over from the original): presumably flags that this is a copy
// of `iou` specialised for `SegBBox` — confirm and deduplicate.
/// Computes the intersection-over-union of two segmentation boxes.
///
/// Extents are measured as `max - min + 1`; a negative overlap is clamped to
/// zero, so boxes that do not overlap produce `0`.
fn iou_seg(b1: &SegBBox, b2: &SegBBox) -> f32 {
    // Length of the span `lo..=hi` under the `+ 1` convention used here.
    fn span(lo: f32, hi: f32) -> f32 {
        hi - lo + 1.
    }

    let first_area = span(b1.xmin, b1.xmax) * span(b1.ymin, b1.ymax);
    let second_area = span(b2.xmin, b2.xmax) * span(b2.ymin, b2.ymax);

    let shared_w = span(b1.xmin.max(b2.xmin), b1.xmax.min(b2.xmax)).max(0.);
    let shared_h = span(b1.ymin.max(b2.ymin), b1.ymax.min(b2.ymax)).max(0.);
    let shared_area = shared_w * shared_h;

    shared_area / (first_area + second_area - shared_area)
}
}

pub struct Image {
Expand Down
15 changes: 15 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,18 @@ fn square(size: i32, w: i32, h: i32) -> (i32, i32) {
(tw, th)
}
}

#[cfg(test)]
mod test {
    use tch::Tensor;

    /// Sanity check for `Tensor::matmul` with the `[1, n] x [n, m]` shape
    /// pattern: a row of ones times a matrix yields the column sums.
    #[test]
    fn matmul() {
        let a = Tensor::from_slice(&[1, 1]).reshape([1, 2]);
        let b = Tensor::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]).reshape([2, 4]);
        println!("a={}", a);
        println!("b={}", b);
        let c = a.matmul(&b);
        println!("c={}", c);

        // Without assertions the test could only fail by panicking inside tch.
        assert_eq!(c.size(), vec![1, 4]);
        let values =
            Vec::<i32>::try_from(c.reshape([-1])).expect("matmul result should be readable as i32");
        assert_eq!(values, vec![6, 8, 10, 12]);
    }
}

0 comments on commit 12c68d7

Please sign in to comment.