Skip to content

Commit

Permalink
modify method of creating label
Browse files Browse the repository at this point in the history
  • Loading branch information
RichardYang committed Mar 9, 2024
1 parent 7d01a3a commit 6670ca1
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 13 deletions.
1 change: 1 addition & 0 deletions yolo_v1/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ image = "0.23"
itertools = "0.12.1"
rand = "0.8"
xml = "0.8.10"
libm = "0.2.8"
35 changes: 25 additions & 10 deletions yolo_v1/src/data.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{cmp::min, fs::{read_dir, File}, io::BufReader, path::Path, usize, vec};
use std::{fs::{read_dir, File}, io::BufReader, path::Path, usize, vec};


use burn::{data::{dataloader::batcher::Batcher, dataset::Dataset}, tensor::{backend::Backend, Data, Tensor}};
Expand Down Expand Up @@ -39,16 +39,18 @@ impl VocObjBox {
}
}

fn get_label_class_idx(&self) -> usize {
fn get_label_class_idx(&self) -> (usize, bool) {
let mut class_idx = 0;
let mut has_obj = false;

for obj_class in OBJ_CLASSES {
if self.box_class.as_str().eq(obj_class) {
has_obj = true;
break;
}
class_idx += 1;
}
class_idx
(class_idx, has_obj)
}
}

Expand All @@ -72,6 +74,8 @@ impl VocLabel {

let mut element_name = String::new();

let mut begin_resolve = false;

loop {
match reader.next() {
Ok(e) => {
Expand All @@ -93,15 +97,22 @@ impl VocLabel {
}
if element_name.as_str().eq("object") {
obj_boxes.push(VocObjBox::new());
begin_resolve = true;
}
}
XmlEvent::EndElement { .. } => {
// println!("EndElement({name})");
if element_name.as_str().eq("object") {
begin_resolve = false;
}
element_name = String::new();
},
XmlEvent::Characters(data) => {
// println!(r#"Characters("{}")"#, data.escape_debug());
let element_name_str = element_name.as_str();
if element_name_str.eq("name") && !begin_resolve {
continue;
}
if element_name_str.eq("width") {
width = data.parse().unwrap_or(0);
} else if element_name_str.eq("height") {
Expand Down Expand Up @@ -209,17 +220,21 @@ impl ItemLoader for VocItemLoader {

let mut label = vec![0f32; 30 * SEGMENT * SEGMENT];

let cell_size_x = WIDTH / SEGMENT;
let cell_size_y = HEIGHT / SEGMENT;
let cell_size_x = (WIDTH / SEGMENT) as f32;
let cell_size_y = (HEIGHT / SEGMENT) as f32;

for obj_box in voc_label.obj_boxes {
let class_idx = obj_box.get_label_class_idx();
let (class_idx, has_obj) = obj_box.get_label_class_idx();

if !has_obj {
continue;
}

let (box_w, box_h) = (obj_box.xmax.abs_diff(obj_box.xmin), obj_box.ymax.abs_diff(obj_box.ymin));
let (cx, cy) = (obj_box.xmin + box_w / 2, obj_box.ymin + box_h / 2);
let (i, j) = (min(cx / cell_size_x, 6), min(cy / cell_size_y, 6));
let (delta_x, delta_y) = (((cx - cell_size_x * i) as f32) / cell_size_x as f32,
((cy - cell_size_y * j) as f32) / cell_size_y as f32);
let (cx, cy) = ((obj_box.xmin + obj_box.xmax) as f32 / 2f32, (obj_box.ymin + obj_box.ymax) as f32 / 2f32);
let (i, j) = (libm::floor((cx / cell_size_x) as f64) as usize, libm::floor((cy / cell_size_y) as f64) as usize);
let (delta_x, delta_y) = ((cx - cell_size_x * i as f32) / cell_size_x as f32,
(cy - cell_size_y * j as f32) / cell_size_y as f32);
label[i * SEGMENT + j] = delta_x;
label[1 * SEGMENT * SEGMENT + i * SEGMENT + j] = delta_y;
label[2 * SEGMENT * SEGMENT + i * SEGMENT + j] = box_w as f32/ WIDTH as f32;
Expand Down
2 changes: 1 addition & 1 deletion yolo_v1/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type MyAutodiffBackend = Autodiff<MyBackend>;

fn main() {
unsafe { backtrace_on_stack_overflow::enable() };
let device = burn::backend::libtorch::LibTorchDevice::Cuda(1);
let device = burn::backend::libtorch::LibTorchDevice::Cuda(0);
training::train::<MyAutodiffBackend>(
"./yolo_v1/model",
TrainingConfig::new(YoloV1Config::new(7, 2), AdamWConfig::new()),
Expand Down
4 changes: 2 additions & 2 deletions yolo_v1/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ pub struct BBox {
impl From<(&[f32], usize, usize)> for BBox {
fn from(value: (&[f32], usize, usize)) -> Self {
let (box_val, i, j) = (value.0, value.1, value.2);
let xmin = box_val[0] + SEGMENT as f32 * i as f32;
let ymin = box_val[1] + SEGMENT as f32 * j as f32;
let box_w = box_val[2] * WIDTH as f32;
let box_h = box_val[3] * HEIGHT as f32;
let xmin = box_val[0] * SEGMENT as f32 + SEGMENT as f32 * i as f32 - box_w / 2f32;
let ymin = box_val[1] * SEGMENT as f32 + SEGMENT as f32 * j as f32 - box_h / 2f32;

let mut box_origin = [
box_val[0],
Expand Down

0 comments on commit 6670ca1

Please sign in to comment.