forked from autoreject/autoreject
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
83 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
__version__ = '0.2.dev0' | ||
|
||
from .autoreject import _GlobalAutoReject, _AutoReject, AutoReject | ||
from .autoreject import RejectLog | ||
from .autoreject import RejectLog, read_autoreject | ||
from .autoreject import compute_thresholds, validation_curve, get_rejection_threshold | ||
from .ransac import Ransac | ||
from .utils import set_matplotlib_defaults |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
# Alexandre Gramfort <[email protected]> | ||
# Denis A. Engemann <[email protected]> | ||
|
||
import os.path as op | ||
import warnings | ||
from functools import partial | ||
|
||
|
@@ -12,6 +13,7 @@ | |
|
||
import mne | ||
from mne import pick_types | ||
from mne.externals.h5io import read_hdf5, write_hdf5 | ||
|
||
from sklearn.base import BaseEstimator | ||
from sklearn.model_selection import RandomizedSearchCV | ||
|
@@ -81,6 +83,23 @@ def validation_curve(epochs, y, param_name, param_range, cv=None): | |
return train_scores, test_scores | ||
|
||
|
||
def read_autoreject(fname): | ||
"""Read autoreject object. | ||
Parameters | ||
---------- | ||
fname : str | ||
The filename where the autoreject object is saved. | ||
Returns | ||
------- | ||
ar : instance of autoreject.AutoReject | ||
""" | ||
state = read_hdf5(fname, title='autoreject') | ||
ar = AutoReject(state) | ||
return ar | ||
|
||
|
||
class BaseAutoReject(BaseEstimator): | ||
"""Base class for rejection.""" | ||
|
||
|
@@ -830,6 +849,22 @@ def __repr__(self): | |
return '%s(%s)' % (class_name, _pprint(params, | ||
offset=len(class_name),),) | ||
|
||
def __getstate__(self): | ||
"""Get the state of autoreject as a dictionary.""" | ||
state = dict() | ||
opt_params = [] | ||
|
||
non_opt_params = [ | ||
'n_interpolate', 'n_interpolate_', 'consensus', 'consensus_', | ||
'cv', 'picks', 'picks_', 'n_jobs', 'verbose', 'random_state', | ||
'threshes_', 'loss_'] # local_reject_ | ||
for param in non_opt_params: | ||
state[param] = getattr(self, param) | ||
for param in opt_params: | ||
if hasattr(self, param): | ||
state[param] = getattr(self, param) | ||
return state | ||
|
||
def fit(self, epochs): | ||
"""Fit the epochs on the AutoReject object. | ||
|
@@ -999,6 +1034,25 @@ def fit_transform(self, epochs, return_log=False): | |
""" | ||
return self.fit(epochs).transform(epochs, return_log=return_log) | ||
|
||
def save(self, fname, overwrite=False): | ||
"""Save autoreject object. | ||
Parameters | ||
---------- | ||
fname : str | ||
The filename to save to. The filename must end | ||
in '.h5' or '.hdf5'. | ||
overwrite : bool | ||
If True, overwrite file if it already exists. Defaults to False. | ||
""" | ||
fname = op.realpath(fname) | ||
if not overwrite and op.isfile(fname): | ||
raise ValueError('%s already exists. Please make overwrite=True' | ||
'if you want to overwrite this file' % fname) | ||
|
||
write_hdf5(fname, self.__getstate__(), overwrite=overwrite, | ||
title='autoreject') | ||
|
||
|
||
def _check_fit(epochs, threshes_, picks_): | ||
if not all(epochs.ch_names[pp] in threshes_ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,16 +2,19 @@ | |
# Denis A. Engemann <[email protected]> | ||
# License: BSD (3-clause) | ||
|
||
import os.path as op | ||
|
||
import numpy as np | ||
from numpy.testing import assert_array_equal | ||
|
||
import mne | ||
from mne.datasets import sample | ||
from mne import io | ||
from mne.utils import _TempDir | ||
|
||
from autoreject import (_GlobalAutoReject, _AutoReject, AutoReject, | ||
compute_thresholds, validation_curve, | ||
get_rejection_threshold) | ||
get_rejection_threshold, read_autoreject) | ||
from autoreject.utils import _get_picks_by_type | ||
from autoreject.autoreject import _get_interp_chs | ||
|
||
|
@@ -241,3 +244,27 @@ def test_autoreject(): | |
threshes_b = compute_thresholds( | ||
epochs_fit, picks=picks, method='bayesian_optimization') | ||
assert_equal(set(threshes_b.keys()), set(ch_names)) | ||
|
||
|
||
def test_io(): | ||
"""Test IO functionality.""" | ||
event_id = None | ||
tmin, tmax = -0.2, 0.5 | ||
events = mne.find_events(raw) | ||
savedir = _TempDir() | ||
fname = op.join(savedir, 'autoreject.hdf5') | ||
|
||
picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False, | ||
eog=True, exclude=[]) | ||
# raise error if preload is false | ||
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, | ||
picks=picks, baseline=(None, 0), | ||
reject=None, preload=False) | ||
|
||
ar = AutoReject(cv=3, random_state=42, n_interpolate=[1, 2], | ||
consensus=[0.5, 1]) | ||
ar.save(fname) | ||
read_autoreject(fname) | ||
ar.fit(epochs) | ||
ar.save(fname) | ||
read_autoreject(fname) |