From ce7cf1fd3d4293200745925e4a3beb28a7c9884f Mon Sep 17 00:00:00 2001
From: mieskolainen
Date: Mon, 22 Jul 2024 15:53:42 +0000
Subject: [PATCH] deploy: 90bb5adb042058d6970681d8101ebed86c629797

---
 _modules/icenet/deep/losstools.html | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/_modules/icenet/deep/losstools.html b/_modules/icenet/deep/losstools.html
index 8fabae48..3ea8714e 100644
--- a/_modules/icenet/deep/losstools.html
+++ b/_modules/icenet/deep/losstools.html
@@ -639,6 +639,21 @@
         weights      = None # TBD. Could re-compute a new set of edge weights 
     # --------------------------------------------
     
+    def SWD_helper(logits):
+        """
+        Sliced Wasserstein reweight regularization
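+
+        Assumed parameter semantics: 'SWD_beta' is the regularization
+        strength (the term is active when > 0), 'SWD_p' the distance
+        order p, 'SWD_num_slices' the number of random 1D projections,
+        and 'SWD_mode' the mode passed through to SWD_reweight_loss.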
+        """
+        if 'SWD_beta' in param and param['SWD_beta'] > 0:
+            
+            beta  = param['SWD_beta']
+            value = beta * SWD_reweight_loss(logits=logits, x=x, y=y, weights=weights,
+                                             p=param['SWD_p'], num_slices=param['SWD_num_slices'],
+                                             mode=param['SWD_mode'])
+
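+            # NB: the dict key doubles as a (LaTeX) label for loss-term bookkeeping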
+            return {f'SWD x $\\beta = {beta}$': value}
+        else:
+            return {}
+    
     def MI_helper(output):
         """ 
         Mutual Information regularization
@@ -695,21 +710,21 @@ 
         logits = model.forward(x)
         loss   = BCE_loss(logits=logits, y=y, weights=weights)
         
-        loss = {'BCE': loss, **LZ_helper(), **LM_helper(logits), **MI_helper(torch.sigmoid(logits))}
+        loss = {'BCE': loss, **SWD_helper(logits), **LZ_helper(), **LM_helper(logits), **MI_helper(torch.sigmoid(logits))}
 
     elif param['lossfunc'] == 'binary_focal_entropy':
         
         logits = model.forward(x)
         loss   = binary_focal_loss(logits=logits, y=y, gamma=param['gamma'], weights=weights)
         
-        loss = {f"FE ($\\gamma = {param['gamma']}$)": loss, **LZ_helper(), **LM_helper(logits), **MI_helper(torch.sigmoid(logits))}
+        loss = {f"FE ($\\gamma = {param['gamma']}$)": loss, **SWD_helper(logits), **LZ_helper(), **LM_helper(logits), **MI_helper(torch.sigmoid(logits))}
 
     elif param['lossfunc'] == 'binary_Lq_entropy':
         
         logits = model.forward(x)
         loss   = Lq_binary_loss(logits=logits, y=y, q=param['q'], weights=weights)
         
-        loss = {f"LQ ($\\gamma = {param['q']}$)": loss, **LZ_helper(), **LM_helper(logits), **MI_helper(torch.sigmoid(logits))}
+        loss = {f"LQ ($\\gamma = {param['q']}$)": loss, **SWD_helper(logits), **LZ_helper(), **LM_helper(logits), **MI_helper(torch.sigmoid(logits))}
 
     elif param['lossfunc'] == 'SWD':
         
@@ -728,7 +743,7 @@ 
         y_hat = model.forward(x)
         loss  = MSE_loss(y_hat=y_hat, y=y, weights=weights)
         
-        loss  = {'MSE': loss, **LZ_helper(), **LM_helper(y_hat), **MI_helper(y_hat)}
+        loss  = {'MSE': loss, **SWD_helper(y_hat), **LZ_helper(), **LM_helper(y_hat), **MI_helper(y_hat)}
 
     elif param['lossfunc'] == 'MSE_prob':
         
@@ -736,14 +751,14 @@ 
         y_hat  = torch.sigmoid(logits)
         loss   = MSE_loss(y_hat=y_hat, y=y, weights=weights)
         
-        loss  = {'MSE': loss, **LZ_helper(), **LM_helper(logits), **MI_helper(y_hat)}
+        loss  = {'MSE': loss, **SWD_helper(logits), **LZ_helper(), **LM_helper(logits), **MI_helper(y_hat)}
     
     elif param['lossfunc'] == 'MAE':
         
         y_hat = model.forward(x)
         loss  = MSE_loss(y_hat=y_hat, y=y, weights=weights)
         
-        loss  = {'MAE': loss, **LZ_helper(), **LM_helper(y_hat), **MI_helper(y_hat)}
+        loss  = {'MAE': loss, **SWD_helper(y_hat), **LZ_helper(), **LM_helper(y_hat), **MI_helper(y_hat)}
     
     elif param['lossfunc'] == 'cross_entropy':
         """