Skip to content

Commit

Permalink
Add: Add YOLOv5 CustomCocoDataset and Rename Fomo CocoDataset (#214)
Browse files Browse the repository at this point in the history
* Refactor: Rename the FOMO dataset CustomCocoDataset to CustomFomoCocoDataset

* Add: Add CustomYOLOv5CocoDataset

* Add: Add YOLOv5 parameter scheduler and optimizer
  • Loading branch information
MILK-BIOS authored Apr 24, 2024
1 parent 68eb5b0 commit 7f9c4e0
Show file tree
Hide file tree
Showing 13 changed files with 423 additions and 11 deletions.
2 changes: 1 addition & 1 deletion configs/fomo/fomo_mobnetv2_0.35_x8_abl_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
widen_factor = 0.35

# DATA
dataset_type = 'CustomCocoDataset'
dataset_type = 'CustomFomoCocoDataset'
# datasets link: https://public.roboflow.com/object-detection/mask-wearing
data_root = 'https://public.roboflow.com/ds/o8GgfOIazi?key=hES8s8Gy7u'

Expand Down
2 changes: 1 addition & 1 deletion configs/fomo/fomo_mobnetv2_0.35_x8_abl_coco_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
widen_factor = 0.35

# DATA
dataset_type = 'CustomCocoDataset'
dataset_type = 'CustomFomoCocoDataset'
# datasets link: https://public.roboflow.com/object-detection/mask-wearing
data_root = 'https://public.roboflow.com/ds/o8GgfOIazi?key=hES8s8Gy7u'

Expand Down
2 changes: 1 addition & 1 deletion configs/fomo/fomo_mobnetv2_x8_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
widen_factor = 1

# DATA
dataset_type = 'CustomCocoDataset'
dataset_type = 'CustomFomoCocoDataset'
# datasets link: https://public.roboflow.com/object-detection/mask-wearing
data_root = 'https://public.roboflow.com/ds/o8GgfOIazi?key=hES8s8Gy7u'

Expand Down
4 changes: 2 additions & 2 deletions sscma/datasets/cocodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os.path as osp
from typing import Optional, Sequence, List

from mmdet.datasets.coco import CocoDataset
from mmdet.datasets import CocoDataset
from mmengine.fileio import get_local_path

from sscma.registry import DATASETS
Expand All @@ -12,7 +12,7 @@


@DATASETS.register_module()
class CustomCocoDataset(CocoDataset):
class CustomFomoCocoDataset(CocoDataset):
METAINFO = {
'classes': (),
# palette is a list of color tuples, which is used for visualization.
Expand Down
6 changes: 5 additions & 1 deletion sscma/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from .loading import LoadSensorFromFile
from .wrappers import MutiBranchPipe
from .transforms import YOLOv5KeepRatioResize, LetterResize, YOLOv5HSVRandomAug, YOLOv5RandomAffine, LoadAnnotations, Mosaic
from .utils import BatchShapePolicy, yolov5_collate


__all__ = ['PackSensorInputs',
'LoadSensorFromFile',
Expand All @@ -12,4 +14,6 @@
'YOLOv5HSVRandomAug',
'YOLOv5RandomAffine',
'LoadAnnotations',
'Mosaic']
'Mosaic',
'BatchShapePolicy',
'yolov5_collate']
114 changes: 114 additions & 0 deletions sscma/datasets/transforms/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence

import numpy as np
import torch
from mmengine.dataset import COLLATE_FUNCTIONS

from sscma.registry import TASK_UTILS


# @COLLATE_FUNCTIONS.register_module()
def yolov5_collate(data_batch: Sequence,
                   use_ms_training: bool = False) -> dict:
    """Collate a batch of samples into flat tensors for faster training.

    For every sample the ground-truth labels and boxes are packed into rows
    of ``[batch_idx, label, *bbox]`` and the rows of all samples are
    concatenated along dim 0.

    Args:
        data_batch (Sequence): Batch of data. Every item is indexed with
            ``'inputs'`` and ``'data_samples'``.
        use_ms_training (bool): Whether to use multi-scale training. When
            True the images are returned as a list instead of being stacked
            into a single tensor.

    Returns:
        dict: ``'data_samples'`` holds ``'bboxes_labels'`` (and ``'masks'``
        when at least one sample carries masks); ``'inputs'`` holds the
        images.
    """
    images = []
    packed_targets = []
    mask_tensors = []
    for sample_idx, sample in enumerate(data_batch):
        images.append(sample['inputs'])

        instances = sample['data_samples'].gt_instances
        boxes = instances.bboxes.tensor
        labels = instances.labels
        if 'masks' in instances:
            mask_tensors.append(
                instances.masks.to_tensor(
                    dtype=torch.bool, device=boxes.device))

        # Column 0 tags every row with the position of its image in the batch.
        index_column = labels.new_full((len(labels), 1), sample_idx)
        packed_targets.append(
            torch.cat((index_column, labels[:, None], boxes), dim=1))

    data_samples = {'bboxes_labels': torch.cat(packed_targets, 0)}
    if mask_tensors:
        data_samples['masks'] = torch.cat(mask_tensors, 0)

    # Multi-scale training keeps the images as a list; otherwise stack them.
    inputs = images if use_ms_training else torch.stack(images, 0)
    return {'data_samples': data_samples, 'inputs': inputs}


@TASK_UTILS.register_module()
class BatchShapePolicy:
    """BatchShapePolicy is only used in the testing phase, which can reduce the
    number of pad pixels during batch inference.

    The data list is sorted by aspect ratio (height / width), split into
    consecutive groups of ``batch_size`` items, and every item receives a
    ``'batch_shape'`` entry shared by its group.

    Args:
        batch_size (int): Single GPU batch size during batch inference.
            Defaults to 32.
        img_size (int): Expected output image size. Defaults to 640.
        size_divisor (int): The minimum size that is divisible
            by size_divisor. Defaults to 32.
        extra_pad_ratio (float): Extra pad ratio. Defaults to 0.5.
    """

    def __init__(self,
                 batch_size: int = 32,
                 img_size: int = 640,
                 size_divisor: int = 32,
                 extra_pad_ratio: float = 0.5):
        self.batch_size = batch_size
        self.img_size = img_size
        self.size_divisor = size_divisor
        self.extra_pad_ratio = extra_pad_ratio

    def __call__(self, data_list: List[dict]) -> List[dict]:
        """Sort ``data_list`` by aspect ratio and attach batch shapes.

        Args:
            data_list (List[dict]): Data infos, each providing ``'width'``
                and ``'height'``.

        Returns:
            List[dict]: The data infos reordered by ascending height/width
            ratio. Each dict is modified in place and gains a
            ``'batch_shape'`` entry (a two-element int64 array; presumably
            in (height, width) order, confirm against the consumer).
        """
        if not data_list:
            # Nothing to group. Without this guard the batch-count lookup
            # below would raise IndexError on an empty array.
            return data_list

        image_shapes = np.array(
            [(info['width'], info['height']) for info in data_list],
            dtype=np.float64)

        n = len(image_shapes)  # number of images
        # Group index of every image once the list is sorted by ratio.
        batch_index = np.floor(np.arange(n) / self.batch_size).astype(
            np.int64)
        number_of_batches = batch_index[-1] + 1

        aspect_ratio = image_shapes[:, 1] / image_shapes[:, 0]  # height/width
        irect = aspect_ratio.argsort()

        data_list = [data_list[i] for i in irect]
        aspect_ratio = aspect_ratio[irect]

        # Relative shape per group: shrink one side only when every image
        # of the group lies on the same side of the square ratio.
        shapes = []
        for i in range(number_of_batches):
            group_ratios = aspect_ratio[batch_index == i]
            min_ratio, max_ratio = group_ratios.min(), group_ratios.max()
            if max_ratio < 1:
                shapes.append([max_ratio, 1])
            elif min_ratio > 1:
                shapes.append([1, 1 / min_ratio])
            else:
                shapes.append([1, 1])

        # Scale to img_size, add the extra padding, and round up to a
        # multiple of size_divisor.
        batch_shapes = np.ceil(
            np.array(shapes) * self.img_size / self.size_divisor +
            self.extra_pad_ratio).astype(np.int64) * self.size_divisor

        for i, data_info in enumerate(data_list):
            data_info['batch_shape'] = batch_shapes[batch_index[i]]

        return data_list
59 changes: 54 additions & 5 deletions sscma/datasets/yolodataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,64 @@
# Copyright (c) Seeed Technology Co.,Ltd. All rights reserved.
import json
import os.path as osp
from typing import Optional, Sequence
from typing import Optional, Sequence, Any

from mmyolo.datasets.yolov5_coco import YOLOv5CocoDataset
from mmdet.datasets import BaseDetDataset, CocoDataset

from sscma.registry import DATASETS
from sscma.registry import DATASETS, TASK_UTILS


class BatchShapePolicyDataset(BaseDetDataset):
    """Detection dataset that can apply a batch shape policy to its data.

    The policy makes paddings with the least pixels during batch inference,
    so the image scales of different batches need not be the same
    throughout validation.

    Args:
        batch_shapes_cfg (dict, optional): Config passed to
            ``TASK_UTILS.build`` to create the policy. When falsy, no policy
            is applied. Defaults to None.
    """

    def __init__(self,
                 *args,
                 batch_shapes_cfg: Optional[dict] = None,
                 **kwargs):
        # Stored before delegating to the parent constructor; presumably the
        # parent may trigger full_init(), which reads this attribute.
        # NOTE(review): confirm against the base dataset implementation.
        self.batch_shapes_cfg = batch_shapes_cfg
        super().__init__(*args, **kwargs)

    def full_init(self):
        """Load and prepare the data list, applying the batch shape policy.

        Rewritten so that the policy runs on the raw data list, before
        filtering, subsetting and serialization.
        """
        if self._fully_initialized:
            return

        # Load the raw data information.
        self.data_list = self.load_data_list()

        # Let the policy reorder / annotate the data list when configured.
        if self.batch_shapes_cfg:
            self.data_list = TASK_UTILS.build(self.batch_shapes_cfg)(
                self.data_list)

        # Drop illegal data, such as data that has no annotations.
        self.data_list = self.filter_data()

        # Restrict to the requested subset of indices, if any.
        if self._indices is not None:
            self.data_list = self._get_unserialized_subset(self._indices)

        # Serialize the data list when requested.
        if self.serialize_data:
            self.data_bytes, self.data_address = self._serialize_data()

        self._fully_initialized = True

    def prepare_data(self, idx: int) -> Any:
        """Run the pipeline on one sample.

        During training the dataset itself is handed to the pipeline under
        the ``'dataset'`` key, to support mixed data augmentation such as
        Mosaic and MixUp.
        """
        if self.test_mode is not False:
            return super().prepare_data(idx)

        sample_info = self.get_data_info(idx)
        sample_info['dataset'] = self
        return self.pipeline(sample_info)


@DATASETS.register_module()
class CustomYOLOv5CocoDataset(YOLOv5CocoDataset):
class CustomYOLOv5CocoDataset(BatchShapePolicyDataset, CocoDataset):
METAINFO = {
'classes': (),
# palette is a list of color tuples, which is used for visualization.
Expand Down Expand Up @@ -130,4 +179,4 @@ def __init__(
data_root=data_root,
batch_shapes_cfg=batch_shapes_cfg,
**kwargs,
)
)
2 changes: 2 additions & 0 deletions sscma/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
WandbLoggerHook,
)
from .runner import GetEpochBasedTrainLoop
from .optimizers import YOLOv5OptimizerConstructor

