Commit

add crop

XiangqianMa committed Oct 7, 2019
1 parent ae5d87d commit a34e54c
Showing 5 changed files with 70 additions and 26 deletions.
14 changes: 10 additions & 4 deletions config.py
@@ -19,12 +19,15 @@ def get_seg_config():
hwp: 6
'''
# parser.add_argument('--image_size', type=int, default=768, help='image size')
-parser.add_argument('--batch_size', type=int, default=24, help='batch size')
-parser.add_argument('--epoch', type=int, default=60, help='epoch')
+parser.add_argument('--batch_size', type=int, default=48, help='batch size')
+parser.add_argument('--epoch', type=int, default=65, help='epoch')

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
parser.add_argument('--mask_only_flag', type=bool, default=False, help='if true, use masked data only.')
+parser.add_argument('--crop', type=bool, default=True, help='if true, crop image to [height, width].')
+parser.add_argument('--height', type=int, default=256, help='the height of cropped image')
+parser.add_argument('--width', type=int, default=512, help='the width of cropped image')

# model set
parser.add_argument('--model_name', type=str, default='unet_resnet34', \
@@ -33,7 +36,7 @@ def get_seg_config():
# model hyper-parameters
parser.add_argument('--class_num', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=8)
-parser.add_argument('--lr', type=float, default=4e-4, help='init lr')
+parser.add_argument('--lr', type=float, default=1e-4, help='init lr')
parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer')

# dataset
@@ -67,6 +70,9 @@ def get_classify_config():

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
+parser.add_argument('--crop', type=bool, default=False, help='if true, crop image to [height, width].')
+parser.add_argument('--height', type=int, default=256, help='the height of cropped image')
+parser.add_argument('--width', type=int, default=512, help='the width of cropped image')

# model set
parser.add_argument('--model_name', type=str, default='unet_resnet34', \
@@ -89,4 +95,4 @@ def get_classify_config()


if __name__ == '__main__':
-config = get_seg_config()
\ No newline at end of file
+config = get_seg_config()
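One caveat on the new flags: argparse's `type=bool` does not parse command-line strings the way the help text suggests — `bool('False')` is truthy, so `--crop False` would still enable cropping. A minimal sketch of the usual workaround (the `str2bool` helper is illustrative, not part of this commit):

```python
import argparse

def str2bool(v):
    # argparse with type=bool calls bool() on the raw string, and any
    # non-empty string (including "False") is truthy; parse it explicitly.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected')

parser = argparse.ArgumentParser()
parser.add_argument('--crop', type=str2bool, default=True,
                    help='if true, crop image to [height, width].')
print(parser.parse_args(['--crop', 'False']).crop)  # False, as intended
```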
46 changes: 31 additions & 15 deletions datasets/steel_dataset.py
@@ -23,21 +23,24 @@

# Dataset Segmentation
class SteelDataset(Dataset):
-def __init__(self, df, data_folder, mean, std, phase):
+def __init__(self, df, data_folder, mean, std, phase, crop=False, height=None, width=None):
super(SteelDataset, self).__init__()
self.df = df
self.root = data_folder
self.mean = mean
self.std = std
self.phase = phase
self.transforms = get_transforms
+self.crop = crop
+self.height = height
+self.width = width
self.fnames = self.df.index.tolist()

def __getitem__(self, idx):
image_id, mask = make_mask(idx, self.df)
image_path = os.path.join(self.root, "train_images", image_id)
img = cv2.imread(image_path)
-img, mask = self.transforms(self.phase, img, mask, self.mean, self.std)
+img, mask = self.transforms(self.phase, img, mask, self.mean, self.std, crop=self.crop, height=self.height, width=self.width)
mask = mask.permute(2, 0, 1)
return img, mask

@@ -47,21 +50,24 @@ def __len__(self):

# Dataset Classification
class SteelClassDataset(Dataset):
-def __init__(self, df, data_folder, mean, std, phase):
+def __init__(self, df, data_folder, mean, std, phase, crop=False, height=None, width=None):
super(SteelClassDataset, self).__init__()
self.df = df
self.root = data_folder
self.mean = mean
self.std = std
self.phase = phase
self.transforms = get_transforms
+self.crop = crop
+self.height = height
+self.width = width
self.fnames = self.df.index.tolist()

def __getitem__(self, idx):
image_id, mask = make_mask(idx, self.df)
image_path = os.path.join(self.root, "train_images", image_id)
img = cv2.imread(image_path)
-img, mask = self.transforms(self.phase, img, mask, self.mean, self.std)
+img, mask = self.transforms(self.phase, img, mask, self.mean, self.std, crop=self.crop, height=self.height, width=self.width)
mask = mask.permute(2, 0, 1) # 4x256x1600
mask = mask.view(mask.size(0), -1)
mask = torch.sum(mask, dim=1)
@@ -99,7 +105,7 @@ def __len__(self):
return self.num_samples


-def augmentation(image, mask):
+def augmentation(image, mask, crop=False, height=None, width=None):
"""进行数据增强
Args:
image: 原始图像
@@ -108,16 +114,16 @@ def augmentation(image, mask):
image_aug: the augmented image, a PIL Image
mask: the augmented mask, a PIL Image
"""
-image_aug, mask_aug = data_augmentation(image, mask)
+image_aug, mask_aug = data_augmentation(image, mask, crop=crop, height=height, width=width)
image_aug = Image.fromarray(image_aug)

return image_aug, mask_aug


-def get_transforms(phase, image, mask, mean, std):
+def get_transforms(phase, image, mask, mean, std, crop=False, height=None, width=None):

if phase == 'train':
-image, mask = augmentation(image, mask)
+image, mask = augmentation(image, mask, crop=crop, height=height, width=width)

to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean, std)
@@ -152,7 +158,10 @@ def provider(
batch_size=8,
num_workers=4,
n_splits=0,
-mask_only=False
+mask_only=False,
+crop=False,
+height=None,
+width=None
):
"""返回数据加载器,用于分割模型
@@ -192,7 +201,7 @@
# Build the dataloaders
dataloaders = list()
for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)):
-train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train')
+train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train', crop=crop, height=height, width=width)
val_dataset = SteelDataset(val_df, data_folder, mean, std, 'val')
if mask_only:
# Train only on samples that contain masks
@@ -242,6 +251,9 @@ def classify_provider(
batch_size=8,
num_workers=4,
n_splits=0,
+crop=False,
+height=None,
+width=None
):
"""返回数据加载器,用于分类模型
@@ -281,7 +293,7 @@
# Build the dataloaders
dataloaders = list()
for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)):
-train_dataset = SteelClassDataset(train_df, data_folder, mean, std, 'train')
+train_dataset = SteelClassDataset(train_df, data_folder, mean, std, 'train', crop=crop, height=height, width=width)
val_dataset = SteelClassDataset(val_df, data_folder, mean, std, 'val')
train_dataloader = DataLoader(
train_dataset,
@@ -303,15 +315,19 @@


if __name__ == "__main__":
data_folder = "Steel_data"
df_path = "Steel_data/train.csv"
data_folder = "datasets/Steel_data"
df_path = "datasets/Steel_data/train.csv"
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
-batch_size = 8
+batch_size = 12
num_workers = 4
n_splits = 1
+mask_only = False
+crop = True
+height = 256
+width = 512
# Test the segmentation dataset
-dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits)
+dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits, mask_only=mask_only, crop=crop, height=height, width=width)
for fold_index, [train_dataloader, val_dataloader] in enumerate(dataloader):
train_bar = tqdm(train_dataloader)
class_color = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [139, 0, 139]]
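For reference, a usage sketch of the updated `provider` signature, mirroring the `__main__` block above (the import path and data paths are assumptions based on this diff; the crop parameters reach the train dataset only, as the hunk at @@ -192,7 shows):

```python
from datasets.steel_dataset import provider  # assumed module path

# Build k-fold segmentation loaders with random 256x512 training crops.
dataloaders = provider(
    'datasets/Steel_data',            # data_folder
    'datasets/Steel_data/train.csv',  # df_path with the RLE annotations
    (0.485, 0.456, 0.406),            # ImageNet mean
    (0.229, 0.224, 0.225),            # ImageNet std
    batch_size=12,
    num_workers=4,
    n_splits=5,
    mask_only=False,
    crop=True, height=256, width=512,
)
for fold_index, (train_loader, val_loader) in enumerate(dataloaders):
    images, masks = next(iter(train_loader))
    # Training batches are cropped to 3x256x512; the validation dataset is
    # built without crop arguments and keeps the full 256x1600 resolution.
    print(images.shape, masks.shape)
```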
13 changes: 12 additions & 1 deletion train_classify.py
@@ -154,7 +154,18 @@ def validation(self, valid_loader):
config = get_classify_config()
mean=(0.485, 0.456, 0.406)
std=(0.229, 0.224, 0.225)
-dataloaders = classify_provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits)
+dataloaders = classify_provider(
+    config.dataset_root,
+    os.path.join(config.dataset_root, 'train.csv'),
+    mean,
+    std,
+    config.batch_size,
+    config.num_workers,
+    config.n_splits,
+    crop=config.crop,
+    height=config.height,
+    width=config.width
+)
for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
if fold_index != 1:
continue
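The crop parameters now also reach `SteelClassDataset`, whose `__getitem__` (see the dataset diff above) reduces each 4-channel mask to per-class pixel counts via `view` and `sum`; the final threshold to a presence label presumably happens in the collapsed lines. A self-contained sketch of that reduction:

```python
import torch

# One 4-channel defect mask for a cropped 256x512 image.
mask = torch.zeros(4, 256, 512)
mask[2, 100:120, 300:400] = 1.0          # a defect region of class 3

# Mirrors the dataset code: flatten each class plane, then sum its pixels.
per_class_pixels = mask.view(mask.size(0), -1).sum(dim=1)
target = (per_class_pixels > 0).float()  # -> tensor([0., 0., 1., 0.])
print(per_class_pixels, target)
```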
9 changes: 6 additions & 3 deletions train_segment.py
@@ -38,8 +38,8 @@ def __init__(self, config, fold):
self.solver = Solver(self.model)

# Load the loss function
-# self.criterion = torch.nn.BCEWithLogitsLoss()
-self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25])
+self.criterion = torch.nn.BCEWithLogitsLoss()
+# self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25])

# Create the directory for saving weights
self.model_path = os.path.join(config.save_path, config.model_name)
@@ -158,7 +158,10 @@ def validation(self, valid_loader):
config.batch_size,
config.num_workers,
config.n_splits,
-mask_only=config.mask_only_flag
+mask_only=config.mask_only_flag,
+crop=config.crop,
+height=config.height,
+width=config.width
)
for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
if fold_index != 1:
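The loss swap above moves from the weighted soft BCE-Dice combination back to plain `BCEWithLogitsLoss` over the four class channels. A minimal sketch on dummy tensors (shapes assume the new 256x512 crops):

```python
import torch

criterion = torch.nn.BCEWithLogitsLoss()

# Raw network outputs and binary ground-truth masks, one channel per class.
logits = torch.randn(2, 4, 256, 512)
targets = torch.randint(0, 2, (2, 4, 256, 512)).float()

# BCEWithLogitsLoss fuses the sigmoid with binary cross-entropy for
# numerical stability, averaging over all elements by default.
loss = criterion(logits, targets)
print(loss.item())
```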
14 changes: 11 additions & 3 deletions utils/data_augmentation.py
@@ -10,7 +10,7 @@
RandomBrightness, RandomContrast, RandomGamma, OneOf,
ToFloat, ShiftScaleRotate, GridDistortion, ElasticTransform, JpegCompression, HueSaturationValue,
RGBShift, RandomBrightnessContrast, RandomContrast, Blur, MotionBlur, MedianBlur, GaussNoise,CenterCrop,
-IAAAdditiveGaussianNoise,GaussNoise,Cutout,Rotate, Normalize
+IAAAdditiveGaussianNoise,GaussNoise,Cutout,Rotate, Normalize, Crop, RandomCrop
)

sys.path.append('.')
@@ -44,7 +44,7 @@ def visualize(image, mask, original_image=None, original_mask=None):
plt.show()


-def data_augmentation(original_image, original_mask):
+def data_augmentation(original_image, original_mask, crop=False, height=None, width=None):
"""进行样本和掩膜的随机增强
Args:
@@ -78,6 +78,14 @@ def data_augmentation(original_image, original_mask):
], p=0.2)
])

+if crop:
+    # whether to apply a random crop
+    assert height and width
+    crop_aug = RandomCrop(height=height, width=width, always_apply=True)
+    crop_sample = crop_aug(image=original_image, mask=original_mask)
+    original_image = crop_sample['image']
+    original_mask = crop_sample['mask']
+
augmented = augmentations(image=original_image, mask=original_mask)
image_aug = augmented['image']
mask_aug = augmented['mask']
@@ -102,7 +110,7 @@
image_id, mask = make_mask(index, df)
image_path = os.path.join(data_folder, 'train_images', image_id)
image = cv2.imread(image_path)
-image_aug, mask_aug = data_augmentation(image, mask)
+image_aug, mask_aug = data_augmentation(image, mask, crop=True, height=256, width=400)
normalize = Normalize(mean=mean, std=std)
image = normalize(image=image)['image']
image_aug = normalize(image=image_aug)['image']
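To see the new crop step in isolation: `RandomCrop(height, width, always_apply=True)` samples one random window and applies it to the image and mask together, so the defect annotations stay aligned with the pixels. A minimal sketch on dummy data sized like the 256x1600 steel images used in this repo:

```python
import numpy as np
from albumentations import RandomCrop

# Dummy 256x1600 image with a 4-channel defect mask.
image = np.random.randint(0, 256, (256, 1600, 3), dtype=np.uint8)
mask = np.zeros((256, 1600, 4), dtype=np.float32)

# Same call as in data_augmentation above.
crop_aug = RandomCrop(height=256, width=512, always_apply=True)
sample = crop_aug(image=image, mask=mask)
print(sample['image'].shape, sample['mask'].shape)  # (256, 512, 3) (256, 512, 4)
```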
