From 755112f8a5ca6d2d554ed4c854cc24f8ea5e2eb1 Mon Sep 17 00:00:00 2001 From: RichardYang Date: Sun, 10 Mar 2024 03:43:13 +0800 Subject: [PATCH] modify yolo loss function --- fashion_mnist/src/data.rs | 2 +- yolo_v1/src/main.rs | 2 +- yolo_v1/src/model.rs | 49 +++++++++++++++++++++------------------ yolo_v1/src/training.rs | 4 ++-- 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/fashion_mnist/src/data.rs b/fashion_mnist/src/data.rs index ca0552c..22cb731 100644 --- a/fashion_mnist/src/data.rs +++ b/fashion_mnist/src/data.rs @@ -105,7 +105,7 @@ impl FashionMNISTDataset { fn new(split: &str) -> Self { // Download dataset - let root = Path::new("/Users/yangyang/Projects/burn-examples/fashion_mnist/data"); + let root = Path::new("./fashion_mnist/data"); // MNIST is tiny so we can load it in-memory // Train images (u8): 28 * 28 * 60000 = 47.04Mb diff --git a/yolo_v1/src/main.rs b/yolo_v1/src/main.rs index c3342d8..eefefff 100644 --- a/yolo_v1/src/main.rs +++ b/yolo_v1/src/main.rs @@ -1,5 +1,5 @@ -use burn::{backend::{Autodiff, LibTorch}, optim::AdamWConfig}; +use burn::{backend::{Autodiff, LibTorch}, optim::AdamWConfig, tensor::Tensor}; // use data::{ItemLoader, VocItem, VocItemLoader}; use training::TrainingConfig; diff --git a/yolo_v1/src/model.rs b/yolo_v1/src/model.rs index 0eaa6a7..c34ed53 100644 --- a/yolo_v1/src/model.rs +++ b/yolo_v1/src/model.rs @@ -17,7 +17,7 @@ pub struct BBox { ymin: f32, xmax: f32, ymax: f32, - prob: f32, + confident: f32, box_origin: [f32; 4], } @@ -51,7 +51,7 @@ impl From<(&[f32], usize, usize)> for BBox { ymin, xmax: xmin + box_w, ymax: ymin + box_h, - prob: box_val[4], + confident: box_val[4], box_origin } } @@ -117,30 +117,33 @@ impl YoloV1Loss { for (i, j) in iproduct!(0..segment_w, 0..segment_h) { let (target_has_obj, (target_bbox, _), target_probs) = self.create_bboxes(&target, i, j); let (_, (predict_bbox_1, predict_bbox_2), predict_probs) = self.create_bboxes(&predict, i, j); + + let iou_1 = self.compute_iou::(predict_bbox_1, target_bbox); + let iou_2 = self.compute_iou::(predict_bbox_2, target_bbox); + + let (resp_bbox, no_resp_bbox) = if iou_1 > iou_2 { + (predict_bbox_1, predict_bbox_2) + } else { + (predict_bbox_2, predict_bbox_1) + }; + if target_has_obj { - let iou_1 = self.compute_iou::(predict_bbox_1, target_bbox); - let iou_2 = self.compute_iou::(predict_bbox_2, target_bbox); - let choosen_bbox = if iou_1 > iou_2 { - loss += (predict_bbox_1.prob - target_bbox.prob).powi(2); - loss += self.l_noobj * (predict_bbox_2.prob - target_bbox.prob).powi(2); - predict_bbox_1.box_origin - } else { - loss += self.l_noobj * (predict_bbox_1.prob - target_bbox.prob).powi(2); - loss += (predict_bbox_2.prob - target_bbox.prob).powi(2); - predict_bbox_2.box_origin - }; - loss += self.l_coord * ((choosen_bbox[0] - target_bbox.box_origin[0])).powi(2); - loss += self.l_coord * ((choosen_bbox[1] - target_bbox.box_origin[1])).powi(2); - loss += self.l_coord * (choosen_bbox[2].sqrt() - target_bbox.box_origin[2].sqrt()).powi(2); - loss += self.l_coord * (choosen_bbox[3].sqrt() - target_bbox.box_origin[3].sqrt()).powi(2); - for i in 0..20 { - if target_probs[i] > 0f32 { - loss += (predict_probs[i] - target_probs[i]).powi(2); - } + if resp_bbox.confident > 0f32 { + loss += self.l_coord * ((resp_bbox.box_origin[0] - target_bbox.box_origin[0])).powi(2); + loss += self.l_coord * ((resp_bbox.box_origin[1] - target_bbox.box_origin[1])).powi(2); + loss += self.l_coord * (resp_bbox.box_origin[2].sqrt() - target_bbox.box_origin[2].sqrt()).powi(2); + loss += self.l_coord * (resp_bbox.box_origin[3].sqrt() - target_bbox.box_origin[3].sqrt()).powi(2); + loss += (resp_bbox.confident - target_bbox.confident).powi(2); } } else { - loss += self.l_noobj * (predict_bbox_1.prob - target_bbox.prob).powi(2); - loss += self.l_noobj * (predict_bbox_2.prob - target_bbox.prob).powi(2); + loss += self.l_noobj * (resp_bbox.confident - target_bbox.confident).powi(2); + loss += self.l_noobj * (no_resp_bbox.confident - target_bbox.confident).powi(2); + } + + for i in 0..20 { + if target_probs[i] > 0f32 { + loss += (predict_probs[i] - target_probs[i]).powi(2); + } } } loss_vec.push(loss); diff --git a/yolo_v1/src/training.rs b/yolo_v1/src/training.rs index 88858f5..69deac4 100644 --- a/yolo_v1/src/training.rs +++ b/yolo_v1/src/training.rs @@ -12,13 +12,13 @@ pub struct TrainingConfig { pub optimizer: AdamWConfig, #[config(default = 10)] pub num_epochs: usize, - #[config(default = 64)] + #[config(default = 32)] pub batch_size: usize, #[config(default = 8)] pub num_workers: usize, #[config(default = 35)] pub seed: u64, - #[config(default = 0.0001)] + #[config(default = 0.001)] pub learing_rate: f64, }