__all__ = [
'TextLoggerHook',
Expand All @@ -17,4 +18,5 @@
'GetEpochBasedTrainLoop',
'Posevisualization',
'DetFomoVisualizationHook',
'YOLOv5OptimizerConstructor',
]
2 changes: 2 additions & 0 deletions sscma/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Posevisualization,
SensorVisualizationHook,
)
from .yolov5_param_scheduler import YOLOv5ParamSchedulerHook

__all__ = [
'TextLoggerHook',
Expand All @@ -22,4 +23,5 @@
'DetFomoVisualizationHook',
'SensorVisualizationHook',
'SemiHook',
'YOLOv5ParamSchedulerHook',
]
113 changes: 113 additions & 0 deletions sscma/engine/hooks/yolov5_param_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) Seeed Technology Co.,Ltd. All rights reserved.
import math
from typing import Optional

import numpy as np
from mmengine.hooks import ParamSchedulerHook
from mmengine.runner import Runner

from sscma.registry import HOOKS


def linear_fn(lr_factor: float, max_epochs: int):
    """Build a linear decay schedule.

    The returned callable maps an epoch ``x`` to a factor that goes
    linearly from 1 at ``x = 0`` to ``lr_factor`` at ``x = max_epochs``.
    """

    def schedule(x):
        return (1 - x / max_epochs) * (1.0 - lr_factor) + lr_factor

    return schedule


