-
Notifications
You must be signed in to change notification settings - Fork 1
/
loss_combination.py
114 lines (83 loc) · 3.73 KB
/
loss_combination.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import torch.nn as nn
from torch.nn import functional as F
from deepnet.model.losses.dice_loss import DiceLoss
from deepnet.model.losses.loss import mse_loss, rmse_loss, bce_loss, bcewithlogits_loss
from deepnet.model.losses.ssim import SSIM_Loss
class BCE_RMSE_LOSS(nn.Module):
    """Combined loss: ``bcewithlogits_loss`` on the mask entry plus 2x ``rmse_loss`` on the depth entry.

    ``prediction`` and ``label`` are both indexed by the keys ``'bg_fg_mask'``
    and ``'bg_fg_depth'``. The depth term is weighted twice as heavily as the
    mask term.
    """

    def __init__(self):
        super().__init__()
        # Registration order (bce, then rmse) is kept so children() order is unchanged.
        self.bce_loss = bcewithlogits_loss()
        self.rmse_loss = rmse_loss()

    def forward(self, prediction, label):
        """Return ``bce(mask) + 2 * rmse(depth)`` for one prediction/label pair."""
        mask_term = self.bce_loss(prediction['bg_fg_mask'], label['bg_fg_mask'])
        depth_term = self.rmse_loss(prediction['bg_fg_depth'], label['bg_fg_depth'])
        return mask_term + 2 * depth_term
class SSIM_RMSE_LOSS(nn.Module):
    """Combined loss: ``SSIM_Loss`` on the mask entry plus 2x ``rmse_loss`` on the depth entry.

    ``prediction`` and ``label`` are both indexed by the keys ``'bg_fg_mask'``
    and ``'bg_fg_depth'``. The depth term is weighted twice as heavily as the
    mask term.
    """

    def __init__(self):
        super().__init__()
        self.ssim_loss = SSIM_Loss()
        self.rmse_loss = rmse_loss()

    def forward(self, prediction, label):
        """Return ``ssim(mask) + 2 * rmse(depth)`` for one prediction/label pair."""
        mask_term = self.ssim_loss(prediction['bg_fg_mask'], label['bg_fg_mask'])
        depth_term = self.rmse_loss(prediction['bg_fg_depth'], label['bg_fg_depth'])
        return mask_term + 2 * depth_term
