Skip to content

Commit

Permalink
feat: swift yolo mbnv4
Browse files Browse the repository at this point in the history
* Squashed commit of the following:

commit 87250f5
Author: mjq2020 <[email protected]>
Date:   Fri Apr 26 10:33:07 2024 +0000

    add: mobilenetv4 backbone

commit 0771f18
Merge: 8e0b2f7 7f9c4e0
Author: mjq2020 <[email protected]>
Date:   Fri Apr 26 10:31:57 2024 +0000

    Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main

commit 8e0b2f7
Merge: ac0f39d 9b00e64
Author: mjq2020 <[email protected]>
Date:   Fri Apr 19 06:22:49 2024 +0000

    Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main

commit ac0f39d
Merge: c4ea712 1f67493
Author: mjq2020 <[email protected]>
Date:   Mon Apr 1 10:02:11 2024 +0000

    Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main

commit c4ea712
Author: mjq2020 <[email protected]>
Date:   Mon Apr 1 10:00:39 2024 +0000

    Fix: cls loss weight too high

commit b87fc04
Merge: f146c73 ee72f81
Author: mjq2020 <[email protected]>
Date:   Mon Apr 1 09:54:45 2024 +0000

    Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main

commit f146c73
Merge: c068454 289360c
Author: mjq2020 <[email protected]>
Date:   Tue Mar 19 02:16:47 2024 +0000

    Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main

commit c068454
Author: mjq2020 <[email protected]>
Date:   Mon Mar 18 07:05:34 2024 +0000

    Optim: model inference display

commit fc8874f
Author: mjq2020 <[email protected]>
Date:   Mon Mar 18 07:03:33 2024 +0000

    Fix: data type bug

commit f1d76fc
Merge: 1378e1b 3c61e3e
Author: mjq2020 <[email protected]>
Date:   Thu Mar 14 18:42:18 2024 +0800

    Merge branch 'Seeed-Studio:main' into main

commit 1378e1b
Merge: 8c8ffd7 31c5291
Author: mjq2020 <[email protected]>
Date:   Tue Jan 30 11:41:16 2024 +0800

    Merge branch 'Seeed-Studio:main' into main

commit 8c8ffd7
Merge: c67ed2d ebb1ec2
Author: mjq2020 <[email protected]>
Date:   Fri Oct 13 11:09:55 2023 +0800

    Merge branch 'Seeed-Studio:main' into main

commit c67ed2d
Merge: d70e424 9be0612
Author: mjq2020 <[email protected]>
Date:   Sat Sep 23 16:18:57 2023 +0800

    Merge branch 'Seeed-Studio:main' into main

* feat: mobilenetv4 swift yolo backbone

* chore: modify mbnv4 medium/large outputs
  • Loading branch information
iChizer0 authored May 22, 2024
1 parent 343a2cb commit 8e6ad23
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 18 deletions.
164 changes: 164 additions & 0 deletions configs/swift_yolo/swift_yolo_1xb16_300e_coco_mbnv4s.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) Seeed Technology Co.,Ltd. All rights reserved.
_base_ = ['./base_arch.py']

# ---------------------- Commonly tuned parameters ----------------------
# Model
num_classes = 71
deepen_factor = 0.33
widen_factor = 1

# Dataset
dataset_type = 'sscma.CustomYOLOv5CocoDataset'
train_ann = 'train/_annotations.coco.json'
train_data = 'train/'  # prefix of the training image paths
val_ann = 'valid/_annotations.coco.json'
val_data = 'valid/'  # prefix of the validation image paths

# Dataset page: https://universe.roboflow.com/team-roboflow/coco-128
data_root = 'https://universe.roboflow.com/ds/z5UOcgxZzD?key=bwx9LQUT0t'
height = 192
width = 192
batch = 16
workers = 2
# Validation reuses the training batch size and worker count.
val_batch = batch
val_workers = workers
# Note the order: (width, height).
imgsz = (width, height)

# Training
persistent_workers = True

# ------------------------------- End -----------------------------------

# Data augmentation
affine_scale = 0.5
# Model
strides = [8, 16, 32]

