Skip to content

Commit

Permalink
feat: fully implemented fish segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ccrutchf committed May 28, 2024
1 parent 78a384f commit a61d4f0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 30 deletions.
32 changes: 16 additions & 16 deletions src/cv_convsersion.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
use std::ffi::c_void;
// use std::ffi::c_void;

use opencv::core::{Mat, CV_8UC1};
use ndarray::Array2;
// use opencv::core::{Mat, CV_8UC1};
// use ndarray::Array2;

pub fn as_mat_u8c1_mut(array: &Array2<u8>) -> Result<Mat, opencv::Error> {
// Array must be contiguous and in the standard row-major layout, or the
// conversion to a `Mat` will produce a corrupted result
assert!(array.is_standard_layout());
// pub fn as_mat_u8c1_mut(array: &Array2<u8>) -> Result<Mat, opencv::Error> {
// // Array must be contiguous and in the standard row-major layout, or the
// // conversion to a `Mat` will produce a corrupted result
// assert!(array.is_standard_layout());

let (height, width) = array.dim();
let array_clone = array.clone();
unsafe { Mat::new_rows_cols_with_data_unsafe_def(
height as i32,
width as i32,
CV_8UC1,
array_clone.into_raw_vec().as_ptr() as *mut c_void,
) }
}
// let (height, width) = array.dim();
// let array_clone = array.clone();
// unsafe { Mat::new_rows_cols_with_data_unsafe_def(
// height as i32,
// width as i32,
// CV_8UC1,
// array_clone.into_raw_vec().as_ptr() as *mut c_void,
// ) }
// }
65 changes: 51 additions & 14 deletions src/fish_segmentation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::cv_convsersion::{self, as_mat_u8c1_mut};
// use crate::cv_convsersion::{self, as_mat_u8c1_mut};

use std::cmp::{max, min};
use std::fs::{File, create_dir};
Expand All @@ -9,7 +9,7 @@ use app_dirs2::{AppDataType, AppInfo, app_root};
use bytes::Bytes;
use image::RgbImage;
use image::imageops::{resize, FilterType};
use ndarray::{array, s, stack, Array, Array2, Array3, Array4, ArrayBase, ArrayView2, ArrayViewD, Axis, Dim, IxDynImpl, OwnedRepr};
use ndarray::{array, s, stack, Array, Array2, Array3, Array4, ArrayBase, ArrayView2, ArrayViewD, Axis, Dim, IxDynImpl, OwnedRepr, ViewRepr};
use ndarray_npy::NpzWriter;
use opencv::core::{Mat, MatTraitConstManual, Point2i, Size, VectorToVec, CV_8UC1, CV_8UC3};
use opencv::imgproc::{fill_poly, find_contours_with_hierarchy, resize_def, CHAIN_APPROX_NONE, LINE_8, RETR_CCOMP};
Expand All @@ -18,6 +18,11 @@ use ort::Session;
use reqwest::blocking::get;
use cv_convert::{FromCv, IntoCv, TryFromCv, TryIntoCv};

fn write<T>(arr: &ArrayBase<OwnedRepr<T>, Dim<IxDynImpl>>) where T : ndarray_npy::WritableElement {
let mut npz = NpzWriter::new(File::create("../outputs/rust.npz").unwrap());
npz.add_array("arr", &arr).unwrap();
npz.finish().unwrap();
}

