forked from autoreject/autoreject
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request autoreject#103 from jasmainak/no_partial
[MRG] MAINT simplify partial + compute_thresholds
- Loading branch information
Showing
4 changed files
with
27 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,8 @@ | |
# Denis A. Engemann <[email protected]> | ||
|
||
import warnings | ||
from functools import partial | ||
|
||
import numpy as np | ||
from scipy.stats.distributions import uniform | ||
|
||
|
@@ -334,8 +336,6 @@ def compute_thresholds(epochs, method='bayesian_optimization', | |
If `'tqdm'`, use `tqdm.tqdm`. | ||
If `'tqdm_notebook'`, use `tqdm.tqdm_notebook`. | ||
If False, suppress all output messages. | ||
n_jobs : int | ||
The number of jobs. | ||
Examples | ||
-------- | ||
|
@@ -764,16 +764,19 @@ class AutoReject(object): | |
The values to try for the number of channels for which to interpolate. | ||
This is :math:`\\rho`.If None, defaults to | ||
np.array([1, 4, 32]) | ||
thresh_func : callable | None | ||
Function which returns the channel-level thresholds. If None, | ||
defaults to :func:`autoreject.compute_thresholds`. | ||
cv : a scikit-learn cross-validation object | ||
Defaults to cv=10 | ||
picks : ndarray, shape(n_channels) | None | ||
The channels to be considered for autoreject. If None, defaults | ||
to data channels {'meg', 'eeg'}, which will lead fitting and combining | ||
autoreject solutions across these channel types. Note that, if picks is | ||
None, autoreject ignores channels marked bad in epochs.info['bads']. | ||
thresh_method : str | ||
'bayesian_optimization' or 'random_search' | ||
n_jobs : int | ||
The number of jobs. | ||
random_state : int seed, RandomState instance, or None (default) | ||
The seed of the pseudo random number generator to use. | ||
verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False | ||
The verbosity of progress messages. | ||
If `'progressbar'`, use `mne.utils.ProgressBar`. | ||
|
@@ -801,14 +804,17 @@ class AutoReject(object): | |
|
||
def __init__(self, n_interpolate=None, consensus=None, | ||
thresh_func=None, cv=10, picks=None, | ||
verbose='progressbar'): | ||
thresh_method='bayesian_optimization', | ||
n_jobs=1, random_state=None, verbose='progressbar'): | ||
"""Init it.""" | ||
self.n_interpolate = n_interpolate | ||
self.consensus = consensus | ||
self.thresh_func = thresh_func | ||
self.thresh_method = thresh_method | ||
self.cv = cv | ||
self.verbose = verbose | ||
self.picks = picks # XXX : should maybe be ch_types? | ||
self.n_jobs = n_jobs | ||
self.random_state = random_state | ||
|
||
if self.consensus is None: | ||
self.consensus = np.linspace(0, 1.0, 11) | ||
|
@@ -818,7 +824,9 @@ def __repr__(self): | |
class_name = self.__class__.__name__ | ||
params = dict(n_interpolate=self.n_interpolate, | ||
consensus=self.consensus, | ||
cv=self.cv, verbose=self.verbose, picks=self.picks) | ||
cv=self.cv, verbose=self.verbose, picks=self.picks, | ||
thresh_method=self.thresh_method, | ||
random_state=self.random_state, n_jobs=self.n_jobs) | ||
return '%s(%s)' % (class_name, _pprint(params, | ||
offset=len(class_name),),) | ||
|
||
|
@@ -841,6 +849,10 @@ def fit(self, epochs): | |
if isinstance(self.cv_, int): | ||
self.cv_ = KFold(n_splits=self.cv_) | ||
|
||
thresh_func = partial(compute_thresholds, n_jobs=self.n_jobs, | ||
method=self.thresh_method, | ||
random_state=self.random_state) | ||
|
||
if self.n_interpolate is None: | ||
if len(self.picks_) < 4: | ||
raise ValueError('Too few channels. autoreject is unlikely' | ||
|
@@ -861,7 +873,7 @@ def fit(self, epochs): | |
if self.verbose is not False: | ||
print('Running autoreject on ch_type=%s' % ch_type) | ||
this_local_reject, this_loss = \ | ||
_run_local_reject_cv(epochs, self.thresh_func, this_picks, | ||
_run_local_reject_cv(epochs, thresh_func, this_picks, | ||
self.n_interpolate, self.cv_, | ||
self.consensus, self.verbose) | ||
self.threshes_.update(this_local_reject.threshes_) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,6 @@ | |
# Denis A. Engemann <[email protected]> | ||
# License: BSD (3-clause) | ||
|
||
from functools import partial | ||
|
||
import numpy as np | ||
from numpy.testing import assert_array_equal | ||
|
||
|
@@ -130,10 +128,7 @@ def test_autoreject(): | |
|
||
ar = _AutoReject(picks=picks) # XXX : why do we need this?? | ||
|
||
thresh_func = partial(compute_thresholds, | ||
method='bayesian_optimization', | ||
random_state=42) | ||
ar = AutoReject(cv=3, picks=picks, thresh_func=thresh_func, | ||
ar = AutoReject(cv=3, picks=picks, random_state=42, | ||
n_interpolate=[1, 2], consensus=[0.5, 1]) | ||
assert_raises(AttributeError, ar.fit, X) | ||
assert_raises(ValueError, ar.transform, X) | ||
|
@@ -203,7 +198,7 @@ def test_autoreject(): | |
|
||
# test that transform ignores bad channels | ||
epochs_with_bads_fit.pick_types(meg='mag', eeg=True, eog=True, exclude=[]) | ||
ar_bads = AutoReject(cv=3, thresh_func=thresh_func, | ||
ar_bads = AutoReject(cv=3, random_state=42, | ||
n_interpolate=[1, 2], consensus=[0.5, 1]) | ||
ar_bads.fit(epochs_with_bads_fit) | ||
epochs_with_bads_clean = ar_bads.transform(epochs_with_bads_fit) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters