Skip to content

Commit

Permalink
add color conver pipline
Browse files Browse the repository at this point in the history
  • Loading branch information
mjq2020 committed Aug 19, 2023
1 parent f8e2b23 commit bd5d5ce
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
30 changes: 29 additions & 1 deletion edgelab/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,41 @@
import copy
from typing import Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
from mmcv.transforms.base import BaseTransform

from edgelab.registry import TRANSFORMS


@TRANSFORMS.register_module()
class Color2Gray(BaseTransform):
def __init__(self, one_channel: bool = False, conver_order: Optional[str] = None) -> None:
super().__init__()
if one_channel and conver_order is not None:
raise ValueError("one_channel and conver_order can only set one of them, not all of them ")

if conver_order is not None:
if not hasattr(cv2, "COLOR_" + conver_order):
opt = ','.join(
(map(lambda x: x.replace("COLOR_", ""), filter(lambda x: x.startswith("COLOR_"), dir(cv2))))
)
raise ValueError(
f"The value of convert_order can only be one of the following[{opt}], but {conver_order} is obtained"
)
self.conver_opt = getattr(cv2, "COLOR_" + conver_order)
self.one_channel = one_channel

def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:
img = results['img']
if self.one_channel:
img = img[..., 0:1]
else:
img = np.expand_dims(cv2.cvtColor(img, self.conver_opt, dstCn=1), -1)
results['img'] = img
return results


@TRANSFORMS.register_module()
class Bbox2FomoMask(BaseTransform):
def __init__(
Expand All @@ -19,7 +48,6 @@ def __init__(
self.num_classes = num_classes

def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:
results['img']
H, W = results['img_shape']
bbox = results['gt_bboxes']
labels = results['gt_bboxes_labels']
Expand Down
2 changes: 1 addition & 1 deletion edgelab/models/backbones/EfficientNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def forward(self, x):
res.append(x)
if i == max(self.out_indices):
break
return res
return tuple(res)

def _freeze_stages(self):
if self.frozen_stages >= 0:
Expand Down

0 comments on commit bd5d5ce

Please sign in to comment.