Commit

Merge pull request #9 from KGrewal1/add_docs
Prepare for publishing
KGrewal1 authored Dec 7, 2023
2 parents 26bf0dd + 62ae9ab commit adc7092
Showing 27 changed files with 709 additions and 182 deletions.
2 changes: 2 additions & 0 deletions .cargo/config.toml
@@ -0,0 +1,2 @@
[build]
rustdocflags = [ "--html-in-header", "./katex-header.html" ]
7 changes: 5 additions & 2 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "optimisers"
version = "0.2.1"
name = "candle-optimisers"
version = "0.3.0"
edition = "2021"
readme = "README.md"
license = "MIT"
@@ -51,3 +51,6 @@ uninlined_format_args = {level = "allow", priority = 1}
similar_names = {level = "allow", priority = 1}
float_cmp = {level = "allow", priority = 1} # as internally rounded before the comparison
doc_markdown= {level = "allow", priority = 1} # otherwise names get flagged

[package.metadata.docs.rs]
rustdoc-args = [ "--html-in-header", "./katex-header.html" ]
6 changes: 3 additions & 3 deletions README.md
@@ -1,4 +1,4 @@
# Optimisers
# Candle Optimisers

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![codecov](https://codecov.io/gh/KGrewal1/optimisers/graph/badge.svg?token=6AFTLS6DFO)](https://codecov.io/gh/KGrewal1/optimisers)
@@ -31,7 +31,7 @@ Adaptive methods:

These are all checked against their PyTorch implementations (see pytorch_test.ipynb) and should implement the same functionality (though without some of the input checking).

Additionally, all of the adaptive methods listed implement decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/pdf/1711.05101.pdf), in addition to the standard weight decay as implemented in PyTorch.
Additionally, all of the adaptive methods listed, as well as SGD, implement decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/pdf/1711.05101.pdf), in addition to the standard weight decay as implemented in PyTorch.

Pseudosecond order methods:

@@ -58,7 +58,7 @@ to use the cuda backend.
## Usage

```cli
cargo add --git https://github.com/KGrewal1/optimisers.git optimisers
cargo add --git https://github.com/KGrewal1/optimisers.git candle-optimisers
```
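
Once added, an optimiser is built from a model's trainable `Var`s and driven through the `candle_nn::Optimizer` trait. A minimal sketch of a single training step, assuming the optimisers expose the standard `Optimizer` constructor taking a `Vec<Var>` plus a params struct and that `ParamsAdaDelta` implements `Default` (the linear model and dummy data here are purely illustrative):

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{loss, Module, Optimizer, VarBuilder, VarMap};
use candle_optimisers::adadelta::{Adadelta, ParamsAdaDelta};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // a single linear layer stands in for a real model
    let model = candle_nn::linear(4, 1, vb)?;

    // build the optimiser over the model's trainable variables
    let mut optimiser = Adadelta::new(varmap.all_vars(), ParamsAdaDelta::default())?;

    // one optimisation step on dummy data
    let xs = Tensor::zeros((8, 4), DType::F32, &device)?;
    let ys = Tensor::zeros((8, 1), DType::F32, &device)?;
    let loss = loss::mse(&model.forward(&xs)?, &ys)?;
    optimiser.backward_step(&loss)?;
    Ok(())
}
```

The other first-order optimisers in the crate follow the same pattern, differing only in their parameter struct; LBFGS is instead driven through `backward_step` on the `LossOptimizer` trait, as in the benchmark code below.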

## To do
4 changes: 2 additions & 2 deletions benches/mnist_bench.rs
@@ -1,10 +1,10 @@
use candle_core::Result as CResult;
use candle_datasets::vision::Dataset;
use criterion::{criterion_group, criterion_main, Criterion};
use optimisers::{
use candle_optimisers::{
adadelta::Adadelta, adagrad::Adagrad, adam::Adam, adamax::Adamax, esgd::SGD, nadam::NAdam,
radam::RAdam, rmsprop::RMSprop,
};
use criterion::{criterion_group, criterion_main, Criterion};
use training::Mlp;

// mod models;
26 changes: 14 additions & 12 deletions benches/training.rs
@@ -1,16 +1,18 @@
use candle_core::{DType, Result, Tensor, Var, D};
use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap};

use optimisers::adadelta::{Adadelta, ParamsAdaDelta};
use optimisers::adagrad::{Adagrad, ParamsAdaGrad};
use optimisers::adam::{Adam, ParamsAdam};
use optimisers::adamax::{Adamax, ParamsAdaMax};
use optimisers::esgd::{ParamsSGD, SGD};
use optimisers::lbfgs::{Lbfgs, LineSearch, ParamsLBFGS};
use optimisers::nadam::{NAdam, ParamsNAdam};
use optimisers::radam::{ParamsRAdam, RAdam};
use optimisers::rmsprop::{ParamsRMSprop, RMSprop};
use optimisers::{LossOptimizer, Model};
use candle_optimisers::{
adadelta::{Adadelta, ParamsAdaDelta},
adagrad::{Adagrad, ParamsAdaGrad},
adam::{Adam, ParamsAdam},
adamax::{Adamax, ParamsAdaMax},
esgd::{ParamsSGD, SGD},
lbfgs::{Lbfgs, LineSearch, ParamsLBFGS},
nadam::{NAdam, ParamsNAdam},
radam::{ParamsRAdam, RAdam},
rmsprop::{ParamsRMSprop, RMSprop},
LossOptimizer, Model,
};

pub trait Optim: Sized {
fn new(vars: Vec<Var>) -> Result<Self>;
@@ -217,8 +219,8 @@ pub fn run_lbfgs_training<M: SimpleModel + Model>(
// step the tensors by backpropagating the loss
let res = optimiser.backward_step(&loss)?;
match res {
optimisers::ModelOutcome::Converged(_, _) => break,
optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss,
candle_optimisers::ModelOutcome::Converged(_, _) => break,
candle_optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss,
// _ => panic!("unexpected outcome"),
}
}
14 changes: 7 additions & 7 deletions examples/mnist/main.rs
@@ -7,13 +7,13 @@ mod training;

use models::{LinearModel, Mlp};

use optimisers::adagrad::Adagrad;
use optimisers::adamax::Adamax;
use optimisers::esgd::SGD;
use optimisers::nadam::NAdam;
use optimisers::radam::RAdam;
use optimisers::rmsprop::RMSprop;
use optimisers::{adadelta::Adadelta, adam::Adam};
use candle_optimisers::adagrad::Adagrad;
use candle_optimisers::adamax::Adamax;
use candle_optimisers::esgd::SGD;
use candle_optimisers::nadam::NAdam;
use candle_optimisers::radam::RAdam;
use candle_optimisers::rmsprop::RMSprop;
use candle_optimisers::{adadelta::Adadelta, adam::Adam};

use parse_cli::{Args, TrainingArgs, WhichModel, WhichOptim};
use training::training_loop;
18 changes: 10 additions & 8 deletions examples/mnist/optim.rs
@@ -1,13 +1,15 @@
use candle_core::{Result, Tensor, Var};
use candle_nn::Optimizer;
use optimisers::adadelta::{Adadelta, ParamsAdaDelta};
use optimisers::adagrad::{Adagrad, ParamsAdaGrad};
use optimisers::adam::{Adam, ParamsAdam};
use optimisers::adamax::{Adamax, ParamsAdaMax};
use optimisers::esgd::{ParamsSGD, SGD};
use optimisers::nadam::{NAdam, ParamsNAdam};
use optimisers::radam::{ParamsRAdam, RAdam};
use optimisers::rmsprop::{ParamsRMSprop, RMSprop};
use candle_optimisers::{
adadelta::{Adadelta, ParamsAdaDelta},
adagrad::{Adagrad, ParamsAdaGrad},
adam::{Adam, ParamsAdam},
adamax::{Adamax, ParamsAdaMax},
esgd::{ParamsSGD, SGD},
nadam::{NAdam, ParamsNAdam},
radam::{ParamsRAdam, RAdam},
rmsprop::{ParamsRMSprop, RMSprop},
};

pub trait Optim: Sized {
fn new(vars: Vec<Var>, lr: f64) -> Result<Self>;
15 changes: 15 additions & 0 deletions katex-header.html
@@ -0,0 +1,15 @@
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css" integrity="sha384-9eLZqc9ds8eNjO3TmqPeYcDj8n+Qfa4nuSiGYa6DjLNcv9BtN69ZIulL9+8CqC9Y" crossorigin="anonymous">
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.js" integrity="sha384-K3vbOmF2BtaVai+Qk37uypf7VrgBubhQreNQe9aGsz9lB63dIFiQVlJbr92dw2Lx" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/contrib/auto-render.min.js" integrity="sha384-kmZOZB5ObwgQnS/DuDg6TScgOiWWBiVt0plIRkZCmE6rDZGrEOQeHM5PcHi+nyqe" crossorigin="anonymous"></script>
<script>
document.addEventListener("DOMContentLoaded", function() {
renderMathInElement(document.body, {
delimiters: [
{left: "$$", right: "$$", display: true},
{left: "\\(", right: "\\)", display: false},
{left: "$", right: "$", display: false},
{left: "\\[", right: "\\]", display: true}
]
});
});
</script>
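
With this header injected through `--html-in-header`, maths written between the delimiters above (for example `$$ … $$`) inside doc comments is typeset by KaTeX's auto-render script when the docs are built. A purely illustrative sketch of such a doc comment (the function here is hypothetical, not part of the crate):

```rust
/// Plain gradient step, shown only to illustrate KaTeX rendering in rustdoc.
///
/// The update applied is
/// $$ \theta_t = \theta_{t-1} - \gamma g_t $$
pub fn gradient_step(theta: f64, gamma: f64, g: f64) -> f64 {
    theta - gamma * g
}
```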
41 changes: 34 additions & 7 deletions src/adadelta.rs
@@ -1,8 +1,37 @@
//! Adadelta optimiser
//!
//! Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)
//!
//! For pseudocode see <https://pytorch.org/docs/stable/generated/torch.optim.Adadelta.html>
/*!
Adadelta optimiser
Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)
Pseudocode (including decoupling of weight decay):
$$
\\begin{aligned}
&\\rule{110mm}{0.4pt} \\\\
&\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)},
\\: f(\\theta) \\text{ (objective)}, \\: \\rho \\text{ (decay)},
\\: \\lambda \\text{ (weight decay)} \\\\
&\\textbf{initialize} : v_0 \\leftarrow 0 \\: \\text{ (square avg)},
\\: u_0 \\leftarrow 0 \\: \\text{ (accumulate variables)} \\\\[-1.ex]
&\\rule{110mm}{0.4pt} \\\\
&\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
&\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
&\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
&\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
&\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
&\\hspace{10mm}\\textbf{else} \\\\
&\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
&\\hspace{5mm} v_t \\leftarrow v_{t-1} \\rho + g^2_t (1 - \\rho) \\\\
&\\hspace{5mm}\\Delta x_t \\leftarrow \\frac{\\sqrt{u_{t-1} +
\\epsilon }}{ \\sqrt{v_t + \\epsilon} }g_t \\hspace{21mm} \\\\
&\\hspace{5mm} u_t \\leftarrow u_{t-1} \\rho +
\\Delta x^2_t (1 - \\rho) \\\\
&\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\Delta x_t \\\\
&\\rule{110mm}{0.4pt} \\\\[-1.ex]
&\\bf{return} \\: \\theta_t \\\\[-1.ex]
&\\rule{110mm}{0.4pt} \\\\[-1.ex]
\\end{aligned}
$$
*/
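
As a plain-Rust illustration of the update above (a scalar sketch only: the crate itself applies these steps to candle tensors, and the optional weight-decay branch is omitted here):

```rust
/// One Adadelta step on a single scalar parameter, mirroring the pseudocode above.
fn adadelta_step(
    theta: &mut f64, // parameter θ
    v: &mut f64,     // running square average v
    u: &mut f64,     // accumulated Δx² u
    g: f64,          // gradient g_t
    lr: f64,         // learning rate γ
    rho: f64,        // decay ρ
    eps: f64,        // ε
) {
    *v = *v * rho + g * g * (1.0 - rho);
    let dx = ((*u + eps).sqrt() / (*v + eps).sqrt()) * g;
    *u = *u * rho + dx * dx * (1.0 - rho);
    *theta -= lr * dx;
}
```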

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;
@@ -12,8 +41,6 @@ use crate::Decay;
/// Adadelta optimiser
///
/// Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701)
///
/// For pseudocode see <https://pytorch.org/docs/stable/generated/torch.optim.Adadelta.html>
#[derive(Debug)]
pub struct Adadelta {
vars: Vec<VarAdaDelta>,
43 changes: 35 additions & 8 deletions src/adagrad.rs
@@ -1,8 +1,38 @@
//! Adagrad optimiser
//!
//! Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html)
//!
//! For pseudocode see <https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html>
/*!
Adagrad optimiser
Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html)
Pseudocode (including decoupling of weight decay):
$$
\\begin{aligned}
&\\rule{110mm}{0.4pt} \\\\
&\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, \\: f(\\theta)
\\text{ (objective)}, \\: \\lambda \\text{ (weight decay)}, \\\\
&\\hspace{12mm} \\tau \\text{ (initial accumulator value)}, \\: \\eta\\text{ (lr decay)}\\\\
&\\textbf{initialize} : statesum_0 \\leftarrow 0 \\\\[-1.ex]
&\\rule{110mm}{0.4pt} \\\\
&\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
&\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
&\\hspace{5mm} \\tilde{\\gamma} \\leftarrow \\gamma / (1 +(t-1) \\eta) \\\\
&\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
&\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
&\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
&\\hspace{10mm}\\textbf{else} \\\\
&\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
&\\hspace{5mm}statesum_t \\leftarrow statesum_{t-1} + g^2_t \\\\
&\\hspace{5mm}\\theta_t \\leftarrow
\\theta_{t-1}- \\tilde{\\gamma} \\frac{g_t}{\\sqrt{statesum_t}+\\epsilon} \\\\
&\\rule{110mm}{0.4pt} \\\\[-1.ex]
&\\bf{return} \\: \\theta_t \\\\[-1.ex]
&\\rule{110mm}{0.4pt} \\\\[-1.ex]
\\end{aligned}
$$
*/
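
As with Adadelta, a scalar sketch of the update above in plain Rust (weight decay omitted; `f64` values stand in for the crate's candle tensors):

```rust
/// One Adagrad step on a single scalar parameter, following the pseudocode above.
fn adagrad_step(
    theta: &mut f64,     // parameter θ
    state_sum: &mut f64, // accumulated squared gradients
    g: f64,              // gradient g_t
    t: u32,              // step number, starting at 1
    lr: f64,             // learning rate γ
    lr_decay: f64,       // learning rate decay η
    eps: f64,            // ε
) {
    let lr_t = lr / (1.0 + f64::from(t - 1) * lr_decay);
    *state_sum += g * g;
    *theta -= lr_t * g / (state_sum.sqrt() + eps);
}
```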

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;
@@ -12,9 +42,6 @@ use crate::Decay;
/// Adagrad optimiser
///
/// Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html)
///
/// For pseudocode see <https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html>
#[derive(Debug)]
pub struct Adagrad {
vars: Vec<VarAdaGrad>,
59 changes: 48 additions & 11 deletions src/adam.rs
@@ -1,14 +1,51 @@
//! Adam optimiser
//!
//! This includes AdamW via use of decoupled weight decay
//!
//! Described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
//! and [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
//!
//! The AMSGrad variant is also implemented, described in [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
//!
//! For pseudocode see <https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam> and
//! <https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW>
/*!
Adam optimiser (including AdamW)
This includes AdamW via use of decoupled weight decay
Described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
and [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
The AMSGrad variant is also implemented, described in [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
Pseudocode (including decoupled weight decay, i.e. AdamW):
Note that the AMSGrad branch differs from the PyTorch pseudocode; as far as I can tell it is nonetheless equivalent to the torch implementation.
$$
\\begin{aligned}
&\\rule{110mm}{0.4pt} \\\\
&\\textbf{input} : \\gamma \\text{ (lr)}, \\beta_1, \\beta_2
\\text{ (betas)},\\theta_0 \\text{ (params)},f(\\theta) \\text{ (objective)} \\\\
&\\hspace{13mm} \\lambda \\text{ (weight decay)}, \\: \\textit{amsgrad} \\\\
&\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)},
v_0\\leftarrow 0 \\text{ (second moment)},\\: v_0^{max}\\leftarrow 0 \\\\[-1.ex]
&\\rule{110mm}{0.4pt} \\\\
&\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
&\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
&\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
&\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
&\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
&\\hspace{10mm}\\textbf{else} \\\\
&\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
&\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\
&\\hspace{5mm}v_t \\leftarrow \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t \\\\
&\\hspace{5mm}\\widehat{m_t} \\leftarrow m_t/\\big(1-\\beta_1^t \\big) \\\\
&\\hspace{5mm}\\textbf{if} \\: amsgrad \\\\
&\\hspace{10mm}v_t^{max} \\leftarrow \\mathrm{max}(v_{t-1}^{max}, v_t) \\\\
&\\hspace{10mm}\\widehat{v_t}^{max} \\leftarrow v_t^{max} /\\big(1-\\beta_2^t \\big) \\\\
&\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t}/
\\big(\\sqrt{\\widehat{v_t}^{max}} + \\epsilon \\big) \\\\
&\\hspace{5mm}\\textbf{else} \\\\
&\\hspace{10mm}\\widehat{v_t} \\leftarrow v_t/\\big(1-\\beta_2^t \\big) \\\\
&\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t}/
\\big(\\sqrt{\\widehat{v_t}} + \\epsilon \\big) \\\\
&\\rule{110mm}{0.4pt} \\\\[-1.ex]
&\\bf{return} \\: \\theta_t \\\\[-1.ex]
&\\rule{110mm}{0.4pt} \\\\[-1.ex]
\\end{aligned}
$$
*/
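
A scalar sketch of the update above, including the AMSGrad branch (weight decay omitted; `f64` values stand in for candle tensors):

```rust
/// One Adam step on a single scalar parameter, following the pseudocode above.
#[allow(clippy::too_many_arguments)]
fn adam_step(
    theta: &mut f64, // parameter θ
    m: &mut f64,     // first moment m
    v: &mut f64,     // second moment v
    v_max: &mut f64, // running maximum of v (AMSGrad)
    g: f64,          // gradient g_t
    t: i32,          // step number, starting at 1
    lr: f64,         // learning rate γ
    beta1: f64,
    beta2: f64,
    eps: f64,
    amsgrad: bool,
) {
    *m = beta1 * *m + (1.0 - beta1) * g;
    *v = beta2 * *v + (1.0 - beta2) * g * g;
    let m_hat = *m / (1.0 - beta1.powi(t));
    let denom = if amsgrad {
        *v_max = (*v_max).max(*v);
        (*v_max / (1.0 - beta2.powi(t))).sqrt() + eps
    } else {
        (*v / (1.0 - beta2.powi(t))).sqrt() + eps
    };
    *theta -= lr * m_hat / denom;
}
```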

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;
40 changes: 33 additions & 7 deletions src/adamax.rs
@@ -1,8 +1,36 @@
//! The Adamax optimiser
//!
//! An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
//!
//! For pseudocode see <https://pytorch.org/docs/stable/generated/torch.optim.Adamax.html#torch.optim.Adamax>
/*!
Adamax optimiser
An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
Pseudocode (including decoupling of weight decay):
$$
\\begin{aligned}
&\\rule{110mm}{0.4pt} \\\\
&\\textbf{input} : \\gamma \\text{ (lr)}, \\beta_1, \\beta_2
\\text{ (betas)},\\theta_0 \\text{ (params)},f(\\theta) \\text{ (objective)},
\\: \\lambda \\text{ (weight decay)}, \\\\
&\\hspace{13mm} \\epsilon \\text{ (epsilon)} \\\\
&\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)},
u_0 \\leftarrow 0 \\text{ ( infinity norm)} \\\\[-1.ex]
&\\rule{110mm}{0.4pt} \\\\
&\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\
&\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\
&\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\
&\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\
&\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\
&\\hspace{10mm}\\textbf{else} \\\\
&\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\
&\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\
&\\hspace{5mm}u_t \\leftarrow \\mathrm{max}(\\beta_2 u_{t-1}, |g_{t}|+\\epsilon) \\\\
&\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\frac{\\gamma m_t}{(1-\\beta^t_1) u_t} \\\\
&\\rule{110mm}{0.4pt} \\\\[-1.ex]
&\\bf{return} \\: \\theta_t \\\\[-1.ex]
&\\rule{110mm}{0.4pt} \\\\[-1.ex]
\\end{aligned}
$$
*/
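
Again as a scalar sketch of the update above in plain Rust (weight decay omitted; `f64` values stand in for candle tensors):

```rust
/// One Adamax step on a single scalar parameter, following the pseudocode above.
#[allow(clippy::too_many_arguments)]
fn adamax_step(
    theta: &mut f64, // parameter θ
    m: &mut f64,     // first moment m
    u: &mut f64,     // infinity norm u
    g: f64,          // gradient g_t
    t: i32,          // step number, starting at 1
    lr: f64,         // learning rate γ
    beta1: f64,
    beta2: f64,
    eps: f64,
) {
    *m = beta1 * *m + (1.0 - beta1) * g;
    *u = (beta2 * *u).max(g.abs() + eps);
    *theta -= lr * *m / ((1.0 - beta1.powi(t)) * *u);
}
```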

use candle_core::{Result, Var};
use candle_nn::optim::Optimizer;
@@ -12,8 +40,6 @@ use crate::Decay;
/// Adamax optimiser
///
/// An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
///
/// For pseudocode see <https://pytorch.org/docs/stable/generated/torch.optim.Adamax.html#torch.optim.Adamax>
#[derive(Debug)]
pub struct Adamax {