Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize model train time #116

Merged
merged 18 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions configs/fomo/fomo_mobnetv2_0.35_x8_abl_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Visualizer settings: type 'FomoLocalVisualizer' with its `fomo` option set to True.
visualizer = dict(type='FomoLocalVisualizer', fomo=True)

# NOTE(review): two consecutive assignments; only the last value (1) is in effect afterwards.
# Presumably the first assignment is the pre-change line of this rendered diff hunk — TODO confirm.
num_classes = 2
num_classes = 1
# Data preprocessor settings: mean 0 and std 255 for each of three channels,
# bgr_to_rgb enabled, pad_size_divisor of 32.
data_preprocessor = dict(
type='mmdet.DetDataPreprocessor', mean=[0, 0, 0], std=[255.0, 255.0, 255.0], bgr_to_rgb=True, pad_size_divisor=32
)
Expand All @@ -22,6 +22,7 @@
loss_cls=dict(type='BCEWithLogitsLoss', reduction='none', pos_weight=40),
loss_bg=dict(type='BCEWithLogitsLoss', reduction='none'),
),
skip_preprocessor=True,
)

# dataset settings
Expand All @@ -33,7 +34,6 @@
workers = 1

albu_train_transforms = [
dict(type='RandomResizedCrop', height=height, width=width, scale=(0.80, 1.2), p=1),
dict(type='Rotate', limit=30),
dict(type='RandomBrightnessContrast', brightness_limit=0.3, contrast_limit=0.3, p=0.5),
dict(type='Blur', p=0.01),
Expand All @@ -46,24 +46,40 @@
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(type='mmdet.LoadAnnotations', with_bbox=True),
]

# Training pipeline: the shared `pre_transform` steps, a resize to (height, width),
# the Albumentations wrapper driven by `albu_train_transforms`, FOMO mask generation,
# and finally packing of the sample with the listed meta keys.
train_pipeline = [
*pre_transform,
dict(type='mmdet.Resize', scale=(height, width)),
dict(
type='mmdet.Albu',
transforms=albu_train_transforms,
bbox_params=dict(type='BboxParams', format='pascal_voc', label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
keymap={'img': 'image', 'gt_bboxes': 'bboxes'},
),
# The mask is built with a single downsample factor of 8 and the module-level `num_classes`.
dict(type='Bbox2FomoMask', downsample_factor=(8,), num_classes=num_classes),
dict(
type='mmdet.PackDetInputs',
# NOTE(review): `meta_keys` is passed twice to the same dict(...) call, which Python rejects
# ("keyword argument repeated"). Presumably the single-line form is the pre-change line of the
# rendered diff and the multi-line form (which adds 'fomo_mask') replaces it — TODO confirm.
meta_keys=('img_path', 'img_id', 'instances', 'img_shape', 'ori_shape', 'gt_bboxes', 'gt_bboxes_labels'),
meta_keys=(
'fomo_mask',
'img_path',
'img_id',
'instances',
'img_shape',
'ori_shape',
'gt_bboxes',
'gt_bboxes_labels',
),
),
]

# Test pipeline: the shared `pre_transform` steps, a resize to (height, width),
# then packing / FOMO mask generation as listed below.
test_pipeline = [
*pre_transform,
dict(type='mmdet.Resize', scale=(height, width)),
# NOTE(review): a PackDetInputs step appears both before and after Bbox2FomoMask. Presumably this
# first one is the pre-change line of the rendered diff and is superseded by the one below — TODO confirm.
dict(type='mmdet.PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')),
# NOTE(review): `ori_shape` is passed here, but the Bbox2FomoMask.__init__ shown in this change only
# lists `downsample_factor` and `classes_num`/`num_classes` (no **kwargs). If that is the full
# signature, building this transform would raise TypeError — verify against transforms.py.
dict(type='Bbox2FomoMask', downsample_factor=(8,), num_classes=num_classes, ori_shape=(height, width)),
dict(
type='mmdet.PackDetInputs',
meta_keys=('fomo_mask', 'img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'),
),
]

