diff --git a/config.py b/config.py
index d6817bd..7fb9fc4 100755
--- a/config.py
+++ b/config.py
@@ -24,6 +24,7 @@ def get_seg_config():
     parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
     parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
+    parser.add_argument('--mask_only_flag', type=bool, default=False, help='if true, use masked data only.')
 
     # model set
     parser.add_argument('--model_name', type=str, default='unet_resnet34', \
diff --git a/datasets/steel_dataset.py b/datasets/steel_dataset.py
index ebcb2ea..5b453cf 100644
--- a/datasets/steel_dataset.py
+++ b/datasets/steel_dataset.py
@@ -9,6 +9,7 @@
 from sklearn.model_selection import train_test_split, StratifiedKFold
 import torch
 from torch.utils.data import DataLoader, Dataset, sampler
+from torch.utils.data.dataloader import default_collate
 from torchvision import transforms
 from albumentations.pytorch import ToTensor
 import sys
@@ -127,6 +128,22 @@ def get_transforms(phase, image, mask, mean, std):
     return image, mask
 
 
+def mask_only_collate_fun(batch):
+    """Custom collate_fn that removes samples without masks from a batch.
+    """
+    batch_scale = list()
+    for image, mask in batch:
+        pair = list()
+        mask_pixel_num = torch.sum(mask)
+        if mask_pixel_num > 0:
+            pair.append(image)
+            pair.append(mask)
+            batch_scale.append(pair)
+    batch_scale = default_collate(batch_scale)
+
+    return batch_scale
+
+
 def provider(
     data_folder,
     df_path,
@@ -135,6 +152,7 @@ def provider(
     batch_size=8,
     num_workers=4,
     n_splits=0,
+    mask_only=False
 ):
     """Return dataloaders for the segmentation model.
 
@@ -146,6 +164,7 @@ def provider(
         batch_size
         num_workers
         n_split: number of cross-validation folds; 1 means cross-validation is not used
+        mask_only: whether to train the segmentation model only on samples that contain masks
 
     Return:
         dataloaders: list; each element is itself a list holding the training and validation loaders of one fold
@@ -175,20 +194,41 @@ def provider(
     for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)):
         train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train')
         val_dataset = SteelDataset(val_df, data_folder, mean, std, 'val')
-        train_dataloader = DataLoader(
-            train_dataset,
-            batch_size=batch_size,
-            num_workers=num_workers,
-            pin_memory=True,
-            shuffle=True
+        if mask_only:
+            # train only on samples that contain masks
+            print('Segmentation model: only masked data.')
+            train_dataloader = DataLoader(
+                train_dataset,
+                batch_size=batch_size,
+                num_workers=num_workers,
+                collate_fn=mask_only_collate_fun,
+                pin_memory=True,
+                shuffle=True
+            )
+            val_dataloader = DataLoader(
+                val_dataset,
+                batch_size=batch_size,
+                num_workers=num_workers,
+                collate_fn=mask_only_collate_fun,
+                pin_memory=True,
+                shuffle=True
+            )
+        else:
+            print('Segmentation model: all data.')
+            train_dataloader = DataLoader(
+                train_dataset,
+                batch_size=batch_size,
+                num_workers=num_workers,
+                pin_memory=True,
+                shuffle=True
+            )
+            val_dataloader = DataLoader(
+                val_dataset,
+                batch_size=batch_size,
+                num_workers=num_workers,
+                pin_memory=True,
+                shuffle=True
         )
-        val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=batch_size,
-            num_workers=num_workers,
-            pin_memory=True,
-            shuffle=True
-        )
         dataloaders.append([train_dataloader, val_dataloader])
 
     return dataloaders
diff --git a/train_segment.py b/train_segment.py
index 34baa0d..ec76da2 100644
--- a/train_segment.py
+++ b/train_segment.py
@@ -14,6 +14,8 @@
 from utils.set_seed import seed_torch
 from config import get_seg_config
 from solver import Solver
+from utils.loss import MultiClassesSoftBCEDiceLoss
+
 
 class TrainVal():
     def __init__(self, config, fold):
@@ -36,7 +38,8 @@ def __init__(self, config, fold):
         self.solver = Solver(self.model)
 
         # load the loss function
-        self.criterion = torch.nn.BCEWithLogitsLoss()
+        # self.criterion = torch.nn.BCEWithLogitsLoss()
+        self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25])
 
         # create the directory for saving weights
         self.model_path = os.path.join(config.save_path, config.model_name)
@@ -145,9 +148,18 @@ def validation(self, valid_loader):
 
 if __name__ == "__main__":
     config = get_seg_config()
-    mean=(0.485, 0.456, 0.406)
-    std=(0.229, 0.224, 0.225)
-    dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits)
+    mean = (0.485, 0.456, 0.406)
+    std = (0.229, 0.224, 0.225)
+    dataloaders = provider(
+        config.dataset_root,
+        os.path.join(config.dataset_root, 'train.csv'),
+        mean,
+        std,
+        config.batch_size,
+        config.num_workers,
+        config.n_splits,
+        mask_only=config.mask_only_flag
+    )
     for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
         if fold_index != 1:
             continue
diff --git a/utils/loss.py b/utils/loss.py
index 8df7817..209bba7 100644
--- a/utils/loss.py
+++ b/utils/loss.py
@@ -32,3 +32,131 @@ def forward(self, logit, truth):
         # raise NotImplementedError
 
         return loss
+
+
+# reference https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/101429#latest-588288
+class SoftDiceLoss(nn.Module):
+    """Weighted dice loss for binary segmentation.
+    """
+    def __init__(self, size_average=True, weight=[0.2, 0.8]):
+        """
+        weight: per-class weights
+        """
+        super(SoftDiceLoss, self).__init__()
+        self.size_average = size_average
+        self.weight = torch.FloatTensor(weight)
+
+    def forward(self, logit_pixel, truth_pixel):
+        batch_size = len(logit_pixel)
+        logit = logit_pixel.view(batch_size, -1)
+        truth = truth_pixel.view(batch_size, -1)
+        assert(logit.shape == truth.shape)
+
+        loss = self.soft_dice_criterion(logit, truth)
+
+        if self.size_average:
+            loss = loss.mean()
+        return loss
+
+    def soft_dice_criterion(self, logit, truth):
+        batch_size = len(logit)
+        probability = torch.sigmoid(logit)
+
+        p = probability.view(batch_size, -1)
+        t = truth.view(batch_size, -1)
+        # weight each position by the class it belongs to
+        w = truth.detach()
+        self.weight = self.weight.type_as(logit)
+        w = w * (self.weight[1] - self.weight[0]) + self.weight[0]
+
+        p = w * (p*2 - 1)  # convert [0, 1] --> [-1, 1]
+        t = w * (t*2 - 1)
+
+        intersection = (p * t).sum(-1)
+        union = (p * p).sum(-1) + (t * t).sum(-1)
+        dice = 1 - 2 * intersection/union
+
+        loss = dice
+        return loss
+
+
+# reference https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/101429#latest-588288
+class SoftBceLoss(nn.Module):
+    """Weighted binary cross-entropy loss.
+    """
+    def __init__(self, weight=[0.25, 0.75]):
+        super(SoftBceLoss, self).__init__()
+        self.weight = weight
+
+    def forward(self, logit_pixel, truth_pixel):
+        logit = logit_pixel.view(-1)
+        truth = truth_pixel.view(-1)
+        assert(logit.shape == truth.shape)
+
+        loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none')
+        if self.weight:
+            pos = (truth > 0.5).float()
+            neg = (truth < 0.5).float()
+            # pos_weight = pos.sum().item() + 1e-12
+            # neg_weight = neg.sum().item() + 1e-12
+            # loss = (self.weight[0]*pos*loss/pos_weight + self.weight[1]*neg*loss/neg_weight).sum()
+            loss = (self.weight[1]*pos*loss + self.weight[0]*neg*loss).mean()
+        else:
+            loss = loss.mean()
+        return loss
+
+
+class SoftBCEDiceLoss(nn.Module):
+    """Weighted BCE + Dice loss.
+    """
+    def __init__(self, size_average=True, weight=[0.2, 0.8]):
+        """
+        weight: weight[0] is the weight of the negative class, weight[1] is the weight of the positive class
+        """
+        super(SoftBCEDiceLoss, self).__init__()
+        self.size_average = size_average
+        self.weight = weight
+        self.bce_loss = nn.BCEWithLogitsLoss(size_average=self.size_average, pos_weight=torch.tensor(self.weight[1]))
+        # self.bce_loss = SoftBceLoss(weight=weight)
+        self.softdiceloss = SoftDiceLoss(size_average=self.size_average, weight=weight)
+
+    def forward(self, input, target):
+        soft_bce_loss = self.bce_loss(input, target)
+        soft_dice_loss = self.softdiceloss(input, target)
+        loss = soft_bce_loss + soft_dice_loss
+
+        return loss
+
+
+class MultiClassesSoftBCEDiceLoss(nn.Module):
+    def __init__(self, classes_num=4, size_average=True, weight=[0.2, 0.8]):
+        super(MultiClassesSoftBCEDiceLoss, self).__init__()
+        self.classes_num = classes_num
+        self.size_average = size_average
+        self.weight = weight
+        self.soft_bce_dice_loss = SoftBCEDiceLoss(size_average=self.size_average, weight=self.weight)
+
+    def forward(self, input, target):
+        """
+        Args:
+            input: tensor, [batch_size, classes_num, height, width]
+            target: tensor, [batch_size, classes_num, height, width]
+        """
+        loss = 0
+        for class_index in range(self.classes_num):
+            input_single_class = input[:, class_index, :, :]
+            target_single_class = target[:, class_index, :, :]
+            single_class_loss = self.soft_bce_dice_loss(input_single_class, target_single_class)
+            loss += single_class_loss
+
+        loss /= self.classes_num
+
+        return loss
+
+
+if __name__ == "__main__":
+    input = torch.Tensor(4, 4, 256, 1600)
+    target = torch.Tensor(4, 4, 256, 1600)
+    criterion = MultiClassesSoftBCEDiceLoss(4, True)
+    loss = criterion(input, target)
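
Below is a minimal sketch (not part of the patch) that smoke-tests the two new pieces together. It assumes the module paths shown in the diff headers (utils/loss.py and datasets/steel_dataset.py) and reuses the 4-class, 256x1600 shapes from the __main__ block of loss.py; the weight setting mirrors the one passed in train_segment.py.

# Minimal sketch: exercise the new loss and the mask-only collate function.
import torch

from utils.loss import MultiClassesSoftBCEDiceLoss
from datasets.steel_dataset import mask_only_collate_fun

# Loss: one BCE + dice term per class, averaged over the 4 defect classes.
criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25])
logit = torch.randn(2, 4, 256, 1600)                   # raw logits, [batch, classes, H, W]
target = (torch.rand(2, 4, 256, 1600) > 0.95).float()  # binary masks, same shape
print(criterion(logit, target))                        # scalar tensor

# Collate: samples whose mask has no positive pixel are dropped before stacking.
image = torch.randn(3, 256, 1600)
masked = torch.ones(4, 256, 1600)
unmasked = torch.zeros(4, 256, 1600)
images, masks = mask_only_collate_fun([(image, masked), (image, unmasked)])
print(images.shape, masks.shape)                       # [1, 3, 256, 1600], [1, 4, 256, 1600]

# Caveat: if no sample in a batch has a mask, default_collate receives an empty
# list and raises; callers may want to guard that case before collating.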