-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add: Add YOLOv5 CustomCocoDataset and Rename Fomo CocoDataset (#214)
* Refractor: Rename the pfld dataset to CustomFomoCocoDataset * Add: Add CustomYOLOv5CocoDataset * Add: Add YOLOv5 parameter scheduler and optimizer
- Loading branch information
Showing
13 changed files
with
423 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import List, Sequence | ||
|
||
import numpy as np | ||
import torch | ||
from mmengine.dataset import COLLATE_FUNCTIONS | ||
|
||
from sscma.registry import TASK_UTILS | ||
|
||
|
||
# @COLLATE_FUNCTIONS.register_module() | ||
def yolov5_collate(data_batch: Sequence, | ||
use_ms_training: bool = False) -> dict: | ||
"""Rewrite collate_fn to get faster training speed. | ||
Args: | ||
data_batch (Sequence): Batch of data. | ||
use_ms_training (bool): Whether to use multi-scale training. | ||
""" | ||
batch_imgs = [] | ||
batch_bboxes_labels = [] | ||
batch_masks = [] | ||
for i in range(len(data_batch)): | ||
datasamples = data_batch[i]['data_samples'] | ||
inputs = data_batch[i]['inputs'] | ||
batch_imgs.append(inputs) | ||
|
||
gt_bboxes = datasamples.gt_instances.bboxes.tensor | ||
gt_labels = datasamples.gt_instances.labels | ||
if 'masks' in datasamples.gt_instances: | ||
masks = datasamples.gt_instances.masks.to_tensor( | ||
dtype=torch.bool, device=gt_bboxes.device) | ||
batch_masks.append(masks) | ||
batch_idx = gt_labels.new_full((len(gt_labels), 1), i) | ||
bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes), | ||
dim=1) | ||
batch_bboxes_labels.append(bboxes_labels) | ||
|
||
collated_results = { | ||
'data_samples': { | ||
'bboxes_labels': torch.cat(batch_bboxes_labels, 0) | ||
} | ||
} | ||
if len(batch_masks) > 0: | ||
collated_results['data_samples']['masks'] = torch.cat(batch_masks, 0) | ||
|
||
if use_ms_training: | ||
collated_results['inputs'] = batch_imgs | ||
else: | ||
collated_results['inputs'] = torch.stack(batch_imgs, 0) | ||
return collated_results | ||
|
||
|
||
@TASK_UTILS.register_module() | ||
class BatchShapePolicy: | ||
"""BatchShapePolicy is only used in the testing phase, which can reduce the | ||
number of pad pixels during batch inference. | ||
Args: | ||
batch_size (int): Single GPU batch size during batch inference. | ||
Defaults to 32. | ||
img_size (int): Expected output image size. Defaults to 640. | ||
size_divisor (int): The minimum size that is divisible | ||
by size_divisor. Defaults to 32. | ||
extra_pad_ratio (float): Extra pad ratio. Defaults to 0.5. | ||
""" | ||
|
||
def __init__(self, | ||
batch_size: int = 32, | ||
img_size: int = 640, | ||
size_divisor: int = 32, | ||
extra_pad_ratio: float = 0.5): | ||
self.batch_size = batch_size | ||
self.img_size = img_size | ||
self.size_divisor = size_divisor | ||
self.extra_pad_ratio = extra_pad_ratio | ||
|
||
def __call__(self, data_list: List[dict]) -> List[dict]: | ||
image_shapes = [] | ||
for data_info in data_list: | ||
image_shapes.append((data_info['width'], data_info['height'])) | ||
|
||
image_shapes = np.array(image_shapes, dtype=np.float64) | ||
|
||
n = len(image_shapes) # number of images | ||
batch_index = np.floor(np.arange(n) / self.batch_size).astype( | ||
np.int64) # batch index | ||
number_of_batches = batch_index[-1] + 1 # number of batches | ||
|
||
aspect_ratio = image_shapes[:, 1] / image_shapes[:, 0] # aspect ratio | ||
irect = aspect_ratio.argsort() | ||
|
||
data_list = [data_list[i] for i in irect] | ||
|
||
aspect_ratio = aspect_ratio[irect] | ||
# Set training image shapes | ||
shapes = [[1, 1]] * number_of_batches | ||
for i in range(number_of_batches): | ||
aspect_ratio_index = aspect_ratio[batch_index == i] | ||
min_index, max_index = aspect_ratio_index.min( | ||
), aspect_ratio_index.max() | ||
if max_index < 1: | ||
shapes[i] = [max_index, 1] | ||
elif min_index > 1: | ||
shapes[i] = [1, 1 / min_index] | ||
|
||
batch_shapes = np.ceil( | ||
np.array(shapes) * self.img_size / self.size_divisor + | ||
self.extra_pad_ratio).astype(np.int64) * self.size_divisor | ||
|
||
for i, data_info in enumerate(data_list): | ||
data_info['batch_shape'] = batch_shapes[batch_index[i]] | ||
|
||
return data_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright (c) Seeed Technology Co.,Ltd. All rights reserved. | ||
import math | ||
from typing import Optional | ||
|
||
import numpy as np | ||
from mmengine.hooks import ParamSchedulerHook | ||
from mmengine.runner import Runner | ||
|
||
from sscma.registry import HOOKS | ||
|
||
|
||
def linear_fn(lr_factor: float, max_epochs: int): | ||
"""Generate linear function.""" | ||
return lambda x: (1 - x / max_epochs) * (1.0 - lr_factor) + lr_factor | ||
|
||
|
||
def cosine_fn(lr_factor: float, max_epochs: int): | ||
"""Generate cosine function.""" | ||
return lambda x: ((1 - math.cos(x * math.pi / max_epochs)) / 2) * (lr_factor - 1) + 1 | ||
|
||
|
||
@HOOKS.register_module() | ||
class YOLOv5ParamSchedulerHook(ParamSchedulerHook): | ||
"""A hook to update learning rate and momentum in optimizer of YOLOv5.""" | ||
|
||
priority = 9 | ||
|
||
scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn} | ||
|
||
def __init__( | ||
self, | ||
scheduler_type: str = 'linear', | ||
lr_factor: float = 0.01, | ||
max_epochs: int = 300, | ||
warmup_epochs: int = 3, | ||
warmup_bias_lr: float = 0.1, | ||
warmup_momentum: float = 0.8, | ||
warmup_mim_iter: int = 1000, | ||
**kwargs, | ||
): | ||
assert scheduler_type in self.scheduler_maps | ||
|
||
self.warmup_epochs = warmup_epochs | ||
self.warmup_bias_lr = warmup_bias_lr | ||
self.warmup_momentum = warmup_momentum | ||
self.warmup_mim_iter = warmup_mim_iter | ||
|
||
kwargs.update({'lr_factor': lr_factor, 'max_epochs': max_epochs}) | ||
self.scheduler_fn = self.scheduler_maps[scheduler_type](**kwargs) | ||
|
||
self._warmup_end = False | ||
self._base_lr = None | ||
self._base_momentum = None | ||
|
||
def before_train(self, runner: Runner): | ||
"""Operations before train. | ||
Args: | ||
runner (Runner): The runner of the training process. | ||
""" | ||
optimizer = runner.optim_wrapper.optimizer | ||
for group in optimizer.param_groups: | ||
# If the param is never be scheduled, record the current value | ||
# as the initial value. | ||
group.setdefault('initial_lr', group['lr']) | ||
group.setdefault('initial_momentum', group.get('momentum', -1)) | ||
|
||
self._base_lr = [group['initial_lr'] for group in optimizer.param_groups] | ||
self._base_momentum = [group['initial_momentum'] for group in optimizer.param_groups] | ||
|
||
def before_train_iter(self, runner: Runner, batch_idx: int, data_batch: Optional[dict] = None): | ||
"""Operations before each training iteration. | ||
Args: | ||
runner (Runner): The runner of the training process. | ||
batch_idx (int): The index of the current batch in the train loop. | ||
data_batch (dict or tuple or list, optional): Data from dataloader. | ||
""" | ||
cur_iters = runner.iter | ||
cur_epoch = runner.epoch | ||
optimizer = runner.optim_wrapper.optimizer | ||
|
||
# The minimum warmup is self.warmup_mim_iter | ||
warmup_total_iters = max(round(self.warmup_epochs * len(runner.train_dataloader)), self.warmup_mim_iter) | ||
|
||
if cur_iters <= warmup_total_iters: | ||
xp = [0, warmup_total_iters] | ||
for group_idx, param in enumerate(optimizer.param_groups): | ||
if group_idx == 2: | ||
# bias learning rate will be handled specially | ||
yp = [self.warmup_bias_lr, self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)] | ||
else: | ||
yp = [0.0, self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)] | ||
param['lr'] = np.interp(cur_iters, xp, yp) | ||
|
||
if 'momentum' in param: | ||
param['momentum'] = np.interp(cur_iters, xp, [self.warmup_momentum, self._base_momentum[group_idx]]) | ||
else: | ||
self._warmup_end = True | ||
|
||
def after_train_epoch(self, runner: Runner): | ||
"""Operations after each training epoch. | ||
Args: | ||
runner (Runner): The runner of the training process. | ||
""" | ||
if not self._warmup_end: | ||
return | ||
|
||
cur_epoch = runner.epoch | ||
optimizer = runner.optim_wrapper.optimizer | ||
for group_idx, param in enumerate(optimizer.param_groups): | ||
param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(cur_epoch) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright (c) Seeed Technology Co.,Ltd. All rights reserved. | ||
from .yolov5_optimizer import * # noqa |
Oops, something went wrong.