Skip to content

Commit

Permalink
add soft_bce_dice loss and segmentation mask only
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Oct 3, 2019
1 parent 2213fb0 commit 3f31fe6
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 17 deletions.
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def get_seg_config():

parser.add_argument('--augmentation_flag', type=bool, default=True, help='if true, use augmentation method in train set')
parser.add_argument('--n_splits', type=int, default=5, help='n_splits_fold')
parser.add_argument('--mask_only_flag', type=bool, default=False, help='if true, use masked data only.')

# model set
parser.add_argument('--model_name', type=str, default='unet_resnet34', \
Expand Down
66 changes: 53 additions & 13 deletions datasets/steel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sklearn.model_selection import train_test_split, StratifiedKFold
import torch
from torch.utils.data import DataLoader, Dataset, sampler
from torch.utils.data.dataloader import default_collate
from torchvision import transforms
from albumentations.pytorch import ToTensor
import sys
Expand Down Expand Up @@ -127,6 +128,22 @@ def get_transforms(phase, image, mask, mean, std):
return image, mask


def mask_only_collate_fun(batch):
    """Custom collate_fn that keeps only samples with a non-empty mask.

    Args:
        batch: list of (image, mask) pairs produced by the dataset.

    Returns:
        The default-collated batch built from the samples whose mask has at
        least one positive pixel. If every sample in the batch is unmasked,
        the full batch is collated instead — ``default_collate`` raises on an
        empty list, which would crash the DataLoader worker.
    """
    masked_pairs = [
        [image, mask] for image, mask in batch if torch.sum(mask) > 0
    ]
    if not masked_pairs:
        # every sample is background-only; keep the whole batch rather than
        # call default_collate([]) and kill the worker
        masked_pairs = [[image, mask] for image, mask in batch]
    return default_collate(masked_pairs)


def provider(
data_folder,
df_path,
Expand All @@ -135,6 +152,7 @@ def provider(
batch_size=8,
num_workers=4,
n_splits=0,
mask_only=False
):
"""返回数据加载器,用于分割模型
Expand All @@ -146,6 +164,7 @@ def provider(
batch_size
num_workers
n_split: 交叉验证折数,为1时不使用交叉验证
mask_only: 是否只在有掩膜的样本上训练分割模型
Return:
dataloadrs: list,该list中的每一个元素为list,元素list中保存训练集和验证集
Expand Down Expand Up @@ -175,20 +194,41 @@ def provider(
for df_index, (train_df, val_df) in enumerate(zip(train_dfs, val_dfs)):
train_dataset = SteelDataset(train_df, data_folder, mean, std, 'train')
val_dataset = SteelDataset(val_df, data_folder, mean, std, 'val')
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
shuffle=True
if mask_only:
# 只在有掩膜的样本上训练
print('Segmentation modle: only masked data.')
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
collate_fn=mask_only_collate_fun,
pin_memory=True,
shuffle=True
)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
num_workers=num_workers,
collate_fn=mask_only_collate_fun,
pin_memory=True,
shuffle=True
)
else:
print('Segmentation model: all data.')
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
shuffle=True
)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
shuffle=True
)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
shuffle=True
)
dataloaders.append([train_dataloader, val_dataloader])

return dataloaders
Expand Down
20 changes: 16 additions & 4 deletions train_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from utils.set_seed import seed_torch
from config import get_seg_config
from solver import Solver
from utils.loss import MultiClassesSoftBCEDiceLoss


class TrainVal():
def __init__(self, config, fold):
Expand All @@ -36,7 +38,8 @@ def __init__(self, config, fold):
self.solver = Solver(self.model)

# 加载损失函数
self.criterion = torch.nn.BCEWithLogitsLoss()
# self.criterion = torch.nn.BCEWithLogitsLoss()
self.criterion = MultiClassesSoftBCEDiceLoss(classes_num=4, size_average=True, weight=[0.75, 0.25])

# 创建保存权重的路径
self.model_path = os.path.join(config.save_path, config.model_name)
Expand Down Expand Up @@ -145,9 +148,18 @@ def validation(self, valid_loader):

if __name__ == "__main__":
config = get_seg_config()
mean=(0.485, 0.456, 0.406)
std=(0.229, 0.224, 0.225)
dataloaders = provider(config.dataset_root, os.path.join(config.dataset_root, 'train.csv'), mean, std, config.batch_size, config.num_workers, config.n_splits)
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
dataloaders = provider(
config.dataset_root,
os.path.join(config.dataset_root, 'train.csv'),
mean,
std,
config.batch_size,
config.num_workers,
config.n_splits,
mask_only=config.mask_only_flag
)
for fold_index, [train_loader, valid_loader] in enumerate(dataloaders):
if fold_index != 1:
continue
Expand Down
128 changes: 128 additions & 0 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,131 @@ def forward(self, logit, truth):
# raise NotImplementedError

return loss


# reference https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/101429#latest-588288
class SoftDiceLoss(nn.Module):
    """Class-weighted soft dice loss for binary segmentation.

    Each pixel contributes with the weight of its ground-truth class:
    background pixels with ``weight[0]``, foreground pixels with ``weight[1]``.
    """
    def __init__(self, size_average=True, weight=(0.2, 0.8)):
        """
        Args:
            size_average: if True, average the per-sample dice loss over the batch.
            weight: (background_weight, foreground_weight) per-pixel class weights.
                A tuple default avoids the shared-mutable-default pitfall.
        """
        super(SoftDiceLoss, self).__init__()
        self.size_average = size_average
        self.weight = torch.FloatTensor(weight)

    def forward(self, logit_pixel, truth_pixel):
        batch_size = len(logit_pixel)
        logit = logit_pixel.view(batch_size, -1)
        truth = truth_pixel.view(batch_size, -1)
        assert logit.shape == truth.shape

        loss = self.soft_dice_criterion(logit, truth)

        if self.size_average:
            loss = loss.mean()
        return loss

    def soft_dice_criterion(self, logit, truth, eps=1e-12):
        """Return the per-sample weighted dice loss, shape [batch_size]."""
        batch_size = len(logit)
        probability = torch.sigmoid(logit)

        p = probability.view(batch_size, -1)
        t = truth.view(batch_size, -1)
        # assign each pixel the weight of its ground-truth class
        w = truth.detach()
        self.weight = self.weight.type_as(logit)
        w = w * (self.weight[1] - self.weight[0]) + self.weight[0]

        p = w * (p * 2 - 1)  # convert from [0, 1] to [-1, 1]
        t = w * (t * 2 - 1)

        intersection = (p * t).sum(-1)
        union = (p * p).sum(-1) + (t * t).sum(-1)
        # eps guards against division by zero when both p and t are all zeros
        # (all-background target with exactly-zero logits)
        dice = 1 - 2 * intersection / (union + eps)

        return dice


# reference https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/discussion/101429#latest-588288
class SoftBceLoss(nn.Module):
    """Class-weighted binary cross-entropy loss on logits.

    Positive pixels are weighted by ``weight[1]``, negative pixels by
    ``weight[0]``. Pass a falsy ``weight`` (e.g. ``None``) for a plain mean BCE.
    """
    def __init__(self, weight=(0.25, 0.75)):
        # tuple default avoids the shared-mutable-default pitfall
        super(SoftBceLoss, self).__init__()
        self.weight = weight

    def forward(self, logit_pixel, truth_pixel):
        logit = logit_pixel.view(-1)
        truth = truth_pixel.view(-1)
        assert logit.shape == truth.shape

        loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none')
        if self.weight:
            pos = (truth > 0.5).float()
            neg = (truth < 0.5).float()
            # weight each pixel's BCE term by its ground-truth class, then average
            loss = (self.weight[1] * pos * loss + self.weight[0] * neg * loss).mean()
        else:
            loss = loss.mean()
        return loss


class SoftBCEDiceLoss(nn.Module):
    """Weighted BCE plus weighted soft dice loss (summed)."""
    def __init__(self, size_average=True, weight=(0.2, 0.8)):
        """
        Args:
            size_average: if True, mean-reduce both terms; otherwise sum-reduce.
            weight: weight[0] is the negative-class weight, weight[1] the
                positive-class weight. Tuple default avoids the
                shared-mutable-default pitfall.
        """
        super(SoftBCEDiceLoss, self).__init__()
        self.size_average = size_average
        self.weight = weight
        # `size_average=` is deprecated in torch.nn losses; map it to the
        # equivalent `reduction` mode (True -> 'mean', False -> 'sum')
        self.bce_loss = nn.BCEWithLogitsLoss(
            reduction='mean' if self.size_average else 'sum',
            pos_weight=torch.tensor(self.weight[1]),
        )
        self.softdiceloss = SoftDiceLoss(size_average=self.size_average, weight=weight)

    def forward(self, input, target):
        soft_bce_loss = self.bce_loss(input, target)
        soft_dice_loss = self.softdiceloss(input, target)
        return soft_bce_loss + soft_dice_loss


class MultiClassesSoftBCEDiceLoss(nn.Module):
    """Mean of per-class SoftBCEDiceLoss over the channel dimension."""
    def __init__(self, classes_num=4, size_average=True, weight=[0.2, 0.8]):
        super(MultiClassesSoftBCEDiceLoss, self).__init__()
        self.classes_num = classes_num
        self.size_average = size_average
        self.weight = weight
        self.soft_bce_dice_loss = SoftBCEDiceLoss(size_average=self.size_average, weight=self.weight)

    def forward(self, input, target):
        """
        Args:
            input: tensor, [batch_size, classes_num, height, width]
            target: tensor, [batch_size, classes_num, height, width]
        Returns:
            scalar loss: the per-class losses averaged over classes_num channels
        """
        per_class_losses = [
            self.soft_bce_dice_loss(input[:, class_index, :, :],
                                    target[:, class_index, :, :])
            for class_index in range(self.classes_num)
        ]
        return sum(per_class_losses) / self.classes_num


if __name__ == "__main__":
    # Smoke test for the multi-class loss.
    # torch.Tensor(...) returns *uninitialized* memory, which can contain
    # NaN/Inf and make the check flaky; use deterministic dummy data instead.
    torch.manual_seed(0)
    input = torch.randn(4, 4, 256, 1600)
    target = (torch.randn(4, 4, 256, 1600) > 0).float()
    criterion = MultiClassesSoftBCEDiceLoss(4, True)
    loss = criterion(input, target)
    print(loss)

0 comments on commit 3f31fe6

Please sign in to comment.