Skip to content

Commit

Permalink
add test for l2 reg
Browse files Browse the repository at this point in the history
  • Loading branch information
KGrewal1 committed Jan 7, 2024
1 parent 054735f commit b874db1
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions src/lbfgs/strong_wolfe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,75 @@ impl<M: Model> Lbfgs<M> {
}
}
}

#[cfg(test)]
mod tests {
// use candle_core::test_utils::{to_vec0_round, to_vec2_round};

use crate::lbfgs::ParamsLBFGS;
use crate::{LossOptimizer, Model};
use anyhow::Result;
use assert_approx_eq::assert_approx_eq;
use candle_core::Device;
use candle_core::{Module, Result as CResult};
pub struct LinearModel {
linear: candle_nn::Linear,
xs: Tensor,
ys: Tensor,
}

impl Model for LinearModel {
fn loss(&self) -> CResult<Tensor> {
let preds = self.forward(&self.xs)?;
let loss = candle_nn::loss::mse(&preds, &self.ys)?;
Ok(loss)
}
}

impl LinearModel {
fn new() -> CResult<(Self, Vec<Var>)> {
let weight = Var::from_tensor(&Tensor::new(&[3f64, 1.], &Device::Cpu)?)?;
let bias = Var::from_tensor(&Tensor::new(-2f64, &Device::Cpu)?)?;

let linear =
candle_nn::Linear::new(weight.as_tensor().clone(), Some(bias.as_tensor().clone()));

Ok((
Self {
linear,
xs: Tensor::new(&[[2f64, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?,
ys: Tensor::new(&[[7f64], [26.], [0.], [27.]], &Device::Cpu)?,
},
vec![weight, bias],
))
}

fn forward(&self, xs: &Tensor) -> CResult<Tensor> {
self.linear.forward(xs)
}
}

use super::*;
#[test]
fn l2_test() -> Result<()> {
let params = ParamsLBFGS {
lr: 0.004,
..Default::default()
};
let (model, vars) = LinearModel::new()?;
let lbfgs = Lbfgs::new(vars, params, model)?;
let l2 = lbfgs.l2_reg()?;
assert_approx_eq!(0.0, l2);

let params = ParamsLBFGS {
lr: 0.004,
weight_decay: Some(1.0),
..Default::default()
};
let (model, vars) = LinearModel::new()?;
let lbfgs = Lbfgs::new(vars, params, model)?;
let l2 = lbfgs.l2_reg()?;
assert_approx_eq!(7.0, l2); // 0.5 *(3^2 +1^2 + (-2)^2)
Ok(())
}
}

0 comments on commit b874db1

Please sign in to comment.