fn write2f(arr: &Array2<f32>) {
let mut npz = NpzWriter::new(File::create("../outputs/rust.npz").unwrap());
Expand Down Expand Up @@ -73,6 +78,26 @@ fn write4u(arr: &Array4<u8>) {
npz.finish().unwrap();
}

fn write_vec_vec_p(vec: &Vec<VectorOfPoint>) {
let mut npz = NpzWriter::new(File::create("../outputs/rust.npz").unwrap());

let mut idx = -1;
for e in vec {
idx += 1;
let mut arr = Array2::<i32>::zeros((e.len() as usize, 2));
let mut p_idx = 0;
for p in e {
arr[[p_idx, 0]] = p.x;
arr[[p_idx, 1]] = p.y;
p_idx += 1;

}
npz.add_array(idx.to_string().as_str(), &arr).unwrap();
}

npz.finish().unwrap();
}



#[derive(Debug)]
Expand Down Expand Up @@ -365,10 +390,10 @@ impl FishSegmentation {
output
}

fn do_paste_mask(&self, masks: &Array4<f32>, img_h: u32, img_w: u32) -> Result<Array2<f32>, SegmentationError> {
let masks_squeezed = masks.clone().remove_axis(Axis(3));
fn do_paste_mask(&self, masks: &Array2<f32>, img_h: u32, img_w: u32) -> Result<Array2<f32>, SegmentationError> {
let masks_unsqueezed = masks.clone().insert_axis(Axis(2));

let conversion_result: Result<Mat, _> = masks_squeezed.try_into_cv();
let conversion_result: Result<Mat, _> = masks_unsqueezed.try_into_cv();
match conversion_result {
Ok(masks_cv) => {
let mut resized_cv = Mat::default();
Expand Down Expand Up @@ -418,7 +443,6 @@ impl FishSegmentation {
// // write4f(&grid);

// let resized_mask = self.grid_sample(masks, &grid);
// write2f(&resized_mask);

// Ok(resized_mask)
// },
Expand All @@ -430,7 +454,9 @@ impl FishSegmentation {
}

fn bitmap_to_polygon(&self, bitmap: &Array2<u8>) -> Result<Vec<VectorOfPoint>, SegmentationError> {
match as_mat_u8c1_mut(bitmap) {
let bitmap3 = bitmap.clone().insert_axis(Axis(2));
let conversion_result: Result<Mat, _> = bitmap3.try_into_cv();
match conversion_result {
Ok(bitmap_cv) => {
let mut contours_cv = VectorOfVectorOfPoint::new();
let mut hierarchy_cv = VectorOfVec4i::new();
Expand All @@ -442,6 +468,9 @@ impl FishSegmentation {
// cv2.CHAIN_APPROX_NONE: stores absolutely all the contour points.
match find_contours_with_hierarchy(&bitmap_cv, &mut contours_cv, &mut hierarchy_cv, RETR_CCOMP, CHAIN_APPROX_NONE, Point2i::new(0, 0)) {
Ok(_) => {
let c_len = contours_cv.len();
let h_len = hierarchy_cv.len();

if hierarchy_cv.is_empty() {
Err(SegmentationError::FishNotFound)
}
Expand All @@ -463,8 +492,8 @@ impl FishSegmentation {
let res = VectorOfPoint::from_iter(poly
.iter()
.map(|point| Point2i::new(
((start_x as f32 + point.x as f32) * width_scale) as i32,
((start_y as f32 + point.y as f32) * height_scale) as i32
((start_x as f32 + point.x as f32).ceil() * width_scale) as i32,
((start_y as f32 + point.y as f32).ceil() * height_scale) as i32
))
.collect::<Vec<_>>());

Expand All @@ -480,12 +509,20 @@ impl FishSegmentation {
width_scale: f32, height_scale: f32,
shape: (usize, usize, usize)) -> Result<Array2<u8>, SegmentationError> {

let mut masks_clone = masks.clone();

masks_clone.swap_axes(3, 2);
masks_clone.swap_axes(2, 1);
masks_clone.swap_axes(1, 0);
masks_clone.swap_axes(1, 2);

// let mut complete_mask = Array2::<u8>::zeros((shape.0, shape.1));
match Mat::new_rows_cols_with_default(shape.0 as i32, shape.1 as i32, CV_8UC1, 0.into()) {
Ok(mut complete_mask_cv) => {
let mask_count = scores.len();

for ind in 0..mask_count {
println!("{}", scores[ind]);
if scores[ind] <= FishSegmentation::SCORE_THRESHOLD {
println!("scores below thresh, {}", scores[ind]);
continue;
Expand All @@ -497,17 +534,18 @@ impl FishSegmentation {
let y2 = boxes[[3, ind]].floor() as u32;
let (mask_h, mask_w) = (y2 - y1 + 1, x2 - x1 + 1);

let mut mask: Array4<f32> = masks.slice(s![.., .., .., ind])
.insert_axis(Axis(3))
let mask = masks_clone.slice(s![ind, .., .., 0])
.mapv(|v| v.to_owned());
mask.swap_axes(0, 1);

// Threshold the mask converting to uint8 casuse opencv diesn't allow other type!
let np_mask = self.do_paste_mask(&mask, mask_h, mask_w)?
.mapv(|v| if v > FishSegmentation::MASK_THRESHOLD {255 as u8} else {0});

// Find contours in the binary mask
let contours = self.bitmap_to_polygon(&np_mask)?;

write_vec_vec_p(&contours);

// Ignore empty contpurs
if contours.is_empty() {
println!("contours empty");
Expand Down Expand Up @@ -543,7 +581,6 @@ impl FishSegmentation {
match conversion_result {
Ok(complete_mask3) => {
let complete_mask = complete_mask3.remove_axis(Axis(2));
write2u(&complete_mask);
Ok(complete_mask)
}
Err(_) => {
Expand Down Expand Up @@ -571,6 +608,7 @@ impl FishSegmentation {
let (boxes, masks, scores) = result;
let masks = self.convert_output_to_mask_and_polygons(&boxes, &masks, &scores, width_scale, height_scale, img.dim())?;

write2u(&masks);
Ok(masks)
}
Err(error) => Err(SegmentationError::OrtErr(error))
Expand All @@ -597,7 +635,6 @@ mod tests {
img_bgr[[y, x, 2]] = img_rgb[[y, x, 0]];
}
}

let rust_segmentations = image::io::Reader::open("./data/segmentations.png").unwrap().decode().unwrap().as_luma8().unwrap().clone();
let truth = Array2::from_shape_vec((rust_segmentations.height() as usize, rust_segmentations.width() as usize), rust_segmentations.as_raw().clone()).unwrap()
.mapv(|v| v as i32);
Expand Down

0 comments on commit a61d4f0

Please sign in to comment.