diff --git a/mne/decoding/base.py b/mne/decoding/base.py index a8e457137da..85ed102b514 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -11,6 +11,7 @@ from sklearn import model_selection as models from sklearn.base import ( # noqa: F401 BaseEstimator, + MetaEstimatorMixin, TransformerMixin, clone, is_classifier, @@ -24,7 +25,7 @@ from ..utils import _pl, logger, verbose, warn -class LinearModel(BaseEstimator): +class LinearModel(MetaEstimatorMixin, BaseEstimator): """Compute and store patterns from linear models. The linear model coefficients (filters) are used to extract discriminant @@ -61,11 +62,14 @@ class LinearModel(BaseEstimator): .. footbibliography:: """ + # TODO: Properly refactor this using + # https://github.com/scikit-learn/scikit-learn/issues/30237#issuecomment-2465572885 _model_attr_wrap = ( "transform", "predict", "predict_proba", "_estimator_type", + "__tags__", "decision_function", "score", "classes_", @@ -77,6 +81,12 @@ def __init__(self, model=None): self.model = model + def __sklearn_tags__(self): + """Get sklearn tags.""" + from sklearn.utils import get_tags # added in 1.6 + + return get_tags(self.model) + def __getattr__(self, attr): """Wrap to model for some attributes.""" if attr in LinearModel._model_attr_wrap: diff --git a/mne/decoding/ems.py b/mne/decoding/ems.py index f6811de460d..911b25e6692 100644 --- a/mne/decoding/ems.py +++ b/mne/decoding/ems.py @@ -13,7 +13,7 @@ from .base import _set_cv -class EMS(BaseEstimator, TransformerMixin): +class EMS(TransformerMixin, BaseEstimator): """Transformer to compute event-matched spatial filters. This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire diff --git a/mne/decoding/receptive_field.py b/mne/decoding/receptive_field.py index 7cdbec64ff0..99412cf56b7 100644 --- a/mne/decoding/receptive_field.py +++ b/mne/decoding/receptive_field.py @@ -6,17 +6,22 @@ import numpy as np from scipy.stats import pearsonr -from sklearn.base import clone, is_regressor +from sklearn.base import ( + BaseEstimator, + MetaEstimatorMixin, + clone, + is_regressor, +) from sklearn.exceptions import NotFittedError from sklearn.metrics import r2_score from ..utils import _validate_type, fill_doc, pinv -from .base import BaseEstimator, _check_estimator, get_coef +from .base import _check_estimator, get_coef from .time_delaying_ridge import TimeDelayingRidge @fill_doc -class ReceptiveField(BaseEstimator): +class ReceptiveField(MetaEstimatorMixin, BaseEstimator): """Fit a receptive field model. This allows you to fit an encoding model (stimulus to brain) or a decoding diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index 64f38a60634..e3059a3e959 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -5,7 +5,7 @@ import logging import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin, clone +from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin, clone from sklearn.metrics import check_scoring from sklearn.preprocessing import LabelEncoder from sklearn.utils import check_array @@ -16,7 +16,7 @@ @fill_doc -class SlidingEstimator(BaseEstimator, TransformerMixin): +class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator): """Search Light. Fit, predict and score a series of models to each subset of the dataset @@ -61,6 +61,20 @@ def __init__( def _estimator_type(self): return getattr(self.base_estimator, "_estimator_type", None) + def __sklearn_tags__(self): + """Get sklearn tags.""" + from sklearn.utils import get_tags + + tags = super().__sklearn_tags__() + sub_tags = get_tags(self.base_estimator) + tags.estimator_type = sub_tags.estimator_type + for kind in ("classifier", "regressor", "transformer"): + if tags.estimator_type == kind: + attr = f"{kind}_tags" + setattr(tags, attr, getattr(sub_tags, attr)) + break + return tags + def __repr__(self): # noqa: D105 repr_str = "<" + super().__repr__() if hasattr(self, "estimators_"): diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 4043aa99835..573f21862bf 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -23,7 +23,7 @@ @fill_doc -class SSD(BaseEstimator, TransformerMixin): +class SSD(TransformerMixin, BaseEstimator): """ Signal decomposition using the Spatio-Spectral Decomposition (SSD). diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 0930d007d28..4ec6ed4d281 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -94,10 +94,18 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): def test_get_coef(): """Test getting linear coefficients (filters/patterns) from estimators.""" lm_classification = LinearModel() + assert hasattr(lm_classification, "__sklearn_tags__") + print(lm_classification.__sklearn_tags__) + assert is_classifier(lm_classification.model) assert is_classifier(lm_classification) + assert not is_regressor(lm_classification.model) + assert not is_regressor(lm_classification) lm_regression = LinearModel(Ridge()) + assert is_regressor(lm_regression.model) assert is_regressor(lm_regression) + assert not is_classifier(lm_regression.model) + assert not is_classifier(lm_regression) parameters = {"kernel": ["linear"], "C": [1, 10]} lm_gs_classification = LinearModel( @@ -433,7 +441,8 @@ def test_cross_val_multiscore(): # raise an error if scoring is defined at cross-val-score level and # search light, because search light does not return a 1-dimensional # prediction. - pytest.raises(ValueError, cross_val_multiscore, clf, X, y, cv=cv, scoring="roc_auc") + with pytest.raises(ValueError, match="multi_class must be"): + cross_val_multiscore(clf, X, y, cv=cv, scoring="roc_auc", n_jobs=1) clf = SlidingEstimator(logreg, scoring="roc_auc") scores_auc = cross_val_multiscore(clf, X, y, cv=cv, n_jobs=None) scores_auc_manual = list() diff --git a/mne/decoding/tests/test_receptive_field.py b/mne/decoding/tests/test_receptive_field.py index 5afe7bcdc25..35e65c051f7 100644 --- a/mne/decoding/tests/test_receptive_field.py +++ b/mne/decoding/tests/test_receptive_field.py @@ -174,6 +174,7 @@ def test_time_delay(): @pytest.mark.slowtest # slow on Azure @pytest.mark.parametrize("n_jobs", n_jobs_test) +@pytest.mark.filterwarnings("ignore:Estimator .* has no __sklearn_tags__.*") def test_receptive_field_basic(n_jobs): """Test model prep and fitting.""" # Make sure estimator pulling works diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 9e15a1df59b..7cb3a66dd81 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -56,6 +56,7 @@ def test_search_light(): sl = SlidingEstimator(Ridge()) assert not is_classifier(sl) sl = SlidingEstimator(LogisticRegression(solver="liblinear")) + assert is_classifier(sl.base_estimator) assert is_classifier(sl) # fit assert_equal(sl.__repr__()[:18], "