From a34e54c505901cfa654f9b6e1e5058ecbcba1440 Mon Sep 17 00:00:00 2001
From: XiangqianMa
Date: Mon, 7 Oct 2019 19:22:54 +0800
Subject: [PATCH] add crop

---
 config.py                  | 14 ++++++++----
 datasets/steel_dataset.py  | 46 +++++++++++++++++++++++++-------------
 train_classify.py          | 13 ++++++++++-
 train_segment.py           |  9 +++++---
 utils/data_augmentation.py | 14 +++++++++---
 5 files changed, 70 insertions(+), 26 deletions(-)

diff --git a/config.py b/config.py
index 7fb9fc4..1930255 100755
--- a/config.py
+++ b/config.py
@@ -19,12 +19,15 @@ def get_seg_config():
         hwp: 6
     '''
     # parser.add_argument('--image_size', type=int, default=768, help='image size')
-    parser.add_argument('--batch_size', type=int, default=24, help='batch size')
-    parser.add_argument('--epoch', type=int, default=60, help='epoch')
+    parser.add_argument('--batch_size', type=int, default=48, help='batch size')
+    parser.add_argument('--epoch', type=int, default=65, help='epoch')
     parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
     parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
     parser.add_argument('--mask_only_flag', type=bool, default=False, help='if true, use masked data only.')
+    parser.add_argument('--crop', type=bool, default=True, help='if true, crop image to [height, width].')
+    parser.add_argument('--height', type=int, default=256, help='the height of cropped image')
+    parser.add_argument('--width', type=int, default=512, help='the width of cropped image')
 
     # model set
     parser.add_argument('--model_name', type=str, default='unet_resnet34', \
@@ -33,7 +36,7 @@ def get_seg_config():
     # model hyper-parameters
     parser.add_argument('--class_num', type=int, default=4)
     parser.add_argument('--num_workers', type=int, default=8)
-    parser.add_argument('--lr', type=float, default=4e-4, help='init lr')
+    parser.add_argument('--lr', type=float, default=1e-4, help='init lr')
     parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer')
 
     # dataset
@@ -67,6 +70,9 @@ def get_classify_config():
     parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
     parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
+    parser.add_argument('--crop', type=bool, default=False, help='if true, crop image to [height, width].')
+    parser.add_argument('--height', type=int, default=256, help='the height of cropped image')
+    parser.add_argument('--width', type=int, default=512, help='the width of cropped image')
 
     # model set
     parser.add_argument('--model_name', type=str, default='unet_resnet34', \
@@ -89,4 +95,4 @@ def get_classify_config():
 
 
 if __name__ == '__main__':
-    config = get_seg_config()
\ No newline at end of file
+    config = get_seg_config()
diff --git a/datasets/steel_dataset.py b/datasets/steel_dataset.py
index 5b453cf..6747c51 100644
--- a/datasets/steel_dataset.py
+++ b/datasets/steel_dataset.py
@@ -23,7 +23,7 @@
 
 # Dataset Segmentation
 class SteelDataset(Dataset):
-    def __init__(self, df, data_folder, mean, std, phase):
+    def __init__(self, df, data_folder, mean, std, phase, crop=False, height=None, width=None):
         super(SteelDataset, self).__init__()
         self.df = df
         self.root = data_folder
@@ -31,13 +31,16 @@ def __init__(self, df, data_folder, mean, std, phase):
         self.std = std
         self.phase = phase
         self.transforms = get_transforms
+        self.crop = crop
+        self.height = height
+        self.width = width
         self.fnames = self.df.index.tolist()
 
     def __getitem__(self, idx):
         image_id, mask = make_mask(idx, self.df)
         image_path = os.path.join(self.root, "train_images", image_id)
         img = cv2.imread(image_path)
-        img, mask = self.transforms(self.phase, img, mask, self.mean, self.std)
+        img, mask = self.transforms(self.phase, img, mask, self.mean, self.std, crop=self.crop, height=self.height, width=self.width)
         mask = mask.permute(2, 0, 1)
         return img, mask
 
@@ -47,7 +50,7 @@ def __len__(self):
 
 # Dataset Classification
 class SteelClassDataset(Dataset):
-    def __init__(self, df, data_folder, mean, std, phase):
+    def __init__(self, df, data_folder, mean, std, phase, crop=False, height=None, width=None):
         super(SteelClassDataset, self).__init__()
         self.df = df
         self.root = data_folder
@@ -55,13 +58,16 @@ def __init__(self, df, data_folder, mean, std, phase):
         self.std = std
         self.phase = phase
         self.transforms = get_transforms
+        self.crop = crop
+        self.height = height
+        self.width = width
         self.fnames = self.df.index.tolist()
 
     def __getitem__(self, idx):
         image_id, mask = make_mask(idx, self.df)
         image_path = os.path.join(self.root, "train_images", image_id)
         img = cv2.imread(image_path)
-        img, mask = self.transforms(self.phase, img, mask, self.mean, self.std)
+        img, mask = self.transforms(self.phase, img, mask, self.mean, self.std, crop=self.crop, height=self.height, width=self.width)
         mask = mask.permute(2, 0, 1)  # 4x256x1600
         mask = mask.view(mask.size(0), -1)
         mask = torch.sum(mask, dim=1)
@@ -99,7 +105,7 @@ def __len__(self):
         return self.num_samples
 
 
-def augmentation(image, mask):
+def augmentation(image, mask, crop=False, height=None, width=None):
     """Apply data augmentation
     Args:
         image: the original image
@@ -108,16 +114,16 @@
         image_aug: the augmented image (PIL Image)
         mask: the augmented mask
     """
-    image_aug, mask_aug = data_augmentation(image, mask)
+    image_aug, mask_aug = data_augmentation(image, mask, crop=crop, height=height, width=width)
     image_aug = Image.fromarray(image_aug)
 
     return image_aug, mask_aug
 
 
-def get_transforms(phase, image, mask, mean, std):
+def get_transforms(phase, image, mask, mean, std, crop=False, height=None, width=None):
     if phase == 'train':
-        image, mask = augmentation(image, mask)
+        image, mask = augmentation(image, mask, crop=crop, height=height, width=width)
 
     to_tensor = transforms.ToTensor()
     normalize = transforms.Normalize(mean, std)
 
@@ -152,7 +158,10 @@ def provider(
     batch_size=8,
     num_workers=4,
     n_splits=0,
-    mask_only=False
+    mask_only=False,
+    crop=False,
+    height=None,
+    width=None
 ):
     """Return dataloaders for the segmentation model
 
@@ -192,7 +201,7 @@ def provider(
     # build the dataloaders
    dataloaders = list()
     for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)):
-        train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train')
+        train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train', crop=crop, height=height, width=width)
         val_dataset = SteelDataset(val_df, data_folder, mean, std, 'val')
         if mask_only:
             # train only on samples that contain masks
@@ -242,6 +251,9 @@ def classify_provider(
     batch_size=8,
     num_workers=4,
     n_splits=0,
+    crop=False,
+    height=None,
+    width=None
 ):
     """Return dataloaders for the classification model
 
@@ -281,7 +293,7 @@ def classify_provider(
     # build the dataloaders
     dataloaders = list()
     for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)):
-        train_dataset = SteelClassDataset(train_df, data_folder, mean, std, 'train')
+        train_dataset = SteelClassDataset(train_df, data_folder, mean, std, 'train', crop=crop, height=height, width=width)
         val_dataset = SteelClassDataset(val_df, data_folder, mean, std, 'val')
         train_dataloader = DataLoader(
             train_dataset,
@@ -303,15 +315,19 @@ def classify_provider(
 
 
 if __name__ == "__main__":
-    data_folder = "Steel_data"
-    df_path = "Steel_data/train.csv"
+    data_folder = "datasets/Steel_data"
+    df_path = "datasets/Steel_data/train.csv"
     mean = (0.485, 0.456, 0.406)
     std = (0.229, 0.224, 0.225)
-    batch_size = 8
+    batch_size = 12
     num_workers = 4
     n_splits = 1
+    mask_only = False
+    crop = True
+    height = 256
+    width = 512
     # test the segmentation dataset
-    dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits)
+    dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits, mask_only=mask_only, crop=crop, height=height, width=width)
     for fold_index, [train_dataloader, val_dataloader] in enumerate(dataloader):
         train_bar = tqdm(train_dataloader)
         class_color = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [139, 0, 139]]
diff --git a/train_classify.py b/train_classify.py
index 64f60e8..e4f3b93 100755
--- a/train_classify.py
+++ b/train_classify.py
@@ -154,7 +154,18 @@ def validation(self, valid_loader):
     config = get_classify_config()
     mean=(0.485, 0.456, 0.406)
     std=(0.229, 0.224, 0.225)
-    dataloaders = classify_provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits)
+    dataloaders = classify_provider(
+        config.dataset_root,
+        os.path.join(config.dataset_root, 'train.csv'),
+        mean,
+        std,
+        config.batch_size,
+        config.num_workers,
+        config.n_splits,
+        crop=config.crop,
+        height=config.height,
+        width=config.width
+    )
     for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
         if fold_index != 1:
             continue
diff --git a/train_segment.py b/train_segment.py
index ec76da2..dfd7426 100644
--- a/train_segment.py
+++ b/train_segment.py
@@ -38,8 +38,8 @@ def __init__(self, config, fold):
         self.solver = Solver(self.model)
 
         # load the loss function
-        # self.criterion = torch.nn.BCEWithLogitsLoss()
-        self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25])
+        self.criterion = torch.nn.BCEWithLogitsLoss()
+        # self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25])
 
         # create the directory for saving weights
         self.model_path = os.path.join(config.save_path, config.model_name)
@@ -158,7 +158,10 @@ def validation(self, valid_loader):
         config.batch_size,
         config.num_workers,
         config.n_splits,
-        mask_only=config.mask_only_flag
+        mask_only=config.mask_only_flag,
+        crop=config.crop,
+        height=config.height,
+        width=config.width
     )
     for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
         if fold_index != 1:
             continue
diff --git a/utils/data_augmentation.py b/utils/data_augmentation.py
index 0c82dac..475d316 100644
--- a/utils/data_augmentation.py
+++ b/utils/data_augmentation.py
@@ -10,7 +10,7 @@
     RandomBrightness, RandomContrast, RandomGamma, OneOf, ToFloat,
     ShiftScaleRotate, GridDistortion, ElasticTransform, JpegCompression, HueSaturationValue,
     RGBShift, RandomBrightnessContrast, RandomContrast, Blur, MotionBlur, MedianBlur, GaussNoise,CenterCrop,
-    IAAAdditiveGaussianNoise,GaussNoise,Cutout,Rotate, Normalize
+    IAAAdditiveGaussianNoise,GaussNoise,Cutout,Rotate, Normalize, Crop, RandomCrop
 )
 
 sys.path.append('.')
@@ -44,7 +44,7 @@ def visualize(image, mask, original_image=None, original_mask=None):
     plt.show()
 
 
-def data_augmentation(original_image, original_mask):
+def data_augmentation(original_image, original_mask, crop=False, height=None, width=None):
     """Randomly augment a sample and its mask
 
     Args:
@@ -78,6 +78,14 @@
         ], p=0.2)
     ])
 
+    if crop:
+        # optionally apply a random crop
+        assert height and width
+        crop_aug = RandomCrop(height=height, width=width, always_apply=True)
+        crop_sample = crop_aug(image=original_image, mask=original_mask)
+        original_image = crop_sample['image']
+        original_mask = crop_sample['mask']
+
     augmented = augmentations(image=original_image, mask=original_mask)
     image_aug = augmented['image']
     mask_aug = augmented['mask']
@@ -102,7 +110,7 @@
         image_id, mask = make_mask(index, df)
         image_path = os.path.join(data_folder, 'train_images', image_id)
         image = cv2.imread(image_path)
-        image_aug, mask_aug = data_augmentation(image, mask)
+        image_aug, mask_aug = data_augmentation(image, mask, crop=True, height=256, width=400)
         normalize = Normalize(mean=mean, std=std)
         image = normalize(image=image)['image']
         image_aug = normalize(image=image_aug)['image']
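For reference, a minimal standalone sketch of the cropping step that the new crop/height/width options enable. It is not part of the patch; the dummy array shapes are assumptions chosen to mirror the 256x1600 Severstal images with 4 defect classes, while the RandomCrop call matches the one added to data_augmentation() above.

# Minimal sketch (assumption: albumentations is installed, as the patch already imports it).
import numpy as np
from albumentations import RandomCrop

# Dummy image/mask shaped like one Severstal sample: 256x1600 image, 4-channel mask.
image = np.zeros((256, 1600, 3), dtype=np.uint8)
mask = np.zeros((256, 1600, 4), dtype=np.float32)

crop, height, width = True, 256, 512  # mirrors the new segmentation config defaults
if crop:
    assert height and width
    crop_aug = RandomCrop(height=height, width=width, always_apply=True)
    # image and mask are cropped with the same randomly chosen window
    sample = crop_aug(image=image, mask=mask)
    image, mask = sample['image'], sample['mask']

print(image.shape, mask.shape)  # -> (256, 512, 3) (256, 512, 4)

Because the crop runs inside data_augmentation(), which is only called for phase == 'train', and the val datasets above are built without crop arguments, validation still sees full-resolution images.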