train_dataloader = dict(
Expand Down Expand Up @@ -103,13 +119,14 @@

find_unused_parameters = True

# The commented-out line below is an alternative wrapper of type "AmpOptimWrapper" around the same optimizer.
# optim_wrapper = dict(type="AmpOptimWrapper",optimizer=dict(type='Adam', lr=lr, weight_decay=5e-4, eps=1e-7))
# Active optimizer: Adam with the module-level `lr`, weight decay 5e-4 and eps 1e-7.
optim_wrapper = dict(optimizer=dict(type='Adam', lr=lr, weight_decay=5e-4, eps=1e-7))

# evaluator
# Validation and test share the same 'FomoMetric' evaluator config object.
val_evaluator = dict(type='FomoMetric')
test_evaluator = val_evaluator

# NOTE(review): `train_cfg` is assigned twice; only the second assignment (with val_interval=5)
# is in effect. Presumably the first is the pre-change line of the rendered diff — TODO confirm.
train_cfg = dict(by_epoch=True, max_epochs=epochs)
train_cfg = dict(by_epoch=True, max_epochs=epochs, val_interval=5)

# learning policy
param_scheduler = [
Expand All @@ -123,3 +140,4 @@
by_epoch=True,
),
]
# cfg=dict(compile=True)
149 changes: 11 additions & 138 deletions edgelab/datasets/cocodataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import json
import os.path as osp
from collections import OrderedDict
from typing import Callable, List, Optional, Sequence, Union
from typing import Optional, Sequence

import cv2
import numpy as np
import torch
from mmdet.datasets.coco import CocoDataset
from sklearn.metrics import confusion_matrix

from edgelab.registry import DATASETS

Expand Down Expand Up @@ -105,26 +100,19 @@ class CustomCocoDataset(CocoDataset):

