Skip to content

Commit

Permalink
MAINT: Fixes for latest sklearn (#12951)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored Nov 11, 2024
1 parent ecc620d commit cd6dd5f
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 14 deletions.
12 changes: 11 additions & 1 deletion mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sklearn import model_selection as models
from sklearn.base import ( # noqa: F401
BaseEstimator,
MetaEstimatorMixin,
TransformerMixin,
clone,
is_classifier,
Expand All @@ -24,7 +25,7 @@
from ..utils import _pl, logger, verbose, warn


class LinearModel(BaseEstimator):
class LinearModel(MetaEstimatorMixin, BaseEstimator):
"""Compute and store patterns from linear models.
The linear model coefficients (filters) are used to extract discriminant
Expand Down Expand Up @@ -61,11 +62,14 @@ class LinearModel(BaseEstimator):
.. footbibliography::
"""

# TODO: Properly refactor this using
# https://github.com/scikit-learn/scikit-learn/issues/30237#issuecomment-2465572885
_model_attr_wrap = (
"transform",
"predict",
"predict_proba",
"_estimator_type",
"__tags__",
"decision_function",
"score",
"classes_",
Expand All @@ -77,6 +81,12 @@ def __init__(self, model=None):

self.model = model

def __sklearn_tags__(self):
"""Get sklearn tags."""
from sklearn.utils import get_tags # added in 1.6

return get_tags(self.model)

def __getattr__(self, attr):
"""Wrap to model for some attributes."""
if attr in LinearModel._model_attr_wrap:
Expand Down
2 changes: 1 addition & 1 deletion mne/decoding/ems.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .base import _set_cv


class EMS(BaseEstimator, TransformerMixin):
class EMS(TransformerMixin, BaseEstimator):
"""Transformer to compute event-matched spatial filters.
This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire
Expand Down
11 changes: 8 additions & 3 deletions mne/decoding/receptive_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@

import numpy as np
from scipy.stats import pearsonr
from sklearn.base import clone, is_regressor
from sklearn.base import (
BaseEstimator,
MetaEstimatorMixin,
clone,
is_regressor,
)
from sklearn.exceptions import NotFittedError
from sklearn.metrics import r2_score

from ..utils import _validate_type, fill_doc, pinv
from .base import BaseEstimator, _check_estimator, get_coef
from .base import _check_estimator, get_coef
from .time_delaying_ridge import TimeDelayingRidge


@fill_doc
class ReceptiveField(BaseEstimator):
class ReceptiveField(MetaEstimatorMixin, BaseEstimator):
"""Fit a receptive field model.
This allows you to fit an encoding model (stimulus to brain) or a decoding
Expand Down
18 changes: 16 additions & 2 deletions mne/decoding/search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin, clone
from sklearn.metrics import check_scoring
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import check_array
Expand All @@ -16,7 +16,7 @@


@fill_doc
class SlidingEstimator(BaseEstimator, TransformerMixin):
class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator):
"""Search Light.
Fit, predict and score a series of models to each subset of the dataset
Expand Down Expand Up @@ -61,6 +61,20 @@ def __init__(
def _estimator_type(self):
return getattr(self.base_estimator, "_estimator_type", None)

def __sklearn_tags__(self):
"""Get sklearn tags."""
from sklearn.utils import get_tags

tags = super().__sklearn_tags__()
sub_tags = get_tags(self.base_estimator)
tags.estimator_type = sub_tags.estimator_type
for kind in ("classifier", "regressor", "transformer"):
if tags.estimator_type == kind:
attr = f"{kind}_tags"
setattr(tags, attr, getattr(sub_tags, attr))
break
return tags

def __repr__(self): # noqa: D105
repr_str = "<" + super().__repr__()
if hasattr(self, "estimators_"):
Expand Down
2 changes: 1 addition & 1 deletion mne/decoding/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@fill_doc
class SSD(BaseEstimator, TransformerMixin):
class SSD(TransformerMixin, BaseEstimator):
"""
Signal decomposition using the Spatio-Spectral Decomposition (SSD).
Expand Down
11 changes: 10 additions & 1 deletion mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,18 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3):
def test_get_coef():
"""Test getting linear coefficients (filters/patterns) from estimators."""
lm_classification = LinearModel()
assert hasattr(lm_classification, "__sklearn_tags__")
print(lm_classification.__sklearn_tags__)
assert is_classifier(lm_classification.model)
assert is_classifier(lm_classification)
assert not is_regressor(lm_classification.model)
assert not is_regressor(lm_classification)

lm_regression = LinearModel(Ridge())
assert is_regressor(lm_regression.model)
assert is_regressor(lm_regression)
assert not is_classifier(lm_regression.model)
assert not is_classifier(lm_regression)

parameters = {"kernel": ["linear"], "C": [1, 10]}
lm_gs_classification = LinearModel(
Expand Down Expand Up @@ -433,7 +441,8 @@ def test_cross_val_multiscore():
# raise an error if scoring is defined at cross-val-score level and
# search light, because search light does not return a 1-dimensional
# prediction.
pytest.raises(ValueError, cross_val_multiscore, clf, X, y, cv=cv, scoring="roc_auc")
with pytest.raises(ValueError, match="multi_class must be"):
cross_val_multiscore(clf, X, y, cv=cv, scoring="roc_auc", n_jobs=1)
clf = SlidingEstimator(logreg, scoring="roc_auc")
scores_auc = cross_val_multiscore(clf, X, y, cv=cv, n_jobs=None)
scores_auc_manual = list()
Expand Down
1 change: 1 addition & 0 deletions mne/decoding/tests/test_receptive_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_time_delay():

@pytest.mark.slowtest # slow on Azure
@pytest.mark.parametrize("n_jobs", n_jobs_test)
@pytest.mark.filterwarnings("ignore:Estimator .* has no __sklearn_tags__.*")
def test_receptive_field_basic(n_jobs):
"""Test model prep and fitting."""
# Make sure estimator pulling works
Expand Down
1 change: 1 addition & 0 deletions mne/decoding/tests/test_search_light.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_search_light():
sl = SlidingEstimator(Ridge())
assert not is_classifier(sl)
sl = SlidingEstimator(LogisticRegression(solver="liblinear"))
assert is_classifier(sl.base_estimator)
assert is_classifier(sl)
# fit
assert_equal(sl.__repr__()[:18], "<SlidingEstimator(")
Expand Down
4 changes: 2 additions & 2 deletions mne/decoding/time_delaying_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from scipy import linalg
from scipy.signal import fftconvolve
from scipy.sparse.csgraph import laplacian
from sklearn.base import BaseEstimator, RegressorMixin

from ..cuda import _setup_cuda_fft_multiply_repeated
from ..filter import next_fast_len
from ..fixes import jit
from ..utils import ProgressBar, _check_option, _validate_type, logger, warn
from .base import BaseEstimator


def _compute_corrs(
Expand Down Expand Up @@ -226,7 +226,7 @@ def _fit_corrs(x_xt, x_y, n_ch_x, reg_type, alpha, n_ch_in):
return w


class TimeDelayingRidge(BaseEstimator):
class TimeDelayingRidge(RegressorMixin, BaseEstimator):
"""Ridge regression of data with time delays.
Parameters
Expand Down
12 changes: 12 additions & 0 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,18 @@ def empirical_covariance(X, assume_centered=False):


class _EstimatorMixin:
def __sklearn_tags__(self):
# If we get here, we should have sklearn installed
from sklearn.utils import Tags, TargetTags

return Tags(
estimator_type=None,
target_tags=TargetTags(required=False),
transformer_tags=None,
regressor_tags=None,
classifier_tags=None,
)

def _param_names(self):
return inspect.getfullargspec(self.__init__).args[1:]

Expand Down
2 changes: 1 addition & 1 deletion mne/preprocessing/xdawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _least_square_evoked(epochs_data, events, tmin, sfreq):
sel = events[:, 2] == this_class

# build toeplitz matrix
trig = np.zeros((n_samples, 1))
trig = np.zeros((n_samples,))
ix_trig = (events[sel, 0]) + n_min
trig[ix_trig] = 1
toeplitz.append(linalg.toeplitz(trig[0:window], trig))
Expand Down
7 changes: 5 additions & 2 deletions tools/install_pre_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ python -m pip install $STD_ARGS git+https://github.com/nilearn/nilearn

echo "VTK"
# No pre until PyVista fixes a bug
# python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://wheels.vtk.org" vtk
python -m pip install $STD_ARGS vtk
if [[ "${PLATFORM}" == "Windows" ]]; then
python -m pip install $STD_ARGS "vtk<9.4" # 9.4 requires GLSL 1.5 and Azure win only has 1.3
else
python -m pip install $STD_ARGS --only-binary ":all:" --extra-index-url "https://wheels.vtk.org" vtk
fi
python -c "import vtk"

echo "PyVista"
Expand Down

0 comments on commit cd6dd5f

Please sign in to comment.