
Commit

add data augmentation
XiangqianMa committed Sep 28, 2019
1 parent b2f4bcf commit bf42381
Showing 3 changed files with 233 additions and 51 deletions.
107 changes: 56 additions & 51 deletions datasets/steel_dataset.py
@@ -4,18 +4,25 @@
import cv2
import warnings
import pandas as pd
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split, StratifiedKFold
import torch
from torch.utils.data import DataLoader, Dataset, sampler
from albumentations import (HorizontalFlip, ShiftScaleRotate, Normalize, Resize, Compose, GaussNoise)
from torchvision import transforms
from albumentations import (HorizontalFlip, VerticalFlip, ShiftScaleRotate, Normalize, Resize, Compose, GaussNoise)
from albumentations.torch import ToTensor
import sys

sys.path.append('.')
from utils.data_augmentation import data_augmentation
from utils.rle_parse import mask2rle, make_mask
from utils.visualize import image_with_mask_torch
import pickle
warnings.filterwarnings("ignore")


# Dataset
# Dataset Segmentation
class SteelDataset(Dataset):
def __init__(self, df, data_folder, mean, std, phase):
super(SteelDataset, self).__init__()
@@ -24,23 +31,22 @@ def __init__(self, df, data_folder, mean, std, phase):
self.mean = mean
self.std = std
self.phase = phase
self.transforms = get_transforms(phase, mean, std)
self.transforms = get_transforms
self.fnames = self.df.index.tolist()

def __getitem__(self, idx):
image_id, mask = make_mask(idx, self.df)
image_path = os.path.join(self.root, "train_images", image_id)
img = cv2.imread(image_path)
augmented = self.transforms(image=img, mask=mask)
img = augmented['image']
mask = augmented['mask'] # 1x256x1600x4
mask = mask[0].permute(2, 0, 1) # 1x4x256x1600
img, mask = self.transforms(self.phase, img, mask, self.mean, self.std)
mask = mask.permute(2, 0, 1)
return img, mask

def __len__(self):
return len(self.fnames)


# Dataset Classification
class SteelClassDataset(Dataset):
def __init__(self, df, data_folder, mean, std, phase):
super(SteelClassDataset, self).__init__()
@@ -49,17 +55,15 @@ def __init__(self, df, data_folder, mean, std, phase):
self.mean = mean
self.std = std
self.phase = phase
self.transforms = get_transforms(phase, mean, std)
self.transforms = get_transforms
self.fnames = self.df.index.tolist()

def __getitem__(self, idx):
image_id, mask = make_mask(idx, self.df)
image_path = os.path.join(self.root, "train_images", image_id)
img = cv2.imread(image_path)
augmented = self.transforms(image=img, mask=mask)
img = augmented['image']
mask = augmented['mask'] # 1x256x1600x4
mask = mask[0].permute(2, 0, 1) # 1x4x256x1600
img, mask = self.transforms(self.phase, img, mask, self.mean, self.std)
mask = mask.permute(2, 0, 1) # 4x256x1600
mask = mask.view(mask.size(0), -1)
mask = torch.sum(mask, dim=1)
mask = mask > 0  # per-class binary label: positive if the class has any defect pixels
@@ -96,22 +100,33 @@ def __len__(self):
return self.num_samples


def get_transforms(phase, mean, std):
list_transforms = []
if phase == "train":
list_transforms.extend(
[
HorizontalFlip(p=0.5), # only horizontal flip as of now
]
)
list_transforms.extend(
[
Normalize(mean=mean, std=std, p=1),
ToTensor(),
]
)
list_trfms = Compose(list_transforms)
return list_trfms
def augmentation(image, mask):
"""Apply data augmentation to a sample.
Args:
image: original image
mask: original mask
Return:
image_aug: augmented image, as a PIL Image
mask_aug: augmented mask, as a numpy array
"""
image_aug, mask_aug = data_augmentation(image, mask)
image_aug = Image.fromarray(image_aug)

return image_aug, mask_aug


def get_transforms(phase, image, mask, mean, std):
"""Augment the sample (train phase only), then convert it to a tensor and normalize."""
if phase == 'train':
image, mask = augmentation(image, mask)

to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean, std)
transform_compose = transforms.Compose([to_tensor, normalize])
image = transform_compose(image)
mask = torch.from_numpy(mask)

return image, mask


def provider(
@@ -250,34 +265,24 @@ def classify_provider(


if __name__ == "__main__":
data_folder = "datasets/Steel_data"
df_path = "datasets/Steel_data/train.csv"
data_folder = "Steel_data"
df_path = "Steel_data/train.csv"
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
batch_size = 8
num_workers = 4
n_splits = 1
# Test the segmentation dataset
# dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits)
# for fold_index, [train_dataloader, val_dataloader] in enumerate(dataloader):
# train_bar = tqdm(train_dataloader)
# class_color = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [139, 0, 139]]
# for images, targets in train_bar:
# image = images[0]
# for i in range(3):
# image[i] = image[i] * std[i]
# image[i] = image[i] + mean[i]

# target = targets[0]
# for i in range(target.size(0)):
# target_0 = target[i] * class_color[i][0]
# target_1 = target[i] * class_color[i][1]
# target_2 = target[i] * class_color[i][2]
# mask = torch.stack([target_0, target_1, target_2], dim=0)
# image += mask
# image = image.permute(1, 2, 0).numpy()
# cv2.imshow('win', image)
# cv2.waitKey(0)
dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits)
for fold_index, [train_dataloader, val_dataloader] in enumerate(dataloader):
train_bar = tqdm(train_dataloader)
class_color = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [139, 0, 139]]
for images, targets in train_bar:
image = images[0]
target = targets[0]
image = image_with_mask_torch(image, target, mean, std)['image']
cv2.imshow('win', image)
cv2.waitKey(0)
class_dataloader = classify_provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits)
# Test the classification dataset
for fold_index, [train_dataloader, val_dataloader] in enumerate(class_dataloader):
@@ -299,6 +304,6 @@ def classify_provider(
font = cv2.FONT_HERSHEY_SIMPLEX
image = cv2.putText(image, str(i), position, font, 1.2, color, 2)
cv2.imshow('win', image)
cv2.waitKey(480)
cv2.waitKey(60)

pass
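For reference, a minimal sketch of how the reworked get_transforms is expected to behave on a single sample, outside of any DataLoader. The 256x1600 image size, the 4-channel mask layout, and the ImageNet statistics are assumptions taken from the defaults above; the dummy arrays stand in for a real cv2.imread/make_mask pair and are not part of this commit (it also assumes the albumentations version used at the time, which still ships albumentations.torch).

import numpy as np
import torch

from datasets.steel_dataset import get_transforms  # assumes running from the repo root

# Dummy stand-ins for an image read with cv2.imread and a mask from make_mask.
img = np.random.randint(0, 256, (256, 1600, 3), dtype=np.uint8)   # HxWx3 uint8
mask = np.zeros((256, 1600, 4), dtype=np.float32)                 # one channel per defect class

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

img_t, mask_t = get_transforms('train', img, mask, mean, std)
mask_t = mask_t.permute(2, 0, 1)  # 4x256x1600, as done in SteelDataset.__getitem__

print(tuple(img_t.shape))   # (3, 256, 1600), normalized float tensor
print(tuple(mask_t.shape))  # (4, 256, 1600)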
116 changes: 116 additions & 0 deletions utils/data_augmentation.py
@@ -0,0 +1,116 @@
import numpy as np
import cv2
import random
import glob
from matplotlib import pyplot as plt
from PIL import Image
import pandas as pd
from tqdm import tqdm
import sys
import os

from albumentations import (
Compose, HorizontalFlip, VerticalFlip, CLAHE, RandomRotate90, HueSaturationValue,
RandomBrightness, RandomContrast, RandomGamma, OneOf,
ToFloat, ShiftScaleRotate, GridDistortion, ElasticTransform, JpegCompression,
RGBShift, RandomBrightnessContrast, Blur, MotionBlur, MedianBlur, GaussNoise, CenterCrop,
IAAAdditiveGaussianNoise, Cutout, Rotate
)

sys.path.append('.')
from utils.visualize import image_with_mask_torch, image_with_mask_numpy
from utils.rle_parse import make_mask


def visualize(image, mask, original_image=None, original_mask=None):
fontsize = 18

if original_image is None and original_mask is None:
f, ax = plt.subplots(2, 1, figsize=(8, 8))

ax[0].imshow(image)
ax[1].imshow(mask)
else:
f, ax = plt.subplots(2, 2, figsize=(8, 8))

ax[0, 0].imshow(original_image)
ax[0, 0].set_title('Original image', fontsize=fontsize)

ax[1, 0].imshow(original_mask)
ax[1, 0].set_title('Original mask', fontsize=fontsize)

ax[0, 1].imshow(image)
ax[0, 1].set_title('Transformed image', fontsize=fontsize)

ax[1, 1].imshow(mask)
ax[1, 1].set_title('Transformed mask', fontsize=fontsize)

plt.show()


def data_augmentation(original_image, original_mask):
"""Randomly augment a sample and its mask.
Args:
original_image: original image
original_mask: original mask
Return:
image_aug: augmented image
mask_aug: augmented mask
"""
original_height, original_width = original_image.shape[:2]
augmentations = Compose([
HorizontalFlip(p=0.4),
Rotate(limit=15, p=0.4),
CenterCrop(p=0.3, height=original_height, width=original_width),
# Histogram equalization
CLAHE(p=0.4),

# Brightness and contrast
RandomGamma(gamma_limit=(80, 120), p=0.1),
RandomBrightnessContrast(p=0.1),

# Blur
OneOf([
MotionBlur(p=0.1),
MedianBlur(blur_limit=3, p=0.1),
Blur(blur_limit=3, p=0.1),
], p=0.3),

OneOf([
IAAAdditiveGaussianNoise(),
GaussNoise(),
], p=0.2)
])

augmented = augmentations(image=original_image, mask=original_mask)
image_aug = augmented['image']
mask_aug = augmented['mask']

return image_aug, mask_aug


if __name__ == "__main__":
data_folder = "../datasets/Steel_data"
df_path = "../datasets/Steel_data/train.csv"

df = pd.read_csv(df_path)
# https://www.kaggle.com/amanooo/defect-detection-starter-u-net
df['ImageId'], df['ClassId'] = zip(*df['ImageId_ClassId'].str.split('_'))
df['ClassId'] = df['ClassId'].astype(int)
df = df.pivot(index='ImageId', columns='ClassId', values='EncodedPixels')
df['defects'] = df.count(axis=1)
file_names = df.index.tolist()

for index in range(len(file_names)):
image_id, mask = make_mask(index, df)
image_path = os.path.join(data_folder, 'train_images', image_id)
image = cv2.imread(image_path)
image_aug, mask_aug = data_augmentation(image, mask)

image_mask = image_with_mask_numpy(image, mask)['image']
image_aug_mask = image_with_mask_numpy(image_aug, mask_aug)['image']
cv2.imshow('image', image_mask)
cv2.imshow('image_aug', image_aug_mask)
cv2.waitKey(0)
pass
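The visualize helper defined above is not called anywhere in this commit; a minimal sketch of using it to compare an original and an augmented pair, with a dummy single-channel mask standing in for real data (the real masks from make_mask are HxWx4):

import numpy as np

from utils.data_augmentation import data_augmentation, visualize  # assumes running from the repo root

# Dummy image and a single-channel mask, for illustration only.
image = np.random.randint(0, 256, (256, 1600, 3), dtype=np.uint8)
mask = np.zeros((256, 1600), dtype=np.float32)
mask[100:150, 400:800] = 1.0  # fake defect region

image_aug, mask_aug = data_augmentation(image, mask)
visualize(image_aug, mask_aug, original_image=image, original_mask=mask)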
61 changes: 61 additions & 0 deletions utils/visualize.py
@@ -0,0 +1,61 @@
# Visualization utilities
import cv2
import torch
from PIL import Image
import numpy as np


def image_with_mask_torch(image, target, mean=None, std=None, mask_only=False):
"""返回numpy形式的样本和掩膜
:param image: 样本,tensor
:param target: 掩膜,tensor
:param mean: 样本均值
:param std: 样本标准差
:param mask_only: bool,是否只返回掩膜
"""
class_color = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [139, 0, 139]]
if mean and std:
for i in range(3):
image[i] = image[i] * std[i]
image[i] = image[i] + mean[i]
mask = torch.zeros(3, target.size(1), target.size(2))
for i in range(target.size(0)):
target_0 = target[i] * class_color[i][0]
target_1 = target[i] * class_color[i][1]
target_2 = target[i] * class_color[i][2]
mask += torch.stack([target_0, target_1, target_2], dim=0)
image += mask

pair = {'mask': mask.permute(1, 2, 0).numpy()}
if not mask_only:
pair['image'] = image.permute(1, 2, 0).numpy()

return pair


def image_with_mask_numpy(image, target, mean=None, std=None, mask_only=False):
"""返回numpy形式的样本和掩膜
:param image: 样本,numpy
:param target: 掩膜,numpy
:param mean: 样本均值
:param std: 样本标准差
:param mask_only: bool,是否只返回掩膜
"""
class_color = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [139, 0, 139]]
if mean and std:
for i in range(3):
image[..., i] = image[..., i] * std[i]
image[..., i] = image[..., i] + mean[i]
mask = np.zeros([target.shape[0], target.shape[1], 3])
for i in range(target.shape[2]):
target_0 = target[..., i] * class_color[i][0]
target_1 = target[..., i] * class_color[i][1]
target_2 = target[..., i] * class_color[i][2]
mask += np.stack([target_0, target_1, target_2], axis=2)
image += np.uint8(mask)

pair = {'mask': np.uint8(mask)}
if not mask_only:
pair['image'] = image

return pair
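
The mask_only flag is not exercised in this commit; a small sketch of using image_with_mask_torch to obtain only the colour-coded mask for a tensor target, with assumed shapes:

import torch

from utils.visualize import image_with_mask_torch  # assumes running from the repo root

# Dummy 4x256x1600 target with one active class; a real target comes from SteelDataset.
target = torch.zeros(4, 256, 1600)
target[2, 50:100, 200:600] = 1.0

colour_mask = image_with_mask_torch(torch.zeros(3, 256, 1600), target, mask_only=True)['mask']
print(colour_mask.shape)  # (256, 1600, 3) numpy array, one colour per defect class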
