Skip to content

Commit

Permalink
deploy: dd1e63a
Browse files Browse the repository at this point in the history
  • Loading branch information
mieskolainen committed Oct 21, 2024
1 parent c68a3e0 commit e7e918c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 24 deletions.
37 changes: 27 additions & 10 deletions _modules/icenet/deep/autogradxgb.html
Original file line number Diff line number Diff line change
Expand Up @@ -529,17 +529,20 @@ <h1>Source code for icenet.deep.autogradxgb</h1><div class="highlight"><pre>
<a class="viewcode-back" href="../../../modules/icenet.html#icenet.deep.autogradxgb.XgboostObjective">[docs]</a>
<span class="k">class</span> <span class="nc">XgboostObjective</span><span class="p">():</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss_func</span><span class="p">:</span> <span class="n">Callable</span><span class="p">[[</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">Tensor</span><span class="p">],</span> <span class="n">mode</span><span class="o">=</span><span class="s1">&#39;train&#39;</span><span class="p">,</span> <span class="n">loss_sign</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">flatten_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">skip_hessian</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">hessian_const</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">):</span>
<span class="n">flatten_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">hessian_mode</span><span class="o">=</span><span class="s1">&#39;constant&#39;</span><span class="p">,</span> <span class="n">hessian_const</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s1">&#39;cpu&#39;</span><span class="p">):</span>

<span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">=</span> <span class="n">mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">loss_func</span> <span class="o">=</span> <span class="n">loss_func</span>
<span class="bp">self</span><span class="o">.</span><span class="n">loss_sign</span> <span class="o">=</span> <span class="n">loss_sign</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="bp">self</span><span class="o">.</span><span class="n">skip_hessian</span> <span class="o">=</span> <span class="n">skip_hessian</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span> <span class="o">=</span> <span class="n">hessian_mode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hessian_const</span> <span class="o">=</span> <span class="n">hessian_const</span>
<span class="bp">self</span><span class="o">.</span><span class="n">flatten_grad</span> <span class="o">=</span> <span class="n">flatten_grad</span>

<span class="nb">print</span><span class="p">(</span><span class="vm">__name__</span> <span class="o">+</span> <span class="sa">f</span><span class="s1">&#39;.__init__: Using device: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="si">}</span><span class="s1"> | skip_hessian = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">skip_hessian</span><span class="si">}</span><span class="s1"> | hessian_const = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">hessian_const</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span> <span class="o">==</span> <span class="s1">&#39;constant&#39;</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="vm">__name__</span> <span class="o">+</span> <span class="sa">f</span><span class="s1">&#39;: Using device: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="si">}</span><span class="s1"> | hessian_mode = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span><span class="si">}</span><span class="s1"> | hessian_const = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">hessian_const</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="vm">__name__</span> <span class="o">+</span> <span class="sa">f</span><span class="s1">&#39;: Using device: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="si">}</span><span class="s1"> | hessian_mode = </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>

<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">,</span> <span class="n">targets</span><span class="p">:</span> <span class="n">xgboost</span><span class="o">.</span><span class="n">DMatrix</span><span class="p">):</span>

Expand Down Expand Up @@ -574,19 +577,34 @@ <h1>Source code for icenet.deep.autogradxgb</h1><div class="highlight"><pre>
<div class="viewcode-block" id="XgboostObjective.derivatives">
<a class="viewcode-back" href="../../../modules/icenet.html#icenet.deep.autogradxgb.XgboostObjective.derivatives">[docs]</a>
<span class="k">def</span> <span class="nf">derivatives</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">preds</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">):</span>

<span class="c1"># Gradient</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot; Gradient and Hessian diagonal</span>
<span class="sd"> &quot;&quot;&quot;</span>

<span class="c1">## Gradient</span>
<span class="n">grad1</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">preds</span><span class="p">,</span> <span class="n">create_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>

<span class="c1"># Diagonal elements of the Hessian matrix</span>
<span class="n">grad2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_const</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">grad1</span><span class="p">)</span>
<span class="c1">## Diagonal elements of the Hessian matrix</span>

<span class="c1"># Constant</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span> <span class="o">==</span> <span class="s1">&#39;constant&#39;</span><span class="p">:</span>
<span class="n">grad2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_const</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">grad1</span><span class="p">)</span>

<span class="c1"># Squared derivative based approximation</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span> <span class="o">==</span> <span class="s1">&#39;squared_approx&#39;</span><span class="p">:</span>
<span class="n">grad2</span> <span class="o">=</span> <span class="n">grad1</span> <span class="o">*</span> <span class="n">grad1</span>

<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip_hessian</span><span class="p">:</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">&#39;Computing Hessian ...&#39;</span><span class="p">)</span>
<span class="c1"># Exact autograd</span>
<span class="k">elif</span> <span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span> <span class="o">==</span> <span class="s1">&#39;exact&#39;</span><span class="p">:</span>

<span class="nb">print</span><span class="p">(</span><span class="vm">__name__</span> <span class="o">+</span> <span class="sa">f</span><span class="s1">&#39;.derivatives: Computing Hessian diagonal with exact autograd ...&#39;</span><span class="p">)</span>

<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">preds</span><span class="p">))):</span> <span class="c1"># Can be very slow</span>
<span class="n">grad2_i</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">grad1</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">preds</span><span class="p">,</span> <span class="n">retain_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">grad2</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">grad2_i</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>

<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">Exception</span><span class="p">(</span><span class="vm">__name__</span> <span class="o">+</span> <span class="sa">f</span><span class="s1">&#39;.derivatives: Unknown &quot;hessian_mode&quot; </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">hessian_mode</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>

<span class="n">grad1</span><span class="p">,</span> <span class="n">grad2</span> <span class="o">=</span> <span class="n">grad1</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">grad2</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>

<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">flatten_grad</span><span class="p">:</span>
Expand All @@ -595,7 +613,6 @@ <h1>Source code for icenet.deep.autogradxgb</h1><div class="highlight"><pre>
<span class="k">return</span> <span class="n">grad1</span><span class="p">,</span> <span class="n">grad2</span></div>
</div>


</pre></div>

</div>
Expand Down
Loading

0 comments on commit e7e918c

Please sign in to comment.