Skip to content

Commit

Permalink
Merge pull request autoreject#103 from jasmainak/no_partial
Browse files Browse the repository at this point in the history
[MRG] MAINT simplify partial + compute_thresholds
  • Loading branch information
dengemann authored Jun 11, 2018
2 parents 0f1ec23 + c3bd24f commit bd9df0c
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 32 deletions.
30 changes: 21 additions & 9 deletions autoreject/autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# Denis A. Engemann <[email protected]>

import warnings
from functools import partial

import numpy as np
from scipy.stats.distributions import uniform

Expand Down Expand Up @@ -334,8 +336,6 @@ def compute_thresholds(epochs, method='bayesian_optimization',
If `'tqdm'`, use `tqdm.tqdm`.
If `'tqdm_notebook'`, use `tqdm.tqdm_notebook`.
If False, suppress all output messages.
n_jobs : int
The number of jobs.
Examples
--------
Expand Down Expand Up @@ -764,16 +764,19 @@ class AutoReject(object):
The values to try for the number of channels for which to interpolate.
This is :math:`\\rho`.If None, defaults to
np.array([1, 4, 32])
thresh_func : callable | None
Function which returns the channel-level thresholds. If None,
defaults to :func:`autoreject.compute_thresholds`.
cv : a scikit-learn cross-validation object
Defaults to cv=10
picks : ndarray, shape(n_channels) | None
The channels to be considered for autoreject. If None, defaults
to data channels {'meg', 'eeg'}, which will lead fitting and combining
autoreject solutions across these channel types. Note that, if picks is
None, autoreject ignores channels marked bad in epochs.info['bads'].
thresh_method : str
'bayesian_optimization' or 'random_search'
n_jobs : int
The number of jobs.
random_state : int seed, RandomState instance, or None (default)
The seed of the pseudo random number generator to use.
verbose : 'tqdm', 'tqdm_notebook', 'progressbar' or False
The verbosity of progress messages.
If `'progressbar'`, use `mne.utils.ProgressBar`.
Expand Down Expand Up @@ -801,14 +804,17 @@ class AutoReject(object):

def __init__(self, n_interpolate=None, consensus=None,
thresh_func=None, cv=10, picks=None,
verbose='progressbar'):
thresh_method='bayesian_optimization',
n_jobs=1, random_state=None, verbose='progressbar'):
"""Init it."""
self.n_interpolate = n_interpolate
self.consensus = consensus
self.thresh_func = thresh_func
self.thresh_method = thresh_method
self.cv = cv
self.verbose = verbose
self.picks = picks # XXX : should maybe be ch_types?
self.n_jobs = n_jobs
self.random_state = random_state

if self.consensus is None:
self.consensus = np.linspace(0, 1.0, 11)
Expand All @@ -818,7 +824,9 @@ def __repr__(self):
class_name = self.__class__.__name__
params = dict(n_interpolate=self.n_interpolate,
consensus=self.consensus,
cv=self.cv, verbose=self.verbose, picks=self.picks)
cv=self.cv, verbose=self.verbose, picks=self.picks,
thresh_method=self.thresh_method,
random_state=self.random_state, n_jobs=self.n_jobs)
return '%s(%s)' % (class_name, _pprint(params,
offset=len(class_name),),)

Expand All @@ -841,6 +849,10 @@ def fit(self, epochs):
if isinstance(self.cv_, int):
self.cv_ = KFold(n_splits=self.cv_)

thresh_func = partial(compute_thresholds, n_jobs=self.n_jobs,
method=self.thresh_method,
random_state=self.random_state)

