From d3a9ab9c4651d1affdca33c86ad9d190a6893ede Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Thu, 7 Dec 2023 12:57:05 +0000 Subject: [PATCH 1/9] add pseudo code for adaptive methods --- katex-header.html | 15 ++++++++++++ src/adadelta.rs | 41 ++++++++++++++++++++++++++------ src/adagrad.rs | 43 +++++++++++++++++++++++++++------- src/adam.rs | 59 ++++++++++++++++++++++++++++++++++++++--------- src/adamax.rs | 40 ++++++++++++++++++++++++++------ src/lib.rs | 6 +++-- src/nadam.rs | 47 ++++++++++++++++++++++++++++++------- src/radam.rs | 53 ++++++++++++++++++++++++++++++++++++------ 8 files changed, 254 insertions(+), 50 deletions(-) create mode 100644 katex-header.html diff --git a/katex-header.html b/katex-header.html new file mode 100644 index 0000000..98e8590 --- /dev/null +++ b/katex-header.html @@ -0,0 +1,15 @@ + + + + \ No newline at end of file diff --git a/src/adadelta.rs b/src/adadelta.rs index c9b46df..7c214a6 100644 --- a/src/adadelta.rs +++ b/src/adadelta.rs @@ -1,8 +1,37 @@ -//! Adadelta optimiser -//! -//! Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701) -//! -//! For pseudocde see +/*! +Adadelta optimiser + +Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701) + +Pseudocode (including decoupling of weight decay): +$$ +\\begin{aligned} + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, + \\: f(\\theta) \\text{ (objective)}, \\: \\rho \\text{ (decay)}, + \\: \\lambda \\text{ (weight decay)} \\\\ + &\\textbf{initialize} : v_0 \\leftarrow 0 \\: \\text{ (square avg)}, + \\: u_0 \\leftarrow 0 \\: \\text{ (accumulate variables)} \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ + &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ + &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ + &\\hspace{10mm}\\textbf{else} \\\\ + &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ + &\\hspace{5mm} v_t \\leftarrow v_{t-1} \\rho + g^2_t (1 - \\rho) \\\\ + &\\hspace{5mm}\\Delta x_t \\leftarrow \\frac{\\sqrt{u_{t-1} + + \\epsilon }}{ \\sqrt{v_t + \\epsilon} }g_t \\hspace{21mm} \\\\ + &\\hspace{5mm} u_t \\leftarrow u_{t-1} \\rho + + \\Delta x^2_t (1 - \\rho) \\\\ + &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\Delta x_t \\\\ + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + &\\bf{return} \\: \\theta_t \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + \\end{aligned} +$$ +*/ use candle_core::{Result, Var}; use candle_nn::optim::Optimizer; @@ -12,8 +41,6 @@ use crate::Decay; /// Adadelta optimiser /// /// Described in [ADADELTA: An Adaptive Learning Rate Method](https://arxiv.org/abs/1212.5701) -/// -/// For pseudocde see #[derive(Debug)] pub struct Adadelta { vars: Vec, diff --git a/src/adagrad.rs b/src/adagrad.rs index ce560b1..73ad51c 100644 --- a/src/adagrad.rs +++ b/src/adagrad.rs @@ -1,8 +1,38 @@ -//! Adagrad optimiser -//! -//! Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html) -//! -//! For pseudocde see +/*! 
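As a companion to the Adadelta pseudocode above, a minimal scalar sketch of the same update (plain Rust with illustrative names; the crate itself operates on candle tensors and also handles the optional weight-decay branches shown in the pseudocode):

```rust
/// One Adadelta step on a single scalar parameter (illustrative sketch only).
struct AdadeltaState {
    v: f64, // running average of squared gradients
    u: f64, // running average of squared updates
}

fn adadelta_step(theta: &mut f64, grad: f64, s: &mut AdadeltaState, lr: f64, rho: f64, eps: f64) {
    s.v = rho * s.v + (1.0 - rho) * grad * grad;
    let delta = ((s.u + eps).sqrt() / (s.v + eps).sqrt()) * grad;
    s.u = rho * s.u + (1.0 - rho) * delta * delta;
    *theta -= lr * delta;
}

fn main() {
    // Minimise f(x) = x^2 (gradient 2x) starting from x = 5.
    let mut x = 5.0_f64;
    let mut state = AdadeltaState { v: 0.0, u: 0.0 };
    for _ in 0..1000 {
        adadelta_step(&mut x, 2.0 * x, &mut state, 1.0, 0.9, 1e-6);
    }
    println!("x after 1000 Adadelta steps: {x}");
}
```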
+Adagrad optimiser + +Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html) + +Pseudocode (including decoupling of weight decay): + +$$ +\\begin{aligned} + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, \\: f(\\theta) + \\text{ (objective)}, \\: \\lambda \\text{ (weight decay)}, \\\\ + &\\hspace{12mm} \\tau \\text{ (initial accumulator value)}, \\: \\eta\\text{ (lr decay)}\\\\ + &\\textbf{initialize} : statesum_0 \\leftarrow 0 \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ + &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ + &\\hspace{5mm} \\tilde{\\gamma} \\leftarrow \\gamma / (1 +(t-1) \\eta) \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ + &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ + &\\hspace{10mm}\\textbf{else} \\\\ + &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ + &\\hspace{5mm}statesum_t \\leftarrow statesum_{t-1} + g^2_t \\\\ + &\\hspace{5mm}\\theta_t \\leftarrow + \\theta_{t-1}- \\tilde{\\gamma} \\frac{g_t}{\\sqrt{statesum_t}+\\epsilon} \\\\ + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + &\\bf{return} \\: \\theta_t \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + \\end{aligned} +$$ + + + +*/ use candle_core::{Result, Var}; use candle_nn::optim::Optimizer; @@ -12,9 +42,6 @@ use crate::Decay; /// Adagrad optimiser /// /// Described in [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](https://jmlr.org/papers/v12/duchi11a.html) -/// -/// For pseudocde see - #[derive(Debug)] pub struct Adagrad { vars: Vec, diff --git a/src/adam.rs b/src/adam.rs index 4846e8c..7c9e32b 100644 --- a/src/adam.rs +++ b/src/adam.rs @@ -1,14 +1,51 @@ -//! Adam optimiser -//! -//! This includes AdamW via use of decoupled weight decay -//! -//! Described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) -//! and [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) -//! -//! The AMSGrad variant is also implemented, described in [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ) -//! -//! For pseudocode see and -//! +/*! +Adam optimiser + +This includes AdamW via use of decoupled weight decay + +Described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) +and [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) + +The AMSGrad variant is also implemented, described in [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ) + +Pseudocode (including decoupling of weight decay AdamW): + +Note the AMSGrad branch is different to the PyTorch pseudocode: this is however equivalent to the torch implementation as far as I can tell. 
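Before the formal pseudocode below, a minimal scalar sketch of one Adam step covering the same branches: decoupled weight decay for AdamW, L2-style weight decay otherwise, and the AMSGrad maximum. Names are illustrative; the crate applies the same logic to candle tensors.

```rust
struct AdamState {
    m: f64,     // first moment
    v: f64,     // second moment
    v_max: f64, // running maximum of v, used by AMSGrad
    t: u32,     // step count
}

#[allow(clippy::too_many_arguments)]
fn adam_step(
    theta: &mut f64,
    mut grad: f64,
    s: &mut AdamState,
    lr: f64,
    (beta1, beta2): (f64, f64),
    eps: f64,
    weight_decay: Option<(f64, bool)>, // (lambda, decoupled?)
    amsgrad: bool,
) {
    s.t += 1;
    match weight_decay {
        Some((lambda, true)) => *theta -= lr * lambda * *theta, // AdamW: decay the weights directly
        Some((lambda, false)) => grad += lambda * *theta,       // plain Adam: fold decay into the gradient
        None => {}
    }
    s.m = beta1 * s.m + (1.0 - beta1) * grad;
    s.v = beta2 * s.v + (1.0 - beta2) * grad * grad;
    let m_hat = s.m / (1.0 - beta1.powi(s.t as i32));
    let denom = if amsgrad {
        // keep the running maximum of the raw second moment, then bias-correct it
        s.v_max = s.v_max.max(s.v);
        (s.v_max / (1.0 - beta2.powi(s.t as i32))).sqrt() + eps
    } else {
        (s.v / (1.0 - beta2.powi(s.t as i32))).sqrt() + eps
    };
    *theta -= lr * m_hat / denom;
}
```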
+ +$$ +\\begin{aligned} + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{input} : \\gamma \\text{ (lr)}, \\beta_1, \\beta_2 + \\text{ (betas)},\\theta_0 \\text{ (params)},f(\\theta) \\text{ (objective)} \\\\ + &\\hspace{13mm} \\lambda \\text{ (weight decay)}, \\: \\textit{amsgrad} \\\\ + &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)}, + v_0\\leftarrow 0 \\text{ (second moment)},\\: v_0^{max}\\leftarrow 0 \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ + &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ + &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ + &\\hspace{10mm}\\textbf{else} \\\\ + &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ + &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\ + &\\hspace{5mm}v_t \\leftarrow \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t \\\\ + &\\hspace{5mm}\\widehat{m_t} \\leftarrow m_t/\\big(1-\\beta_1^t \\big) \\\\ + &\\hspace{5mm}\\textbf{if} \\: amsgrad \\\\ + &\\hspace{10mm}v_t^{max} \\leftarrow \\mathrm{max}(v_{t-1}^{max}, v_t) \\\\ + &\\hspace{10mm}\\widehat{v_t}^{max} \\leftarrow v_t^{max} /\\big(1-\\beta_2^t \\big) \\\\ + &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t}/ + \\big(\\sqrt{\\widehat{v_t}^{max}} + \\epsilon \\big) \\\\ + &\\hspace{5mm}\\textbf{else} \\\\ + &\\hspace{10mm}\\widehat{v_t} \\leftarrow v_t/\\big(1-\\beta_2^t \\big) \\\\ + &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t}/ + \\big(\\sqrt{\\widehat{v_t}} + \\epsilon \\big) \\\\ + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + &\\bf{return} \\: \\theta_t \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\[-1.ex] +\\end{aligned} +$$ +*/ use candle_core::{Result, Var}; use candle_nn::optim::Optimizer; diff --git a/src/adamax.rs b/src/adamax.rs index f3c3e1a..7e5b4e6 100644 --- a/src/adamax.rs +++ b/src/adamax.rs @@ -1,8 +1,36 @@ -//! The Adamax optimiser -//! -//! An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) -//! -//! For pseudocde see +/*! 
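A scalar sketch of the Adamax update described by the pseudocode that follows (the infinity-norm variant of Adam); illustrative names only, and the weight-decay branches are omitted for brevity:

```rust
struct AdamaxState {
    m: f64, // first moment
    u: f64, // infinity-norm accumulator
    t: u32, // step count
}

fn adamax_step(theta: &mut f64, grad: f64, s: &mut AdamaxState, lr: f64, beta1: f64, beta2: f64, eps: f64) {
    s.t += 1;
    s.m = beta1 * s.m + (1.0 - beta1) * grad;
    // exponentially weighted infinity norm of the gradients
    s.u = (beta2 * s.u).max(grad.abs() + eps);
    // the (1 - beta1^t) factor bias-corrects the first moment
    *theta -= lr * s.m / ((1.0 - beta1.powi(s.t as i32)) * s.u);
}
```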
+Adamax optimiser + +An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) + +Pseudocode (including decoupling of weight decay): + +$$ +\\begin{aligned} + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{input} : \\gamma \\text{ (lr)}, \\beta_1, \\beta_2 + \\text{ (betas)},\\theta_0 \\text{ (params)},f(\\theta) \\text{ (objective)}, + \\: \\lambda \\text{ (weight decay)}, \\\\ + &\\hspace{13mm} \\epsilon \\text{ (epsilon)} \\\\ + &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)}, + u_0 \\leftarrow 0 \\text{ ( infinity norm)} \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ + &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ + &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ + &\\hspace{10mm}\\textbf{else} \\\\ + &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ + &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\ + &\\hspace{5mm}u_t \\leftarrow \\mathrm{max}(\\beta_2 u_{t-1}, |g_{t}|+\\epsilon) \\\\ + &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\frac{\\gamma m_t}{(1-\\beta^t_1) u_t} \\\\ + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + &\\bf{return} \\: \\theta_t \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\[-1.ex] +\\end{aligned} +$$ +*/ use candle_core::{Result, Var}; use candle_nn::optim::Optimizer; @@ -12,8 +40,6 @@ use crate::Decay; /// Adamax optimiser /// /// An Adam optimiser based on infinity norm, described in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) -/// -/// For pseudocde see #[derive(Debug)] pub struct Adamax { diff --git a/src/lib.rs b/src/lib.rs index 6bb2b88..7469944 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ -//! Optimisers for use with the candle framework for lightweight machine learning. -//! These currently all implement the [`candle_nn::optim::Optimizer`] trait from candle-nn +/*! +Optimisers for use with the candle framework for lightweight machine learning. +Apart from LBFGS, these all implement the [`candle_nn::optim::Optimizer`] trait from candle-nn +*/ use std::fmt::Debug; diff --git a/src/nadam.rs b/src/nadam.rs index dcdf33b..168a9c2 100644 --- a/src/nadam.rs +++ b/src/nadam.rs @@ -1,8 +1,42 @@ -//! The NAdam optimiser: Adam with Nesterov momentum -//! -//! Described in [Incorporating Nesterov Momentum into Adam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) -//! -//! For pseudocode see +/*! 
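A scalar sketch of the NAdam update described by the pseudocode that follows; the running product of the momentum schedule is the only extra state compared with Adam. Illustrative names, weight-decay branches omitted:

```rust
struct NAdamState {
    m: f64,       // first moment, initialise to 0.0
    v: f64,       // second moment, initialise to 0.0
    mu_prod: f64, // running product of the momentum schedule, initialise to 1.0
    t: u32,       // step count, initialise to 0
}

#[allow(clippy::too_many_arguments)]
fn nadam_step(theta: &mut f64, grad: f64, s: &mut NAdamState, lr: f64, beta1: f64, beta2: f64, eps: f64, psi: f64) {
    s.t += 1;
    let t = f64::from(s.t);
    // momentum decay schedule mu_t and mu_{t+1}
    let mu_t = beta1 * (1.0 - 0.5 * 0.96_f64.powf(t * psi));
    let mu_next = beta1 * (1.0 - 0.5 * 0.96_f64.powf((t + 1.0) * psi));
    s.mu_prod *= mu_t;
    let mu_prod_next = s.mu_prod * mu_next;

    s.m = beta1 * s.m + (1.0 - beta1) * grad;
    s.v = beta2 * s.v + (1.0 - beta2) * grad * grad;
    let m_hat = mu_next * s.m / (1.0 - mu_prod_next) + (1.0 - mu_t) * grad / (1.0 - s.mu_prod);
    let v_hat = s.v / (1.0 - beta2.powi(s.t as i32));
    *theta -= lr * m_hat / (v_hat.sqrt() + eps);
}
```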
+NAdam optimiser: Adam with Nesterov momentum + +Described in [Incorporating Nesterov Momentum into Adam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) + +Pseudocode (including decoupling of weight decay): + +$$ +\\begin{aligned} + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{input} : \\gamma_t \\text{ (lr)}, \\: \\beta_1,\\beta_2 \\text{ (betas)}, + \\: \\theta_0 \\text{ (params)}, \\: f(\\theta) \\text{ (objective)} \\\\ + &\\hspace{12mm} \\: \\lambda \\text{ (weight decay)}, \\:\\psi \\text{ (momentum decay)} \\\\ + &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)}, + v_0 \\leftarrow 0 \\text{ ( second moment)} \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ + &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ + &\\hspace{5mm} \\theta_t \\leftarrow \\theta_{t-1} \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ + &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ + &\\hspace{10mm}\\textbf{else} \\\\ + &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ + &\\hspace{5mm} \\mu_t \\leftarrow \\beta_1 \\big(1 - \\frac{1}{2} 0.96^{t \\psi} \\big) \\\\ + &\\hspace{5mm} \\mu_{t+1} \\leftarrow \\beta_1 \\big(1 - \\frac{1}{2} 0.96^{(t+1)\\psi}\\big)\\\\ + &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\ + &\\hspace{5mm}v_t \\leftarrow \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t \\\\ + &\\hspace{5mm}\\widehat{m_t} \\leftarrow \\mu_{t+1} m_t/(1-\\prod_{i=1}^{t+1}\\mu_i)\\\\[-1.ex] + & \\hspace{11mm} + (1-\\mu_t) g_t /(1-\\prod_{i=1}^{t} \\mu_{i}) \\\\ + &\\hspace{5mm}\\widehat{v_t} \\leftarrow v_t/\\big(1-\\beta_2^t \\big) \\\\ + &\\hspace{5mm}\\theta_t \\leftarrow \\theta_t - \\gamma \\widehat{m_t}/ + \\big(\\sqrt{\\widehat{v_t}} + \\epsilon \\big) \\\\ + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + &\\bf{return} \\: \\theta_t \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\[-1.ex] +\\end{aligned} +$$ +*/ use candle_core::{Result, Var}; use candle_nn::optim::Optimizer; @@ -12,9 +46,6 @@ use crate::Decay; /// Adam optimiser with Nesterov momentum /// /// Described in -/// -/// For pseudocde see - #[derive(Debug)] pub struct NAdam { vars: Vec, diff --git a/src/radam.rs b/src/radam.rs index 59eae79..694c5de 100644 --- a/src/radam.rs +++ b/src/radam.rs @@ -1,8 +1,49 @@ -//! The RAdam optimiser -//! -//! Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) -//! -//! For pseudocde see +/*! 
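A scalar sketch of the rectified update described by the pseudocode that follows: the step falls back to a plain bias-corrected momentum step until the variance estimate is considered tractable. Illustrative names, weight-decay branches omitted:

```rust
struct RAdamState {
    m: f64, // first moment
    v: f64, // second moment
    t: u32, // step count
}

fn radam_step(theta: &mut f64, grad: f64, s: &mut RAdamState, lr: f64, beta1: f64, beta2: f64, eps: f64) {
    s.t += 1;
    let t = f64::from(s.t);
    let rho_inf = 2.0 / (1.0 - beta2) - 1.0;

    s.m = beta1 * s.m + (1.0 - beta1) * grad;
    s.v = beta2 * s.v + (1.0 - beta2) * grad * grad;
    let bias2 = 1.0 - beta2.powi(s.t as i32);
    let m_hat = s.m / (1.0 - beta1.powi(s.t as i32));
    // approximate length of the simple moving average
    let rho_t = rho_inf - 2.0 * t * beta2.powi(s.t as i32) / bias2;

    if rho_t > 5.0 {
        let l_t = bias2.sqrt() / (s.v.sqrt() + eps);
        let r_t = (((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
            / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
            .sqrt();
        *theta -= lr * m_hat * r_t * l_t;
    } else {
        // variance not yet rectifiable: take an unadapted momentum step
        *theta -= lr * m_hat;
    }
}
```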
+RAdam optimiser: Adam with Nesterov momentum + +Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) + +As decoupled weight decay is implemented, this can be used equivalent to the paper (which uses decoupled weight decay), +or the PyTorch implementation (which does not) + +Pseudocode (including decoupling of weight decay): + +$$ +\\begin{aligned} + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\beta_1, \\beta_2 + \\text{ (betas)}, \\: \\theta_0 \\text{ (params)}, \\:f(\\theta) \\text{ (objective)}, \\: + \\lambda \\text{ (weightdecay)}, \\\\ + &\\hspace{13mm} \\epsilon \\text{ (epsilon)} \\\\ + &\\textbf{initialize} : m_0 \\leftarrow 0 \\text{ ( first moment)}, + v_0 \\leftarrow 0 \\text{ ( second moment)}, \\\\ + &\\hspace{18mm} \\rho_{\\infty} \\leftarrow 2/(1-\\beta_2) -1 \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ + &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ + &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ + &\\hspace{10mm}\\textbf{else} \\\\ + &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ + &\\hspace{5mm}m_t \\leftarrow \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\ + &\\hspace{5mm}v_t \\leftarrow \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t \\\\ + &\\hspace{5mm}\\widehat{m_t} \\leftarrow m_t/\\big(1-\\beta_1^t \\big) \\\\ + &\\hspace{5mm}\\rho_t \\leftarrow \\rho_{\\infty} - + 2 t \\beta^t_2 /\\big(1-\\beta_2^t \\big) \\\\[0.1.ex] + &\\hspace{5mm}\\textbf{if} \\: \\rho_t > 5 \\\\ + &\\hspace{10mm} l_t \\leftarrow \\frac{\\sqrt{ (1-\\beta^t_2) }}{ \\sqrt{v_t} +\\epsilon } \\\\ + &\\hspace{10mm} r_t \\leftarrow + \\sqrt{\\frac{(\\rho_t-4)(\\rho_t-2)\\rho_{\\infty}}{(\\rho_{\\infty}-4)(\\rho_{\\infty}-2) \\rho_t}} \\\\ + &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t} r_t l_t \\\\ + &\\hspace{5mm}\\textbf{else} \\\\ + &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\widehat{m_t} \\\\ + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + &\\bf{return} \\: \\theta_t \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\[-1.ex] +\\end{aligned} +$$ +*/ use candle_core::{Result, Var}; use candle_nn::optim::Optimizer; @@ -12,8 +53,6 @@ use crate::Decay; /// R Adam optimiser /// /// Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) -/// -/// For pseudocde see #[derive(Debug)] pub struct RAdam { From d18d370b49da691bb533a35c279b9ae99d3605e5 Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Thu, 7 Dec 2023 13:13:51 +0000 Subject: [PATCH 2/9] add SGDW --- src/esgd.rs | 211 +++++++++++++++++++++++++++++++------------- tests/esgd_tests.rs | 138 ++++++++++++++++++++++++++++- 2 files changed, 287 insertions(+), 62 deletions(-) diff --git a/src/esgd.rs b/src/esgd.rs index a5b2598..25f9b31 100644 --- a/src/esgd.rs +++ b/src/esgd.rs @@ -3,7 +3,7 @@ use candle_core::{Result, Var}; use candle_nn::optim::Optimizer; -use crate::Momentum; +use crate::{Decay, Momentum}; /// Optimizer for Stochastic Gradient Descent with momentum. 
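The `Decay` branching that this patch threads through SGD reduces, per scalar parameter, to the following sketch, mirroring the `Decay::WeightDecay` and `Decay::DecoupledWeightDecay` arms in the step function below (not the crate's tensor code):

```rust
enum Decay {
    WeightDecay(f64),          // L2-style: fold lambda * theta into the gradient
    DecoupledWeightDecay(f64), // SGDW-style: shrink the weights directly, before the gradient step
}

fn apply_decay(theta: &mut f64, grad: &mut f64, lr: f64, decay: Decay) {
    match decay {
        Decay::WeightDecay(lambda) => *grad += lambda * *theta,
        // same form as `lr.mul_add(-decay, 1.)` used on tensors below: theta *= 1 - lr * lambda
        Decay::DecoupledWeightDecay(lambda) => *theta *= lr.mul_add(-lambda, 1.0),
    }
}
```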
/// @@ -29,7 +29,7 @@ pub struct ParamsSGD { /// Learning rate pub lr: f64, /// Weight decay - pub weight_decay: Option, + pub weight_decay: Option, /// Momentum pub momentum: Option, /// Dampening @@ -72,33 +72,71 @@ impl Optimizer for SGD { if let Some(momentum) = self.params.momentum { match momentum { Momentum::Classical(momentum) => { - if let Some(wd) = self.params.weight_decay { - for var in &mut self.vars { - let theta = &var.theta; - // let prev_step = var.b; - if let Some(grad) = grads.get(theta) { - let grad = &(grad + (wd * theta.as_tensor())?)?; - if let Some(prev_step) = &(var.b) { - // println!("Exists"); - // bt​←μbt−1​+(1−τ)gt - let bt = ((prev_step.as_tensor() * momentum)? - + (1. - self.params.dampening) * (grad))?; + if let Some(decay) = self.params.weight_decay { + match decay { + Decay::WeightDecay(decay) => { + for var in &mut self.vars { + let theta = &var.theta; + // let prev_step = var.b; + if let Some(grad) = grads.get(theta) { + let grad = &(grad + (decay * theta.as_tensor())?)?; + if let Some(prev_step) = &(var.b) { + // println!("Exists"); + // bt​←μbt−1​+(1−τ)gt + let bt = ((prev_step.as_tensor() * momentum)? + + (1. - self.params.dampening) * (grad))?; - // if not nesterov gt = bt - theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?; - // println!("Momentum {}", bt); - prev_step.set(&bt)?; - } else { - // println!("Doesn't Exist"); - // bt​←μbt−1​+(1−τ)gt - // if there is no history bt = gt = grad with no weight_decay - let bt = grad.clone(); // clone must occur invariably due to need to store in hashmap + // if not nesterov gt = bt + theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?; + // println!("Momentum {}", bt); + prev_step.set(&bt)?; + } else { + // println!("Doesn't Exist"); + // bt​←μbt−1​+(1−τ)gt + // if there is no history bt = gt = grad with no weight_decay + let bt = grad.clone(); // clone must occur invariably due to need to store in hashmap - // if not nesterov gt = bt - theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?; - // println!("Momentum {}", bt); - var.b = Some(Var::from_tensor(&bt)?); - }; + // if not nesterov gt = bt + theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?; + // println!("Momentum {}", bt); + var.b = Some(Var::from_tensor(&bt)?); + }; + } + } + } + Decay::DecoupledWeightDecay(decay) => { + for var in &mut self.vars { + let theta = &var.theta; + // let prev_step = var.b; + if let Some(grad) = grads.get(theta) { + // decoupled weight decay step + theta.set( + &(theta.as_tensor() + * self.params.lr.mul_add(-decay, 1.))?, + )?; + if let Some(prev_step) = &(var.b) { + // println!("Exists"); + // bt​←μbt−1​+(1−τ)gt + let bt = ((prev_step.as_tensor() * momentum)? + + (1. 
- self.params.dampening) * (grad))?; + + // if not nesterov gt = bt + theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?; + // println!("Momentum {}", bt); + prev_step.set(&bt)?; + } else { + // println!("Doesn't Exist"); + // bt​←μbt−1​+(1−τ)gt + // if there is no history bt = gt = grad with no weight_decay + let bt = grad.clone(); // clone must occur invariably due to need to store in hashmap + + // if not nesterov gt = bt + theta.set(&theta.sub(&(&bt * self.params.lr)?)?)?; + // println!("Momentum {}", bt); + var.b = Some(Var::from_tensor(&bt)?); + }; + } + } } } } else { @@ -132,33 +170,71 @@ impl Optimizer for SGD { } } Momentum::Nesterov(momentum) => { - if let Some(wd) = self.params.weight_decay { - for var in &mut self.vars { - let theta = &var.theta; - // let prev_step = var.b; - if let Some(grad) = grads.get(theta) { - let grad = &(grad + (wd * theta.as_tensor())?)?; - if let Some(prev_step) = &(var.b) { - // println!("Exists"); - // bt​←μbt−1​+(1−τ)gt - let bt = ((prev_step.as_tensor() * momentum)? - + (1. - self.params.dampening) * (grad))?; + if let Some(decay) = self.params.weight_decay { + match decay { + Decay::WeightDecay(decay) => { + for var in &mut self.vars { + let theta = &var.theta; + // let prev_step = var.b; + if let Some(grad) = grads.get(theta) { + let grad = &(grad + (decay * theta.as_tensor())?)?; + if let Some(prev_step) = &(var.b) { + // println!("Exists"); + // bt​←μbt−1​+(1−τ)gt + let bt = ((prev_step.as_tensor() * momentum)? + + (1. - self.params.dampening) * (grad))?; - let gt = (grad + (momentum * &bt)?)?; - // println!("Momentum {}", bt); - prev_step.set(&bt)?; - theta.set(&theta.sub(&(gt * self.params.lr)?)?)?; - } else { - // println!("Doesn't Exist"); - // bt​←μbt−1​+(1−τ)gt - // if there is no history bt = gt = grad with no weight_decay - let bt = grad.clone(); // clone must occur invariably due to need to store in hashmap + let gt = (grad + (momentum * &bt)?)?; + // println!("Momentum {}", bt); + prev_step.set(&bt)?; + theta.set(&theta.sub(&(gt * self.params.lr)?)?)?; + } else { + // println!("Doesn't Exist"); + // bt​←μbt−1​+(1−τ)gt + // if there is no history bt = gt = grad with no weight_decay + let bt = grad.clone(); // clone must occur invariably due to need to store in hashmap - let gt = (grad + (momentum * &bt)?)?; - // println!("Momentum {}", bt); - var.b = Some(Var::from_tensor(&bt)?); - theta.set(&theta.sub(&(gt * self.params.lr)?)?)?; - }; + let gt = (grad + (momentum * &bt)?)?; + // println!("Momentum {}", bt); + var.b = Some(Var::from_tensor(&bt)?); + theta.set(&theta.sub(&(gt * self.params.lr)?)?)?; + }; + } + } + } + Decay::DecoupledWeightDecay(decay) => { + for var in &mut self.vars { + let theta = &var.theta; + // let prev_step = var.b; + if let Some(grad) = grads.get(theta) { + // decoupled weight decay step + theta.set( + &(theta.as_tensor() + * self.params.lr.mul_add(-decay, 1.))?, + )?; + if let Some(prev_step) = &(var.b) { + // println!("Exists"); + // bt​←μbt−1​+(1−τ)gt + let bt = ((prev_step.as_tensor() * momentum)? + + (1. 
- self.params.dampening) * (grad))?; + + let gt = (grad + (momentum * &bt)?)?; + // println!("Momentum {}", bt); + prev_step.set(&bt)?; + theta.set(&theta.sub(&(gt * self.params.lr)?)?)?; + } else { + // println!("Doesn't Exist"); + // bt​←μbt−1​+(1−τ)gt + // if there is no history bt = gt = grad with no weight_decay + let bt = grad.clone(); // clone must occur invariably due to need to store in hashmap + + let gt = (grad + (momentum * &bt)?)?; + // println!("Momentum {}", bt); + var.b = Some(Var::from_tensor(&bt)?); + theta.set(&theta.sub(&(gt * self.params.lr)?)?)?; + }; + } + } } } } else { @@ -192,13 +268,30 @@ impl Optimizer for SGD { } } } - } else if let Some(wd) = self.params.weight_decay { - for var in &mut self.vars { - let theta = &var.theta; - // let prev_step = var.b; - if let Some(grad) = grads.get(theta) { - let grad = &(grad + (wd * theta.as_tensor())?)?; // weight decay grad - theta.set(&theta.sub(&(grad * self.params.lr)?)?)?; // update theta + } else if let Some(decay) = self.params.weight_decay { + // These should be the same up to numeric precision + // For SGD with no momentum decoupled weight decay and L2 reg are equivalent + match decay { + Decay::WeightDecay(decay) => { + for var in &mut self.vars { + let theta = &var.theta; + // let prev_step = var.b; + if let Some(grad) = grads.get(theta) { + let grad = &(grad + (decay * theta.as_tensor())?)?; // weight decay grad + theta.set(&theta.sub(&(grad * self.params.lr)?)?)?; // update theta + } + } + } + Decay::DecoupledWeightDecay(decay) => { + for var in &mut self.vars { + let theta = &var.theta; + // let prev_step = var.b; + if let Some(grad) = grads.get(theta) { + theta + .set(&(theta.as_tensor() * self.params.lr.mul_add(-decay, 1.))?)?; + theta.set(&theta.sub(&(grad * self.params.lr)?)?)?; // update theta based on grad + } + } } } } else { diff --git a/tests/esgd_tests.rs b/tests/esgd_tests.rs index 86fc072..1ff4872 100644 --- a/tests/esgd_tests.rs +++ b/tests/esgd_tests.rs @@ -114,7 +114,7 @@ fn nesterov_decay_sgd_test() -> Result<()> { let params = ParamsSGD { lr: 0.004, - weight_decay: Some(0.1), + weight_decay: Some(optimisers::Decay::WeightDecay(0.1)), momentum: Some(Momentum::Nesterov(0.1)), dampening: 0.0, // nesterov: true, @@ -234,7 +234,7 @@ fn momentum_sgd_decay_test() -> Result<()> { let params = ParamsSGD { lr: 0.004, - weight_decay: Some(0.4), + weight_decay: Some(optimisers::Decay::WeightDecay(0.4)), momentum: Some(Momentum::Classical(0.1)), dampening: 0.0, // nesterov: false, @@ -326,7 +326,75 @@ fn sgd_test() -> Result<()> { let params = ParamsSGD { lr: 0.004, - weight_decay: Some(0.4), + weight_decay: None, + momentum: None, + dampening: 0.0, + // nesterov: false, + }; + // Now use backprop to run a linear regression between samples and get the coefficients back. + let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; + let b = Var::new(0f32, &Device::Cpu)?; + let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; + let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); + for _step in 0..100 { + let ys = lin.forward(&sample_xs)?; + let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; + n_sgd.backward_step(&loss)?; + } + + assert_eq!(to_vec2_round(&w, 4)?, &[[2.8809, 0.8513]]); + assert_eq!(to_vec0_round(&b, 4)?, -0.5606); + Ok(()) +} + +#[test] +fn sgd_decay_test() -> Result<()> { + // Generate some linear data, y = 3.x1 + x2 - 2. 
+ let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; + let b_gen = Tensor::new(-2f32, &Device::Cpu)?; + let gen = Linear::new(w_gen, Some(b_gen)); + let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; + let sample_ys = gen.forward(&sample_xs)?; + + let params = ParamsSGD { + lr: 0.004, + weight_decay: Some(optimisers::Decay::WeightDecay(0.4)), + momentum: None, + dampening: 0.0, + // nesterov: false, + }; + // Now use backprop to run a linear regression between samples and get the coefficients back. + let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; + let b = Var::new(0f32, &Device::Cpu)?; + let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; + let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); + for _step in 0..100 { + let ys = lin.forward(&sample_xs)?; + let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; + n_sgd.backward_step(&loss)?; + } + + assert_eq!(to_vec2_round(&w, 4)?, &[[2.8700, 0.8450]]); + assert_eq!(to_vec0_round(&b, 4)?, -0.5003); + Ok(()) +} + +// The following are not tested against torch +// As torch has no implementation of SGDW + +// This should be the same (as without momentum, decoupling is equivalent) +#[test] +fn sgdw_decay_test() -> Result<()> { + // Generate some linear data, y = 3.x1 + x2 - 2. + let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; + let b_gen = Tensor::new(-2f32, &Device::Cpu)?; + let gen = Linear::new(w_gen, Some(b_gen)); + let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; + let sample_ys = gen.forward(&sample_xs)?; + + let params = ParamsSGD { + lr: 0.004, + weight_decay: Some(optimisers::Decay::DecoupledWeightDecay(0.4)), momentum: None, dampening: 0.0, // nesterov: false, @@ -346,3 +414,67 @@ fn sgd_test() -> Result<()> { assert_eq!(to_vec0_round(&b, 4)?, -0.5003); Ok(()) } + +#[test] +fn momentum_sgdw_decay_test() -> Result<()> { + // Generate some linear data, y = 3.x1 + x2 - 2. + let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; + let b_gen = Tensor::new(-2f32, &Device::Cpu)?; + let gen = Linear::new(w_gen, Some(b_gen)); + let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; + let sample_ys = gen.forward(&sample_xs)?; + + let params = ParamsSGD { + lr: 0.004, + weight_decay: Some(optimisers::Decay::DecoupledWeightDecay(0.4)), + momentum: Some(Momentum::Classical(0.1)), + dampening: 0.0, + // nesterov: false, + }; + // Now use backprop to run a linear regression between samples and get the coefficients back. + let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; + let b = Var::new(0f32, &Device::Cpu)?; + let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; + let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); + for _step in 0..100 { + let ys = lin.forward(&sample_xs)?; + let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; + n_sgd.backward_step(&loss)?; + } + + assert_eq!(to_vec2_round(&w, 4)?, &[[2.8763, 0.8521]]); + assert_eq!(to_vec0_round(&b, 4)?, -0.5693); + Ok(()) +} + +#[test] +fn nesterov_decay_sgdw_test() -> Result<()> { + // Generate some linear data, y = 3.x1 + x2 - 2. 
+ let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; + let b_gen = Tensor::new(-2f32, &Device::Cpu)?; + let gen = Linear::new(w_gen, Some(b_gen)); + let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; + let sample_ys = gen.forward(&sample_xs)?; + + let params = ParamsSGD { + lr: 0.004, + weight_decay: Some(optimisers::Decay::DecoupledWeightDecay(0.1)), + momentum: Some(Momentum::Nesterov(0.1)), + dampening: 0.0, + // nesterov: true, + }; + // Now use backprop to run a linear regression between samples and get the coefficients back. + let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; + let b = Var::new(0f32, &Device::Cpu)?; + let mut n_sgd = SGD::new(vec![w.clone(), b.clone()], params)?; + let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); + for _step in 0..100 { + let ys = lin.forward(&sample_xs)?; + let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; + n_sgd.backward_step(&loss)?; + } + + assert_eq!(to_vec2_round(&w, 4)?, &[[0.9992, -10.3397]]); + assert_eq!(to_vec0_round(&b, 4)?, -1.9302); + Ok(()) +} From 3e5227fac13428c06f9359fab4987f8aad35f74f Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Thu, 7 Dec 2023 13:50:21 +0000 Subject: [PATCH 3/9] add pseudocode for rmsprop and sgd --- README.md | 2 +- src/adam.rs | 2 +- src/esgd.rs | 43 +++++++++++++++++++++++++++++++++++++------ src/rmsprop.rs | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 39f114f..459c2de 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ Adaptive methods: These are all checked against their pytorch implementation (see pytorch_test.ipynb) and should implement the same functionality (though without some input checking). -Additionally all of the adaptive mehods listed implement decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/pdf/1711.05101.pdf), in addition to the standard weight decay as implemented in pytorch. +Additionally all of the adaptive mehods listed and SGD implement decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/pdf/1711.05101.pdf), in addition to the standard weight decay as implemented in pytorch. Pseudosecond order methods: diff --git a/src/adam.rs b/src/adam.rs index 7c9e32b..8e97f28 100644 --- a/src/adam.rs +++ b/src/adam.rs @@ -1,5 +1,5 @@ /*! -Adam optimiser +Adam optimiser (inlcuding AdamW) This includes AdamW via use of decoupled weight decay diff --git a/src/esgd.rs b/src/esgd.rs index 25f9b31..90bce0f 100644 --- a/src/esgd.rs +++ b/src/esgd.rs @@ -1,4 +1,40 @@ -//! Stochastic Gradient Descent with momentum, weight decay and Nestervov momentum +/*! 
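A scalar sketch of one SGD step as spelled out in the pseudocode that follows, covering classical and Nesterov momentum, dampening, and both weight-decay modes (illustrative names; the implementation below works on candle tensors):

```rust
struct SgdState {
    b: Option<f64>, // momentum buffer, created lazily on the first step
}

fn sgd_step(
    theta: &mut f64,
    mut grad: f64,
    s: &mut SgdState,
    lr: f64,
    weight_decay: Option<(f64, bool)>, // (lambda, decoupled?)
    momentum: Option<(f64, bool)>,     // (mu, nesterov?)
    dampening: f64,
) {
    match weight_decay {
        Some((lambda, true)) => *theta -= lr * lambda * *theta, // SGDW: decay the weights directly
        Some((lambda, false)) => grad += lambda * *theta,       // L2-style decay on the gradient
        None => {}
    }
    if let Some((mu, nesterov)) = momentum {
        let b = match s.b {
            Some(prev) => mu * prev + (1.0 - dampening) * grad,
            None => grad, // first step: the buffer starts as the raw gradient
        };
        s.b = Some(b);
        grad = if nesterov { grad + mu * b } else { b };
    }
    *theta -= lr * grad;
}
```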
+ Stochastic Gradient Descent + + This incoporates Nesterov and classical momentum as well as weight decay and decoupled weight decay + (as described as SGDW in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)) + +$$ +\\begin{aligned} + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{input} : \\gamma \\text{ (lr)}, \\: \\theta_0 \\text{ (params)}, \\: f(\\theta) + \\text{ (objective)}, \\: \\lambda \\text{ (weight decay)}, \\\\ + &\\hspace{13mm} \\:\\mu \\text{ (momentum)}, \\:\\tau \\text{ (dampening)} \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ + &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm}\\textbf{if} \\: \\textit{decoupled} \\\\ + &\\hspace{15mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma \\lambda \\theta_{t-1} \\\\ + &\\hspace{10mm}\\textbf{else} \\\\ + &\\hspace{15mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\mu \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm}\\textbf{if} \\: t>1 \\\\ + &\\hspace{15mm} b_t \\leftarrow \\mu b_{t-1} + (1-\\tau)g_{t} \\\\ + &\\hspace{10mm}\\textbf{else} \\\\ + &\\hspace{15mm} b_t \\leftarrow g_{t} \\\\ + &\\hspace{10mm}\\textbf{if} \\: \\textit{nesterov} \\\\ + &\\hspace{15mm} g_t \\leftarrow gt + \\mu b_t \\\\ + &\\hspace{10mm}\\textbf{else} \\\\ + &\\hspace{15mm} g_t \\leftarrow b_t \\\\ + &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma g_t \\\\ + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + &\\bf{return} \\: \\theta_t \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\[-1.ex] +\\end{aligned} +$$ + +*/ use candle_core::{Result, Var}; use candle_nn::optim::Optimizer; @@ -6,11 +42,6 @@ use candle_nn::optim::Optimizer; use crate::{Decay, Momentum}; /// Optimizer for Stochastic Gradient Descent with momentum. -/// -/// Utilised same interface as pytorch but allows negative momenta and dampening with Nesterov -/// -/// For pseudocde see - #[derive(Debug)] pub struct SGD { vars: Vec, diff --git a/src/rmsprop.rs b/src/rmsprop.rs index a3fc297..110eb9e 100644 --- a/src/rmsprop.rs +++ b/src/rmsprop.rs @@ -4,6 +4,48 @@ //! //! For pseudocde see +/*! 
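A scalar sketch of the RMSprop step described by the pseudocode that follows, including the centred variant and the optional momentum buffer (illustrative names only):

```rust
struct RmsPropState {
    v: f64,     // square average
    g_ave: f64, // gradient average, used by the centred variant
    b: f64,     // momentum buffer
}

#[allow(clippy::too_many_arguments)]
fn rmsprop_step(
    theta: &mut f64,
    mut grad: f64,
    s: &mut RmsPropState,
    lr: f64,
    alpha: f64,
    eps: f64,
    weight_decay: Option<f64>,
    momentum: Option<f64>,
    centered: bool,
) {
    if let Some(lambda) = weight_decay {
        grad += lambda * *theta;
    }
    s.v = alpha * s.v + (1.0 - alpha) * grad * grad;
    let mut v_tilde = s.v;
    if centered {
        s.g_ave = alpha * s.g_ave + (1.0 - alpha) * grad;
        v_tilde -= s.g_ave * s.g_ave;
    }
    if let Some(mu) = momentum {
        s.b = mu * s.b + grad / (v_tilde.sqrt() + eps);
        *theta -= lr * s.b;
    } else {
        *theta -= lr * grad / (v_tilde.sqrt() + eps);
    }
}
```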
+RMS prop algorithm + +Described in + +Pseudocode: + +$$ +\\begin{aligned} + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{input} : \\alpha \\text{ (alpha)},\\: \\gamma \\text{ (lr)}, + \\: \\theta_0 \\text{ (params)}, \\: f(\\theta) \\text{ (objective)} \\\\ + &\\hspace{13mm} \\lambda \\text{ (weight decay)},\\: \\mu \\text{ (momentum)} \\\\ + &\\textbf{initialize} : v_0 \\leftarrow 0 \\text{ (square average)}, \\: + b_0 \\leftarrow 0 \\text{ (buffer)}, \\: g_0^{ave} \\leftarrow 0 \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\ + &\\textbf{for} \\: t=1 \\: \\textbf{to} \\: \\ldots \\: \\textbf{do} \\\\ + &\\hspace{5mm}g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\lambda \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm} g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\ + &\\hspace{5mm}v_t \\leftarrow \\alpha v_{t-1} + (1 - \\alpha) g^2_t + \\hspace{8mm} \\\\ + &\\hspace{5mm} \\tilde{v_t} \\leftarrow v_t \\\\ + &\\hspace{5mm}\\textbf{if} \\: centered \\\\ + &\\hspace{10mm} g_t^{ave} \\leftarrow g_{t-1}^{ave} \\alpha + (1-\\alpha) g_t \\\\ + &\\hspace{10mm} \\tilde{v_t} \\leftarrow \\tilde{v_t} - \\big(g_{t}^{ave} \\big)^2 \\\\ + &\\hspace{5mm}\\textbf{if} \\: \\mu \\textbf{ is } \\text{Some} \\\\ + &\\hspace{10mm} b_t\\leftarrow \\mu b_{t-1} + + g_t/ \\big(\\sqrt{\\tilde{v_t}} + \\epsilon \\big) \\\\ + &\\hspace{10mm} \\theta_t \\leftarrow \\theta_{t-1} - \\gamma b_t \\\\ + &\\hspace{5mm} \\textbf{else} \\\\ + &\\hspace{10mm}\\theta_t \\leftarrow \\theta_{t-1} - + \\gamma g_t/ \\big(\\sqrt{\\tilde{v_t}} + \\epsilon \\big) \\hspace{3mm} \\\\ + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + &\\bf{return} \\: \\theta_t \\\\[-1.ex] + &\\rule{110mm}{0.4pt} \\\\[-1.ex] + \\end{aligned} +$$ + + +*/ + use candle_core::{Result, Var}; use candle_nn::optim::Optimizer; From 27ccc6f6934baefaf60357760e760f189c577ebc Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Thu, 7 Dec 2023 13:51:36 +0000 Subject: [PATCH 4/9] correct typo --- src/lbfgs.rs | 2 +- src/radam.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lbfgs.rs b/src/lbfgs.rs index a79bac6..c5fc8b5 100644 --- a/src/lbfgs.rs +++ b/src/lbfgs.rs @@ -1,4 +1,4 @@ -//! LBFGS optimiser +//! Limited memory Broyden–Fletcher–Goldfarb–Shanno algorithm //! //! A pseudo second order optimiser based on the BFGS method. //! diff --git a/src/radam.rs b/src/radam.rs index 694c5de..2f5ca6d 100644 --- a/src/radam.rs +++ b/src/radam.rs @@ -1,5 +1,5 @@ /*! 
-RAdam optimiser: Adam with Nesterov momentum +RAdam optimiser Described in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1908.03265) From 90b06312257114be2a025965966cecda481f4b73 Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Thu, 7 Dec 2023 14:26:36 +0000 Subject: [PATCH 5/9] config and doc typos --- .cargo/config.toml | 2 ++ Cargo.toml | 3 +++ src/esgd.rs | 2 +- src/lib.rs | 16 ++++++++++++++-- 4 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 .cargo/config.toml diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..beb80be --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +rustdocflags = [ "--html-in-header", "./katex-header.html" ] \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 7f1f93b..6070358 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,3 +51,6 @@ uninlined_format_args = {level = "allow", priority = 1} similar_names = {level = "allow", priority = 1} float_cmp = {level = "allow", priority = 1} # as internaly rounded before the comparison doc_markdown= {level = "allow", priority = 1} # otherwise names get flagged + +[package.metadata.docs.rs] +rustdoc-args = [ "--html-in-header", "./katex-header.html" ] \ No newline at end of file diff --git a/src/esgd.rs b/src/esgd.rs index 90bce0f..fe5edd1 100644 --- a/src/esgd.rs +++ b/src/esgd.rs @@ -24,7 +24,7 @@ $$ &\\hspace{10mm}\\textbf{else} \\\\ &\\hspace{15mm} b_t \\leftarrow g_{t} \\\\ &\\hspace{10mm}\\textbf{if} \\: \\textit{nesterov} \\\\ - &\\hspace{15mm} g_t \\leftarrow gt + \\mu b_t \\\\ + &\\hspace{15mm} g_t \\leftarrow g_t + \\mu b_t \\\\ &\\hspace{10mm}\\textbf{else} \\\\ &\\hspace{15mm} g_t \\leftarrow b_t \\\\ &\\hspace{5mm}\\theta_t \\leftarrow \\theta_{t-1} - \\gamma g_t \\\\ diff --git a/src/lib.rs b/src/lib.rs index 7469944..5d2a0ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,9 +60,21 @@ pub enum ModelOutcome { /// Method of weight decay to use #[derive(Clone, Copy, Debug)] pub enum Decay { - /// weight decay regularisation to penalise large weights + /// Weight decay regularisation to penalise large weights + /// + /// The gradient is transformed as + /// $$ g_{t} \\gets g_{t} + + \\lambda \\theta_{t-1}$$ + /// + /// This is equivalent to an L2 regularisation term in the loss adding $\\frac{\\lambda}{2}||\theta||_{2}^{2}$ but avoids autodifferentiation + /// of the L2 term WeightDecay(f64), - /// Decoupled weight decay as described in + /// Decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) + /// + /// This directly decays the weights as + /// + /// $$ \\theta_{t} \\gets (1 - \\eta \\lambda) \\theta_{t-1}$$ + /// + /// This is equivalent to regularisation, only for SGD without momentum, but is different for adaptive gradient methods DecoupledWeightDecay(f64), } From 86c798aa167f8806a5e0440e1e30d65ea02b1815 Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Thu, 7 Dec 2023 14:35:25 +0000 Subject: [PATCH 6/9] update Readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 459c2de..8c71a95 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Optimisers +# Candle Optimisers [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![codecov](https://codecov.io/gh/KGrewal1/optimisers/graph/badge.svg?token=6AFTLS6DFO)](https://codecov.io/gh/KGrewal1/optimisers) From 93bc2c0a4c54637b4f062d1499c4fbe578f83414 Mon Sep 17 
00:00:00 2001 From: Kirpal Grewal Date: Thu, 7 Dec 2023 14:42:31 +0000 Subject: [PATCH 7/9] rename crate to candle-optimisers --- Cargo.toml | 2 +- benches/mnist_bench.rs | 4 ++-- benches/training.rs | 26 ++++++++++++++------------ examples/mnist/main.rs | 14 +++++++------- examples/mnist/optim.rs | 18 ++++++++++-------- tests/adadelta_tests.rs | 2 +- tests/adagrad_tests.rs | 6 +++--- tests/adam_tests.rs | 2 +- tests/adamax_tests.rs | 6 +++--- tests/esgd_tests.rs | 16 ++++++++-------- tests/lbfgs_tests.rs | 24 ++++++++++++------------ tests/nadam_tests.rs | 2 +- tests/radam-tests.rs | 9 ++++++--- tests/rmsprop-tests.rs | 2 +- 14 files changed, 70 insertions(+), 63 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6070358..0140589 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "optimisers" +name = "candle-optimisers" version = "0.2.1" edition = "2021" readme = "README.md" diff --git a/benches/mnist_bench.rs b/benches/mnist_bench.rs index 46c91fa..0cbd4c8 100644 --- a/benches/mnist_bench.rs +++ b/benches/mnist_bench.rs @@ -1,10 +1,10 @@ use candle_core::Result as CResult; use candle_datasets::vision::Dataset; -use criterion::{criterion_group, criterion_main, Criterion}; -use optimisers::{ +use candle_optimisers::{ adadelta::Adadelta, adagrad::Adagrad, adam::Adam, adamax::Adamax, esgd::SGD, nadam::NAdam, radam::RAdam, rmsprop::RMSprop, }; +use criterion::{criterion_group, criterion_main, Criterion}; use training::Mlp; // mod models; diff --git a/benches/training.rs b/benches/training.rs index 26e75d7..75194a0 100644 --- a/benches/training.rs +++ b/benches/training.rs @@ -1,16 +1,18 @@ use candle_core::{DType, Result, Tensor, Var, D}; use candle_nn::{loss, ops, Linear, Module, Optimizer, VarBuilder, VarMap}; -use optimisers::adadelta::{Adadelta, ParamsAdaDelta}; -use optimisers::adagrad::{Adagrad, ParamsAdaGrad}; -use optimisers::adam::{Adam, ParamsAdam}; -use optimisers::adamax::{Adamax, ParamsAdaMax}; -use optimisers::esgd::{ParamsSGD, SGD}; -use optimisers::lbfgs::{Lbfgs, LineSearch, ParamsLBFGS}; -use optimisers::nadam::{NAdam, ParamsNAdam}; -use optimisers::radam::{ParamsRAdam, RAdam}; -use optimisers::rmsprop::{ParamsRMSprop, RMSprop}; -use optimisers::{LossOptimizer, Model}; +use candle_optimisers::{ + adadelta::{Adadelta, ParamsAdaDelta}, + adagrad::{Adagrad, ParamsAdaGrad}, + adam::{Adam, ParamsAdam}, + adamax::{Adamax, ParamsAdaMax}, + esgd::{ParamsSGD, SGD}, + lbfgs::{Lbfgs, LineSearch, ParamsLBFGS}, + nadam::{NAdam, ParamsNAdam}, + radam::{ParamsRAdam, RAdam}, + rmsprop::{ParamsRMSprop, RMSprop}, + LossOptimizer, Model, +}; pub trait Optim: Sized { fn new(vars: Vec) -> Result; @@ -217,8 +219,8 @@ pub fn run_lbfgs_training( // step the tensors by backpropagating the loss let res = optimiser.backward_step(&loss)?; match res { - optimisers::ModelOutcome::Converged(_, _) => break, - optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss, + candle_optimisers::ModelOutcome::Converged(_, _) => break, + candle_optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss, // _ => panic!("unexpected outcome"), } } diff --git a/examples/mnist/main.rs b/examples/mnist/main.rs index 3796c78..5a365a4 100644 --- a/examples/mnist/main.rs +++ b/examples/mnist/main.rs @@ -7,13 +7,13 @@ mod training; use models::{LinearModel, Mlp}; -use optimisers::adagrad::Adagrad; -use optimisers::adamax::Adamax; -use optimisers::esgd::SGD; -use optimisers::nadam::NAdam; -use optimisers::radam::RAdam; -use optimisers::rmsprop::RMSprop; -use 
optimisers::{adadelta::Adadelta, adam::Adam}; +use candle_optimisers::adagrad::Adagrad; +use candle_optimisers::adamax::Adamax; +use candle_optimisers::esgd::SGD; +use candle_optimisers::nadam::NAdam; +use candle_optimisers::radam::RAdam; +use candle_optimisers::rmsprop::RMSprop; +use candle_optimisers::{adadelta::Adadelta, adam::Adam}; use parse_cli::{Args, TrainingArgs, WhichModel, WhichOptim}; use training::training_loop; diff --git a/examples/mnist/optim.rs b/examples/mnist/optim.rs index 0fcd7ec..ddcaa1a 100644 --- a/examples/mnist/optim.rs +++ b/examples/mnist/optim.rs @@ -1,13 +1,15 @@ use candle_core::{Result, Tensor, Var}; use candle_nn::Optimizer; -use optimisers::adadelta::{Adadelta, ParamsAdaDelta}; -use optimisers::adagrad::{Adagrad, ParamsAdaGrad}; -use optimisers::adam::{Adam, ParamsAdam}; -use optimisers::adamax::{Adamax, ParamsAdaMax}; -use optimisers::esgd::{ParamsSGD, SGD}; -use optimisers::nadam::{NAdam, ParamsNAdam}; -use optimisers::radam::{ParamsRAdam, RAdam}; -use optimisers::rmsprop::{ParamsRMSprop, RMSprop}; +use candle_optimisers::{ + adadelta::{Adadelta, ParamsAdaDelta}, + adagrad::{Adagrad, ParamsAdaGrad}, + adam::{Adam, ParamsAdam}, + adamax::{Adamax, ParamsAdaMax}, + esgd::{ParamsSGD, SGD}, + nadam::{NAdam, ParamsNAdam}, + radam::{ParamsRAdam, RAdam}, + rmsprop::{ParamsRMSprop, RMSprop}, +}; pub trait Optim: Sized { fn new(vars: Vec, lr: f64) -> Result; diff --git a/tests/adadelta_tests.rs b/tests/adadelta_tests.rs index 5d1500a..39e535d 100644 --- a/tests/adadelta_tests.rs +++ b/tests/adadelta_tests.rs @@ -9,7 +9,7 @@ use candle_core::test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle_core::{Device, Tensor, Var}; use candle_nn::{Linear, Module, Optimizer}; -use optimisers::{ +use candle_optimisers::{ adadelta::{Adadelta, ParamsAdaDelta}, Decay, }; diff --git a/tests/adagrad_tests.rs b/tests/adagrad_tests.rs index b204c5c..2fcaa9c 100644 --- a/tests/adagrad_tests.rs +++ b/tests/adagrad_tests.rs @@ -9,7 +9,7 @@ use candle_core::test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle_core::{Device, Tensor, Var}; use candle_nn::{Linear, Module, Optimizer}; -use optimisers::adagrad::{Adagrad, ParamsAdaGrad}; +use candle_optimisers::adagrad::{Adagrad, ParamsAdaGrad}; /* The results of this test have been checked against the following PyTorch code. 
import torch @@ -169,7 +169,7 @@ fn adagrad_weight_decay_test() -> Result<()> { let params = ParamsAdaGrad { lr: 0.004, lr_decay: 0.0, - weight_decay: Some(optimisers::Decay::WeightDecay(0.2)), + weight_decay: Some(candle_optimisers::Decay::WeightDecay(0.2)), initial_acc: 0.0, eps: 1e-10, }; @@ -204,7 +204,7 @@ fn adagrad_decoupled_weight_decay_test() -> Result<()> { let params = ParamsAdaGrad { lr: 0.004, lr_decay: 0.0, - weight_decay: Some(optimisers::Decay::DecoupledWeightDecay(0.2)), + weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.2)), initial_acc: 0.0, eps: 1e-10, }; diff --git a/tests/adam_tests.rs b/tests/adam_tests.rs index 338bcac..fa9a28f 100644 --- a/tests/adam_tests.rs +++ b/tests/adam_tests.rs @@ -9,7 +9,7 @@ use candle_core::test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle_core::{Device, Tensor, Var}; use candle_nn::{Linear, Module, Optimizer}; -use optimisers::{ +use candle_optimisers::{ adam::{Adam, ParamsAdam}, Decay, }; diff --git a/tests/adamax_tests.rs b/tests/adamax_tests.rs index 0fe0d96..b8f01f9 100644 --- a/tests/adamax_tests.rs +++ b/tests/adamax_tests.rs @@ -9,7 +9,7 @@ use candle_core::test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle_core::{Device, Tensor, Var}; use candle_nn::{Linear, Module, Optimizer}; -use optimisers::adamax::{Adamax, ParamsAdaMax}; +use candle_optimisers::adamax::{Adamax, ParamsAdaMax}; /* The results of this test have been checked against the following PyTorch code. import torch @@ -106,7 +106,7 @@ fn adamax_weight_decay_test() -> Result<()> { let params = ParamsAdaMax { lr: 0.004, - weight_decay: Some(optimisers::Decay::WeightDecay(0.6)), + weight_decay: Some(candle_optimisers::Decay::WeightDecay(0.6)), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. @@ -139,7 +139,7 @@ fn adamax_decoupled_weight_decay_test() -> Result<()> { let params = ParamsAdaMax { lr: 0.004, - weight_decay: Some(optimisers::Decay::DecoupledWeightDecay(0.6)), + weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.6)), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. diff --git a/tests/esgd_tests.rs b/tests/esgd_tests.rs index 1ff4872..7d160f5 100644 --- a/tests/esgd_tests.rs +++ b/tests/esgd_tests.rs @@ -9,9 +9,9 @@ use candle_core::test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle_core::{Device, Tensor, Var}; use candle_nn::{Linear, Module, Optimizer}; -use optimisers::{ +use candle_optimisers::{ esgd::{ParamsSGD, SGD}, - Momentum, + Decay, Momentum, }; /* The results of this test have been checked against the following PyTorch code. 
@@ -114,7 +114,7 @@ fn nesterov_decay_sgd_test() -> Result<()> { let params = ParamsSGD { lr: 0.004, - weight_decay: Some(optimisers::Decay::WeightDecay(0.1)), + weight_decay: Some(Decay::WeightDecay(0.1)), momentum: Some(Momentum::Nesterov(0.1)), dampening: 0.0, // nesterov: true, @@ -234,7 +234,7 @@ fn momentum_sgd_decay_test() -> Result<()> { let params = ParamsSGD { lr: 0.004, - weight_decay: Some(optimisers::Decay::WeightDecay(0.4)), + weight_decay: Some(Decay::WeightDecay(0.4)), momentum: Some(Momentum::Classical(0.1)), dampening: 0.0, // nesterov: false, @@ -358,7 +358,7 @@ fn sgd_decay_test() -> Result<()> { let params = ParamsSGD { lr: 0.004, - weight_decay: Some(optimisers::Decay::WeightDecay(0.4)), + weight_decay: Some(Decay::WeightDecay(0.4)), momentum: None, dampening: 0.0, // nesterov: false, @@ -394,7 +394,7 @@ fn sgdw_decay_test() -> Result<()> { let params = ParamsSGD { lr: 0.004, - weight_decay: Some(optimisers::Decay::DecoupledWeightDecay(0.4)), + weight_decay: Some(Decay::DecoupledWeightDecay(0.4)), momentum: None, dampening: 0.0, // nesterov: false, @@ -426,7 +426,7 @@ fn momentum_sgdw_decay_test() -> Result<()> { let params = ParamsSGD { lr: 0.004, - weight_decay: Some(optimisers::Decay::DecoupledWeightDecay(0.4)), + weight_decay: Some(Decay::DecoupledWeightDecay(0.4)), momentum: Some(Momentum::Classical(0.1)), dampening: 0.0, // nesterov: false, @@ -458,7 +458,7 @@ fn nesterov_decay_sgdw_test() -> Result<()> { let params = ParamsSGD { lr: 0.004, - weight_decay: Some(optimisers::Decay::DecoupledWeightDecay(0.1)), + weight_decay: Some(candle_optimisers::Decay::DecoupledWeightDecay(0.1)), momentum: Some(Momentum::Nesterov(0.1)), dampening: 0.0, // nesterov: true, diff --git a/tests/lbfgs_tests.rs b/tests/lbfgs_tests.rs index f4f59a5..cd2927e 100644 --- a/tests/lbfgs_tests.rs +++ b/tests/lbfgs_tests.rs @@ -9,8 +9,8 @@ extern crate accelerate_src; use anyhow::Result; use candle_core::test_utils::to_vec2_round; use candle_core::{DType, Device, Result as CResult, Tensor}; -use optimisers::lbfgs::{GradConv, Lbfgs, LineSearch, ParamsLBFGS, StepConv}; -use optimisers::{LossOptimizer, Model}; +use candle_optimisers::lbfgs::{GradConv, Lbfgs, LineSearch, ParamsLBFGS, StepConv}; +use candle_optimisers::{LossOptimizer, Model, ModelOutcome}; /* These tests all use the 2D Rosenbrock function as a test function for the optimisers. 
This has minimum 0 at (1, 1) @@ -70,8 +70,8 @@ fn lbfgs_test() -> Result<()> { let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys // println!("end step {}", _step); match res { - optimisers::ModelOutcome::Converged(_, _) => break, - optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss, + ModelOutcome::Converged(_, _) => break, + ModelOutcome::Stepped(new_loss, _) => loss = new_loss, // _ => panic!("unexpected outcome"), } } @@ -108,8 +108,8 @@ fn lbfgs_test_strong_wolfe() -> Result<()> { let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys // println!("end step {}", _step); match res { - optimisers::ModelOutcome::Converged(_, _) => break, - optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss, + ModelOutcome::Converged(_, _) => break, + ModelOutcome::Stepped(new_loss, _) => loss = new_loss, // _ => panic!("unexpected outcome"), } } @@ -146,8 +146,8 @@ fn lbfgs_rms_grad_test() -> Result<()> { let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys // println!("end step {}", _step); match res { - optimisers::ModelOutcome::Converged(_, _) => break, - optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss, + ModelOutcome::Converged(_, _) => break, + ModelOutcome::Stepped(new_loss, _) => loss = new_loss, // _ => panic!("unexpected outcome"), } } @@ -185,8 +185,8 @@ fn lbfgs_rms_step_test() -> Result<()> { let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys // println!("end step {}", _step); match res { - optimisers::ModelOutcome::Converged(_, _) => break, - optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss, + ModelOutcome::Converged(_, _) => break, + ModelOutcome::Stepped(new_loss, _) => loss = new_loss, // _ => panic!("unexpected outcome"), } } @@ -224,8 +224,8 @@ fn lbfgs_test_strong_wolfe_weight_decay() -> Result<()> { let res = lbfgs.backward_step(&loss)?; //&sample_xs, &sample_ys // println!("end step {}", _step); match res { - optimisers::ModelOutcome::Converged(_, _) => break, - optimisers::ModelOutcome::Stepped(new_loss, _) => loss = new_loss, + ModelOutcome::Converged(_, _) => break, + ModelOutcome::Stepped(new_loss, _) => loss = new_loss, // _ => panic!("unexpected outcome"), } } diff --git a/tests/nadam_tests.rs b/tests/nadam_tests.rs index 4f20558..2224502 100644 --- a/tests/nadam_tests.rs +++ b/tests/nadam_tests.rs @@ -9,7 +9,7 @@ use candle_core::test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle_core::{Device, Tensor, Var}; use candle_nn::{Linear, Module, Optimizer}; -use optimisers::{ +use candle_optimisers::{ nadam::{NAdam, ParamsNAdam}, Decay, }; diff --git a/tests/radam-tests.rs b/tests/radam-tests.rs index 5f87f53..25396c9 100644 --- a/tests/radam-tests.rs +++ b/tests/radam-tests.rs @@ -9,7 +9,10 @@ use candle_core::test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle_core::{Device, Tensor, Var}; use candle_nn::{Linear, Module, Optimizer}; -use optimisers::radam::{ParamsRAdam, RAdam}; +use candle_optimisers::{ + radam::{ParamsRAdam, RAdam}, + Decay, +}; /* The results of this test have been checked against the following PyTorch code. import torch @@ -102,7 +105,7 @@ fn radam_weight_decay_test() -> Result<()> { let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRAdam { - weight_decay: Some(optimisers::Decay::WeightDecay(0.4)), + weight_decay: Some(Decay::WeightDecay(0.4)), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. 
@@ -134,7 +137,7 @@ fn radam_decoupled_weight_decay_test() -> Result<()> { let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRAdam { - weight_decay: Some(optimisers::Decay::DecoupledWeightDecay(0.4)), + weight_decay: Some(Decay::DecoupledWeightDecay(0.4)), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. diff --git a/tests/rmsprop-tests.rs b/tests/rmsprop-tests.rs index f8bd6cf..9307c04 100644 --- a/tests/rmsprop-tests.rs +++ b/tests/rmsprop-tests.rs @@ -9,7 +9,7 @@ use candle_core::test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle_core::{Device, Tensor, Var}; use candle_nn::{Linear, Module, Optimizer}; -use optimisers::rmsprop::{ParamsRMSprop, RMSprop}; +use candle_optimisers::rmsprop::{ParamsRMSprop, RMSprop}; /* The results of this test have been checked against the following PyTorch code. import torch From eb42b5a681b441bace8b68c41af52fa011a8a795 Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Thu, 7 Dec 2023 14:48:26 +0000 Subject: [PATCH 8/9] update readme and version --- Cargo.toml | 2 +- README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0140589..31f6dc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-optimisers" -version = "0.2.1" +version = "0.3.0" edition = "2021" readme = "README.md" license = "MIT" diff --git a/README.md b/README.md index 8c71a95..8fc9c95 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ to use the cuda backend. ## Usage ```cli -cargo add --git https://github.com/KGrewal1/optimisers.git optimisers +cargo add --git https://github.com/KGrewal1/optimisers.git candle-optimisers ``` ## To do From 62ae9ab6d1259c27ddcdbff7f08872688faa78ac Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Thu, 7 Dec 2023 14:49:54 +0000 Subject: [PATCH 9/9] lint fixing --- src/esgd.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/esgd.rs b/src/esgd.rs index fe5edd1..d5e9152 100644 --- a/src/esgd.rs +++ b/src/esgd.rs @@ -99,6 +99,7 @@ impl Optimizer for SGD { self.params.lr } + #[allow(clippy::too_many_lines)] fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> { if let Some(momentum) = self.params.momentum { match momentum {
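With the crate renamed to `candle-optimisers`, end-to-end usage follows the same pattern as the tests in this series. A condensed sketch fitting y = 3.x1 + x2 - 2 with SGDW (decoupled weight decay); the hyperparameter values are illustrative:

```rust
use candle_core::{Device, Result, Tensor, Var};
use candle_nn::{Linear, Module, Optimizer};
use candle_optimisers::{
    esgd::{ParamsSGD, SGD},
    Decay, Momentum,
};

fn main() -> Result<()> {
    // Generate some linear data, y = 3.x1 + x2 - 2.
    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
    let gen = Linear::new(w_gen, Some(b_gen));
    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let sample_ys = gen.forward(&sample_xs)?;

    let params = ParamsSGD {
        lr: 0.004,
        weight_decay: Some(Decay::DecoupledWeightDecay(0.4)),
        momentum: Some(Momentum::Classical(0.1)),
        dampening: 0.0,
    };
    // Use backprop to run a linear regression and recover the coefficients.
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;
    let mut optimiser = SGD::new(vec![w.clone(), b.clone()], params)?;
    let model = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
    for _step in 0..100 {
        let ys = model.forward(&sample_xs)?;
        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
        optimiser.backward_step(&loss)?;
    }
    // After 100 steps, w and b have moved from zero towards the generating coefficients.
    Ok(())
}
```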