Skip to content

Commit

Permalink
Add UserWarning for OLS solution divergence when n_features >= n_samples
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Apr 14, 2024
1 parent fec8edc commit bb27bc7
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions bnpm/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,11 @@ def fit(self, X, y):
"""
self.n_features_in_ = X.shape[1]

## Give a UserWarning if n_features >= n_samples
if X.shape[1] >= X.shape[0]:
import warnings
warnings.warn('OLS solution is expected to diverge from sklearn solution when n_features >= n_samples')

ns = self.get_backend_namespace(X=X, y=y)
zeros = ns['zeros']

Expand Down

0 comments on commit bb27bc7

Please sign in to comment.