From bb27bc776a3c7a977d745bc981a1d152eaac964f Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Sun, 14 Apr 2024 16:22:07 -0400 Subject: [PATCH] Add UserWarning for OLS solution divergence when n_features >= n_samples --- bnpm/linear_regression.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bnpm/linear_regression.py b/bnpm/linear_regression.py index ca6591e..13bf8c6 100644 --- a/bnpm/linear_regression.py +++ b/bnpm/linear_regression.py @@ -367,6 +367,11 @@ def fit(self, X, y): """ self.n_features_in_ = X.shape[1] + ## Give a UserWarning if n_features >= n_samples + if X.shape[1] >= X.shape[0]: + import warnings + warnings.warn('OLS solution is expected to diverge from sklearn solution when n_features >= n_samples') + ns = self.get_backend_namespace(X=X, y=y) zeros = ns['zeros']