Skip to content

Commit

Permalink
[Fix] Fix cd transform (#3598)
Browse files Browse the repository at this point in the history
## Motivation

Fix the bug that data augmentation only takes effect on one image in the
change detection task.

## Modification

configs/base/datasets/levir_256x256.py
configs/swin/swin-tiny-patch4-window7_upernet_1xb8-20k_levir-256x256.py
mmseg/datasets/transforms/transforms.py
  • Loading branch information
Zoulinx authored Mar 18, 2024
1 parent 5465118 commit b677081
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
11 changes: 10 additions & 1 deletion configs/_base_/datasets/levir_256x256.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@
train_pipeline = [
dict(type='LoadMultipleRSImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Albu', transforms=albu_train_transforms),
dict(
type='Albu',
keymap={
'img': 'image',
'img2': 'image2',
'gt_seg_map': 'mask'
},
transforms=albu_train_transforms,
additional_targets={'image2': 'image'},
bgr_to_rgb=False),
dict(type='ConcatCDInput'),
dict(type='PackSegInputs')
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
size=crop_size,
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53, 123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375, 58.395, 57.12, 57.375])
std=[58.395, 57.12, 57.375, 58.395, 57.12, 57.375],
bgr_to_rgb=False)

model = dict(
data_preprocessor=data_preprocessor,
Expand Down
33 changes: 28 additions & 5 deletions mmseg/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2329,14 +2329,19 @@ class Albu(BaseTransform):
Args:
transforms (list[dict]): A list of albu transformations
keymap (dict): Contains {'input key':'albumentation-style key'}
additional_targets(dict): Allows applying same augmentations to \
multiple objects of same type.
update_pad_shape (bool): Whether to update padding shape according to \
the output shape of the last transform
bgr_to_rgb (bool): Whether to convert the band order to RGB
"""

def __init__(self,
transforms: List[dict],
keymap: Optional[dict] = None,
update_pad_shape: bool = False):
additional_targets: Optional[dict] = None,
update_pad_shape: bool = False,
bgr_to_rgb: bool = True):
if not ALBU_INSTALLED:
raise ImportError(
'albumentations is not installed, '
Expand All @@ -2349,9 +2354,12 @@ def __init__(self,

self.transforms = transforms
self.keymap = keymap
self.additional_targets = additional_targets
self.update_pad_shape = update_pad_shape
self.bgr_to_rgb = bgr_to_rgb

self.aug = Compose([self.albu_builder(t) for t in self.transforms])
self.aug = Compose([self.albu_builder(t) for t in self.transforms],
additional_targets=self.additional_targets)

if not keymap:
self.keymap_to_albu = {'img': 'image', 'gt_seg_map': 'mask'}
Expand Down Expand Up @@ -2417,12 +2425,27 @@ def transform(self, results):
results = self.mapper(results, self.keymap_to_albu)

# Convert to RGB since Albumentations works with RGB images
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB)

if self.bgr_to_rgb:
results['image'] = cv2.cvtColor(results['image'],
cv2.COLOR_BGR2RGB)
if self.additional_targets:
for key, value in self.additional_targets.items():
if value == 'image':
results[key] = cv2.cvtColor(results[key],
cv2.COLOR_BGR2RGB)

# Apply Transform
results = self.aug(**results)

# Convert back to BGR
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR)
if self.bgr_to_rgb:
results['image'] = cv2.cvtColor(results['image'],
cv2.COLOR_RGB2BGR)
if self.additional_targets:
for key, value in self.additional_targets.items():
if value == 'image':
results[key] = cv2.cvtColor(results['image2'],
cv2.COLOR_RGB2BGR)

# back to the original format
results = self.mapper(results, self.keymap_back)
Expand Down

0 comments on commit b677081

Please sign in to comment.