class BCE_SSIM_LOSS(nn.Module):
    """Combined loss: ``bcewithlogits_loss`` on the mask entry plus 2x ``SSIM_Loss`` on the depth entry.

    ``prediction`` and ``label`` are both indexed by the keys ``'bg_fg_mask'``
    and ``'bg_fg_depth'``. The depth term is weighted twice as heavily as the
    mask term.
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = bcewithlogits_loss()
        self.ssim_loss = SSIM_Loss()

    def forward(self, prediction, label):
        """Return ``bce(mask) + 2 * ssim(depth)`` for one prediction/label pair."""
        mask_term = self.bce_loss(prediction['bg_fg_mask'], label['bg_fg_mask'])
        depth_term = self.ssim_loss(prediction['bg_fg_depth'], label['bg_fg_depth'])
        return mask_term + 2 * depth_term
class RMSE_SSIM_LOSS(nn.Module):
    """Combined loss: ``rmse_loss`` on the mask entry plus 2x ``SSIM_Loss`` on the depth entry.

    ``prediction`` and ``label`` are both indexed by the keys ``'bg_fg_mask'``
    and ``'bg_fg_depth'``. The depth term is weighted twice as heavily as the
    mask term.
    """

    def __init__(self):
        super().__init__()
        # The RMSE criterion used to be stored as ``self.bce_loss`` while
        # ``forward`` reads ``self.rmse_loss``. That made every forward call
        # raise AttributeError, so register it under the name forward uses.
        self.rmse_loss = rmse_loss()
        self.ssim_loss = SSIM_Loss()

    def forward(self, prediction, label):
        """Return ``rmse(mask) + 2 * ssim(depth)`` for one prediction/label pair."""
        mask_term = self.rmse_loss(prediction['bg_fg_mask'], label['bg_fg_mask'])
        depth_term = self.ssim_loss(prediction['bg_fg_depth'], label['bg_fg_depth'])
        return mask_term + 2 * depth_term
class SSIM_DICE_LOSS(nn.Module):
    """Combined loss: ``SSIM_Loss`` on the mask entry plus 2x ``DiceLoss`` on the depth entry.

    ``prediction`` and ``label`` are both indexed by the keys ``'bg_fg_mask'``
    and ``'bg_fg_depth'``. The depth term is weighted twice as heavily as the
    mask term.

    NOTE(review): Dice is applied to the ``'bg_fg_depth'`` entry and SSIM to
    ``'bg_fg_mask'``. Confirm that this pairing is intended.
    """

    def __init__(self):
        super().__init__()
        self.ssim_loss = SSIM_Loss()
        self.dice_loss = DiceLoss()

    def forward(self, prediction, label):
        """Return ``ssim(mask) + 2 * dice(depth)`` for one prediction/label pair."""
        mask_term = self.ssim_loss(prediction['bg_fg_mask'], label['bg_fg_mask'])
        depth_term = self.dice_loss(prediction['bg_fg_depth'], label['bg_fg_depth'])
        return mask_term + 2 * depth_term
class RMSE_DICE_LOSS(nn.Module):
    """Combined loss: ``rmse_loss`` on the mask entry plus 2x ``DiceLoss`` on the depth entry.

    ``prediction`` and ``label`` are both indexed by the keys ``'bg_fg_mask'``
    and ``'bg_fg_depth'``. The depth term is weighted twice as heavily as the
    mask term.

    NOTE(review): Dice is applied to the ``'bg_fg_depth'`` entry and RMSE to
    ``'bg_fg_mask'``. Confirm that this pairing is intended.
    """

    def __init__(self):
        super().__init__()
        self.rmse_loss = rmse_loss()
        self.dice_loss = DiceLoss()

    def forward(self, prediction, label):
        """Return ``rmse(mask) + 2 * dice(depth)`` for one prediction/label pair."""
        mask_term = self.rmse_loss(prediction['bg_fg_mask'], label['bg_fg_mask'])
        depth_term = self.dice_loss(prediction['bg_fg_depth'], label['bg_fg_depth'])
        return mask_term + 2 * depth_term
class BCEDiceLoss(nn.Module):
    """BCE-with-logits plus twice a Dice loss computed on ``sigmoid(input)``.

    Args:
        smooth (float, optional): Smoothing value forwarded positionally to
            ``DiceLoss``. Defaults to ``1e-6``.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.dice = DiceLoss(smooth)

    def forward(self, input, target):
        """Compute ``BCEWithLogits(input, target) + 2 * Dice(sigmoid(input), target)``.

        Args:
            input (torch.Tensor): Model predictions, treated as logits. They are
                passed unchanged to ``binary_cross_entropy_with_logits`` and
                through ``torch.sigmoid`` before the Dice term.
            target (torch.Tensor): Target values.

        Returns:
            The combined BCE + 2 * Dice loss.
        """
        bce_term = F.binary_cross_entropy_with_logits(input, target)
        probabilities = torch.sigmoid(input)
        dice_term = self.dice(probabilities, target)
        return bce_term + 2 * dice_term
class RmseBceDiceLoss(nn.Module):
    """Combined loss: 2x ``rmse_loss`` on the mask entry plus ``BCEDiceLoss`` on the depth entry.

    ``prediction`` and ``label`` are both indexed by the keys ``'bg_fg_mask'``
    and ``'bg_fg_depth'``.

    NOTE(review): here the factor of 2 is on the mask term, whereas the other
    combinations in this module weight the depth term. Confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        self.rmse = rmse_loss()
        self.bce_dice = BCEDiceLoss()

    def forward(self, prediction, label):
        """Return ``2 * rmse(mask) + bce_dice(depth)`` for one prediction/label pair."""
        mask_term = self.rmse(prediction['bg_fg_mask'], label['bg_fg_mask'])
        depth_term = self.bce_dice(prediction['bg_fg_depth'], label['bg_fg_depth'])
        return 2 * mask_term + depth_term