diff --git a/choose_thre_area.py b/choose_thre_area.py index 246aec6..f3ec347 100644 --- a/choose_thre_area.py +++ b/choose_thre_area.py @@ -201,8 +201,9 @@ def get_model(model_name, load_path): config = get_seg_config() mean=(0.485, 0.456, 0.406) std=(0.229, 0.224, 0.225) - - dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits) + mask_only = True + + dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits, mask_only) results = {} # 存放权重的路径 model_path = os.path.join(config.save_path, config.model_name) diff --git a/classify_segment.py b/classify_segment.py index fe2ffd4..7195c48 100644 --- a/classify_segment.py +++ b/classify_segment.py @@ -204,6 +204,65 @@ def classify_segment_folds(self, images): return results +class Classify_Segment_Folds_Split(): + def __init__(self, model_name, classify_folds, segment_folds, model_path, class_num=4, tta_flag=False): + ''' 首先的到分类模型的集成结果,再得到分割模型的集成结果,最后将两个结果进行融合 + + :param model_name: 当前的模型名称 + :param classify_folds: 参与集成的分类模型的折序号,为list列表 + :param segment_folds: 参与集成的分割模型的折序号,为list列表 + :param model_path: 存放所有模型的路径 + :param class_num: 类别总数 + ''' + self.model_name = model_name + self.classify_folds = classify_folds + self.segment_folds = segment_folds + self.model_path = model_path + self.class_num = class_num + self.tta_flag = tta_flag + + self.classify_models, self.segment_models = list(), list() + self.get_classify_segment_models() + + def get_classify_segment_models(self): + ''' 加载所有折的分割模型和分类模型 + ''' + for fold in self.classify_folds: + self.classify_models.append(Get_Classify_Results(self.model_name, fold, self.model_path, self.class_num, tta_flag=self.tta_flag)) + for fold in self.segment_folds: + self.segment_models.append(Get_Segment_Results(self.model_name, fold, self.model_path, self.class_num, 
tta_flag=self.tta_flag)) + + def classify_segment_folds(self, images): + ''' 使用投票法处理所有fold一个batch的分割结果和分类结果 + + :param images: 一个batch的数据,维度为[batch, channels, height, width] + :return: results,使用投票法处理所有fold一个batch的分割结果和分类结果,维度为[batch, class_num, height, width] + ''' + classify_results = torch.zeros(images.shape[0], self.class_num) + segment_results = torch.zeros(images.shape[0], self.class_num, images.shape[2], images.shape[3]) + # 得到分类结果 + for classify_index, classify_model in enumerate(self.classify_models): + classify_result_fold = classify_model.get_classify_results(images) + classify_results += classify_result_fold.detach().cpu().squeeze().float() + classify_vote_model_num = len(self.classify_folds) + classify_vote_ticket = round(classify_vote_model_num / 2.0) + classify_results = classify_results > classify_vote_ticket + + # 得到分割结果 + for segment_index, segment_model in enumerate(self.segment_models): + segment_result_fold = segment_model.get_segment_results(images) + segment_results += segment_result_fold.detach().cpu() + segment_vote_model_num = len(self.segment_folds) + segment_vote_ticket = round(segment_vote_model_num / 2.0) + segment_results = segment_results > segment_vote_ticket + + # 将分类结果和分割结果进行融合 + for batch_index, classify_result in enumerate(classify_results): + segment_results[batch_index, 1-classify_result, ...] 
= 0 + + return segment_results + + class Segment_Folds(): def __init__(self, model_name, n_splits, model_path, class_num=4, tta_flag=False): ''' 使用投票法处理所有fold一个batch的分割结果 diff --git a/config.py b/config.py index 7fb9fc4..547acc8 100755 --- a/config.py +++ b/config.py @@ -19,12 +19,15 @@ def get_seg_config(): hwp: 6 ''' # parser.add_argument('--image_size', type=int, default=768, help='image size') - parser.add_argument('--batch_size', type=int, default=24, help='batch size') - parser.add_argument('--epoch', type=int, default=60, help='epoch') + parser.add_argument('--batch_size', type=int, default=4, help='batch size') + parser.add_argument('--epoch', type=int, default=65, help='epoch') parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set') parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold') parser.add_argument('--mask_only_flag', type=bool, default=False, help='if true, use masked data only.') + parser.add_argument('--crop', type=bool, default=True, help='if true, crop image to [height, width].') + parser.add_argument('--height', type=int, default=256, help='the height of cropped image') + parser.add_argument('--width', type=int, default=512, help='the width of cropped image') # model set parser.add_argument('--model_name', type=str, default='unet_resnet34', \ @@ -33,7 +36,7 @@ def get_seg_config(): # model hyper-parameters parser.add_argument('--class_num', type=int, default=4) parser.add_argument('--num_workers', type=int, default=8) - parser.add_argument('--lr', type=float, default=4e-4, help='init lr') + parser.add_argument('--lr', type=float, default=1e-4, help='init lr') parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay in optimizer') # dataset @@ -67,6 +70,9 @@ def get_classify_config(): parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set') 
parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold') + parser.add_argument('--crop', type=bool, default=False, help='if true, crop image to [height, width].') + parser.add_argument('--height', type=int, default=256, help='the height of cropped image') + parser.add_argument('--width', type=int, default=512, help='the width of cropped image') # model set parser.add_argument('--model_name', type=str, default='unet_resnet34', \ @@ -89,4 +95,4 @@ def get_classify_config(): if __name__ == '__main__': - config = get_seg_config() \ No newline at end of file + config = get_seg_config() diff --git a/create_submission.py b/create_submission.py index 9458e1f..9e61cea 100644 --- a/create_submission.py +++ b/create_submission.py @@ -14,7 +14,7 @@ package_path = '/kaggle/input/sources' # add unet script dataset import sys sys.path.append(package_path) -from classify_segment import Classify_Segment_Folds, Classify_Segment_Fold +from classify_segment import Classify_Segment_Folds, Classify_Segment_Fold, Classify_Segment_Folds_Split class TestDataset(Dataset): @@ -55,10 +55,11 @@ def mask2rle(img): return ' '.join(str(x) for x in runs) -def create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=False): +def create_submission(classify_splits, seg_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=False): ''' - :param n_splits: 折数,类型为list + :param classify_splits: 分类模型的折数,类型为list + :param seg_splits: 分割模型的折数,类型为list :param model_name: 当前模型的名称 :param batch_size: batch的大小 :param num_workers: 加载数据的线程 @@ -79,10 +80,12 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std, num_workers=num_workers, pin_memory=True ) - if len(n_splits) == 1: - classify_segment = Classify_Segment_Fold(model_name, n_splits[0], model_path, tta_flag=tta_flag).classify_segment - else: - classify_segment = 
Classify_Segment_Folds(model_name, n_splits, model_path, tta_flag=tta_flag).classify_segment_folds + if len(classify_splits) == 1 and len(seg_splits) == 1: + classify_segment = Classify_Segment_Fold(model_name, classify_splits[0], model_path, tta_flag=tta_flag).classify_segment + elif len(classify_splits) == len(seg_splits): + classify_segment = Classify_Segment_Folds(model_name, classify_splits, model_path, tta_flag=tta_flag).classify_segment_folds + elif len(classify_splits) != len(seg_splits): + classify_segment = Classify_Segment_Folds_Split(model_name, classify_splits, seg_splits, model_path, tta_flag=tta_flag).classify_segment_folds # start prediction predictions = [] @@ -104,11 +107,12 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std, # 设置超参数 model_name = 'unet_resnet34' num_workers = 12 - batch_size = 6 + batch_size = 4 mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) - n_splits = [1] # [0, 1, 2, 3, 4] - tta_flag = False + classify_splits = [1] # [0, 1, 2, 3, 4] + segment_splits = [0, 1, 2, 3, 4] + tta_flag = True if kaggle: sample_submission_path = '/kaggle/input/severstal-steel-defect-detection/sample_submission.csv' @@ -119,5 +123,5 @@ def create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder = 'datasets/Steel_data/test_images' model_path = './checkpoints/' + model_name - create_submission(n_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, + create_submission(classify_splits, segment_splits, model_name, batch_size, num_workers, mean, std, test_data_folder, sample_submission_path, model_path, tta_flag=tta_flag) diff --git a/datasets/steel_dataset.py b/datasets/steel_dataset.py index 5b453cf..6747c51 100644 --- a/datasets/steel_dataset.py +++ b/datasets/steel_dataset.py @@ -23,7 +23,7 @@ # Dataset Segmentation class SteelDataset(Dataset): - def __init__(self, df, data_folder, mean, std, phase): + def __init__(self, df, data_folder, mean, std, phase, 
crop=False, height=None, width=None): super(SteelDataset, self).__init__() self.df = df self.root = data_folder @@ -31,13 +31,16 @@ def __init__(self, df, data_folder, mean, std, phase): self.std = std self.phase = phase self.transforms = get_transforms + self.crop = crop + self.height = height + self.width = width self.fnames = self.df.index.tolist() def __getitem__(self, idx): image_id, mask = make_mask(idx, self.df) image_path = os.path.join(self.root, "train_images", image_id) img = cv2.imread(image_path) - img, mask = self.transforms(self.phase, img, mask, self.mean, self.std) + img, mask = self.transforms(self.phase, img, mask, self.mean, self.std, crop=self.crop, height=self.height, width=self.width) mask = mask.permute(2, 0, 1) return img, mask @@ -47,7 +50,7 @@ def __len__(self): # Dataset Classification class SteelClassDataset(Dataset): - def __init__(self, df, data_folder, mean, std, phase): + def __init__(self, df, data_folder, mean, std, phase, crop=False, height=None, width=None): super(SteelClassDataset, self).__init__() self.df = df self.root = data_folder @@ -55,13 +58,16 @@ def __init__(self, df, data_folder, mean, std, phase): self.std = std self.phase = phase self.transforms = get_transforms + self.crop = crop + self.height = height + self.width = width self.fnames = self.df.index.tolist() def __getitem__(self, idx): image_id, mask = make_mask(idx, self.df) image_path = os.path.join(self.root, "train_images", image_id) img = cv2.imread(image_path) - img, mask = self.transforms(self.phase, img, mask, self.mean, self.std) + img, mask = self.transforms(self.phase, img, mask, self.mean, self.std, crop=self.crop, height=self.height, width=self.width) mask = mask.permute(2, 0, 1) # 4x256x1600 mask = mask.view(mask.size(0), -1) mask = torch.sum(mask, dim=1) @@ -99,7 +105,7 @@ def __len__(self): return self.num_samples -def augmentation(image, mask): +def augmentation(image, mask, crop=False, height=None, width=None): """进行数据增强 Args: image: 原始图像 @@ 
-108,16 +114,16 @@ def augmentation(image, mask): image_aug: 增强后的图像,Image图像 mask: 增强后的掩膜,Image图像 """ - image_aug, mask_aug = data_augmentation(image, mask) + image_aug, mask_aug = data_augmentation(image, mask, crop=crop, height=height, width=width) image_aug = Image.fromarray(image_aug) return image_aug, mask_aug -def get_transforms(phase, image, mask, mean, std): +def get_transforms(phase, image, mask, mean, std, crop=False, height=None, width=None): if phase == 'train': - image, mask = augmentation(image, mask) + image, mask = augmentation(image, mask, crop=crop, height=height, width=width) to_tensor = transforms.ToTensor() normalize = transforms.Normalize(mean, std) @@ -152,7 +158,10 @@ def provider( batch_size=8, num_workers=4, n_splits=0, - mask_only=False + mask_only=False, + crop=False, + height=None, + width=None ): """返回数据加载器,用于分割模型 @@ -192,7 +201,7 @@ def provider( # 生成dataloader dataloaders = list() for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)): - train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train') + train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train', crop=crop, height=height, width=width) val_dataset = SteelDataset(val_df, data_folder, mean, std, 'val') if mask_only: # 只在有掩膜的样本上训练 @@ -242,6 +251,9 @@ def classify_provider( batch_size=8, num_workers=4, n_splits=0, + crop=False, + height=None, + width=None ): """返回数据加载器,用于分类模型 @@ -281,7 +293,7 @@ def classify_provider( # 生成dataloader dataloaders = list() for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)): - train_dataset = SteelClassDataset(train_df, data_folder, mean, std, 'train') + train_dataset = SteelClassDataset(train_df, data_folder, mean, std, 'train', crop=crop, height=height, width=width) val_dataset = SteelClassDataset(val_df, data_folder, mean, std, 'val') train_dataloader = DataLoader( train_dataset, @@ -303,15 +315,19 @@ def classify_provider( if __name__ == "__main__": - data_folder = "Steel_data" - 
df_path = "Steel_data/train.csv" + data_folder = "datasets/Steel_data" + df_path = "datasets/Steel_data/train.csv" mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) - batch_size = 8 + batch_size = 12 num_workers = 4 n_splits = 1 + mask_only = False + crop = True + height = 256 + width = 512 # 测试分割数据集 - dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits) + dataloader = provider(data_folder, df_path, mean, std, batch_size, num_workers, n_splits, mask_only=mask_only, crop=crop, height=height, width=width) for fold_index, [train_dataloader, val_dataloader] in enumerate(dataloader): train_bar = tqdm(train_dataloader) class_color = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [139, 0, 139]] diff --git a/demo.py b/demo.py index 7519ac7..f821267 100644 --- a/demo.py +++ b/demo.py @@ -104,7 +104,7 @@ def pred_show(images, preds, mean, std, targets=None, flag=False, auto_flag=Fals # 是否显示真实的mask show_truemask_flag = True # 加载哪几折的模型进行测试,若list中有多个值,则使用投票法 - n_splits = [1] # [0, 1, 2, 3, 4] + n_splits = [0, 1, 2, 3, 4] # [0, 1, 2, 3, 4] # 是否只使用分割模型 use_segment_only = True # 是否使用自动显示 diff --git a/experiments.md b/experiments.md new file mode 100644 index 0000000..aab5ba4 --- /dev/null +++ b/experiments.md @@ -0,0 +1,114 @@ +# 钢铁损伤检测试验记录 + +## 分割模型使用加权BCE+DICE损失 + +### 学习率:4e-4,Epoch: 70 + +#### 训练曲线 + +* 训练集损失曲线 + + ![训练集损失曲线](images/Screenshot from 2019-10-05 17-23-44.png) + + 如上图所示,研究曲线后发现,模型在训练前期过早收敛,后期收敛缓慢,因而认为学习率可能过大。 + +* 验证集损失曲线和dice曲线 + + ![验证集损失曲线](images/Screenshot from 2019-10-05 17-24-09.png) + + ![验证集DICE](images/Screenshot from 2019-10-05 17-24-01.png) + + 观察损失曲线和dice曲线,可以发现: + + * 学习率有点高 + * 验证集的指标仍在上升 + +#### LB + +0.89005 + +#### 改进策略 + +设备:LZD + +1. 降低学习率,1e-4 +2. 
多跑几个epoch,75 + +## 只使用有掩膜的样本训练分割模型,不加载分类权重 + +### 学习率:4e-4,Epoch:60 + +#### 训练曲线 + +* 训练集损失曲线 + + ![](images/Screenshot from 2019-10-05 17-47-24.png) + + 训练前期,曲线下降过快,后期缓慢,掉入局部最小点。 + +* 验证集损失曲线和Dice曲线 + + ![](images/Screenshot from 2019-10-05 17-47-55.png) + + ![](images/Screenshot from 2019-10-05 17-47-37.png) + + 验证集曲线比较正常。 + +#### LB + +0.90092 + +#### 改进策略 + +设备:MXQ + +1. 降低学习率,1e-4 +2. 增加训练epoch,70 epoch + +## unet_se_renext50 + +### 学习率:4e-4,Epoch:60 + +#### 训练曲线 + +* 训练集损失曲线 + + ![](images/Screenshot from 2019-10-05 18-19-45.png) + + 由于batch_size比较小的原因,模型的训练损失曲线很震荡,同时也有点过早收敛。 + +* 验证集损失曲线和Dice曲线 + + ![](images/Screenshot from 2019-10-05 18-20-10.png) + + ![](images/Screenshot from 2019-10-05 18-19-54.png) + + 观察验证集的损失曲线,很容易发现发生了过拟合现象。 + +#### LB + +* TTA: 0.90268 +* w/ TTA:0.90222 + +### 改进策略 + +设备:HWP + +需要解决batch_size过小和过拟合问题。 + +* batch_size + + * 使用GN(分组归一化) + * 用MXQ的设备 + +* 过拟合 + + * 早停,45个epoch + * 权重衰减,5e-4 + +* 稍微降低学习率 + + 4e-4 -> 4e-5 + + + diff --git a/images/Screenshot from 2019-10-05 17-23-44.png b/images/Screenshot from 2019-10-05 17-23-44.png new file mode 100644 index 0000000..5edbabe Binary files /dev/null and b/images/Screenshot from 2019-10-05 17-23-44.png differ diff --git a/images/Screenshot from 2019-10-05 17-24-01.png b/images/Screenshot from 2019-10-05 17-24-01.png new file mode 100644 index 0000000..3c550a5 Binary files /dev/null and b/images/Screenshot from 2019-10-05 17-24-01.png differ diff --git a/images/Screenshot from 2019-10-05 17-24-09.png b/images/Screenshot from 2019-10-05 17-24-09.png new file mode 100644 index 0000000..44ea3ee Binary files /dev/null and b/images/Screenshot from 2019-10-05 17-24-09.png differ diff --git a/images/Screenshot from 2019-10-05 17-47-24.png b/images/Screenshot from 2019-10-05 17-47-24.png new file mode 100644 index 0000000..4bf5216 Binary files /dev/null and b/images/Screenshot from 2019-10-05 17-47-24.png differ diff --git a/images/Screenshot from 2019-10-05 17-47-37.png b/images/Screenshot from 
2019-10-05 17-47-37.png new file mode 100644 index 0000000..6d74169 Binary files /dev/null and b/images/Screenshot from 2019-10-05 17-47-37.png differ diff --git a/images/Screenshot from 2019-10-05 17-47-55.png b/images/Screenshot from 2019-10-05 17-47-55.png new file mode 100644 index 0000000..adbeee8 Binary files /dev/null and b/images/Screenshot from 2019-10-05 17-47-55.png differ diff --git a/images/Screenshot from 2019-10-05 18-19-45.png b/images/Screenshot from 2019-10-05 18-19-45.png new file mode 100644 index 0000000..2c940a8 Binary files /dev/null and b/images/Screenshot from 2019-10-05 18-19-45.png differ diff --git a/images/Screenshot from 2019-10-05 18-19-54.png b/images/Screenshot from 2019-10-05 18-19-54.png new file mode 100644 index 0000000..95f841a Binary files /dev/null and b/images/Screenshot from 2019-10-05 18-19-54.png differ diff --git a/images/Screenshot from 2019-10-05 18-20-10.png b/images/Screenshot from 2019-10-05 18-20-10.png new file mode 100644 index 0000000..337a6fa Binary files /dev/null and b/images/Screenshot from 2019-10-05 18-20-10.png differ diff --git a/run.sh b/run.sh index cdab892..8351016 100755 --- a/run.sh +++ b/run.sh @@ -1,3 +1,4 @@ #!/bin/bash -python train_classify.py -python train_segment.py \ No newline at end of file +# python train_classify.py +python train_segment.py +python choose_thre_area.py \ No newline at end of file diff --git a/train_classify.py b/train_classify.py index 64f60e8..e4f3b93 100755 --- a/train_classify.py +++ b/train_classify.py @@ -154,7 +154,18 @@ def validation(self, valid_loader): config = get_classify_config() mean=(0.485, 0.456, 0.406) std=(0.229, 0.224, 0.225) - dataloaders = classify_provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits) + dataloaders = classify_provider( + config.dataset_root, + os.path.join(config.dataset_root, 'train.csv'), + mean, + std, + config.batch_size, + 
config.num_workers, + config.n_splits, + crop=config.crop, + height=config.height, + width=config.width + ) for fold_index, [train_loader, valid_loader] in enumerate(dataloaders): if fold_index != 1: continue diff --git a/train_segment.py b/train_segment.py index ec76da2..dfd7426 100644 --- a/train_segment.py +++ b/train_segment.py @@ -38,8 +38,8 @@ def __init__(self, config, fold): self.solver = Solver(self.model) # 加载损失函数 - # self.criterion = torch.nn.BCEWithLogitsLoss() - self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25]) + self.criterion = torch.nn.BCEWithLogitsLoss() + # self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25]) # 创建保存权重的路径 self.model_path = os.path.join(config.save_path, config.model_name) @@ -158,7 +158,10 @@ def validation(self, valid_loader): config.batch_size, config.num_workers, config.n_splits, - mask_only=config.mask_only_flag + mask_only=config.mask_only_flag, + crop=config.crop, + height=config.height, + width=config.width ) for fold_index, [train_loader, valid_loader] in enumerate(dataloaders): if fold_index != 1: diff --git a/utils/data_augmentation.py b/utils/data_augmentation.py index 0c82dac..475d316 100644 --- a/utils/data_augmentation.py +++ b/utils/data_augmentation.py @@ -10,7 +10,7 @@ RandomBrightness, RandomContrast, RandomGamma, OneOf, ToFloat, ShiftScaleRotate, GridDistortion, ElasticTransform, JpegCompression, HueSaturationValue, RGBShift, RandomBrightnessContrast, RandomContrast, Blur, MotionBlur, MedianBlur, GaussNoise,CenterCrop, - IAAAdditiveGaussianNoise,GaussNoise,Cutout,Rotate, Normalize + IAAAdditiveGaussianNoise,GaussNoise,Cutout,Rotate, Normalize, Crop, RandomCrop ) sys.path.append('.') @@ -44,7 +44,7 @@ def visualize(image, mask, original_image=None, original_mask=None): plt.show() -def data_augmentation(original_image, original_mask): +def data_augmentation(original_image, original_mask, crop=False, height=None, 
width=None): """进行样本和掩膜的随机增强 Args: @@ -78,6 +78,14 @@ def data_augmentation(original_image, original_mask): ], p=0.2) ]) + if crop: + # 是否进行随机裁剪 + assert height and width + crop_aug = RandomCrop(height=height, width=width, always_apply=True) + crop_sample = crop_aug(image=original_image, mask=original_mask) + original_image = crop_sample['image'] + original_mask = crop_sample['mask'] + augmented = augmentations(image=original_image, mask=original_mask) image_aug = augmented['image'] mask_aug = augmented['mask'] @@ -102,7 +110,7 @@ def data_augmentation(original_image, original_mask): image_id, mask = make_mask(index, df) image_path = os.path.join(data_folder, 'train_images', image_id) image = cv2.imread(image_path) - image_aug, mask_aug = data_augmentation(image, mask) + image_aug, mask_aug = data_augmentation(image, mask, crop=True, height=256, width=400) normalize = Normalize(mean=mean, std=std) image = normalize(image=image)['image'] image_aug = normalize(image=image_aug)['image']