def cosine_fn(lr_factor: float, max_epochs: int):
    """Build a cosine decay schedule.

    The returned callable maps an epoch ``x`` to a factor that follows a
    half cosine from 1 at ``x = 0`` to ``lr_factor`` at ``x = max_epochs``.
    """

    def schedule(x):
        # Rises from 0 to 1 over [0, max_epochs] along a half cosine.
        progress = (1 - math.cos(x * math.pi / max_epochs)) / 2
        return progress * (lr_factor - 1) + 1

    return schedule


@HOOKS.register_module()
class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
    """A hook to update learning rate and momentum in optimizer of YOLOv5.

    During warmup the learning rate and momentum of every parameter group
    are interpolated per iteration; once warmup has ended the learning rate
    is rescaled after each epoch with the selected schedule.

    Args:
        scheduler_type (str): Key into ``scheduler_maps``, either
            ``'linear'`` or ``'cosine'``. Defaults to 'linear'.
        lr_factor (float): Forwarded to the schedule factory.
            Defaults to 0.01.
        max_epochs (int): Forwarded to the schedule factory.
            Defaults to 300.
        warmup_epochs (int): Warmup length in epochs; multiplied by the
            number of batches in the train dataloader. Defaults to 3.
        warmup_bias_lr (float): Learning rate at which parameter group 2
            starts the warmup. Defaults to 0.1.
        warmup_momentum (float): Momentum at which warmup starts.
            Defaults to 0.8.
        warmup_mim_iter (int): Minimum number of warmup iterations. The
            name is spelled ``mim`` and is kept unchanged for config
            compatibility. Defaults to 1000.
        **kwargs: Extra keyword arguments forwarded to the schedule factory.
    """

    priority = 9

    scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn}

    def __init__(
        self,
        scheduler_type: str = 'linear',
        lr_factor: float = 0.01,
        max_epochs: int = 300,
        warmup_epochs: int = 3,
        warmup_bias_lr: float = 0.1,
        warmup_momentum: float = 0.8,
        warmup_mim_iter: int = 1000,
        **kwargs,
    ):
        assert scheduler_type in self.scheduler_maps

        self.warmup_epochs = warmup_epochs
        self.warmup_bias_lr = warmup_bias_lr
        self.warmup_momentum = warmup_momentum
        self.warmup_mim_iter = warmup_mim_iter

        # The schedule factory receives lr_factor / max_epochs plus any
        # extra keyword arguments supplied by the caller.
        kwargs.update(lr_factor=lr_factor, max_epochs=max_epochs)
        build_schedule = self.scheduler_maps[scheduler_type]
        self.scheduler_fn = build_schedule(**kwargs)

        self._warmup_end = False
        self._base_lr = None
        self._base_momentum = None

    def before_train(self, runner: Runner):
        """Record the initial lr and momentum of every parameter group.

        Args:
            runner (Runner): The runner of the training process.
        """
        base_lrs = []
        base_momentums = []
        for group in runner.optim_wrapper.optimizer.param_groups:
            # If the parameter has never been scheduled, record the current
            # value as the initial value (-1 when the group has no momentum).
            base_lrs.append(group.setdefault('initial_lr', group['lr']))
            base_momentums.append(
                group.setdefault('initial_momentum',
                                 group.get('momentum', -1)))

        self._base_lr = base_lrs
        self._base_momentum = base_momentums

    def before_train_iter(self, runner: Runner, batch_idx: int, data_batch: Optional[dict] = None):
        """Interpolate lr and momentum while training is still in warmup.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
        """
        cur_iters = runner.iter
        cur_epoch = runner.epoch
        optimizer = runner.optim_wrapper.optimizer

        # Warmup never lasts fewer than self.warmup_mim_iter iterations.
        warmup_total_iters = max(
            round(self.warmup_epochs * len(runner.train_dataloader)),
            self.warmup_mim_iter)

        if cur_iters > warmup_total_iters:
            self._warmup_end = True
            return

        xp = [0, warmup_total_iters]
        lr_scale = self.scheduler_fn(cur_epoch)
        for group_idx, group in enumerate(optimizer.param_groups):
            # Bias learning rate is handled specially: group 2 starts from
            # warmup_bias_lr instead of 0.
            # NOTE(review): assumes the optimizer constructor places bias
            # parameters in group index 2; confirm.
            start_lr = self.warmup_bias_lr if group_idx == 2 else 0.0
            target_lr = self._base_lr[group_idx] * lr_scale
            group['lr'] = np.interp(cur_iters, xp, [start_lr, target_lr])

            if 'momentum' in group:
                group['momentum'] = np.interp(
                    cur_iters, xp,
                    [self.warmup_momentum, self._base_momentum[group_idx]])

    def after_train_epoch(self, runner: Runner):
        """Rescale the lr of every parameter group once warmup has ended.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self._warmup_end:
            return

        lr_scale = self.scheduler_fn(runner.epoch)
        param_groups = runner.optim_wrapper.optimizer.param_groups
        for group_idx, group in enumerate(param_groups):
            group['lr'] = self._base_lr[group_idx] * lr_scale
2 changes: 2 additions & 0 deletions sscma/engine/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Seeed Technology Co.,Ltd. All rights reserved.
from .yolov5_optimizer import * # noqa
Loading

0 comments on commit 7f9c4e0

Please sign in to comment.