def __init__(
self,
*args,
data_prefix: dict = dict(img_path=''),
ann_file: str = '',
metainfo: Optional[dict] = None,
data_root=None,
data_prefix: dict = dict(img_path=''),
filter_cfg: Optional[dict] = None,
indices: Optional[Union[int, Sequence[int]]] = None,
serialize_data: bool = True,
pipeline: List[Union[dict, Callable]] = [],
test_mode: bool = False,
lazy_init: bool = False,
max_refetch: int = 1000,
data_root: str = '',
filter_supercat: bool = True,
file_client_args: Optional[dict] = dict(backend='disk'),
classes=None,
classes: Optional[Sequence[str]] = None,
**kwargs,
):
if data_root:
if not (osp.isabs(ann_file) and (osp.isabs(data_prefix['img']))):
data_root = check_file(data_root, data_name='coco') if data_root else data_root
if metainfo is None and not self.METAINFO['classes'] and not classes:
if metainfo is None and not self.METAINFO['classes']:
if not osp.isabs(ann_file) and ann_file:
self.ann_file = osp.join(data_root, ann_file)
with open(self.ann_file, 'r') as f:
Expand All @@ -138,125 +126,10 @@ def __init__(
self.METAINFO['classes'] = classes

super().__init__(
ann_file,
metainfo,
data_root,
data_prefix,
filter_cfg,
indices,
serialize_data,
pipeline,
test_mode,
lazy_init,
max_refetch,
*args,
data_prefix=data_prefix,
ann_file=ann_file,
metainfo=metainfo,
data_root=data_root,
**kwargs,
)

def bboxe2cell(self, bboxe, img_h, img_w, H, W):
    """Map a bounding box to the grid cell that holds its centre.

    Args:
        bboxe: Box given as ``(x1, y1, x2, y2)``; indices 0/2 are averaged
            and normalised by ``img_w``, indices 1/3 by ``img_h``.
        img_h: Image height used to normalise the vertical centre.
        img_w: Image width used to normalise the horizontal centre.
        H: Number of rows in the target grid.
        W: Number of columns in the target grid.

    Returns:
        Tuple ``(x, y)`` of integer cell indices, clamped to
        ``0 <= x <= W - 1`` and ``0 <= y <= H - 1``.
    """
    center_x = (bboxe[0] + bboxe[2]) / 2
    center_y = (bboxe[1] + bboxe[3]) / 2

    # Normalise to [0, 1] and scale onto the last valid index of each axis.
    x = int(center_x / img_w * (W - 1))
    y = int(center_y / img_h * (H - 1))

    # A box centre outside the image would otherwise give a negative index
    # (which silently wraps around when used for indexing) or one past the
    # grid; clamp so the caller always gets a valid cell.
    x = min(max(x, 0), W - 1)
    y = min(max(y, 0), H - 1)
    return (x, y)

def build_target(self, preds, targets, img_h, img_w):
    """Rasterise ground-truth boxes into a label grid shaped like ``preds``.

    Args:
        preds: Tensor whose shape unpacks as ``(B, H, W)``; only its shape
            and device are used.
        targets: Mapping with ``'bboxes'`` and ``'labels'`` entries that
            are iterated in lockstep.
        img_h: Image height forwarded to ``bboxe2cell``.
        img_w: Image width forwarded to ``bboxe2cell``.

    Returns:
        Zero-initialised tensor of shape ``(B, H, W)`` on ``preds.device``
        where the cell of every box centre holds ``label + 1`` (0 is left
        for cells without a box). All boxes are written to index 0 of the
        first axis.
    """
    B, H, W = preds.shape
    # The tensor starts out all zero, so no explicit reset is needed.
    target_data = torch.zeros(size=(B, H, W), device=preds.device)

    cells = [self.bboxe2cell(bboxe, img_h, img_w, H, W) for bboxe in targets['bboxes']]
    for (x, y), label in zip(cells, targets['labels']):
        # bboxe2cell returns (x, y); the grid is indexed [row, column].
        target_data[0, y, x] = label + 1

    return target_data

# Derive pooled true-positive / false-positive / false-negative counts for one
# prediction map by comparing it cell-by-cell with its target map.
#
# Both tensors are flattened, moved to CPU and passed to sklearn's
# `confusion_matrix` with `target` as the first argument and `pred` as the
# second, over the label set 0 .. len(self.CLASSES).
# Returns the tuple (tp, fp, fn).
def compute_FTP(self, pred, target):
confusion = confusion_matrix(
target.flatten().cpu().numpy(), pred.flatten().cpu().numpy(), labels=range(len(self.CLASSES) + 1)
)
# Entry [0, 0] counts cells where both maps hold label 0.
tn = confusion[0, 0]
# Every other diagonal entry is a cell whose non-zero label matches.
tp = np.diagonal(confusion).sum() - tn
# Strictly-lower triangle (first index > second index) is counted as false negatives,
# strictly-upper triangle (first index < second index) as false positives.
# NOTE(review): mismatches between two non-zero labels also land in one of these
# triangles, depending only on which label id is larger — confirm this is intended.
fn = np.tril(confusion, k=-1).sum()
fp = np.triu(confusion, k=1).sum()

return tp, fp, fn

def computer_prf(self, tp, fp, fn):
    """Turn TP / FP / FN counts into precision, recall and F1.

    Returns ``(1.0, 1.0, 1.0)`` when all three counts are zero. Otherwise
    each ratio falls back to ``0.0`` whenever its denominator is zero.

    Args:
        tp: Number of true positives.
        fp: Number of false positives.
        fn: Number of false negatives.

    Returns:
        Tuple ``(precision, recall, f1)``.
    """
    nothing_counted = tp == 0 and fn == 0 and fp == 0
    if nothing_counted:
        return 1.0, 1.0, 1.0

    if tp + fp == 0:
        precision = 0.0
    else:
        precision = tp / (tp + fp)

    if tp + fn == 0:
        recall = 0.0
    else:
        recall = tp / (tp + fn)

    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1

def evaluate(
    self,
    results,
    metric='bbox',
    logger=None,
    jsonfile_prefix=None,
    classwise=False,
    proposal_nums=...,
    iou_thrs=None,
    fomo=False,
    metric_items=None,
):
    """Evaluate predictions, with a dedicated branch for FOMO outputs.

    When ``fomo`` is false the call is forwarded unchanged to the parent
    class. When it is true, every prediction in ``results`` is compared
    with a target grid built from the annotation of the same index.

    Args:
        results: Per-sample predictions; in the FOMO branch each item must
            expose a ``shape`` that unpacks into three values.
        metric, logger, jsonfile_prefix, classwise, proposal_nums,
        iou_thrs, metric_items: Only used when delegating to the parent.
        fomo: Select the FOMO evaluation branch.

    Returns:
        In the FOMO branch, an ``OrderedDict`` with keys ``'Acc'``, ``'P'``,
        ``'R'`` and ``'F1'``; otherwise whatever the parent returns.
    """
    if not fomo:
        return super().evaluate(
            results, metric, logger, jsonfile_prefix, classwise, proposal_nums, iou_thrs, metric_items
        )

    # just with here evaluate for fomo data
    annotations = [self.get_ann_info(i) for i in range(len(self))]
    accuracies = []
    TP, FP, FN = [], [], []
    for idx, (pred, ann) in enumerate(zip(results, annotations)):
        data = self.__getitem__(idx)
        _, H, W = pred.shape
        img_h, img_w = data['img_metas'][0].data['ori_shape'][:2]
        target = self.build_target(pred, ann, img_h, img_w)
        tp, fp, fn = self.compute_FTP(pred, target)
        # Per-sample accuracy: matching cells divided by H * W.
        mask = torch.eq(pred, target)
        accuracies.append(torch.sum(mask) / (H * W))
        TP.append(tp)
        FP.append(fp)
        FN.append(fn)

    # Precision / recall / F1 are computed once from the counts summed over all samples.
    P, R, F1 = self.computer_prf(sum(TP), sum(FP), sum(FN))

    eval_results = OrderedDict()
    eval_results['Acc'] = torch.mean(torch.Tensor(accuracies)).cpu().item()
    eval_results['P'] = P
    eval_results['R'] = R
    eval_results['F1'] = F1
    return eval_results


def show_result(result, img_path, classes):
    """Draw the positive cells of a prediction grid on its image and display it.

    The image is read from ``img_path``. For every grid position whose
    value is greater than zero, a circle is drawn at the cell centre and
    the class name (looked up with ``value - 1``) is written at the cell's
    top-left corner. The window is then shown until a key is pressed.

    Args:
        result: Prediction tensor; moved to CPU and converted to numpy.
        img_path: Path of the image to draw on.
        classes: Indexable collection of class names.
    """
    canvas = cv2.imread(img_path)
    img_h, img_w = canvas.shape[:-1]
    grid = result.cpu().numpy()

    for _, row, col in np.argwhere(grid > 0):
        name = classes[grid[0, row, col] - 1]

        # Pixel size of one grid cell along each axis.
        cell_w = img_w / result[0].shape[1]
        cell_h = img_h / result[0].shape[0]

        centre = (int(cell_w * (col + 0.5)), int(cell_h * (row + 0.5)))
        cv2.circle(canvas, centre, 5, (0, 0, 255), 1)

        corner = (int(cell_w * col), int(cell_h * row))
        cv2.putText(
            canvas,
            str(name),
            org=corner,
            color=(255, 0, 0),
            fontScale=1,
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        )

    cv2.imshow('img', canvas)
    cv2.waitKey(0)
33 changes: 18 additions & 15 deletions edgelab/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
from typing import Dict, List, Optional, Tuple, Union

import torch
import numpy as np
from mmcv.transforms.base import BaseTransform

from edgelab.registry import TRANSFORMS
Expand All @@ -11,33 +12,35 @@ class Bbox2FomoMask(BaseTransform):
# Store the transform's configuration: the tuple of downsample factors and the class count.
# NOTE(review): both `classes_num` and `num_classes` appear as parameters and as attributes.
# Presumably these are the pre- and post-rename lines of the rendered diff rather than two
# real parameters — TODO confirm which name the file actually uses.
def __init__(
self,
# One target mask is produced per factor (see the loop in `transform`).
downsample_factor: Tuple[int, ...] = (8,),
classes_num: int = 80,
num_classes: int = 80,
) -> None:
super().__init__()
self.downsample_factor = downsample_factor
self.classes_num = classes_num
self.num_classes = num_classes

# Build target masks for one sample and store them in results['fomo_mask'].
# Reads results['img_shape'] as (H, W), results['gt_bboxes'] and results['gt_bboxes_labels'],
# builds one target per entry of self.downsample_factor, and returns the same `results` dict.
# NOTE(review): several statements appear in two variants (the `Dh, Dw` computation, the
# `build_target` call, the `fomo_mask` assignment). Presumably the first of each pair is the
# pre-change line of the rendered diff — TODO confirm.
def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:
# Bare lookup with no effect other than raising KeyError when 'img' is missing.
results['img']
H, W = results['img_shape']
bbox = results['gt_bboxes']
# NOTE(review): debug output on every sample.
print(bbox)
labels = results['gt_bboxes_labels']

res = []
for factor in self.downsample_factor:
# True division: the sizes are floats in this variant.
Dh, Dw = H / factor, W / factor
target = self.build_target(bbox, shape=(Dh, Dw))
# Integer sizes in this variant; note that ori_shape is passed in (W, H) order.
Dh, Dw = int(H / factor), int(W / factor)
target = self.build_target(bbox, feature_shape=(Dh, Dw), ori_shape=(W, H), labels=labels)
res.append(target)

results['fomo_mask'] = res
# Stores an independent deep copy of the list of targets.
results['fomo_mask'] = copy.deepcopy(res)
return results

# Build a target array whose channel 0 is 1 by default and whose marked cells carry a
# class channel instead; returns `target_data`.
# NOTE(review): two `def build_target` signatures and two loop bodies appear back to back.
# Presumably the `targets, shape` / torch variant is the pre-change code of the rendered diff
# and the `bboxs, feature_shape, ori_shape, labels` / numpy variant replaces it — TODO confirm.
def build_target(self, targets, shape):
(H, W) = shape
target_data = torch.zeros(size=(H, W, self.classes_num + 1))
def build_target(self, bboxs, feature_shape, ori_shape, labels):
(H, W) = feature_shape
# target_data = torch.zeros(size=(1,H, W, self.num_classes + 1))
# Array of shape (1, H, W, num_classes + 1), initially all zero.
target_data = np.zeros((1, H, W, self.num_classes + 1))
# Channel 0 is set to 1 everywhere; it is cleared again for cells that receive a label below.
target_data[..., 0] = 1
for i in targets:
h, w = int(i[3].item() * H), int(i[2].item() * W)
target_data[int(i[0]), h, w, 0] = 0 # background
target_data[int(i[0]), h, w, int(i[1])] = 1 # label

for idx, i in enumerate(bboxs):
# NOTE(review): assuming centers[0] is (x, y) and ori_shape is (W, H) as passed by
# `transform`, x is normalised by the image width but scaled by the feature height H,
# and y by the feature width W. The two look swapped for non-square grids — TODO confirm.
# NOTE(review): a centre exactly on the far edge would give an index equal to the grid
# size — verify whether inputs can reach that.
w = int(i.centers[0][0] / ori_shape[0] * H)
h = int(i.centers[0][1] / ori_shape[1] * W)
target_data[0, h, w, 0] = 0 # background
# Class channels start at 1, so label k is written to channel k + 1.
target_data[0, h, w, int(labels[idx] + 1)] = 1 # label
return target_data
3 changes: 3 additions & 0 deletions edgelab/datasets/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .audio_augs import * # noqa
from .download import * # noqa
from .functions import * # noqa
3 changes: 3 additions & 0 deletions edgelab/engine/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .batch_augs import * # noqa
from .helper_funcs import * # noqa
from .resample import * # noqa
2 changes: 2 additions & 0 deletions edgelab/models/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .data_preprocessor import * # noqa
from .general import * # noqa
Loading
Loading