Skip to content

Commit

Permalink
Fixing the test cases and minor adjustment to the algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
prkbuilds committed Oct 26, 2024
1 parent 21850fe commit e467474
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions src/machine_learning/logistic_regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub fn logistic_regression(
return None;
}

let num_features = data_points[0].0.len();
let num_features = data_points[0].0.len() + 1;
let mut params = vec![0.0; num_features];

let derivative_fn = |params: &[f64]| derivative(params, &data_points);
Expand All @@ -26,11 +26,12 @@ fn derivative(params: &[f64], data_points: &[(Vec<f64>, f64)]) -> Vec<f64> {
let mut gradients = vec![0.0; num_features];

for (features, y_i) in data_points {
let z = params.iter().zip(features).map(|(p, x)| p * x).sum::<f64>();
let z = params[0] + params[1..].iter().zip(features).map(|(p, x)| p * x).sum::<f64>();
let prediction = 1.0 / (1.0 + E.powf(-z));

gradients[0] += prediction - y_i;
for (i, x_i) in features.iter().enumerate() {
gradients[i] += (prediction - y_i) * x_i;
gradients[i + 1] += (prediction - y_i) * x_i;
}
}

Expand All @@ -42,21 +43,46 @@ mod test {
use super::*;

#[test]
fn test_logistic_regression() {
fn test_logistic_regression_simple() {
let data = vec![
(vec![0.0, 0.0], 0.0),
(vec![1.0, 1.0], 1.0),
(vec![2.0, 2.0], 1.0),
(vec![0.0], 0.0),
(vec![1.0], 0.0),
(vec![2.0], 0.0),
(vec![3.0], 1.0),
(vec![4.0], 1.0),
(vec![5.0], 1.0),
];
let result = logistic_regression(data, 10000, 0.1);

let result = logistic_regression(data, 10000, 0.05);
assert!(result.is_some());

let params = result.unwrap();
assert!((params[0] + 17.65).abs() < 1.0);
assert!((params[1] - 7.13).abs() < 1.0);
}

#[test]
fn test_logistic_regression_extreme_data() {
let data = vec![
(vec![-100.0], 0.0),
(vec![-10.0], 0.0),
(vec![0.0], 0.0),
(vec![10.0], 1.0),
(vec![100.0], 1.0),
];

let result = logistic_regression(data, 10000, 0.05);
assert!(result.is_some());

let params = result.unwrap();
assert!((params[0] - 6.902976808251308).abs() < 1e-6);
assert!((params[1] - 2000.4659358334482).abs() < 1e-6);
assert!((params[0] + 6.20).abs() < 1.0);
assert!((params[1] - 5.5).abs() < 1.0);
}

#[test]
fn test_empty_list_logistic_regression() {
assert_eq!(logistic_regression(vec![], 10000, 0.1), None);
fn test_logistic_regression_no_data() {
let result = logistic_regression(vec![], 5000, 0.1);
assert_eq!(result, None);
}
}

0 comments on commit e467474

Please sign in to comment.