# One row of three (w, h) anchors per stride listed in `strides`.
anchors = [
    [(10, 13), (16, 30), (33, 23)],  # P3/8
    [(30, 61), (62, 45), (59, 119)],  # P4/16
    [(116, 90), (156, 198), (373, 326)],  # P5/32
]

# default_scope = 'sscma'

# Detector definition: a YOLO detector with a MobileNetv4 ('small') backbone,
# a YOLOv5 PAFPN neck and a head sized for `num_classes`.
model = dict(
    type='mmyolo.YOLODetector',
    backbone=dict(
        # `_delete_` asks the config system to discard the inherited backbone
        # settings rather than merge these keys into them.
        _delete_=True,
        type='sscma.MobileNetv4',
        arch='small',
    ),
    neck=dict(
        type='mmyolo.YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        # NOTE(review): presumably the channel widths of the three backbone
        # output stages -- verify against the MobileNetv4 'small' spec.
        in_channels=[64, 96, 128],
        out_channels=[64, 96, 128],
    ),
    bbox_head=dict(
        head_module=dict(
            num_classes=num_classes,
            # Same widths as the neck's `out_channels`.
            in_channels=[64, 96, 128],
            widen_factor=widen_factor,
        ),
    ),
)

# ======================datasets==================


# Batch-shape policy consumed by the validation dataset (see `val_dataloader`).
# `batch_size` follows the validation dataloader's batch size (`val_batch`)
# instead of a hard-coded 1, matching the upstream MMYOLO convention of tying
# the policy to the per-GPU validation batch size so that the images grouped
# into one batch are assigned a common padded shape.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch,
    # imgsz is (width, height); the policy takes a single edge length.
    img_size=imgsz[0],
    # The padded image size must be divisible by size_divisor
    size_divisor=32,
    # Additional paddings for pixel scale
    extra_pad_ratio=0.5,
)

# Albumentations transforms handed to the `mmdet.Albu` wrapper in the training
# pipeline; every transform is configured with probability p=0.01.
albu_train_transforms = [
    dict(type=albu_type, p=0.01)
    for albu_type in ('Blur', 'MedianBlur', 'ToGray', 'CLAHE')
]

# Loading steps shared by the training pipeline and by Mosaic's own
# `pre_transform` argument: read the image from disk, then load its boxes.
pre_transform = [
    dict(
        type='LoadImageFromFile',
        file_client_args=dict(backend='disk'),
    ),
    dict(
        type='LoadAnnotations',
        with_bbox=True,
        _scope_='sscma',
    ),
]

# Training pipeline: load -> Mosaic -> random affine -> Albumentations ->
# HSV jitter -> random horizontal flip -> pack inputs.
train_pipeline = [
    *pre_transform,
    dict(
        type='Mosaic',
        img_scale=imgsz,
        pad_val=114.0,
        pre_transform=pre_transform,
        _scope_='sscma',
    ),
    dict(
        type='YOLOv5RandomAffine',
        # Rotation and shear are disabled; only scaling is randomised.
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
        # imgsz is (width, height)
        border=(-imgsz[0] // 2, -imgsz[1] // 2),
        border_val=(114, 114, 114),
        _scope_='sscma',
    ),
    dict(
        type='mmdet.Albu',
        transforms=albu_train_transforms,
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_bboxes_labels', 'gt_ignore_flags'],
        ),
        # Map pipeline keys to the key names used on the Albumentations side.
        keymap={'img': 'image', 'gt_bboxes': 'bboxes'},
    ),
    dict(type='YOLOv5HSVRandomAug', _scope_='sscma'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=(
            'img_id',
            'img_path',
            'ori_shape',
            'img_shape',
            'flip',
            'flip_direction',
        ),
    ),
]

# Training dataloader: shuffled sampling over the training split.
train_dataloader = dict(
    batch_size=batch,
    num_workers=workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann,
        data_prefix=dict(img=train_data),
        # filter_empty_gt=False: images without annotations are not filtered.
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline,
    ),
)