if self.n_interpolate is None:
if len(self.picks_) < 4:
raise ValueError('Too few channels. autoreject is unlikely'
Expand All @@ -861,7 +873,7 @@ def fit(self, epochs):
if self.verbose is not False:
print('Running autoreject on ch_type=%s' % ch_type)
this_local_reject, this_loss = \
_run_local_reject_cv(epochs, self.thresh_func, this_picks,
_run_local_reject_cv(epochs, thresh_func, this_picks,
self.n_interpolate, self.cv_,
self.consensus, self.verbose)
self.threshes_.update(this_local_reject.threshes_)
Expand Down
9 changes: 2 additions & 7 deletions autoreject/tests/test_autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# Denis A. Engemann <[email protected]>
# License: BSD (3-clause)

from functools import partial

import numpy as np
from numpy.testing import assert_array_equal

Expand Down Expand Up @@ -130,10 +128,7 @@ def test_autoreject():

ar = _AutoReject(picks=picks) # XXX : why do we need this??

thresh_func = partial(compute_thresholds,
method='bayesian_optimization',
random_state=42)
ar = AutoReject(cv=3, picks=picks, thresh_func=thresh_func,
ar = AutoReject(cv=3, picks=picks, random_state=42,
n_interpolate=[1, 2], consensus=[0.5, 1])
assert_raises(AttributeError, ar.fit, X)
assert_raises(ValueError, ar.transform, X)
Expand Down Expand Up @@ -203,7 +198,7 @@ def test_autoreject():

# test that transform ignores bad channels
epochs_with_bads_fit.pick_types(meg='mag', eeg=True, eog=True, exclude=[])
ar_bads = AutoReject(cv=3, thresh_func=thresh_func,
ar_bads = AutoReject(cv=3, random_state=42,
n_interpolate=[1, 2], consensus=[0.5, 1])
ar_bads.fit(epochs_with_bads_fit)
epochs_with_bads_clean = ar_bads.transform(epochs_with_bads_fit)
Expand Down
14 changes: 3 additions & 11 deletions examples/plot_auto_repair.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@
# provided to the :class:`autoreject.AutoReject` class for computing
# the channel-level thresholds.

from autoreject import (AutoReject, compute_thresholds,
set_matplotlib_defaults) # noqa
from autoreject import (AutoReject, set_matplotlib_defaults) # noqa

###############################################################################
# Let us now read in the raw `fif` file for MNE sample dataset.
Expand Down Expand Up @@ -82,13 +81,6 @@
baseline=(None, 0), reject=None,
verbose=False, detrend=0, preload=True)

###############################################################################
# First, we set up the function to compute the sensor-level thresholds.

from functools import partial # noqa
thresh_func = partial(compute_thresholds, picks=picks, method='random_search',
random_state=42)

###############################################################################
# :class:`autoreject.AutoReject` internally does cross-validation to
# determine the optimal values :math:`\rho^{*}` and :math:`\kappa^{*}`
Expand All @@ -102,9 +94,9 @@
# Here we only use a subset of channels to save time.

ar = AutoReject(n_interpolates, consensus_percs, picks=picks,
thresh_func=thresh_func)
thresh_method='random_search', random_state=42)

# Not that fitting and transforming can be done on different compatible
# Note that fitting and transforming can be done on different compatible
# portions of data if needed.
ar.fit(epochs['Auditory/Left'])
epochs_clean = ar.transform(epochs['Auditory/Left'])
Expand Down
6 changes: 1 addition & 5 deletions examples/plot_visualize_bad_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,12 @@
# Now, we apply autoreject

from autoreject import AutoReject, compute_thresholds # noqa
from functools import partial # noqa

this_epoch = epochs['famous']
exclude = [] # XXX
picks = mne.pick_types(epochs.info, meg=False, eeg=True, stim=False,
eog=False, exclude=exclude)

thresh_func = partial(compute_thresholds, random_state=42, n_jobs=1)

###############################################################################
# Note that :class:`autoreject.AutoReject` by design supports multiple
# channels. If no picks are passed separate solutions will be computed for each
Expand All @@ -107,8 +104,7 @@
# may be saved by fitting :class:`autoreject.AutoReject` on a
# representative subsample of the data.


ar = AutoReject(thresh_func=thresh_func, picks=picks, verbose='tqdm')
ar = AutoReject(picks=picks, random_state=42, n_jobs=1, verbose='tqdm')

epochs_ar, reject_log = ar.fit_transform(this_epoch, return_log=True)

Expand Down

0 comments on commit bd9df0c

Please sign in to comment.