From da1b10d159630bc83bd25eb2f981adeadb0fafc4 Mon Sep 17 00:00:00 2001
From: Pratik Fandade
Date: Fri, 25 Oct 2024 20:13:34 -0400
Subject: [PATCH] Fixing the test cases and minor adjustment to the algorithm

---
 src/machine_learning/logistic_regression.rs | 54 ++++++++++++++++-----
 1 file changed, 42 insertions(+), 12 deletions(-)

diff --git a/src/machine_learning/logistic_regression.rs b/src/machine_learning/logistic_regression.rs
index cae85814671..fc020a795ac 100644
--- a/src/machine_learning/logistic_regression.rs
+++ b/src/machine_learning/logistic_regression.rs
@@ -11,7 +11,7 @@ pub fn logistic_regression(
         return None;
     }
 
-    let num_features = data_points[0].0.len();
+    let num_features = data_points[0].0.len() + 1;
     let mut params = vec![0.0; num_features];
 
     let derivative_fn = |params: &[f64]| derivative(params, &data_points);
@@ -26,11 +26,17 @@ fn derivative(params: &[f64], data_points: &[(Vec<f64>, f64)]) -> Vec<f64> {
     let mut gradients = vec![0.0; num_features];
 
     for (features, y_i) in data_points {
-        let z = params.iter().zip(features).map(|(p, x)| p * x).sum::<f64>();
+        let z = params[0]
+            + params[1..]
+                .iter()
+                .zip(features)
+                .map(|(p, x)| p * x)
+                .sum::<f64>();
         let prediction = 1.0 / (1.0 + E.powf(-z));
 
+        gradients[0] += prediction - y_i;
         for (i, x_i) in features.iter().enumerate() {
-            gradients[i] += (prediction - y_i) * x_i;
+            gradients[i + 1] += (prediction - y_i) * x_i;
         }
     }
 
@@ -42,21 +48,45 @@ mod test {
     use super::*;
 
     #[test]
-    fn test_logistic_regression() {
+    fn test_logistic_regression_simple() {
         let data = vec![
-            (vec![0.0, 0.0], 0.0),
-            (vec![1.0, 1.0], 1.0),
-            (vec![2.0, 2.0], 1.0),
+            (vec![0.0], 0.0),
+            (vec![1.0], 0.0),
+            (vec![2.0], 0.0),
+            (vec![3.0], 1.0),
+            (vec![4.0], 1.0),
+            (vec![5.0], 1.0),
         ];
-        let result = logistic_regression(data, 10000, 0.1);
+
+        let result = logistic_regression(data, 10000, 0.05);
         assert!(result.is_some());
+
+        let params = result.unwrap();
+        assert!((params[0] + 17.65).abs() < 1.0);
+        assert!((params[1] - 7.13).abs() < 1.0);
+    }
+
+    #[test]
+    fn test_logistic_regression_extreme_data() {
+        let data = vec![
+            (vec![-100.0], 0.0),
+            (vec![-10.0], 0.0),
+            (vec![0.0], 0.0),
+            (vec![10.0], 1.0),
+            (vec![100.0], 1.0),
+        ];
+
+        let result = logistic_regression(data, 10000, 0.05);
+        assert!(result.is_some());
+
         let params = result.unwrap();
-        assert!((params[0] - 6.902976808251308).abs() < 1e-6);
-        assert!((params[1] - 2000.4659358334482).abs() < 1e-6);
+        assert!((params[0] + 6.20).abs() < 1.0);
+        assert!((params[1] - 5.5).abs() < 1.0);
     }
 
     #[test]
-    fn test_empty_list_logistic_regression() {
-        assert_eq!(logistic_regression(vec![], 10000, 0.1), None);
+    fn test_logistic_regression_no_data() {
+        let result = logistic_regression(vec![], 5000, 0.1);
+        assert_eq!(result, None);
     }
 }