From 6670ca11948765430f7f638121bfd1d5700fec39 Mon Sep 17 00:00:00 2001 From: RichardYang Date: Sat, 9 Mar 2024 23:34:08 +0800 Subject: [PATCH] modify method of creating label --- yolo_v1/Cargo.toml | 1 + yolo_v1/src/data.rs | 35 +++++++++++++++++++++++++---------- yolo_v1/src/main.rs | 2 +- yolo_v1/src/model.rs | 4 ++-- 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/yolo_v1/Cargo.toml b/yolo_v1/Cargo.toml index 6190380..182e38f 100644 --- a/yolo_v1/Cargo.toml +++ b/yolo_v1/Cargo.toml @@ -13,3 +13,4 @@ image = "0.23" itertools = "0.12.1" rand = "0.8" xml = "0.8.10" +libm = "0.2.8" diff --git a/yolo_v1/src/data.rs b/yolo_v1/src/data.rs index 71b86c2..5773c6d 100644 --- a/yolo_v1/src/data.rs +++ b/yolo_v1/src/data.rs @@ -1,4 +1,4 @@ -use std::{cmp::min, fs::{read_dir, File}, io::BufReader, path::Path, usize, vec}; +use std::{fs::{read_dir, File}, io::BufReader, path::Path, usize, vec}; use burn::{data::{dataloader::batcher::Batcher, dataset::Dataset}, tensor::{backend::Backend, Data, Tensor}}; @@ -39,16 +39,18 @@ impl VocObjBox { } } - fn get_label_class_idx(&self) -> usize { + fn get_label_class_idx(&self) -> (usize, bool) { let mut class_idx = 0; + let mut has_obj = false; for obj_class in OBJ_CLASSES { if self.box_class.as_str().eq(obj_class) { + has_obj = true; break; } class_idx += 1; } - class_idx + (class_idx, has_obj) } } @@ -72,6 +74,8 @@ impl VocLabel { let mut element_name = String::new(); + let mut begin_resolve = false; + loop { match reader.next() { Ok(e) => { @@ -93,15 +97,22 @@ impl VocLabel { } if element_name.as_str().eq("object") { obj_boxes.push(VocObjBox::new()); + begin_resolve = true; } } XmlEvent::EndElement { .. } => { // println!("EndElement({name})"); + if element_name.as_str().eq("object") { + begin_resolve = false; + } element_name = String::new(); }, XmlEvent::Characters(data) => { // println!(r#"Characters("{}")"#, data.escape_debug()); let element_name_str = element_name.as_str(); + if element_name_str.eq("name") && !begin_resolve { + continue; + } if element_name_str.eq("width") { width = data.parse().unwrap_or(0); } else if element_name_str.eq("height") { @@ -209,17 +220,21 @@ impl ItemLoader for VocItemLoader { let mut label = vec![0f32; 30 * SEGMENT * SEGMENT]; - let cell_size_x = WIDTH / SEGMENT; - let cell_size_y = HEIGHT / SEGMENT; + let cell_size_x = (WIDTH / SEGMENT) as f32; + let cell_size_y = (HEIGHT / SEGMENT) as f32; for obj_box in voc_label.obj_boxes { - let class_idx = obj_box.get_label_class_idx(); + let (class_idx, has_obj) = obj_box.get_label_class_idx(); + + if !has_obj { + continue; + } let (box_w, box_h) = (obj_box.xmax.abs_diff(obj_box.xmin), obj_box.ymax.abs_diff(obj_box.ymin)); - let (cx, cy) = (obj_box.xmin + box_w / 2, obj_box.ymin + box_h / 2); - let (i, j) = (min(cx / cell_size_x, 6), min(cy / cell_size_y, 6)); - let (delta_x, delta_y) = (((cx - cell_size_x * i) as f32) / cell_size_x as f32, - ((cy - cell_size_y * j) as f32) / cell_size_y as f32); + let (cx, cy) = ((obj_box.xmin + obj_box.xmax) as f32 / 2f32, (obj_box.ymin + obj_box.ymax) as f32 / 2f32); + let (i, j) = (libm::floor((cx / cell_size_x) as f64) as usize, libm::floor((cy / cell_size_y) as f64) as usize); + let (delta_x, delta_y) = ((cx - cell_size_x * i as f32) / cell_size_x as f32, + (cy - cell_size_y * j as f32) / cell_size_y as f32); label[i * SEGMENT + j] = delta_x; label[1 * SEGMENT * SEGMENT + i * SEGMENT + j] = delta_y; label[2 * SEGMENT * SEGMENT + i * SEGMENT + j] = box_w as f32/ WIDTH as f32; diff --git a/yolo_v1/src/main.rs b/yolo_v1/src/main.rs index 036af01..13a2fa7 100644 --- a/yolo_v1/src/main.rs +++ b/yolo_v1/src/main.rs @@ -14,7 +14,7 @@ type MyAutodiffBackend = Autodiff; fn main() { unsafe { backtrace_on_stack_overflow::enable() }; - let device = burn::backend::libtorch::LibTorchDevice::Cuda(1); + let device = burn::backend::libtorch::LibTorchDevice::Cuda(0); training::train::( "./yolo_v1/model", TrainingConfig::new(YoloV1Config::new(7, 2), AdamWConfig::new()), diff --git a/yolo_v1/src/model.rs b/yolo_v1/src/model.rs index f1ad7dc..ee40ba2 100644 --- a/yolo_v1/src/model.rs +++ b/yolo_v1/src/model.rs @@ -24,10 +24,10 @@ pub struct BBox { impl From<(&[f32], usize, usize)> for BBox { fn from(value: (&[f32], usize, usize)) -> Self { let (box_val, i, j) = (value.0, value.1, value.2); - let xmin = box_val[0] + SEGMENT as f32 * i as f32; - let ymin = box_val[1] + SEGMENT as f32 * j as f32; let box_w = box_val[2] * WIDTH as f32; let box_h = box_val[3] * HEIGHT as f32; + let xmin = box_val[0] * SEGMENT as f32 + SEGMENT as f32 * i as f32 - box_w / 2f32; + let ymin = box_val[1] * SEGMENT as f32 + SEGMENT as f32 * j as f32 - box_h / 2f32; let mut box_origin = [ box_val[0],