# Validation/test pipeline: load image -> keep-ratio resize -> letterbox to
# `imgsz` -> load annotations -> pack inputs.
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args=dict(backend='disk'),
    ),
    dict(type='YOLOv5KeepRatioResize', scale=imgsz, _scope_='sscma'),
    dict(
        type='sscma.LetterResize',
        scale=imgsz,
        allow_scale_up=False,
        pad_val=dict(img=114),
        _scope_='sscma',
    ),
    # Annotations are loaded after the two resize steps.
    dict(type='LoadAnnotations', with_bbox=True, _scope_='sscma'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=(
            'img_id',
            'img_path',
            'ori_shape',
            'img_shape',
            'scale_factor',
            'pad_param',
        ),
    ),
]

# Validation dataloader: unshuffled pass over the validation split, keeping
# the final partial batch (drop_last=False).
val_dataloader = dict(
    batch_size=val_batch,
    num_workers=val_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img=val_data),
        ann_file=val_ann,
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg,
    ),
)

# Testing uses the very same dataloader settings as validation (same object).
test_dataloader = val_dataloader
52 changes: 34 additions & 18 deletions sscma/models/backbones/MobileNetv4.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from torch import Tensor

from sscma.models.base.general import ConvNormActivation
from sscma.registry import BACKBONES
from sscma.registry import MODELS

from sscma.models.layers.nn_blocks import (
UniversalInvertedBottleneckBlock,
InvertedBottleneckBlock,
Expand Down Expand Up @@ -129,7 +130,8 @@ def mhsa_large_12px():
)


@BACKBONES.register_module()

