diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py index 64e23230c6..f4052adacc 100644 --- a/mmseg/datasets/transforms/transforms.py +++ b/mmseg/datasets/transforms/transforms.py @@ -2424,31 +2424,38 @@ def transform(self, results): # dict to albumentations format results = self.mapper(results, self.keymap_to_albu) + albumentations_results = { + k: results.pop(k) for k in ['image', 'mask'] if k in results + } + # Convert to RGB since Albumentations works with RGB images if self.bgr_to_rgb: - results['image'] = cv2.cvtColor(results['image'], + albumentations_results['image'] = cv2.cvtColor(albumentations_results['image'], cv2.COLOR_BGR2RGB) if self.additional_targets: for key, value in self.additional_targets.items(): if value == 'image': - results[key] = cv2.cvtColor(results[key], + albumentations_results[key] = cv2.cvtColor(albumentations_results[key], cv2.COLOR_BGR2RGB) # Apply Transform - results = self.aug(**results) + albumentations_results = self.aug(**albumentations_results) # Convert back to BGR if self.bgr_to_rgb: - results['image'] = cv2.cvtColor(results['image'], + albumentations_results['image'] = cv2.cvtColor(albumentations_results['image'], cv2.COLOR_RGB2BGR) if self.additional_targets: for key, value in self.additional_targets.items(): if value == 'image': - results[key] = cv2.cvtColor(results['image2'], + albumentations_results[key] = cv2.cvtColor(albumentations_results['image2'], cv2.COLOR_RGB2BGR) # back to the original format - results = self.mapper(results, self.keymap_back) + albumentations_results = self.mapper(albumentations_results, self.keymap_back) + + # Update corresponding keys in the original `results` + results.update(albumentations_results) # update final shape if self.update_pad_shape: