
Commit

resize and padding using libtorch, no opencv dep needed
Michal Conos committed Sep 1, 2024
1 parent 885c005 commit a1f495f
Showing 6 changed files with 117 additions and 188 deletions.
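
The heart of the change is replacing the OpenCV-based preprocessing with a resize-and-pad (letterbox) step done entirely with libtorch via tch. The new utils::preprocess_torch helper that Image::new now calls is not itself shown in this diff, so the following is only a sketch of the general idea: the function name letterbox, the gray padding value 114, and the centred placement are assumptions, not the repository's actual implementation.

use tch::{Device, Kind, Tensor};

// Sketch: scale the longer side to `target`, keep the aspect ratio,
// then pad the shorter side so the network input is square.
fn letterbox(path: &str, target: i64) -> Result<Tensor, tch::TchError> {
    let img = tch::vision::image::load(path)?; // CHW, u8
    let (_, h, w) = img.size3()?;
    let (new_w, new_h) = if w >= h {
        (target, (h * target) / w)
    } else {
        ((w * target) / h, target)
    };
    let resized = tch::vision::image::resize(&img, new_w, new_h)?;
    // Gray canvas; copying the resized image into the centre is what the
    // bbox maths added in src/lib.rs below assumes.
    let canvas = Tensor::full(&[3, target, target], 114, (Kind::Uint8, Device::Cpu));
    let (top, left) = ((target - new_h) / 2, (target - new_w) / 2);
    canvas.narrow(1, top, new_h).narrow(2, left, new_w).copy_(&resized);
    Ok(canvas)
}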
1 change: 0 additions & 1 deletion Cargo.toml
@@ -6,7 +6,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
-opencv = "0.92.2"
tch = "0.17.0"

[[example]]
2 changes: 1 addition & 1 deletion code
@@ -4,4 +4,4 @@ export LIBTORCH_LIB=$(pwd)/libtorch/

export LD_LIBRARY_PATH="$LIBTORCH/lib/:$LD_LIBRARY_PATH"

-code
+code .
16 changes: 8 additions & 8 deletions examples/predict/main.rs
@@ -1,16 +1,16 @@
use tch::TchError;
use yolo_v8::{Image, YoloV8Classifier, YoloV8ObjectDetection, YoloV8Segmentation};

-fn object_detection() {
+fn object_detection(path: &str) {
// Load image to perform object detection, note that YOLOv8 resolution must match
// scaling width and height here
-let mut image = Image::new("images/bus.jpg", YoloV8ObjectDetection::input_dimension());
+let mut image = Image::new(path, YoloV8ObjectDetection::input_dimension());

// Load exported torchscript for object detection
let yolo = YoloV8ObjectDetection::new();

// Predict with non-max-suppression in the end
-let bboxes = yolo.predict(&image, 0.15, 0.35);
+let bboxes = yolo.predict(&image, 0.25, 0.7);
println!("bboxes={:?}", bboxes);

// Draw rectangles around detected objects
@@ -19,9 +19,9 @@ fn object_detection() {
image.save("images/result2.jpg");
}

-fn image_classification() {
+fn image_classification(path: &str) {
// Load image to perform image classification
-let image = Image::new("images/test.jpg", YoloV8Classifier::input_dimension());
+let image = Image::new(path, YoloV8Classifier::input_dimension());

// Load exported torchscript for object detection
let yolo = YoloV8Classifier::new();
@@ -39,10 +39,10 @@ fn image_segmentation() {
let classes = yolo.predict(&image);
}

-// YOLOv8n (nano model) for object detection in image
+// YOLOv8n for object detection in image
fn main() -> Result<(), TchError> {
-object_detection();
-// image_classification();
+object_detection("images/katri.jpg");
+// image_classification("images/katri.jpg");
// image_segmentation();
Ok(())
}
Binary file added images/katri.jpg
51 changes: 32 additions & 19 deletions src/lib.rs
@@ -98,10 +98,10 @@ impl YoloV8ObjectDetection {
}

pub fn predict(&self, image: &Image, conf_thresh: f64, iou_thresh: f64) -> Vec<BBox> {
println!("predict(): image={:?}", image.scaled_image);
// println!("predict(): image={:?}", image.scaled_image);
let pred = self.yolo.predict(image);
println!("pred={:?}", pred);
self.non_max_suppression(&pred.get(0), conf_thresh, iou_thresh)
// println!("pred={:?}", pred);
self.non_max_suppression(image, &pred.get(0), conf_thresh, iou_thresh)
}

fn iou(&self, b1: &BBox, b2: &BBox) -> f64 {
@@ -117,6 +117,7 @@ impl YoloV8ObjectDetection {

fn non_max_suppression(
&self,
+image: &Image,
prediction: &tch::Tensor,
conf_thresh: f64,
iou_thresh: f64,
@@ -143,11 +144,33 @@ impl YoloV8ObjectDetection {
// CLASSES[class_index]
// );

+let (_, orig_h, orig_w) = image.image.size3().unwrap();
+let (_, sh, sw) = image.scaled_image.size3().unwrap();
+let cx = sw as f64 / 2.0;
+let cy = sh as f64 / 2.0;
+let mut dx = pred[0] - cx;
+let mut dy = pred[1] - cy;
+let mut w = pred[2];
+let mut h = pred[3];
+
+let aspect = orig_w as f64 / orig_h as f64;
+
+if orig_w > orig_h {
+dy *= aspect;
+h *= aspect;
+} else {
+dx /= aspect;
+w /= aspect;
+}
+
+let x = cx + dx;
+let y = cy + dy;
+
let bbox = BBox {
-xmin: pred[0] - pred[2] / 2.,
-ymin: pred[1] - pred[3] / 2.,
-xmax: pred[0] + pred[2] / 2.,
-ymax: pred[1] + pred[3] / 2.,
+xmin: x - w / 2.,
+ymin: y - h / 2.,
+xmax: x + w / 2.,
+ymax: y + h / 2.,
conf: confidence,
cls: class_index,
name: DETECT_CLASSES[class_index],
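
The block added above undoes the letterbox distortion before each BBox is built: the prediction's offset from the centre of the square scaled frame is stretched along the padded axis by the original aspect ratio, and the box size on that axis is stretched by the same factor. A quick worked check, assuming centred padding as in the sketch near the top of this page:

// A 1280x720 original letterboxed into 640x640 occupies a 640x360 band
// centred vertically, so aspect = 1280/720 = 16/9 and the padded axis is y.
let aspect = 1280.0_f64 / 720.0;
// A centre point detected 90 px above the middle of the padded frame...
let dy_padded = -90.0_f64;
// ...corresponds to 160 px above the middle in a frame where the image
// fills the whole square, exactly what dy *= aspect (and h *= aspect)
// computes in the diff above.
let dy_unpadded = dy_padded * aspect;
assert!((dy_unpadded - (-160.0)).abs() < 1e-6);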
@@ -253,16 +276,12 @@ impl YOLOv8 {
pub fn predict(&self, image: &Image) -> Tensor {
let img = &image.scaled_image;

println!("img={:?}", img);

let img = img
.unsqueeze(0)
.to_kind(tch::Kind::Float)
.to_device(self.device)
.g_div_scalar(255.);

println!("img_float={:?}", img);

let pred = self
.model
.forward_ts(&[img])
@@ -304,15 +323,9 @@ impl Image {
let width = dimension.0;
let height = dimension.1;
let image = tch::vision::image::load(path).expect("can't load image");
-let scaled_image =
-tch::vision::image::resize(&image, width, height).expect("can't resize image");
-utils::print_tensor(&scaled_image);
-println!("---------------------------------");
-// let scaled_image =
-// utils::preprocess(path, dimension.0 as i32, true).expect("image preprocess");
-let scaled_image = utils::plain_resize(path).expect("XXXXXXXXXXXXXXXXXXXXXx");
-println!("AHOJ");
-utils::print_tensor(&scaled_image);
+// tch::vision::image::resize(&image, width, height).expect("can't resize image");
+let scaled_image = utils::preprocess_torch(path, dimension.0 as i32);
Self {
width,
height,
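
For reference, the updated example in examples/predict/main.rs drives the new pipeline end to end. Everything below is taken from the hunks above; the rectangle-drawing call between predict and save is collapsed in the diff, so it is left out here:

use tch::TchError;
use yolo_v8::{Image, YoloV8ObjectDetection};

fn main() -> Result<(), TchError> {
    // Image::new now resizes and pads internally via utils::preprocess_torch.
    // mut is kept because the omitted drawing step mutates the image.
    let mut image = Image::new("images/katri.jpg", YoloV8ObjectDetection::input_dimension());
    let yolo = YoloV8ObjectDetection::new();
    // Confidence and IoU thresholds as set in the updated example.
    let bboxes = yolo.predict(&image, 0.25, 0.7);
    println!("bboxes={:?}", bboxes);
    image.save("images/result2.jpg");
    Ok(())
}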
