-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fcc0e81
commit 93ccd57
Showing
3 changed files
with
198 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
这是关于SPPF的系列改进包括SimSPPF,SPPCSPC,SimCSPSPPF | ||
|
||
SimSPPF为YOLOv6提出的,论文地址:[YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications | PDF (arxiv.org)](https://arxiv.org/pdf/2209.02976v1.pdf); | ||
|
||
SPPCSPC为YOLOv7提出的,论文地址:[YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors | PDF (arxiv.org)](https://arxiv.org/pdf/2207.02696v1.pdf); | ||
|
||
SimCSPSPPF为YOLOv6 3.0提出的,论文地址:https://arxiv.org/pdf/2301.05586v1.pdf。 | ||
|
||
具体改进详情见论文。 | ||
|
||
使用方法: | ||
|
||
1.将SPPF_imporve.py复制到models文件夹下 | ||
|
||
2.在models/yolo.py导入所需模块 | ||
|
||
from SPPF_imporve import SimSPPF, SPPCSPC, SimCSPSPPF | ||
|
||
``` | ||
elif m is [SimSPPF,SPPCSPC,SimCSPSPPF]: | ||
c1, c2 = ch[f], args[0] | ||
if c2 != no: # if not output | ||
c2 = make_divisible(c2 * gw, 8) | ||
args = [c1, c2, *args[1:]] | ||
if m in [SimSPPF,SPPCSPC,SimCSPSPPF]: | ||
args.insert(2, n) # number of repeats | ||
n = 1 | ||
``` | ||
|
||
3.进行训练 | ||
|
||
```python | ||
python train.py --cfg YOLOv5_sppf_imporve.yaml | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import time | ||
import torch | ||
import torch.nn as nn | ||
|
||
def autopad(k, p=None): # kernel, padding | ||
# Pad to 'same' | ||
if p is None: | ||
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad | ||
return p | ||
|
||
class Conv(nn.Module): | ||
# Standard convolution | ||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups | ||
super(Conv, self).__init__() | ||
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) | ||
self.bn = nn.BatchNorm2d(c2) | ||
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) | ||
|
||
def forward(self, x): | ||
return self.act(self.bn(self.conv(x))) | ||
|
||
def forward_fuse(self, x): | ||
return self.act(self.conv(x)) | ||
|
||
class SimConv(nn.Module): | ||
'''Normal Conv with ReLU activation''' | ||
|
||
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False): | ||
super().__init__() | ||
padding = kernel_size // 2 | ||
self.conv = nn.Conv2d( | ||
in_channels, | ||
out_channels, | ||
kernel_size=kernel_size, | ||
stride=stride, | ||
padding=padding, | ||
groups=groups, | ||
bias=bias, | ||
) | ||
self.bn = nn.BatchNorm2d(out_channels) | ||
self.act = nn.ReLU() | ||
|
||
def forward(self, x): | ||
return self.act(self.bn(self.conv(x))) | ||
|
||
def forward_fuse(self, x): | ||
return self.act(self.conv(x)) | ||
|
||
class SimSPPF(nn.Module): | ||
'''Simplified SPPF with ReLU activation''' | ||
|
||
def __init__(self, in_channels, out_channels, kernel_size=5): | ||
super().__init__() | ||
c_ = in_channels // 2 # hidden channels | ||
self.cv1 = SimConv(in_channels, c_, 1, 1) | ||
self.cv2 = SimConv(c_ * 4, out_channels, 1, 1) | ||
self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2) | ||
|
||
def forward(self, x): | ||
x = self.cv1(x) | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter('ignore') | ||
y1 = self.m(x) | ||
y2 = self.m(y1) | ||
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1)) | ||
|
||
class SPPCSPC(nn.Module): | ||
# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks | ||
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)): | ||
super(SPPCSPC, self).__init__() | ||
c_ = int(2 * c2 * e) # hidden channels | ||
self.cv1 = Conv(c1, c_, 1, 1) | ||
self.cv2 = Conv(c1, c_, 1, 1) | ||
self.cv3 = Conv(c_, c_, 3, 1) | ||
self.cv4 = Conv(c_, c_, 1, 1) | ||
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) | ||
self.cv5 = Conv(4 * c_, c_, 1, 1) | ||
self.cv6 = Conv(c_, c_, 3, 1) | ||
self.cv7 = Conv(2 * c_, c2, 1, 1) | ||
|
||
def forward(self, x): | ||
x1 = self.cv4(self.cv3(self.cv1(x))) | ||
y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1))) | ||
y2 = self.cv2(x) | ||
return self.cv7(torch.cat((y1, y2), dim=1)) | ||
|
||
|
||
class SimCSPSPPF(nn.Module): | ||
# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks | ||
def __init__(self, in_channels, out_channels, kernel_size=5, e=0.5): | ||
super(SimCSPSPPF, self).__init__() | ||
c_ = int(out_channels * e) # hidden channels | ||
self.cv1 = SimConv(in_channels, c_, 1, 1) | ||
self.cv2 = SimConv(in_channels, c_, 1, 1) | ||
self.cv3 = SimConv(c_, c_, 3, 1) | ||
self.cv4 = SimConv(c_, c_, 1, 1) | ||
|
||
self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2) | ||
self.cv5 = SimConv(4 * c_, c_, 1, 1) | ||
self.cv6 = SimConv(c_, c_, 3, 1) | ||
self.cv7 = SimConv(2 * c_, out_channels, 1, 1) | ||
|
||
def forward(self, x): | ||
x1 = self.cv4(self.cv3(self.cv1(x))) | ||
y0 = self.cv2(x) | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter('ignore') | ||
y1 = self.m(x1) | ||
y2 = self.m(y1) | ||
y3 = self.cv6(self.cv5(torch.cat([x1, y1, y2, self.m(y2)], 1))) | ||
return self.cv7(torch.cat((y0, y3), dim=1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license | ||
|
||
# Parameters | ||
nc: 80 # number of classes | ||
depth_multiple: 0.33 # model depth multiple | ||
width_multiple: 0.50 # layer channel multiple | ||
anchors: | ||
- [10,13, 16,30, 33,23] # P3/8 | ||
- [30,61, 62,45, 59,119] # P4/16 | ||
- [116,90, 156,198, 373,326] # P5/32 | ||
|
||
# YOLOv5 v6.0 backbone | ||
backbone: | ||
# [from, number, module, args] | ||
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 | ||
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4 | ||
[-1, 3, C3, [128]], | ||
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8 | ||
[-1, 6, C3, [256]], | ||
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16 | ||
[-1, 9, C3, [512]], | ||
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 | ||
[-1, 3, C3, [1024]], | ||
[-1, 1, SPPF, [1024, 5]], # 9 | ||
# [-1, 1, SimSPPF, [1024, 5]], | ||
# [-1, 1, SPPCSPC, [1024]], | ||
# [-1, 1, SimCSPSPPF, [1024, 5]], | ||
] | ||
|
||
|
||
# YOLOv5 v6.0 head | ||
head: | ||
[[-1, 1, Conv, [512, 1, 1]], | ||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||
[[-1, 6], 1, Concat, [1]], # cat backbone P4 | ||
[-1, 3, C3, [512, False]], # 13 | ||
|
||
[-1, 1, Conv, [256, 1, 1]], | ||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||
[[-1, 4], 1, Concat, [1]], # cat backbone P3 | ||
[-1, 3, C3, [256, False]], # 17 (P3/8-small) | ||
|
||
[-1, 1, Conv, [256, 3, 2]], | ||
[[-1, 14], 1, Concat, [1]], # cat head P4 | ||
[-1, 3, C3, [512, False]], # 20 (P4/16-medium) | ||
|
||
[-1, 1, Conv, [512, 3, 2]], | ||
[[-1, 10], 1, Concat, [1]], # cat head P5 | ||
[-1, 3, C3, [1024, False]], # 23 (P5/32-large) | ||
|
||
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) | ||
] |