-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
39 lines (30 loc) · 1.26 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from torch import nn
class DiceLoss(nn.Module):
    """Soft Dice loss with a squared-magnitude denominator.

    For each sample (each index along dim 0) the loss is::

        1 - (2 * sum(input * target) + smooth)
            / (sum(input ** 2) + sum(target ** 2) + smooth)

    with the sums taken over every dimension except dim 0. The per-sample
    losses are then averaged into a scalar.

    Args:
        smooth: Additive smoothing term applied to both numerator and
            denominator. With ``smooth > 0`` the ratio stays defined (and the
            loss is 0) when ``input`` and ``target`` are both all zeros.
            Defaults to 1, the value that was previously hard-coded.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, input, target):
        """Return the batch-mean Dice loss.

        Args:
            input: Prediction tensor whose dim 0 indexes samples. Presumably
                probabilities in [0, 1] (e.g. sigmoid outputs) — TODO confirm
                with callers.
            target: Tensor assumed to have the same shape as ``input``;
                presumably a binary or soft mask.

        Returns:
            0-dim tensor: mean over dim 0 of the per-sample Dice losses.
        """
        # Reduce over all non-batch dims so each sample gets its own score.
        axes = tuple(range(1, input.dim()))
        intersect = (input * target).sum(dim=axes)
        # Squared terms in the denominator (rather than plain sums).
        union = torch.pow(input, 2).sum(dim=axes) + torch.pow(target, 2).sum(dim=axes)
        loss = 1 - (2 * intersect + self.smooth) / (union + self.smooth)
        return loss.mean()
class FocalLoss(nn.Module):
    """Binary focal loss on probability inputs, averaged over all elements.

    Per element, with ``p`` = ``input`` clamped to ``[eps, 1 - eps]`` and
    ``t`` = ``target``::

        -(t * (1 - p) ** gamma * log(p) + (1 - t) * p ** gamma * log(1 - p))

    No class-balancing (alpha) weight is applied. With ``gamma=0`` this
    reduces to binary cross-entropy on the clamped probabilities.

    Args:
        gamma: Focusing exponent that down-weights well-classified elements.
        eps: Clamp margin keeping ``log(p)`` and ``log(1 - p)`` finite.
            Defaults to 1e-3, the value that was previously hard-coded.
    """

    def __init__(self, gamma=2, eps=1e-3):
        super().__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, input, target):
        """Return the element-mean focal loss.

        Args:
            input: Predicted probabilities. Values outside
                ``[eps, 1 - eps]`` are clamped, so raw logits would be
                silently mangled — presumably callers pass sigmoid outputs
                (TODO confirm).
            target: Tensor broadcastable with ``input``; presumably 0/1 labels
                (soft labels also work with the formula).

        Returns:
            0-dim tensor: mean of the per-element losses.
        """
        # Clamp away from 0 and 1 so neither log term becomes -inf.
        prob = input.clamp(self.eps, 1 - self.eps)
        positive_term = target * torch.pow(1 - prob, self.gamma) * torch.log(prob)
        negative_term = (1 - target) * torch.pow(prob, self.gamma) * torch.log(1 - prob)
        loss = -(positive_term + negative_term)
        return loss.mean()
class Dice_and_FocalLoss(nn.Module):
    """Unweighted sum of ``DiceLoss`` and ``FocalLoss`` on the same inputs.

    Args:
        gamma: Focusing exponent forwarded to ``FocalLoss``.
    """

    def __init__(self, gamma=2):
        super().__init__()
        # Both sub-losses are registered as submodules under these names.
        self.dice_loss = DiceLoss()
        self.focal_loss = FocalLoss(gamma)

    def forward(self, input, target):
        """Return ``DiceLoss(input, target) + FocalLoss(input, target)``."""
        dice_term = self.dice_loss(input, target)
        focal_term = self.focal_loss(input, target)
        return dice_term + focal_term