@MODELS.register_module()
class MobileNetv4(nn.Module):
'''
Architecture: https://arxiv.org/abs/2404.10518
Expand All @@ -144,7 +146,7 @@ class MobileNetv4(nn.Module):
'small': [
('convbn', 'ReLU', 3, None, None, False, 2, 32, None, False), # 1/2
('fused_ib', 'ReLU', 3, None, None, False, 2, 32, 1, False), # 1/4
('fused_ib', 'ReLU', 3, None, None, False, 2, 64, 3, False), # 1/8
('fused_ib', 'ReLU', 3, None, None, False, 2, 64, 3, True), # 1/8
('uib', 'ReLU', None, 5, 5, True, 2, 96, 3.0, False), # 1/16
('uib', 'ReLU', None, 0, 3, True, 1, 96, 2.0, False), # IB
('uib', 'ReLU', None, 0, 3, True, 1, 96, 2.0, False), # IB
Expand Down Expand Up @@ -193,7 +195,7 @@ class MobileNetv4(nn.Module):
],
'large': [
('convbn', 'ReLU', 3, None, None, False, 2, 24, None, False),
('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4.0, True),
('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4.0, False),
('uib', 'ReLU', None, 3, 5, True, 2, 96, 4.0, False),
('uib', 'ReLU', None, 3, 3, True, 1, 96, 4.0, True),
('uib', 'ReLU', None, 3, 5, True, 2, 192, 4.0, False),
Expand Down Expand Up @@ -227,9 +229,9 @@ class MobileNetv4(nn.Module):
],
'hybridmedium': [
('convbn', 'ReLU', 3, None, None, False, 2, 32, None, False), # 1/2
('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4, True), # 1/4
('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4, False), # 1/4
('uib', 'ReLU', None, 3, 5, True, 2, 80, 4.0, False), # IB
('uib', 'ReLU', None, 3, 3, True, 1, 80, 2.0, False), # IB
('uib', 'ReLU', None, 3, 3, True, 1, 80, 2.0, True), # IB
('uib', 'ReLU', None, 3, 5, True, 2, 160, 6.0, False), # IB
('uib', 'ReLU', None, 0, 0, True, 1, 160, 2.0, False), # IB
('uib', 'ReLU', None, 3, 3, True, 1, 160, 4.0, False), # IB
Expand All @@ -242,7 +244,7 @@ class MobileNetv4(nn.Module):
('uib', 'ReLU', None, 3, 3, True, 1, 160, 4.0, False),
mhsa_medium_24px(),
('uib', 'ReLU', None, 3, 0, True, 1, 160, 4.0, True),
('uib', 'ReLU', None, 5, 5, True, 2, 256, 6.0, True),
('uib', 'ReLU', None, 5, 5, True, 2, 256, 6.0, False),
('uib', 'ReLU', None, 5, 5, True, 1, 256, 4.0, False),
('uib', 'ReLU', None, 3, 5, True, 1, 256, 4.0, False),
('uib', 'ReLU', None, 3, 5, True, 1, 256, 4.0, False),
Expand All @@ -265,7 +267,7 @@ class MobileNetv4(nn.Module):
],
'hybridlarge': [
('convbn', 'GELU', 3, None, None, False, 2, 24, None, False), # 1/2
('fused_ib', 'GELU', 3, None, None, False, 2, 48, 4, True), # 1/4
('fused_ib', 'GELU', 3, None, None, False, 2, 48, 4, False), # 1/4
('uib', 'GELU', None, 3, 5, True, 2, 96, 4.0, False), # IB
('uib', 'GELU', None, 3, 3, True, 1, 96, 4.0, True), # IB
('uib', 'GELU', None, 3, 5, True, 2, 192, 4.0, False), # IB
Expand All @@ -283,7 +285,7 @@ class MobileNetv4(nn.Module):
('uib', 'GELU', None, 5, 3, True, 1, 192, 4.0, False),
mhsa_large_24px(),
('uib', 'GELU', None, 3, 0, True, 1, 192, 4.0, True), # output
('uib', 'GELU', None, 5, 5, True, 2, 512, 4.0, True),
('uib', 'GELU', None, 5, 5, True, 2, 512, 4.0, False),
('uib', 'GELU', None, 5, 5, True, 1, 512, 4.0, False),
('uib', 'GELU', None, 5, 5, True, 1, 512, 4.0, False),
('uib', 'GELU', None, 5, 5, True, 1, 512, 4.0, False),
Expand Down Expand Up @@ -326,23 +328,29 @@ def __init__(

self._output_stride: int = (1,)

self.blocks_setting = []
self.block_settings = []
for setting in arch_setting:
if isinstance(setting, tuple):
self.blocks_setting.append(BlockConfig(*setting, input_channels=input_channels))
self.block_settings.append(BlockConfig(*setting, input_channels=input_channels))
else:
self.blocks_setting.append(BlockConfig(**setting, input_channels=input_channels))
if self.blocks_setting[-1].output_channels is not None:
input_channels = self.blocks_setting[-1].output_channels
self.block_settings.append(BlockConfig(**setting, input_channels=input_channels))
if self.block_settings[-1].output_channels is not None:
input_channels = self.block_settings[-1].output_channels

self._forward_blocks = self.build_layers()
last_output_block = 0
for i, block in enumerate(self.block_settings):
if block.isoutputblock:
last_output_block = i

self._forward_blocks = self.build_layers()[: last_output_block + 1]

def build_layers(self):
layers = []
block: BlockConfig
current_stride = 1
rate = 1
for block in self.blocks_setting:

for block in self.block_settings:

if not block.stride:
block.stride = 1
Expand All @@ -355,6 +363,7 @@ def build_layers(self):
layer_stride = block.stride
layer_rate = 1
current_stride *= block.stride

if block.block_name == 'convbn':
layer = ConvNormActivation(
block.input_channels,
Expand Down Expand Up @@ -422,9 +431,16 @@ def build_layers(self):
)
else:
raise ValueError(f'block name "{block.block_name}" is not supported')

layers.append(layer)

return nn.Sequential(*layers)

def forward(self, x):
x = self._forward_blocks(x)
return x
outs = []
for cfg, blk in zip(self.block_settings, self._forward_blocks):
x = blk(x)
if cfg.isoutputblock:
outs.append(x)

return tuple(outs)
2 changes: 2 additions & 0 deletions sscma/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .EfficientNet import EfficientNet
from .MobileNetv2 import MobileNetv2
from .MobileNetv3 import MobileNetV3
from .MobileNetv4 import MobileNetv4
from .pfld_mobilenet_v2 import PfldMobileNetV2
from .ShuffleNetV2 import ShuffleNetV2, CustomShuffleNetV2, FastShuffleNetV2
from .SoundNet import SoundNetRaw
Expand All @@ -17,6 +18,7 @@
'CustomShuffleNetV2',
'AxesNet',
'MobileNetV3',
'MobileNetv4',
'ShuffleNetV2',
'SqueezeNet',
'EfficientNet',
Expand Down

0 comments on commit 8e6ad23

Please sign in to comment.