Skip to content

Commit

Permalink
modify yolo loss function
Browse files (browse the repository at this point in the history)
  • Loading branch information
RichardYang committed Mar 9, 2024
1 parent 6bb5cef commit 755112f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 27 deletions.
2 changes: 1 addition & 1 deletion fashion_mnist/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl FashionMNISTDataset {

fn new(split: &str) -> Self {
// Download dataset
let root = Path::new("/Users/yangyang/Projects/burn-examples/fashion_mnist/data");
let root = Path::new("./fashion_mnist/data");

// MNIST is tiny so we can load it in-memory
// Train images (u8): 28 * 28 * 60000 bytes = 47.04 MB
Expand Down
2 changes: 1 addition & 1 deletion yolo_v1/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

use burn::{backend::{Autodiff, LibTorch}, optim::AdamWConfig};
use burn::{backend::{Autodiff, LibTorch}, optim::AdamWConfig, tensor::Tensor};
// use data::{ItemLoader, VocItem, VocItemLoader};
use training::TrainingConfig;

Expand Down
49 changes: 26 additions & 23 deletions yolo_v1/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct BBox {
ymin: f32,
xmax: f32,
ymax: f32,
prob: f32,
confident: f32,
box_origin: [f32; 4],
}

Expand Down Expand Up @@ -51,7 +51,7 @@ impl From<(&[f32], usize, usize)> for BBox {
ymin,
xmax: xmin + box_w,
ymax: ymin + box_h,
prob: box_val[4],
confident: box_val[4],
box_origin
}
}
Expand Down Expand Up @@ -117,30 +117,33 @@ impl YoloV1Loss {
for (i, j) in iproduct!(0..segment_w, 0..segment_h) {
let (target_has_obj, (target_bbox, _), target_probs) = self.create_bboxes(&target, i, j);
let (_, (predict_bbox_1, predict_bbox_2), predict_probs) = self.create_bboxes(&predict, i, j);

let iou_1 = self.compute_iou::<B>(predict_bbox_1, target_bbox);
let iou_2 = self.compute_iou::<B>(predict_bbox_2, target_bbox);

let (resp_bbox, no_resp_bbox) = if iou_1 > iou_2 {
(predict_bbox_1, predict_bbox_2)
} else {
(predict_bbox_2, predict_bbox_1)
};

if target_has_obj {
let iou_1 = self.compute_iou::<B>(predict_bbox_1, target_bbox);
let iou_2 = self.compute_iou::<B>(predict_bbox_2, target_bbox);
let choosen_bbox = if iou_1 > iou_2 {
loss += (predict_bbox_1.prob - target_bbox.prob).powi(2);
loss += self.l_noobj * (predict_bbox_2.prob - target_bbox.prob).powi(2);
predict_bbox_1.box_origin
} else {
loss += self.l_noobj * (predict_bbox_1.prob - target_bbox.prob).powi(2);
loss += (predict_bbox_2.prob - target_bbox.prob).powi(2);
predict_bbox_2.box_origin
};
loss += self.l_coord * ((choosen_bbox[0] - target_bbox.box_origin[0])).powi(2);
loss += self.l_coord * ((choosen_bbox[1] - target_bbox.box_origin[1])).powi(2);
loss += self.l_coord * (choosen_bbox[2].sqrt() - target_bbox.box_origin[2].sqrt()).powi(2);
loss += self.l_coord * (choosen_bbox[3].sqrt() - target_bbox.box_origin[3].sqrt()).powi(2);
for i in 0..20 {
if target_probs[i] > 0f32 {
loss += (predict_probs[i] - target_probs[i]).powi(2);
}
if resp_bbox.confident > 0f32 {
loss += self.l_coord * ((resp_bbox.box_origin[0] - target_bbox.box_origin[0])).powi(2);
loss += self.l_coord * ((resp_bbox.box_origin[1] - target_bbox.box_origin[1])).powi(2);
loss += self.l_coord * (resp_bbox.box_origin[2].sqrt() - target_bbox.box_origin[2].sqrt()).powi(2);
loss += self.l_coord * (resp_bbox.box_origin[3].sqrt() - target_bbox.box_origin[3].sqrt()).powi(2);
loss += (resp_bbox.confident - target_bbox.confident).powi(2);
}
} else {
loss += self.l_noobj * (predict_bbox_1.prob - target_bbox.prob).powi(2);
loss += self.l_noobj * (predict_bbox_2.prob - target_bbox.prob).powi(2);
loss += self.l_noobj * (resp_bbox.confident - target_bbox.confident).powi(2);
loss += self.l_noobj * (no_resp_bbox.confident - target_bbox.confident).powi(2);
}

for i in 0..20 {
if target_probs[i] > 0f32 {
loss += (predict_probs[i] - target_probs[i]).powi(2);
}
}
}
loss_vec.push(loss);
Expand Down
4 changes: 2 additions & 2 deletions yolo_v1/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ pub struct TrainingConfig {
pub optimizer: AdamWConfig,
#[config(default = 10)]
pub num_epochs: usize,
#[config(default = 64)]
#[config(default = 32)]
pub batch_size: usize,
#[config(default = 8)]
pub num_workers: usize,
#[config(default = 35)]
pub seed: u64,
#[config(default = 0.0001)]
#[config(default = 0.001)]
pub learing_rate: f64,
}

Expand Down

0 comments on commit 755112f

Please sign in to comment.