Skip to content

Commit

Permalink
loss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangqianMa committed Oct 13, 2019
1 parent 1e06713 commit 05c3fc9
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self, size_average=True, weight=[1.0, 1.0]):
super(SoftBCEDiceLoss, self).__init__()
self.size_average = size_average
self.weight = weight
self.bce_loss = nn.BCEWithLogitsLoss(size_average=self.size_average)
self.bce_loss = nn.BCEWithLogitsLoss(size_average=self.size_average, pos_weight=torch.tensor(weight[0]))
# self.bce_loss = SoftBceLoss(weight=weight)
self.softdiceloss = SoftDiceLoss(size_average=self.size_average, weight=weight)

Expand All @@ -103,12 +103,23 @@ def forward(self, input, target):


class MultiClassesSoftBCEDiceLoss(nn.Module):
def __init__(self, classes_num=4, size_average=True, class_weight=[0.3, 0.2, 0.3, 0.2]):
def __init__(self, classes_num=4, size_average=True, weight=[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], class_weight=[0.3, 0.2, 0.3, 0.2]):
"""
Args:
weight: 正负样本权重
class_weight: 类别间权重
"""
super(MultiClassesSoftBCEDiceLoss, self).__init__()
self.classes_num = classes_num
self.size_average = size_average
self.class_weight = class_weight
self.soft_bce_dice_loss = SoftBCEDiceLoss(size_average=self.size_average)
self.soft_bce_dice_loss = [
SoftBCEDiceLoss(size_average=self.size_average, weight=weight[0]),
SoftBCEDiceLoss(size_average=self.size_average, weight=weight[1]),
SoftBCEDiceLoss(size_average=self.size_average, weight=weight[2]),
SoftBCEDiceLoss(size_average=self.size_average, weight=weight[3]),
]

def forward(self, input, target):
"""
Expand All @@ -120,7 +131,7 @@ def forward(self, input, target):
for class_index in range(self.classes_num):
input_single_class = input[:, class_index, :, :]
target_singlt_class = target[:, class_index, :, :]
single_class_loss = self.soft_bce_dice_loss(input_single_class, target_singlt_class)
single_class_loss = self.soft_bce_dice_loss[class_index](input_single_class, target_singlt_class)
loss += self.class_weight[class_index] * single_class_loss

return loss
Expand Down

0 comments on commit 05c3fc9

Please sign in to comment.