From ad1640855945a88d427c1c5e573572c7cc63aa32 Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Fri, 3 May 2024 17:24:30 +0100 Subject: [PATCH 1/6] change f to scalar in line search --- src/lbfgs/strong_wolfe.rs | 138 +++++++++++++++----------------------- 1 file changed, 55 insertions(+), 83 deletions(-) diff --git a/src/lbfgs/strong_wolfe.rs b/src/lbfgs/strong_wolfe.rs index 66b137d..7401936 100644 --- a/src/lbfgs/strong_wolfe.rs +++ b/src/lbfgs/strong_wolfe.rs @@ -92,7 +92,9 @@ impl Lbfgs { // evaluate objective and gradient using initial step let (f_new, g_new, mut l2_new) = self.directional_evaluate(step_size, direction)?; let g_new = Var::from_tensor(&g_new)?; - let f_new = Var::from_tensor(&f_new)?; + let mut f_new = f_new + .to_dtype(candle_core::DType::F64)? + .to_scalar::()?; let mut ls_func_evals = 1; let mut gtd_new = g_new .unsqueeze(0)? @@ -104,7 +106,10 @@ impl Lbfgs { // bracket an interval containing a point satisfying the Wolfe criteria let g_prev = Var::from_tensor(grad)?; - let f_prev = Var::from_tensor(loss)?; + let dtype = loss.dtype(); + let shape = loss.shape(); + let dev = loss.device(); + let mut f_prev = loss.to_dtype(candle_core::DType::F64)?.to_scalar::()?; let l2_init = self.l2_reg()?; let mut l2_prev = l2_init; let (mut t_prev, mut gtd_prev) = (0., directional_grad); @@ -113,21 +118,13 @@ impl Lbfgs { let mut bracket_gtd; let mut bracket_l2; - let bracket_f; + let mut bracket_f; let (mut bracket, bracket_g) = loop { // check conditions - if f_new - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + l2_new - >= f_prev - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? 
- + l2_prev - { + if f_new + l2_new >= f_prev + l2_prev { bracket_gtd = [gtd_prev, gtd_new]; bracket_l2 = [l2_prev, l2_new]; - bracket_f = [f_prev, Var::from_tensor(f_new.as_tensor())?]; + bracket_f = [f_prev, f_new]; break ( [t_prev, step_size], [g_prev, Var::from_tensor(g_new.as_tensor())?], @@ -138,10 +135,7 @@ impl Lbfgs { done = true; bracket_gtd = [gtd_prev, gtd_new]; bracket_l2 = [l2_prev, l2_new]; - bracket_f = [ - Var::from_tensor(f_new.as_tensor())?, - Var::from_tensor(f_new.as_tensor())?, - ]; + bracket_f = [f_new, f_new]; break ( [step_size, step_size], [ @@ -154,7 +148,7 @@ impl Lbfgs { if gtd_new >= 0. { bracket_gtd = [gtd_prev, gtd_new]; bracket_l2 = [l2_prev, l2_new]; - bracket_f = [f_prev, Var::from_tensor(f_new.as_tensor())?]; + bracket_f = [f_prev, f_new]; break ( [t_prev, step_size], [g_prev, Var::from_tensor(g_new.as_tensor())?], @@ -167,23 +161,17 @@ impl Lbfgs { let tmp = step_size; step_size = cubic_interpolate( t_prev, - f_prev - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + l2_prev, + f_prev + l2_prev, gtd_prev, step_size, - f_new - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + l2_new, + f_new + l2_new, gtd_new, Some((min_step, max_step)), ); // next step t_prev = tmp; - f_prev.set(f_new.as_tensor())?; + f_prev = f_new; g_prev.set(g_new.as_tensor())?; l2_prev = l2_new; gtd_prev = gtd_new; @@ -191,7 +179,9 @@ impl Lbfgs { let (next_f, next_g, next_l2) = self.directional_evaluate(step_size, direction)?; // overwrite - f_new.set(&next_f)?; + f_new = next_f + .to_dtype(candle_core::DType::F64)? 
+ .to_scalar::()?; g_new.set(&next_g)?; l2_new = next_l2; @@ -211,8 +201,8 @@ impl Lbfgs { bracket_gtd = [gtd_prev, gtd_new]; bracket_l2 = [l2_prev, l2_new]; bracket_f = [ - Var::from_tensor(loss)?, - Var::from_tensor(f_new.as_tensor())?, + loss.to_dtype(candle_core::DType::F64)?.to_scalar::()?, + f_new, ]; break ( [0., step_size], @@ -229,19 +219,12 @@ impl Lbfgs { // exact point satisfying the criteria let mut insuf_progress = false; // find high and low points in bracket - let (mut low_pos, mut high_pos) = if bracket_f[0] - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + bracket_l2[0] - <= bracket_f[1] - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + bracket_l2[1] - { - (0, 1) - } else { - (1, 0) - }; + let (mut low_pos, mut high_pos) = + if bracket_f[0] + bracket_l2[0] <= bracket_f[1] + bracket_l2[1] { + (0, 1) + } else { + (1, 0) + }; while !done && ls_iter < max_ls { // line-search bracket is so small if (bracket[1] - bracket[0]).abs() * d_norm < tolerance_change { @@ -251,16 +234,10 @@ impl Lbfgs { // compute new trial value step_size = cubic_interpolate( bracket[0], - bracket_f[0] - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + bracket_l2[0], + bracket_f[0] + bracket_l2[0], bracket_gtd[0], bracket[1], - bracket_f[1] - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + bracket_l2[1], + bracket_f[1] + bracket_l2[1], bracket_gtd[1], None, ); @@ -296,7 +273,9 @@ impl Lbfgs { // assign to temp vars: let (next_f, next_g, next_l2) = self.directional_evaluate(step_size, direction)?; // overwrite - f_new.set(&next_f)?; + f_new = next_f + .to_dtype(candle_core::DType::F64)? + .to_scalar::()?; g_new.set(&next_g)?; l2_new = next_l2; ls_func_evals += 1; @@ -310,42 +289,25 @@ impl Lbfgs { .to_scalar::()?; ls_iter += 1; - if f_new - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + l2_new + if f_new + l2_new > (loss.to_dtype(candle_core::DType::F64)?.to_scalar::()? 
+ l2_init + c1 * step_size * directional_grad) - || f_new - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + l2_new - >= bracket_f[low_pos] - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + bracket_l2[low_pos] + || f_new + l2_new >= bracket_f[low_pos] + bracket_l2[low_pos] { // Armijo condition not satisfied or not lower than lowest point bracket[high_pos] = step_size; - bracket_f[high_pos].set(&f_new)?; - bracket_g[high_pos].set(g_new.as_tensor())?; + bracket_f[high_pos] = f_new; + let _ = bracket_g[high_pos].set(g_new.as_tensor()); bracket_l2[high_pos] = l2_new; bracket_gtd[high_pos] = gtd_new; - (low_pos, high_pos) = if bracket_f[0] - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + bracket_l2[0] - <= bracket_f[1] - .to_dtype(candle_core::DType::F64)? - .to_scalar::()? - + bracket_l2[1] - { - (0, 1) - } else { - (1, 0) - }; + (low_pos, high_pos) = + if bracket_f[0] + bracket_l2[0] <= bracket_f[1] + bracket_l2[1] { + (0, 1) + } else { + (1, 0) + }; } else { if gtd_new.abs() <= -c2 * directional_grad { // Wolfe conditions satisfied @@ -353,7 +315,7 @@ impl Lbfgs { } else if gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0. 
{ // old low becomes new high bracket[high_pos] = bracket[low_pos]; - bracket_f[high_pos].set(bracket_f[low_pos].as_tensor())?; + bracket_f[high_pos] = bracket_f[low_pos]; bracket_g[high_pos].set(bracket_g[low_pos].as_tensor())?; bracket_gtd[high_pos] = bracket_gtd[low_pos]; bracket_l2[high_pos] = bracket_l2[low_pos]; @@ -361,7 +323,7 @@ impl Lbfgs { // new point becomes new low bracket[low_pos] = step_size; - bracket_f[low_pos].set(f_new.as_tensor())?; + bracket_f[low_pos] = f_new; bracket_g[low_pos].set(g_new.as_tensor())?; bracket_gtd[low_pos] = gtd_new; bracket_l2[low_pos] = l2_new; @@ -374,9 +336,19 @@ impl Lbfgs { let [f0, f1] = bracket_f; if low_pos == 1 { // if b is the lower value set a to b, else a should be returned - Ok((f1.into_inner(), g1.into_inner(), step_size, ls_func_evals)) + Ok(( + Tensor::from_slice(&[f1], shape, &dev)?.to_dtype(dtype)?, + g1.into_inner(), + step_size, + ls_func_evals, + )) } else { - Ok((f0.into_inner(), g0.into_inner(), step_size, ls_func_evals)) + Ok(( + Tensor::from_slice(&[f0], shape, &dev)?.to_dtype(dtype)?, + g0.into_inner(), + step_size, + ls_func_evals, + )) } } From 916f0ba407b081933de5c93c85b482dfc52bf4e5 Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Sat, 4 May 2024 21:23:26 +0100 Subject: [PATCH 2/6] fix lbfgs --- Cargo.toml | 6 +++--- src/lbfgs.rs | 8 +++----- src/lbfgs/strong_wolfe.rs | 33 +++++++++++++++------------------ 3 files changed, 21 insertions(+), 26 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 85dca5d..7b135a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,15 +17,15 @@ exclude = [ [dependencies] -candle-core = "0.4.0" -candle-nn = "0.4.0" +candle-core = "0.5.0" +candle-nn = "0.5.0" log = "0.4.20" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } assert_approx_eq = "1.1.0" -candle-datasets = "0.4.0" +candle-datasets = "0.5.0" clap = {version = "4.4.6", features = ["derive"] } criterion = { version = "0.5.1", features = ["html_reports"] } diff --git a/src/lbfgs.rs 
b/src/lbfgs.rs index e5ff132..b6a1b1f 100644 --- a/src/lbfgs.rs +++ b/src/lbfgs.rs @@ -153,7 +153,7 @@ impl LossOptimizer for Lbfgs { let mut evals = 1; let grad = if let Some(this_grad) = &self.next_grad { - this_grad.as_tensor().clone() + this_grad.as_tensor().copy()? } else { flat_grads(&self.vars, loss, self.params.weight_decay)? }; @@ -302,10 +302,8 @@ impl LossOptimizer for Lbfgs { if let Some(ls) = &self.params.line_search { match ls { LineSearch::StrongWolfe(c1, c2, tol) => { - let (loss, grad, t, steps) = self.strong_wolfe( - lr, &q, loss, //.to_dtype(candle_core::DType::F64)?.to_scalar()? - &grad, dd, *c1, *c2, *tol, 25, - )?; + let (loss, grad, t, steps) = + self.strong_wolfe(lr, &q, loss, &grad, dd, *c1, *c2, *tol, 25)?; if let Some(next_grad) = &self.next_grad { next_grad.set(&grad)?; } else { diff --git a/src/lbfgs/strong_wolfe.rs b/src/lbfgs/strong_wolfe.rs index 7401936..7a65b9d 100644 --- a/src/lbfgs/strong_wolfe.rs +++ b/src/lbfgs/strong_wolfe.rs @@ -83,6 +83,10 @@ impl Lbfgs { ) -> CResult<(Tensor, Tensor, f64, usize)> { // ported from https://github.com/torch/optim/blob/master/lswolfe.lua + let dtype = loss.dtype(); + let shape = loss.shape(); + let dev = loss.device(); + let d_norm = &direction .abs()? .max(0)? 
@@ -105,11 +109,10 @@ impl Lbfgs { .to_scalar::()?; // bracket an interval containing a point satisfying the Wolfe criteria - let g_prev = Var::from_tensor(grad)?; - let dtype = loss.dtype(); - let shape = loss.shape(); - let dev = loss.device(); - let mut f_prev = loss.to_dtype(candle_core::DType::F64)?.to_scalar::()?; + let grad_det = grad.copy()?; + let g_prev = Var::from_tensor(&grad_det)?; + let scalar_loss = loss.to_dtype(candle_core::DType::F64)?.to_scalar::()?; + let mut f_prev = scalar_loss; let l2_init = self.l2_reg()?; let mut l2_prev = l2_init; let (mut t_prev, mut gtd_prev) = (0., directional_grad); @@ -139,7 +142,7 @@ impl Lbfgs { break ( [step_size, step_size], [ - Var::from_tensor(g_new.as_tensor())?, + Var::from_tensor(&g_new.as_tensor().copy()?)?, Var::from_tensor(g_new.as_tensor())?, ], ); @@ -200,10 +203,7 @@ impl Lbfgs { if ls_iter == max_ls { bracket_gtd = [gtd_prev, gtd_new]; bracket_l2 = [l2_prev, l2_new]; - bracket_f = [ - loss.to_dtype(candle_core::DType::F64)?.to_scalar::()?, - f_new, - ]; + bracket_f = [scalar_loss, f_new]; break ( [0., step_size], [ @@ -276,11 +276,11 @@ impl Lbfgs { f_new = next_f .to_dtype(candle_core::DType::F64)? .to_scalar::()?; - g_new.set(&next_g)?; + l2_new = next_l2; ls_func_evals += 1; - gtd_new = g_new + gtd_new = next_g .unsqueeze(0)? .matmul(&(direction.unsqueeze(1)?))? .to_dtype(candle_core::DType::F64)? @@ -289,16 +289,13 @@ impl Lbfgs { .to_scalar::()?; ls_iter += 1; - if f_new + l2_new - > (loss.to_dtype(candle_core::DType::F64)?.to_scalar::()? 
- + l2_init - + c1 * step_size * directional_grad) + if f_new + l2_new > (scalar_loss + l2_init + c1 * step_size * directional_grad) || f_new + l2_new >= bracket_f[low_pos] + bracket_l2[low_pos] { // Armijo condition not satisfied or not lower than lowest point bracket[high_pos] = step_size; bracket_f[high_pos] = f_new; - let _ = bracket_g[high_pos].set(g_new.as_tensor()); + bracket_g[high_pos].set(&next_g)?; bracket_l2[high_pos] = l2_new; bracket_gtd[high_pos] = gtd_new; @@ -324,7 +321,7 @@ impl Lbfgs { // new point becomes new low bracket[low_pos] = step_size; bracket_f[low_pos] = f_new; - bracket_g[low_pos].set(g_new.as_tensor())?; + bracket_g[low_pos].set(&next_g)?; bracket_gtd[low_pos] = gtd_new; bracket_l2[low_pos] = l2_new; } From f4c208f1cc9bc3a383f7a765da915866380824f9 Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Sat, 4 May 2024 21:25:03 +0100 Subject: [PATCH 3/6] clippy lint setup --- Cargo.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7b135a0..b0df402 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,12 +41,12 @@ cuda = ["candle-core/cuda", "candle-nn/cuda"] lto = true # maximal LTO optimisaiton [lints.clippy] -pedantic = "warn" -suspicious = "warn" -perf = "warn" -complexity = "warn" -style = "warn" -cargo = "warn" +pedantic = {level = "warn", priority = -1} +suspicious = {level = "warn", priority = -1} +perf = {level = "warn", priority = -1} +complexity = {level = "warn", priority = -1} +style = {level = "warn", priority = -1} +cargo = {level = "warn", priority = -1} imprecise_flops = "warn" missing_errors_doc = {level = "allow", priority = 1} uninlined_format_args = {level = "allow", priority = 1} From 624d9746b65803a54390abaedc22b27b107b27b2 Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Sat, 4 May 2024 21:25:31 +0100 Subject: [PATCH 4/6] lint --- src/lbfgs/strong_wolfe.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lbfgs/strong_wolfe.rs 
b/src/lbfgs/strong_wolfe.rs index 7a65b9d..09dc24f 100644 --- a/src/lbfgs/strong_wolfe.rs +++ b/src/lbfgs/strong_wolfe.rs @@ -334,14 +334,14 @@ impl Lbfgs { if low_pos == 1 { // if b is the lower value set a to b, else a should be returned Ok(( - Tensor::from_slice(&[f1], shape, &dev)?.to_dtype(dtype)?, + Tensor::from_slice(&[f1], shape, dev)?.to_dtype(dtype)?, g1.into_inner(), step_size, ls_func_evals, )) } else { Ok(( - Tensor::from_slice(&[f0], shape, &dev)?.to_dtype(dtype)?, + Tensor::from_slice(&[f0], shape, dev)?.to_dtype(dtype)?, g0.into_inner(), step_size, ls_func_evals, From 4e7fa71650e3d9807b270f16e7509bded0cab47f Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Sat, 4 May 2024 21:26:48 +0100 Subject: [PATCH 5/6] changelog --- Changelog.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Changelog.md b/Changelog.md index ef58f5f..bf33e6f 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,10 @@ # Changelog +## v0.5.0 (2024-05-04) + +* Bump candle requirement to 0.5.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn +* Internal changes for LBFGS line search + ## v0.4.0 (2024-02-28) * Bump candle requirtement to 0.4.0: this is considered a breaking change due to the reliance of this library on candle-core and candle-nn From 772f45df709cb6b7abac0940ec4ee36081aa7374 Mon Sep 17 00:00:00 2001 From: Kirpal Grewal Date: Sat, 4 May 2024 21:28:37 +0100 Subject: [PATCH 6/6] bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index b0df402..f57e134 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-optimisers" -version = "0.4.0" +version = "0.5.0" edition = "2021" readme = "README.md" license = "MIT"