Skip to content

Commit

Permalink
add leaky relu function to network
Browse files Browse the repository at this point in the history
  • Loading branch information
richardyang92 committed Mar 10, 2024
1 parent 755112f commit e2a3cda
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
2 changes: 1 addition & 1 deletion yolo_v1/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type MyAutodiffBackend = Autodiff<MyBackend>;

fn main() {
unsafe { backtrace_on_stack_overflow::enable() };
let device = burn::backend::libtorch::LibTorchDevice::Cuda(0);
let device = burn::backend::libtorch::LibTorchDevice::Mps;
training::train::<MyAutodiffBackend>(
"./yolo_v1/model",
TrainingConfig::new(YoloV1Config::new(7, 2), AdamWConfig::new()),
Expand Down
52 changes: 51 additions & 1 deletion yolo_v1/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{cmp::{max_by, min_by}, fmt::Display};

use burn::{config::Config, module::Module, nn::{conv::{Conv2d, Conv2dConfig}, pool::{MaxPool2d, MaxPool2dConfig}, Linear, LinearConfig, PaddingConfig2d}, tensor::{backend::{AutodiffBackend, Backend}, Float, Tensor}, train::{metric::{Adaptor, LossInput}, TrainOutput, TrainStep, ValidStep}};
use burn::{config::Config, module::Module, nn::{conv::{Conv2d, Conv2dConfig}, pool::{MaxPool2d, MaxPool2dConfig}, Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, ReLU}, tensor::{backend::{AutodiffBackend, Backend}, ElementConversion, Float, Tensor}, train::{metric::{Adaptor, LossInput}, TrainOutput, TrainStep, ValidStep}};

use crate::data::YoloV1Batch;

Expand Down Expand Up @@ -194,6 +194,18 @@ impl<B: Backend> Adaptor<LossInput<B>> for YoloV1RegressionOutput<B> {
}
}

#[derive(Module, Clone, Debug, Default, new)]
pub struct LeakyReLU { }

impl LeakyReLU {
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let input_primitive = input.clone().into_primitive();
let negative_slope = input.mul_scalar(0.01).into_primitive();
let mask = B::float_lower_equal_elem(input_primitive.clone(), 0.elem());
Tensor::new(B::float_mask_where(input_primitive, mask, negative_slope))
}
}

#[derive(Module, Debug)]
pub struct YoloV1<B: Backend> {
conv1: Conv2d<B>,
Expand Down Expand Up @@ -226,41 +238,67 @@ pub struct YoloV1<B: Backend> {
conv6_2: Conv2d<B>,
fc1: Linear<B>,
fc2: Linear<B>,
leaky_relu: LeakyReLU,
dropout: Dropout,
}

impl<B: Backend> YoloV1<B> {
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv1.forward(input);
let x = self.leaky_relu.forward(x);
let x = self.pool1.forward(x);
let x = self.conv2.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.pool2.forward(x);
let x = self.conv3_1.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv3_2.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv3_3.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv3_4.forward(x);
let x = self.leaky_relu.forward(x);
let x= self.pool3.forward(x);
let x = self.conv4_1.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv4_2.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv4_3.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv4_4.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv4_5.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv4_6.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv4_7.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv4_8.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv4_9.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv4_10.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.pool4.forward(x);
let x = self.conv5_1.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv5_2.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv5_3.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv5_4.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.pool5.forward(x);
let x = self.conv6_1.forward(x);
let x = self.leaky_relu.forward(x);
let x = self.conv6_2.forward(x);
let x = self.leaky_relu.forward(x);

let [batch_size, channels, height, width] = x.dims();
let x = x.reshape([batch_size, channels * height * width]);

let x = self.fc1.forward(x);
let x = self.dropout.forward(x);
let x = self.fc2.forward(x);
x.reshape([batch_size, 30, 7, 7])
}
Expand All @@ -280,6 +318,8 @@ impl<B: Backend> YoloV1<B> {
pub struct YoloV1Config {
segments_number: usize,
boxes_number: usize,
#[config(default = 0.3)]
dropout_prob: f64,
}

impl YoloV1Config {
Expand Down Expand Up @@ -329,6 +369,9 @@ impl YoloV1Config {
.with_padding(PaddingConfig2d::Explicit(1, 1)).init(device);
let fc1 = LinearConfig::new(1024 * self.segments_number * self.segments_number, 4096).init(device);
let fc2 = LinearConfig::new(4096, 30 * self.segments_number * self.segments_number).init(device);
let relu = ReLU::new();
let leaky_relu = LeakyReLU::new();
let dropout = DropoutConfig::new(self.dropout_prob).init();

YoloV1 {
conv1,
Expand Down Expand Up @@ -361,6 +404,8 @@ impl YoloV1Config {
conv6_2,
fc1,
fc2,
leaky_relu,
dropout,
}
}

Expand Down Expand Up @@ -410,6 +455,9 @@ impl YoloV1Config {
.with_padding(PaddingConfig2d::Explicit(1, 1)).init_with(record.conv6_2);
let fc1 = LinearConfig::new(1024 * self.segments_number * self.segments_number, 4096).init_with(record.fc1);
let fc2 = LinearConfig::new(4096, 30 * self.segments_number * self.segments_number).init_with(record.fc2);
let relu = ReLU::new();
let leaky_relu = LeakyReLU::new();
let dropout = DropoutConfig::new(self.dropout_prob).init();

YoloV1 {
conv1,
Expand Down Expand Up @@ -442,6 +490,8 @@ impl YoloV1Config {
conv6_2,
fc1,
fc2,
leaky_relu,
dropout,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions yolo_v1/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use burn::{config::Config, data::{dataloader::DataLoaderBuilder, dataset::Datase

use crate::{data::{YoloV1Batcher, VocDataset}, model::YoloV1Config};

// const VOC2007_ROOT: &'static str = "/Users/yangyang/Downloads/VOCdevkit 2/VOC2007";
// const VOC2007_ROOT: &'static str = "/Users/yangyang/Projects/burn-examples/yolo_v1/data";
const VOC2007_ROOT: &'static str = "/media/yang/MyFiles/VOC2007";

Expand Down

0 comments on commit e2a3cda

Please sign in to comment.