forked from autoreject/autoreject
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request autoreject#120 from jasmainak/autoreject_io
[MRG] start adding IO to autoreject
- Loading branch information
Showing
8 changed files
with
135 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
__version__ = '0.2.dev0' | ||
|
||
from .autoreject import _GlobalAutoReject, _AutoReject, AutoReject | ||
from .autoreject import RejectLog | ||
from .autoreject import RejectLog, read_auto_reject | ||
from .autoreject import compute_thresholds, validation_curve, get_rejection_threshold | ||
from .ransac import Ransac | ||
from .utils import set_matplotlib_defaults |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
# Alexandre Gramfort <[email protected]> | ||
# Denis A. Engemann <[email protected]> | ||
|
||
import os.path as op | ||
import warnings | ||
from functools import partial | ||
|
||
|
@@ -12,6 +13,7 @@ | |
|
||
import mne | ||
from mne import pick_types | ||
from mne.externals.h5io import read_hdf5, write_hdf5 | ||
|
||
from sklearn.base import BaseEstimator | ||
from sklearn.model_selection import RandomizedSearchCV | ||
|
@@ -28,6 +30,13 @@ | |
mem = Memory(cachedir='cachedir') | ||
mem.clear(warn=False) | ||
|
||
_INIT_PARAMS = ('consensus', 'n_interpolate', 'picks', | ||
'verbose', 'n_jobs', 'cv', 'random_state', | ||
'thresh_method') | ||
|
||
_FIT_PARAMS = ('threshes_', 'n_interpolate_', 'consensus_', 'picks_', | ||
'loss_') | ||
|
||
|
||
def _slicemean(obj, this_slice, axis): | ||
mean = np.nan | ||
|
@@ -81,6 +90,25 @@ def validation_curve(epochs, y, param_name, param_range, cv=None): | |
return train_scores, test_scores | ||
|
||
|
||
def read_auto_reject(fname): | ||
"""Read AutoReject object. | ||
Parameters | ||
---------- | ||
fname : str | ||
The filename where the AutoReject object is saved. | ||
Returns | ||
------- | ||
ar : instance of autoreject.AutoReject | ||
""" | ||
state = read_hdf5(fname, title='autoreject') | ||
init_kwargs = {param: state[param] for param in _INIT_PARAMS} | ||
ar = AutoReject(**init_kwargs) | ||
ar.__setstate__(state) | ||
return ar | ||
|
||
|
||
class BaseAutoReject(BaseEstimator): | ||
"""Base class for rejection.""" | ||
|
||
|
@@ -830,6 +858,43 @@ def __repr__(self): | |
return '%s(%s)' % (class_name, _pprint(params, | ||
offset=len(class_name),),) | ||
|
||
def __getstate__(self): | ||
"""Get the state of autoreject as a dictionary.""" | ||
state = dict() | ||
|
||
for param in _INIT_PARAMS: | ||
state[param] = getattr(self, param) | ||
for param in _FIT_PARAMS: | ||
if hasattr(self, param): | ||
state[param] = getattr(self, param) | ||
|
||
if hasattr(self, 'local_reject_'): | ||
state['local_reject_'] = dict() | ||
for ch_type in self.local_reject_: | ||
state['local_reject_'][ch_type] = dict() | ||
for param in _INIT_PARAMS[:4] + _FIT_PARAMS[:3]: | ||
state['local_reject_'][ch_type][param] = \ | ||
getattr(self.local_reject_[ch_type], param) | ||
return state | ||
|
||
def __setstate__(self, state): | ||
"""Set the state of autoreject.""" | ||
for param in state.keys(): | ||
if param == 'local_reject_': | ||
local_reject_ = dict() | ||
for ch_type in state['local_reject_']: | ||
init_kwargs = { | ||
key: state['local_reject_'][ch_type][key] | ||
for key in _INIT_PARAMS[:4] | ||
} | ||
local_reject_[ch_type] = _AutoReject(**init_kwargs) | ||
for key in _FIT_PARAMS[:3]: | ||
setattr(local_reject_[ch_type], key, | ||
state['local_reject_'][ch_type][key]) | ||
self.local_reject_ = local_reject_ | ||
elif param not in _INIT_PARAMS: | ||
setattr(self, param, state[param]) | ||
|
||
def fit(self, epochs): | ||
"""Fit the epochs on the AutoReject object. | ||
|
@@ -999,6 +1064,25 @@ def fit_transform(self, epochs, return_log=False): | |
""" | ||
return self.fit(epochs).transform(epochs, return_log=return_log) | ||
|
||
def save(self, fname, overwrite=False): | ||
"""Save autoreject object. | ||
Parameters | ||
---------- | ||
fname : str | ||
The filename to save to. The filename must end | ||
in '.h5' or '.hdf5'. | ||
overwrite : bool | ||
If True, overwrite file if it already exists. Defaults to False. | ||
""" | ||
fname = op.realpath(fname) | ||
if not overwrite and op.isfile(fname): | ||
raise ValueError('%s already exists. Please make overwrite=True' | ||
'if you want to overwrite this file' % fname) | ||
|
||
write_hdf5(fname, self.__getstate__(), overwrite=overwrite, | ||
title='autoreject') | ||
|
||
|
||
def _check_fit(epochs, threshes_, picks_): | ||
if not all(epochs.ch_names[pp] in threshes_ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,16 +2,19 @@ | |
# Denis A. Engemann <[email protected]> | ||
# License: BSD (3-clause) | ||
|
||
import os.path as op | ||
|
||
import numpy as np | ||
from numpy.testing import assert_array_equal | ||
|
||
import mne | ||
from mne.datasets import sample | ||
from mne import io | ||
from mne.utils import _TempDir | ||
|
||
from autoreject import (_GlobalAutoReject, _AutoReject, AutoReject, | ||
compute_thresholds, validation_curve, | ||
get_rejection_threshold) | ||
get_rejection_threshold, read_auto_reject) | ||
from autoreject.utils import _get_picks_by_type | ||
from autoreject.autoreject import _get_interp_chs | ||
|
||
|
@@ -241,3 +244,40 @@ def test_autoreject(): | |
threshes_b = compute_thresholds( | ||
epochs_fit, picks=picks, method='bayesian_optimization') | ||
assert_equal(set(threshes_b.keys()), set(ch_names)) | ||
|
||
|
||
def test_io(): | ||
"""Test IO functionality.""" | ||
event_id = None | ||
tmin, tmax = -0.2, 0.5 | ||
events = mne.find_events(raw) | ||
savedir = _TempDir() | ||
fname = op.join(savedir, 'autoreject.hdf5') | ||
|
||
include = [u'EEG %03d' % i for i in range(1, 45, 3)] | ||
picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False, | ||
eog=True, include=include, exclude=[]) | ||
|
||
# raise error if preload is false | ||
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, | ||
picks=picks, baseline=(None, 0), decim=4, | ||
reject=None, preload=True)[:10] | ||
ar = AutoReject(cv=2, random_state=42, n_interpolate=[1], | ||
consensus=[0.5]) | ||
ar.save(fname) # save without fitting | ||
|
||
# check that fit after saving is the same as fit | ||
# without saving | ||
ar2 = read_auto_reject(fname) | ||
ar.fit(epochs) | ||
ar2.fit(epochs) | ||
assert_equal(np.sum([ar.threshes_[k] - ar2.threshes_[k] | ||
for k in ar.threshes_.keys()]), 0.) | ||
|
||
assert_raises(ValueError, ar.save, fname) | ||
ar.save(fname, overwrite=True) | ||
ar3 = read_auto_reject(fname) | ||
epochs_clean1, reject_log1 = ar.transform(epochs, return_log=True) | ||
epochs_clean2, reject_log2 = ar3.transform(epochs, return_log=True) | ||
assert_array_equal(epochs_clean1.get_data(), epochs_clean2.get_data()) | ||
assert_array_equal(reject_log1.labels, reject_log2.labels) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,4 +6,4 @@ cover-package = autoreject | |
|
||
[flake8] | ||
exclude = __init__.py | ||
ignore = E241 | ||
ignore = E241, W504 |