Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
xqh5201314 authored Jun 21, 2023
1 parent fcc0e81 commit 93ccd57
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 0 deletions.
35 changes: 35 additions & 0 deletions SPPF-imporve/SPPF-imporve.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
这是关于SPPF的系列改进包括SimSPPF,SPPCSPC,SimCSPSPPF

SimSPPF为YOLOv6提出的,论文地址:[YOLOv6: A Single-Stage Object Detection Framework for Industrial Applications | PDF (arxiv.org)](https://arxiv.org/pdf/2209.02976v1.pdf)

SPPCSPC为YOLOv7提出的,论文地址:[YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors | PDF (arxiv.org)](https://arxiv.org/pdf/2207.02696v1.pdf)

SimCSPSPPF为YOLOv6 3.0提出的,论文地址:https://arxiv.org/pdf/2301.05586v1.pdf。

具体改进详情见论文。

使用方法:

1.将SPPF_imporve.py复制到models文件夹下

2.在models/yolo.py导入所需模块

from SPPF_imporve import SimSPPF, SPPCSPC, SimCSPSPPF

```
elif m is [SimSPPF,SPPCSPC,SimCSPSPPF]:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)
args = [c1, c2, *args[1:]]
if m in [SimSPPF,SPPCSPC,SimCSPSPPF]:
args.insert(2, n) # number of repeats
n = 1
```

3.进行训练

```python
python train.py --cfg YOLOv5_sppf_imporve.yaml
```

111 changes: 111 additions & 0 deletions SPPF-imporve/SPPF_imporve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import time
import torch
import torch.nn as nn

def autopad(k, p=None): # kernel, padding
# Pad to 'same'
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p

class Conv(nn.Module):
# Standard convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super(Conv, self).__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

def forward(self, x):
return self.act(self.bn(self.conv(x)))

def forward_fuse(self, x):
return self.act(self.conv(x))

class SimConv(nn.Module):
'''Normal Conv with ReLU activation'''

def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
super().__init__()
padding = kernel_size // 2
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=bias,
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.ReLU()

def forward(self, x):
return self.act(self.bn(self.conv(x)))

def forward_fuse(self, x):
return self.act(self.conv(x))

class SimSPPF(nn.Module):
'''Simplified SPPF with ReLU activation'''

def __init__(self, in_channels, out_channels, kernel_size=5):
super().__init__()
c_ = in_channels // 2 # hidden channels
self.cv1 = SimConv(in_channels, c_, 1, 1)
self.cv2 = SimConv(c_ * 4, out_channels, 1, 1)
self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

def forward(self, x):
x = self.cv1(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))

class SPPCSPC(nn.Module):
# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
super(SPPCSPC, self).__init__()
c_ = int(2 * c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(c_, c_, 3, 1)
self.cv4 = Conv(c_, c_, 1, 1)
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
self.cv5 = Conv(4 * c_, c_, 1, 1)
self.cv6 = Conv(c_, c_, 3, 1)
self.cv7 = Conv(2 * c_, c2, 1, 1)

def forward(self, x):
x1 = self.cv4(self.cv3(self.cv1(x)))
y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
y2 = self.cv2(x)
return self.cv7(torch.cat((y1, y2), dim=1))


class SimCSPSPPF(nn.Module):
# CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, in_channels, out_channels, kernel_size=5, e=0.5):
super(SimCSPSPPF, self).__init__()
c_ = int(out_channels * e) # hidden channels
self.cv1 = SimConv(in_channels, c_, 1, 1)
self.cv2 = SimConv(in_channels, c_, 1, 1)
self.cv3 = SimConv(c_, c_, 3, 1)
self.cv4 = SimConv(c_, c_, 1, 1)

self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
self.cv5 = SimConv(4 * c_, c_, 1, 1)
self.cv6 = SimConv(c_, c_, 3, 1)
self.cv7 = SimConv(2 * c_, out_channels, 1, 1)

def forward(self, x):
x1 = self.cv4(self.cv3(self.cv1(x)))
y0 = self.cv2(x)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
y1 = self.m(x1)
y2 = self.m(y1)
y3 = self.cv6(self.cv5(torch.cat([x1, y1, y2, self.m(y2)], 1)))
return self.cv7(torch.cat((y0, y3), dim=1))
52 changes: 52 additions & 0 deletions SPPF-imporve/YOLOv5_sppf_imporve.yaml.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32

# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
# [-1, 1, SimSPPF, [1024, 5]],
# [-1, 1, SPPCSPC, [1024]],
# [-1, 1, SimCSPSPPF, [1024, 5]],
]


# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13

[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)

[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)

[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)

[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]

0 comments on commit 93ccd57

Please sign in to comment.