Skip to content

Commit

Permalink
deploy: 90bb5ad
Browse files Browse the repository at this point in the history
  • Loading branch information
mieskolainen committed Jul 22, 2024
1 parent a6caacf commit ce7cf1f
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions _modules/icenet/deep/losstools.html
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,21 @@ <h1>Source code for icenet.deep.losstools</h1><div class="highlight"><pre>
<span class="n">weights</span> <span class="o">=</span> <span class="kc">None</span> <span class="c1"># TBD. Could re-compute a new set of edge weights </span>
<span class="c1"># --------------------------------------------</span>

<span class="k">def</span> <span class="nf">SWD_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sliced Wasserstein reweight regularization</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="s1">&#39;SWD_beta&#39;</span> <span class="ow">in</span> <span class="n">param</span> <span class="ow">and</span> <span class="n">param</span><span class="p">[</span><span class="s1">&#39;SWD_beta&#39;</span><span class="p">]</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>

<span class="n">beta</span> <span class="o">=</span> <span class="n">param</span><span class="p">[</span><span class="s1">&#39;SWD_beta&#39;</span><span class="p">]</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">beta</span> <span class="o">*</span> <span class="n">SWD_reweight_loss</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights</span><span class="p">,</span>
<span class="n">p</span><span class="o">=</span><span class="n">param</span><span class="p">[</span><span class="s1">&#39;SWD_p&#39;</span><span class="p">],</span> <span class="n">num_slices</span><span class="o">=</span><span class="n">param</span><span class="p">[</span><span class="s1">&#39;SWD_num_slices&#39;</span><span class="p">],</span>
<span class="n">mode</span><span class="o">=</span><span class="n">param</span><span class="p">[</span><span class="s1">&#39;SWD_mode&#39;</span><span class="p">])</span>

<span class="k">return</span> <span class="p">{</span><span class="sa">f</span><span class="s1">&#39;SWD x $</span><span class="se">\\</span><span class="s1">beta = </span><span class="si">{</span><span class="n">beta</span><span class="si">}</span><span class="s1">$&#39;</span><span class="p">:</span> <span class="n">value</span><span class="p">}</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="p">{}</span>

<span class="k">def</span> <span class="nf">MI_helper</span><span class="p">(</span><span class="n">output</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot; </span>
<span class="sd"> Mutual Information regularization</span>
Expand Down Expand Up @@ -695,21 +710,21 @@ <h1>Source code for icenet.deep.losstools</h1><div class="highlight"><pre>
<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">BCE_loss</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights</span><span class="p">)</span>

<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;BCE&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">logits</span><span class="p">))}</span>
<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;BCE&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">SWD_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">logits</span><span class="p">))}</span>

<span class="k">elif</span> <span class="n">param</span><span class="p">[</span><span class="s1">&#39;lossfunc&#39;</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;binary_focal_entropy&#39;</span><span class="p">:</span>

<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">binary_focal_loss</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">gamma</span><span class="o">=</span><span class="n">param</span><span class="p">[</span><span class="s1">&#39;gamma&#39;</span><span class="p">],</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights</span><span class="p">)</span>

<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="sa">f</span><span class="s2">&quot;FE ($</span><span class="se">\\</span><span class="s2">gamma = </span><span class="si">{</span><span class="n">param</span><span class="p">[</span><span class="s1">&#39;gamma&#39;</span><span class="p">]</span><span class="si">}</span><span class="s2">$)&quot;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">logits</span><span class="p">))}</span>
<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="sa">f</span><span class="s2">&quot;FE ($</span><span class="se">\\</span><span class="s2">gamma = </span><span class="si">{</span><span class="n">param</span><span class="p">[</span><span class="s1">&#39;gamma&#39;</span><span class="p">]</span><span class="si">}</span><span class="s2">$)&quot;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">SWD_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">logits</span><span class="p">))}</span>

<span class="k">elif</span> <span class="n">param</span><span class="p">[</span><span class="s1">&#39;lossfunc&#39;</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;binary_Lq_entropy&#39;</span><span class="p">:</span>

<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">Lq_binary_loss</span><span class="p">(</span><span class="n">logits</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">q</span><span class="o">=</span><span class="n">param</span><span class="p">[</span><span class="s1">&#39;q&#39;</span><span class="p">],</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights</span><span class="p">)</span>

<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="sa">f</span><span class="s2">&quot;LQ ($</span><span class="se">\\</span><span class="s2">gamma = </span><span class="si">{</span><span class="n">param</span><span class="p">[</span><span class="s1">&#39;q&#39;</span><span class="p">]</span><span class="si">}</span><span class="s2">$)&quot;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">logits</span><span class="p">))}</span>
<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="sa">f</span><span class="s2">&quot;LQ ($</span><span class="se">\\</span><span class="s2">gamma = </span><span class="si">{</span><span class="n">param</span><span class="p">[</span><span class="s1">&#39;q&#39;</span><span class="p">]</span><span class="si">}</span><span class="s2">$)&quot;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">SWD_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">logits</span><span class="p">))}</span>

<span class="k">elif</span> <span class="n">param</span><span class="p">[</span><span class="s1">&#39;lossfunc&#39;</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;SWD&#39;</span><span class="p">:</span>

Expand All @@ -728,22 +743,22 @@ <h1>Source code for icenet.deep.losstools</h1><div class="highlight"><pre>
<span class="n">y_hat</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">MSE_loss</span><span class="p">(</span><span class="n">y_hat</span><span class="o">=</span><span class="n">y_hat</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights</span><span class="p">)</span>

<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;MSE&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">)}</span>
<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;MSE&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">SWD_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">)}</span>

<span class="k">elif</span> <span class="n">param</span><span class="p">[</span><span class="s1">&#39;lossfunc&#39;</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;MSE_prob&#39;</span><span class="p">:</span>

<span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">y_hat</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">MSE_loss</span><span class="p">(</span><span class="n">y_hat</span><span class="o">=</span><span class="n">y_hat</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights</span><span class="p">)</span>

<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;MSE&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">)}</span>
<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;MSE&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">SWD_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">)}</span>

<span class="k">elif</span> <span class="n">param</span><span class="p">[</span><span class="s1">&#39;lossfunc&#39;</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;MAE&#39;</span><span class="p">:</span>

<span class="n">y_hat</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">MSE_loss</span><span class="p">(</span><span class="n">y_hat</span><span class="o">=</span><span class="n">y_hat</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights</span><span class="p">)</span>

<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;MAE&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">)}</span>
<span class="n">loss</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;MAE&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="p">,</span> <span class="o">**</span><span class="n">SWD_helper</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="o">**</span><span class="n">LZ_helper</span><span class="p">(),</span> <span class="o">**</span><span class="n">LM_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">),</span> <span class="o">**</span><span class="n">MI_helper</span><span class="p">(</span><span class="n">y_hat</span><span class="p">)}</span>

<span class="k">elif</span> <span class="n">param</span><span class="p">[</span><span class="s1">&#39;lossfunc&#39;</span><span class="p">]</span> <span class="o">==</span> <span class="s1">&#39;cross_entropy&#39;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
Expand Down

0 comments on commit ce7cf1f

Please sign in to comment.