Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend functionality of Sequential Feature Selector to allow repeating cross-validation. #272

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions mlxtend/feature_selection/sequential_feature_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@
from sklearn.base import BaseEstimator
from sklearn.base import MetaEstimatorMixin
from ..externals.name_estimators import _name_estimators
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_score, RepeatedKFold, RepeatedStratifiedKFold
from sklearn.externals.joblib import Parallel, delayed


def _calc_score(selector, X, y, indices):
if selector.cv:
if selector.n_cv_repeats > 0:
if selector._estimator_type == 'classifier':
cv_folds = RepeatedStratifiedKFold(n_splits=selector.cv, n_repeats=selector.n_cv_repeats)
else:
cv_folds = RepeatedKFold(n_splits=selector.cv, n_repeats=selector.n_cv_repeats)
scores = cross_val_score(selector.est_,
X[:, indices], y,
cv=selector.cv,
cv=selector.cv if selector.n_cv_repeats == 0 else cv_folds,
scoring=selector.scorer,
n_jobs=1,
pre_dispatch=selector.pre_dispatch)
Expand Down Expand Up @@ -103,6 +107,11 @@ class SequentialFeatureSelector(BaseEstimator, MetaEstimatorMixin):
if False. Set to False if the estimator doesn't
implement scikit-learn's set_params and get_params methods.
In addition, it is required to set cv=0, and n_jobs=1.
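n_cv_repeats : int (default = 0)
Number of repeats for repeated k-fold cross-validation. If 0,
cross-validation is not repeated and `cv` is used as given. If greater
than 0, `cv` is used as the number of splits and scikit-learn's
RepeatedStratifiedKFold (for a classifier) or RepeatedKFold (otherwise)
is used with n_repeats=n_cv_repeats. Negative values raise an
AttributeError, as does a value greater than 0 when cross-validation
is disabled (cv=0).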


Attributes
----------
Expand All @@ -125,7 +134,8 @@ def __init__(self, estimator, k_features=1,
verbose=0, scoring=None,
cv=5, n_jobs=1,
pre_dispatch='2*n_jobs',
clone_estimator=True):
clone_estimator=True,
n_cv_repeats = 0):

self.estimator = estimator
self.k_features = k_features
Expand All @@ -149,6 +159,13 @@ def __init__(self, estimator, k_features=1,
self.est_ = self.estimator
self.scoring = scoring

self.n_cv_repeats = n_cv_repeats
if self.n_cv_repeats < 0:
raise AttributeError('Number of cross-validation repeats should be >= 0.')
if not self.cv and self.n_cv_repeats > 0:
raise AttributeError('Cannot repeat cross-validation when it\'s set to 0.')


if scoring is None:
if self.est_._estimator_type == 'classifier':
scoring = 'accuracy'
Expand Down