diff --git a/bnpm/linear_regression.py b/bnpm/linear_regression.py index ca6591e..13bf8c6 100644 --- a/bnpm/linear_regression.py +++ b/bnpm/linear_regression.py @@ -367,6 +367,11 @@ def fit(self, X, y): """ self.n_features_in_ = X.shape[1] + ## Give a UserWarning if n_features >= n_samples + if X.shape[1] >= X.shape[0]: + import warnings + warnings.warn('OLS solution is expected to diverge from sklearn solution when n_features >= n_samples') + ns = self.get_backend_namespace(X=X, y=y) zeros = ns['zeros']