Skip to content

Commit

Permalink
deploy: 6e39e4b
Browse files Browse the repository at this point in the history
  • Loading branch information
mieskolainen committed Dec 11, 2024
1 parent ec19e85 commit a3bfe39
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 33 deletions.
36 changes: 25 additions & 11 deletions _modules/icenet/deep/autogradxgb.html
Original file line number Diff line number Diff line change
Expand Up @@ -532,12 +532,11 @@ <h1>Source code for icenet.deep.autogradxgb</h1><div class="highlight"><pre>
<div class="viewcode-block" id="XgboostObjective">
<a class="viewcode-back" href="../../../modules/icenet.html#icenet.deep.autogradxgb.XgboostObjective">[docs]</a>
<span class="k">class</span> <span class="nc">XgboostObjective</span><span class="p">():</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss_func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;train&#39;</span><span class="p">,</span> <span class="n">loss_sign</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss_func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;train&#39;</span><span class="p">,</span>
<span class="n">flatten_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">hessian_mode</span><span class="o">=</span><span class="s1">&#39;constant&#39;</span><span class="p">,</span> <span class="n">hessian_const</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">):</span>

<span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">loss_func</span>
<span class="bp">self</span><span class="o">.</span><span class="n">loss_sign</span> <span class="o">=</span> <span class="n">loss_sign</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span> <span class="o">=</span> <span class="n">hessian_mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hessian_const</span> <span class="o">=</span> <span class="n">hessian_const</span>
Expand All @@ -553,10 +552,10 @@ <h1>Source code for icenet.deep.autogradxgb</h1><div class="highlight"><pre>
<span class="n">preds_</span><span class="p">,</span> <span class="n">targets_</span><span class="p">,</span> <span class="n">weights_</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">torch_conversion</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="n">targets</span><span class="p">)</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;train&#39;</span><span class="p">:</span>
<span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_sign</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds_</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="n">targets_</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights_</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds_</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="n">targets_</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights_</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">derivatives</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="n">loss</span><span class="p">,</span> <span class="n">preds</span><span class="o">=</span><span class="n">preds_</span><span class="p">)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;eval&#39;</span><span class="p">:</span>
<span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_sign</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds_</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="n">targets_</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights_</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span><span class="p">(</span><span class="n">preds</span><span class="o">=</span><span class="n">preds_</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="n">targets_</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="n">weights_</span><span class="p">)</span>
<span class="k">return</span> <span class="s1">&#39;custom&#39;</span><span class="p">,</span> <span class="n">loss</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Unknown mode (set either &quot;train&quot; or &quot;eval&quot;)&#39;</span><span class="p">)</span>
Expand Down Expand Up @@ -589,25 +588,40 @@ <h1>Source code for icenet.deep.autogradxgb</h1><div class="highlight"><pre>

<span class="c1">## Diagonal elements of the Hessian matrix</span>

<span class="c1"># Constant</span>
<span class="c1"># Constant curvature</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span> <span class="o">==</span> <span class="s1">&#39;constant&#39;</span><span class="p">:</span>
<span class="n">grad2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_const</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">grad1</span><span class="p">)</span>

<span class="c1"># Squared derivative based approximation</span>
<span class="c1"># Squared derivative based [uncontrolled] approximation (always positive curvature)</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span> <span class="o">==</span> <span class="s1">&#39;squared_approx&#39;</span><span class="p">:</span>
<span class="n">grad2</span> <span class="o">=</span> <span class="n">grad1</span> <span class="o">*</span> <span class="n">grad1</span>

<span class="c1"># Exact autograd</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span> <span class="o">==</span> <span class="s1">&#39;exact&#39;</span><span class="p">:</span>

<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Computing Hessian diagonal with exact autograd ...&#39;</span><span class="p">)</span>
<span class="w"> </span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> for i in tqdm(range(len(preds))):</span>
<span class="sd"> grad2_i = torch.autograd.grad(grad1[i], preds, retain_graph=True)[0]</span>
<span class="sd"> grad2[i] = grad2_i[i]</span>
<span class="sd"> &quot;&quot;&quot;</span>

<span class="n">hess_diag</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">))):</span>

<span class="c1"># A basis vector</span>
<span class="n">e_i</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">preds</span><span class="p">)</span>
<span class="n">e_i</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="mf">1.0</span>

<span class="c1"># Compute the Hessian-vector product H e_i</span>
<span class="n">Hv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">grad1</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">grad_outputs</span><span class="o">=</span><span class="n">e_i</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">hess_diag</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Hv</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>

<span class="n">grad2</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">hess_diag</span><span class="p">)</span>

<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">))):</span> <span class="c1"># Can be very slow</span>
<span class="n">grad2_i</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">grad1</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">preds</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">grad2</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">grad2_i</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>

<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="s1">&#39;Unknown &quot;hessian_mode&quot; </span><span class="si">{self.hessian_mode}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;Unknown &quot;hessian_mode&quot; </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>

<span class="n">grad1</span><span class="p">,</span> <span class="n">grad2</span> <span class="o">=</span> <span class="n">grad1</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">grad2</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>

Expand Down
Loading

0 comments on commit a3bfe39

Please sign in to comment.