From 8e6ad23aa25435e25d3b22b36b63287d268b5eb4 Mon Sep 17 00:00:00 2001 From: iChizer0 <62390647+iChizer0@users.noreply.github.com> Date: Wed, 22 May 2024 15:40:39 +0800 Subject: [PATCH] feat: swift yolo mbnv4 * Squashed commit of the following: commit 87250f5ced163fc9dd6ad011dc86a1806278cb95 Author: mjq2020 Date: Fri Apr 26 10:33:07 2024 +0000 add: mobilenetv4 backbone commit 0771f1866ea81db5629f08ac6f9e97c156664d23 Merge: 8e0b2f7 7f9c4e0 Author: mjq2020 Date: Fri Apr 26 10:31:57 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit 8e0b2f72d7bc60eade1ae9d6b2d67f4975b70abc Merge: ac0f39d 9b00e64 Author: mjq2020 Date: Fri Apr 19 06:22:49 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit ac0f39d286914df6cf6801fde1557eff0b9b7261 Merge: c4ea712 1f67493 Author: mjq2020 Date: Mon Apr 1 10:02:11 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit c4ea712a5e9bf899403913b0800c732076af6938 Author: mjq2020 Date: Mon Apr 1 10:00:39 2024 +0000 Fix: cls loss weight too high commit b87fc0402e5530f0b3235c011dd2475a612c8053 Merge: f146c73 ee72f81 Author: mjq2020 Date: Mon Apr 1 09:54:45 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit f146c733893e17ae57ca534192c2d0e5231e79a3 Merge: c068454 289360c Author: mjq2020 Date: Tue Mar 19 02:16:47 2024 +0000 Merge branch 'main' of https://github.com/mjq2020/EdgeLab into main commit c0684545a69e4ed39a23e78157d9d8fd16ea35d4 Author: mjq2020 Date: Mon Mar 18 07:05:34 2024 +0000 Optim: model inference display commit fc8874f7a6dfc2d3d975dbff96b6dec6844aa6a0 Author: mjq2020 Date: Mon Mar 18 07:03:33 2024 +0000 Fix: data type bug commit f1d76fc2d16d3362883add1ee6f733b9dabd80b4 Merge: 1378e1b 3c61e3e Author: mjq2020 <74635395+mjq2020@users.noreply.github.com> Date: Thu Mar 14 18:42:18 2024 +0800 Merge branch 'Seeed-Studio:main' into main commit 1378e1b9b8ed4799bc2d90c1f506c431b6487ebd Merge: 8c8ffd7 31c5291 Author: mjq2020 <74635395+mjq2020@users.noreply.github.com> Date: Tue Jan 30 11:41:16 2024 +0800 Merge branch 'Seeed-Studio:main' into main commit 8c8ffd75ca355ee6fe31311e379ae36cc21ace0c Merge: c67ed2d ebb1ec2 Author: mjq2020 <74635395+mjq2020@users.noreply.github.com> Date: Fri Oct 13 11:09:55 2023 +0800 Merge branch 'Seeed-Studio:main' into main commit c67ed2dc1a0af414a75e91a26e09fec06c09d4b6 Merge: d70e424 9be0612 Author: mjq2020 <74635395+mjq2020@users.noreply.github.com> Date: Sat Sep 23 16:18:57 2023 +0800 Merge branch 'Seeed-Studio:main' into main * feat: mobilenetv4 swift yolo backbone * chore: modify mbnv4 medium/large outputs --- .../swift_yolo_1xb16_300e_coco_mbnv4s.py | 164 ++++++++++++++++++ sscma/models/backbones/MobileNetv4.py | 52 ++++-- sscma/models/backbones/__init__.py | 2 + 3 files changed, 200 insertions(+), 18 deletions(-) create mode 100644 configs/swift_yolo/swift_yolo_1xb16_300e_coco_mbnv4s.py diff --git a/configs/swift_yolo/swift_yolo_1xb16_300e_coco_mbnv4s.py b/configs/swift_yolo/swift_yolo_1xb16_300e_coco_mbnv4s.py new file mode 100644 index 00000000..f08b03d7 --- /dev/null +++ b/configs/swift_yolo/swift_yolo_1xb16_300e_coco_mbnv4s.py @@ -0,0 +1,164 @@ +# Copyright (c) Seeed Technology Co.,Ltd. All rights reserved. +_base_ = ['./base_arch.py'] + +# ========================Suggested optional parameters======================== +# MODEL +num_classes = 71 +deepen_factor = 0.33 +widen_factor = 1 + +# DATA +dataset_type = 'sscma.CustomYOLOv5CocoDataset' +train_ann = 'train/_annotations.coco.json' +train_data = 'train/' # Prefix of train image path +val_ann = 'valid/_annotations.coco.json' +val_data = 'valid/' # Prefix of val image path + +# dataset link: https://universe.roboflow.com/team-roboflow/coco-128 +data_root = 'https://universe.roboflow.com/ds/z5UOcgxZzD?key=bwx9LQUT0t' +height = 192 +width = 192 +batch = 16 +workers = 2 +val_batch = batch +val_workers = workers +imgsz = (width, height) + +# TRAIN +persistent_workers = True + +# ================================END================================= + +# DATA +affine_scale = 0.5 +# MODEL +strides = [8, 16, 32] + +anchors = [ + [(10, 13), (16, 30), (33, 23)], # P3/8 + [(30, 61), (62, 45), (59, 119)], # P4/16 + [(116, 90), (156, 198), (373, 326)], # P5/32 +] + +# default_scope = 'sscma' + +model = dict( + type='mmyolo.YOLODetector', + backbone=dict( + _delete_=True, + type='sscma.MobileNetv4', + arch='small' + ), + neck=dict( + type='mmyolo.YOLOv5PAFPN', + deepen_factor=deepen_factor, + widen_factor=widen_factor, + in_channels=[64, 96, 128], + out_channels=[64, 96, 128] + ), + bbox_head=dict( + head_module=dict( + num_classes=num_classes, + in_channels=[64, 96, 128], + widen_factor=widen_factor, + ), + ), +) + +# ======================datasets================== + + +batch_shapes_cfg = dict( + type='BatchShapePolicy', + batch_size=1, + img_size=imgsz[0], + # The image scale of padding should be divided by pad_size_divisor + size_divisor=32, + # Additional paddings for pixel scale + extra_pad_ratio=0.5, +) + +albu_train_transforms = [ + dict(type='Blur', p=0.01), + dict(type='MedianBlur', p=0.01), + dict(type='ToGray', p=0.01), + dict(type='CLAHE', p=0.01), +] + +pre_transform = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict(type='LoadAnnotations', with_bbox=True, _scope_='sscma'), +] + +train_pipeline = [ + *pre_transform, + dict(type='Mosaic', img_scale=imgsz, pad_val=114.0, pre_transform=pre_transform, _scope_='sscma'), + dict( + type='YOLOv5RandomAffine', + max_rotate_degree=0.0, + max_shear_degree=0.0, + scaling_ratio_range=(1 - affine_scale, 1 + affine_scale), + # imgsz is (width, height) + border=(-imgsz[0] // 2, -imgsz[1] // 2), + border_val=(114, 114, 114), + _scope_='sscma' + ), + dict( + type='mmdet.Albu', + transforms=albu_train_transforms, + bbox_params=dict(type='BboxParams', format='pascal_voc', label_fields=['gt_bboxes_labels', 'gt_ignore_flags']), + keymap={'img': 'image', 'gt_bboxes': 'bboxes'}, + ), + dict(type='YOLOv5HSVRandomAug', _scope_='sscma'), + dict(type='mmdet.RandomFlip', prob=0.5), + dict( + type='mmdet.PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip', 'flip_direction') + ), +] + +train_dataloader = dict( + batch_size=batch, + num_workers=workers, + persistent_workers=persistent_workers, + pin_memory=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=train_ann, + data_prefix=dict(img=train_data), + filter_cfg=dict(filter_empty_gt=False, min_size=32), + pipeline=train_pipeline, + ), +) + +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')), + dict(type='YOLOv5KeepRatioResize', scale=imgsz, _scope_='sscma'), + dict(type='sscma.LetterResize', scale=imgsz, allow_scale_up=False, pad_val=dict(img=114), _scope_='sscma'), + dict(type='LoadAnnotations', with_bbox=True, _scope_='sscma'), + dict( + type='mmdet.PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor', 'pad_param'), + ), +] + +val_dataloader = dict( + batch_size=val_batch, + num_workers=val_workers, + persistent_workers=persistent_workers, + pin_memory=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + test_mode=True, + data_prefix=dict(img=val_data), + ann_file=val_ann, + pipeline=test_pipeline, + batch_shapes_cfg=batch_shapes_cfg, + ), +) + +test_dataloader = val_dataloader diff --git a/sscma/models/backbones/MobileNetv4.py b/sscma/models/backbones/MobileNetv4.py index d2be6cca..3b42515e 100644 --- a/sscma/models/backbones/MobileNetv4.py +++ b/sscma/models/backbones/MobileNetv4.py @@ -5,7 +5,8 @@ from torch import Tensor from sscma.models.base.general import ConvNormActivation -from sscma.registry import BACKBONES +from sscma.registry import MODELS + from sscma.models.layers.nn_blocks import ( UniversalInvertedBottleneckBlock, InvertedBottleneckBlock, @@ -129,7 +130,8 @@ def mhsa_large_12px(): ) -@BACKBONES.register_module() + +@MODELS.register_module() class MobileNetv4(nn.Module): ''' Architecture: https://arxiv.org/abs/2404.10518 @@ -144,7 +146,7 @@ class MobileNetv4(nn.Module): 'small': [ ('convbn', 'ReLU', 3, None, None, False, 2, 32, None, False), # 1/2 ('fused_ib', 'ReLU', 3, None, None, False, 2, 32, 1, False), # 1/4 - ('fused_ib', 'ReLU', 3, None, None, False, 2, 64, 3, False), # 1/8 + ('fused_ib', 'ReLU', 3, None, None, False, 2, 64, 3, True), # 1/8 ('uib', 'ReLU', None, 5, 5, True, 2, 96, 3.0, False), # 1/16 ('uib', 'ReLU', None, 0, 3, True, 1, 96, 2.0, False), # IB ('uib', 'ReLU', None, 0, 3, True, 1, 96, 2.0, False), # IB @@ -193,7 +195,7 @@ class MobileNetv4(nn.Module): ], 'large': [ ('convbn', 'ReLU', 3, None, None, False, 2, 24, None, False), - ('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4.0, True), + ('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4.0, False), ('uib', 'ReLU', None, 3, 5, True, 2, 96, 4.0, False), ('uib', 'ReLU', None, 3, 3, True, 1, 96, 4.0, True), ('uib', 'ReLU', None, 3, 5, True, 2, 192, 4.0, False), @@ -227,9 +229,9 @@ class MobileNetv4(nn.Module): ], 'hybridmedium': [ ('convbn', 'ReLU', 3, None, None, False, 2, 32, None, False), # 1/2 - ('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4, True), # 1/4 + ('fused_ib', 'ReLU', 3, None, None, False, 2, 48, 4, False), # 1/4 ('uib', 'ReLU', None, 3, 5, True, 2, 80, 4.0, False), # IB - ('uib', 'ReLU', None, 3, 3, True, 1, 80, 2.0, False), # IB + ('uib', 'ReLU', None, 3, 3, True, 1, 80, 2.0, True), # IB ('uib', 'ReLU', None, 3, 5, True, 2, 160, 6.0, False), # IB ('uib', 'ReLU', None, 0, 0, True, 1, 160, 2.0, False), # IB ('uib', 'ReLU', None, 3, 3, True, 1, 160, 4.0, False), # IB @@ -242,7 +244,7 @@ class MobileNetv4(nn.Module): ('uib', 'ReLU', None, 3, 3, True, 1, 160, 4.0, False), mhsa_medium_24px(), ('uib', 'ReLU', None, 3, 0, True, 1, 160, 4.0, True), - ('uib', 'ReLU', None, 5, 5, True, 2, 256, 6.0, True), + ('uib', 'ReLU', None, 5, 5, True, 2, 256, 6.0, False), ('uib', 'ReLU', None, 5, 5, True, 1, 256, 4.0, False), ('uib', 'ReLU', None, 3, 5, True, 1, 256, 4.0, False), ('uib', 'ReLU', None, 3, 5, True, 1, 256, 4.0, False), @@ -265,7 +267,7 @@ class MobileNetv4(nn.Module): ], 'hybridlarge': [ ('convbn', 'GELU', 3, None, None, False, 2, 24, None, False), # 1/2 - ('fused_ib', 'GELU', 3, None, None, False, 2, 48, 4, True), # 1/4 + ('fused_ib', 'GELU', 3, None, None, False, 2, 48, 4, False), # 1/4 ('uib', 'GELU', None, 3, 5, True, 2, 96, 4.0, False), # IB ('uib', 'GELU', None, 3, 3, True, 1, 96, 4.0, True), # IB ('uib', 'GELU', None, 3, 5, True, 2, 192, 4.0, False), # IB @@ -283,7 +285,7 @@ class MobileNetv4(nn.Module): ('uib', 'GELU', None, 5, 3, True, 1, 192, 4.0, False), mhsa_large_24px(), ('uib', 'GELU', None, 3, 0, True, 1, 192, 4.0, True), # output - ('uib', 'GELU', None, 5, 5, True, 2, 512, 4.0, True), + ('uib', 'GELU', None, 5, 5, True, 2, 512, 4.0, False), ('uib', 'GELU', None, 5, 5, True, 1, 512, 4.0, False), ('uib', 'GELU', None, 5, 5, True, 1, 512, 4.0, False), ('uib', 'GELU', None, 5, 5, True, 1, 512, 4.0, False), @@ -326,23 +328,29 @@ def __init__( self._output_stride: int = (1,) - self.blocks_setting = [] + self.block_settings = [] for setting in arch_setting: if isinstance(setting, tuple): - self.blocks_setting.append(BlockConfig(*setting, input_channels=input_channels)) + self.block_settings.append(BlockConfig(*setting, input_channels=input_channels)) else: - self.blocks_setting.append(BlockConfig(**setting, input_channels=input_channels)) - if self.blocks_setting[-1].output_channels is not None: - input_channels = self.blocks_setting[-1].output_channels + self.block_settings.append(BlockConfig(**setting, input_channels=input_channels)) + if self.block_settings[-1].output_channels is not None: + input_channels = self.block_settings[-1].output_channels - self._forward_blocks = self.build_layers() + last_output_block = 0 + for i, block in enumerate(self.block_settings): + if block.isoutputblock: + last_output_block = i + + self._forward_blocks = self.build_layers()[: last_output_block + 1] def build_layers(self): layers = [] block: BlockConfig current_stride = 1 rate = 1 - for block in self.blocks_setting: + + for block in self.block_settings: if not block.stride: block.stride = 1 @@ -355,6 +363,7 @@ def build_layers(self): layer_stride = block.stride layer_rate = 1 current_stride *= block.stride + if block.block_name == 'convbn': layer = ConvNormActivation( block.input_channels, @@ -422,9 +431,16 @@ def build_layers(self): ) else: raise ValueError(f'block name "{block.block_name}" is not supported') + layers.append(layer) + return nn.Sequential(*layers) def forward(self, x): - x = self._forward_blocks(x) - return x + outs = [] + for cfg, blk in zip(self.block_settings, self._forward_blocks): + x = blk(x) + if cfg.isoutputblock: + outs.append(x) + + return tuple(outs) diff --git a/sscma/models/backbones/__init__.py b/sscma/models/backbones/__init__.py index c65e8783..bb35d84f 100644 --- a/sscma/models/backbones/__init__.py +++ b/sscma/models/backbones/__init__.py @@ -3,6 +3,7 @@ from .EfficientNet import EfficientNet from .MobileNetv2 import MobileNetv2 from .MobileNetv3 import MobileNetV3 +from .MobileNetv4 import MobileNetv4 from .pfld_mobilenet_v2 import PfldMobileNetV2 from .ShuffleNetV2 import ShuffleNetV2, CustomShuffleNetV2, FastShuffleNetV2 from .SoundNet import SoundNetRaw @@ -17,6 +18,7 @@ 'CustomShuffleNetV2', 'AxesNet', 'MobileNetV3', + 'MobileNetv4', 'ShuffleNetV2', 'SqueezeNet', 'EfficientNet',