Skip to content

Commit

Permalink
fixed data_augmentation bug
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Sep 28, 2019
1 parent b759d18 commit f8a6ef8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion datasets/steel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def classify_provider(
target = targets[0]
image = image_with_mask_torch(image, target, mean, std)['image']
cv2.imshow('win', image)
cv2.waitKey(0)
cv2.waitKey(480)
class_dataloader = classify_provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits)
# 测试分类数据集
for fold_index, [train_dataloader, val_dataloader] in enumerate(class_dataloader):
Expand Down
19 changes: 11 additions & 8 deletions utils/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from tqdm import tqdm
import sys
import os
from copy import deepcopy

from albumentations import (
Compose, HorizontalFlip, VerticalFlip, CLAHE, RandomRotate90, HueSaturationValue,
RandomBrightness, RandomContrast, RandomGamma,OneOf,
RandomBrightness, RandomContrast, RandomGamma, OneOf,
ToFloat, ShiftScaleRotate,GridDistortion, ElasticTransform, JpegCompression, HueSaturationValue,
RGBShift, RandomBrightnessContrast, RandomContrast, Blur, MotionBlur, MedianBlur, GaussNoise,CenterCrop,
IAAAdditiveGaussianNoise,GaussNoise,Cutout,Rotate
IAAAdditiveGaussianNoise,GaussNoise,Cutout,Rotate, Normalize
)

sys.path.append('.')
Expand Down Expand Up @@ -58,11 +59,9 @@ def data_augmentation(original_image, original_mask):
image_aug: 增强后的图片
mask_aug: 增强后的掩膜
"""
original_height, original_width = original_image.shape[:2]
augmentations = Compose([
HorizontalFlip(p=0.4),
Rotate(limit=15, p=0.4),
CenterCrop(p=0.3, height=original_height, width=original_width),
Rotate(limit=15, p=0.4),
# 直方图均衡化
CLAHE(p=0.4),

Expand Down Expand Up @@ -93,7 +92,8 @@ def data_augmentation(original_image, original_mask):
if __name__ == "__main__":
data_folder = "../datasets/Steel_data"
df_path = "../datasets/Steel_data/train.csv"

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
df = pd.read_csv(df_path)
# https://www.kaggle.com/amanooo/defect-detection-starter-u-net
df['ImageId'], df['ClassId'] = zip(*df['ImageId_ClassId'].str.split('_'))
Expand All @@ -107,9 +107,12 @@ def data_augmentation(original_image, original_mask):
image_path = os.path.join(data_folder, 'train_images', image_id)
image = cv2.imread(image_path)
image_aug, mask_aug = data_augmentation(image, mask)
normalize = Normalize(mean=mean, std=std)
image = normalize(image=image)['image']
image_aug = normalize(image=image_aug)['image']

image_mask = image_with_mask_numpy(image, mask)['image']
image_aug_mask = image_with_mask_numpy(image_aug, mask_aug)['image']
image_mask = image_with_mask_numpy(deepcopy(image), mask, mean, std)['image']
image_aug_mask = image_with_mask_numpy(deepcopy(image_aug), mask_aug, mean,std)['image']
cv2.imshow('image', image_mask)
cv2.imshow('image_aug', image_aug_mask)
cv2.waitKey(0)
Expand Down

0 comments on commit f8a6ef8

Please sign in to comment.