forked from seungjunlee96/U-Net_Lung-Segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dice_loss.py
40 lines (31 loc) · 1.23 KB
/
dice_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
class DiceCoefficient(torch.autograd.Function):
    """Dice coefficient of a single prediction/target pair.

    Dice coefficient = 2 * |X n Y| / (|X| + |Y|)
                     = 2 / (1 / Precision + 1 / Recall)

    ``|X n Y|`` is computed as the dot product of the flattened tensors and
    ``|X| + |Y|`` as the sum of all their elements.

    NOTE(review): ``forward`` and ``backward`` keep their legacy
    instance-method signatures because this file calls
    ``DiceCoefficient().forward(...)`` directly. Recent PyTorch releases only
    accept static ``forward``/``backward`` through ``.apply``; migrating to
    that style would require updating every caller at the same time.
    """

    # Guards the division when both tensors are empty (all zeros).
    _EPS = 1e-10

    def forward(self, input, target):
        """Compute the Dice coefficient of ``input`` against ``target``.

        Args:
            input: Prediction tensor.
            target: Ground-truth tensor with the same number of elements as
                ``input`` (required by the dot product).

        Returns:
            Scalar tensor holding ``2 * inter / (union + eps)``.
        """
        self.save_for_backward(input, target)
        # view(-1) requires contiguous memory; contiguous() returns the tensor
        # itself when it already is, so the common case costs nothing.
        flat_input = input.contiguous().view(-1)
        flat_target = target.contiguous().view(-1)
        self.inter = torch.dot(flat_input, flat_target)  # inter = |X n Y|
        self.union = torch.sum(input) + torch.sum(target)  # union = |X| + |Y|
        return 2 * self.inter.float() / (self.union.float() + DiceCoefficient._EPS)

    def backward(self, grad_output):
        """Return the gradients with respect to ``input`` and ``target``.

        With ``D = union + eps``:
            d(dice) / d(input_i) = 2 * (target_i * D - inter) / D ** 2

        Args:
            grad_output: Gradient flowing back into the scalar Dice value.

        Returns:
            Tuple ``(grad_input, grad_target)``. ``grad_input`` is ``None``
            when ``input`` does not need a gradient; ``grad_target`` is always
            ``None`` because the target is treated as a constant.
        """
        _, target = self.saved_tensors
        # Default both gradients so the return value is defined even when a
        # branch below is skipped.
        grad_input = grad_target = None
        if self.needs_input_grad[0]:
            # Use the same eps-protected denominator as forward(), so two empty
            # masks yield a zero gradient instead of 0 / 0.
            denom = self.union + DiceCoefficient._EPS
            grad_input = (
                grad_output * 2 * (target * denom - self.inter) / (denom * denom)
            )
        return grad_input, grad_target
def dice_coefficient(input, target):
    """Return the mean Dice coefficient over a batch.

    ``input`` and ``target`` are iterated in lockstep along their first
    dimension; each pair is scored with ``DiceCoefficient`` and the scores are
    averaged.

    Args:
        input: Batch of predictions (a tensor).
        target: Batch of ground-truth tensors, paired element-wise with
            ``input``.

    Returns:
        One-element float32 tensor on the same device as ``input`` holding the
        sum of the per-example Dice values divided by ``len(input)``.
    """
    # Allocate the accumulator on the input's own device. A bare `.cuda()`
    # would pick the current default GPU, which is wrong for inputs living on
    # another GPU.
    total = torch.zeros(1, dtype=torch.float32, device=input.device)
    for prediction, mask in zip(input, target):
        total += DiceCoefficient().forward(prediction, mask)
    # The divisor is the length of `input`, even if zip() stopped early because
    # `target` was shorter.
    return total / len(input)