Skip to content

Commit

Permalink
Merge pull request autoreject#120 from jasmainak/autoreject_io
Browse files Browse the repository at this point in the history
[MRG] start adding IO to autoreject
  • Loading branch information
jasmainak authored Nov 15, 2018
2 parents 32fdbd4 + 96fcf6a commit 873e4d4
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ install:
- source activate testenv
- conda install --yes --quiet numpy scipy scikit-learn matplotlib
- conda install --yes --quiet nose coverage
- pip install -q flake8 mne check-manifest
- pip install -q flake8 mne check-manifest h5py
- pip install coverage coveralls
- python setup.py install
script:
Expand Down
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ These are the dependencies to use autoreject:
* mne-python (>=0.14)
* scikit-learn (>=0.18)

Two optional dependencies are `tqdm` (for nice progressbars) and `h5py` (for IO).

Cite
----

Expand Down
2 changes: 1 addition & 1 deletion autoreject/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__version__ = '0.2.dev0'

from .autoreject import _GlobalAutoReject, _AutoReject, AutoReject
from .autoreject import RejectLog
from .autoreject import RejectLog, read_auto_reject
from .autoreject import compute_thresholds, validation_curve, get_rejection_threshold
from .ransac import Ransac
from .utils import set_matplotlib_defaults
84 changes: 84 additions & 0 deletions autoreject/autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Alexandre Gramfort <[email protected]>
# Denis A. Engemann <[email protected]>

import os.path as op
import warnings
from functools import partial

Expand All @@ -12,6 +13,7 @@

import mne
from mne import pick_types
from mne.externals.h5io import read_hdf5, write_hdf5

from sklearn.base import BaseEstimator
from sklearn.model_selection import RandomizedSearchCV
Expand All @@ -28,6 +30,13 @@
mem = Memory(cachedir='cachedir')
mem.clear(warn=False)

_INIT_PARAMS = ('consensus', 'n_interpolate', 'picks',
'verbose', 'n_jobs', 'cv', 'random_state',
'thresh_method')

_FIT_PARAMS = ('threshes_', 'n_interpolate_', 'consensus_', 'picks_',
'loss_')


def _slicemean(obj, this_slice, axis):
mean = np.nan
Expand Down Expand Up @@ -81,6 +90,25 @@ def validation_curve(epochs, y, param_name, param_range, cv=None):
return train_scores, test_scores


def read_auto_reject(fname):
"""Read AutoReject object.
Parameters
----------
fname : str
The filename where the AutoReject object is saved.
Returns
-------
ar : instance of autoreject.AutoReject
"""
state = read_hdf5(fname, title='autoreject')
init_kwargs = {param: state[param] for param in _INIT_PARAMS}
ar = AutoReject(**init_kwargs)
ar.__setstate__(state)
return ar


class BaseAutoReject(BaseEstimator):
"""Base class for rejection."""

Expand Down Expand Up @@ -830,6 +858,43 @@ def __repr__(self):
return '%s(%s)' % (class_name, _pprint(params,
offset=len(class_name),),)

def __getstate__(self):
"""Get the state of autoreject as a dictionary."""
state = dict()

for param in _INIT_PARAMS:
state[param] = getattr(self, param)
for param in _FIT_PARAMS:
if hasattr(self, param):
state[param] = getattr(self, param)

if hasattr(self, 'local_reject_'):
state['local_reject_'] = dict()
for ch_type in self.local_reject_:
state['local_reject_'][ch_type] = dict()
for param in _INIT_PARAMS[:4] + _FIT_PARAMS[:3]:
state['local_reject_'][ch_type][param] = \
getattr(self.local_reject_[ch_type], param)
return state

def __setstate__(self, state):
"""Set the state of autoreject."""
for param in state.keys():
if param == 'local_reject_':
local_reject_ = dict()
for ch_type in state['local_reject_']:
init_kwargs = {
key: state['local_reject_'][ch_type][key]
for key in _INIT_PARAMS[:4]
}
local_reject_[ch_type] = _AutoReject(**init_kwargs)
for key in _FIT_PARAMS[:3]:
setattr(local_reject_[ch_type], key,
state['local_reject_'][ch_type][key])
self.local_reject_ = local_reject_
elif param not in _INIT_PARAMS:
setattr(self, param, state[param])

def fit(self, epochs):
"""Fit the epochs on the AutoReject object.
Expand Down Expand Up @@ -999,6 +1064,25 @@ def fit_transform(self, epochs, return_log=False):
"""
return self.fit(epochs).transform(epochs, return_log=return_log)

def save(self, fname, overwrite=False):
"""Save autoreject object.
Parameters
----------
fname : str
The filename to save to. The filename must end
in '.h5' or '.hdf5'.
overwrite : bool
If True, overwrite file if it already exists. Defaults to False.
"""
fname = op.realpath(fname)
if not overwrite and op.isfile(fname):
raise ValueError('%s already exists. Please make overwrite=True'
'if you want to overwrite this file' % fname)

write_hdf5(fname, self.__getstate__(), overwrite=overwrite,
title='autoreject')


def _check_fit(epochs, threshes_, picks_):
if not all(epochs.ch_names[pp] in threshes_
Expand Down
42 changes: 41 additions & 1 deletion autoreject/tests/test_autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
# Denis A. Engemann <[email protected]>
# License: BSD (3-clause)

import os.path as op

import numpy as np
from numpy.testing import assert_array_equal

import mne
from mne.datasets import sample
from mne import io
from mne.utils import _TempDir

from autoreject import (_GlobalAutoReject, _AutoReject, AutoReject,
compute_thresholds, validation_curve,
get_rejection_threshold)
get_rejection_threshold, read_auto_reject)
from autoreject.utils import _get_picks_by_type
from autoreject.autoreject import _get_interp_chs

Expand Down Expand Up @@ -241,3 +244,40 @@ def test_autoreject():
threshes_b = compute_thresholds(
epochs_fit, picks=picks, method='bayesian_optimization')
assert_equal(set(threshes_b.keys()), set(ch_names))


def test_io():
"""Test IO functionality."""
event_id = None
tmin, tmax = -0.2, 0.5
events = mne.find_events(raw)
savedir = _TempDir()
fname = op.join(savedir, 'autoreject.hdf5')

include = [u'EEG %03d' % i for i in range(1, 45, 3)]
picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
eog=True, include=include, exclude=[])

# raise error if preload is false
epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
picks=picks, baseline=(None, 0), decim=4,
reject=None, preload=True)[:10]
ar = AutoReject(cv=2, random_state=42, n_interpolate=[1],
consensus=[0.5])
ar.save(fname) # save without fitting

# check that fit after saving is the same as fit
# without saving
ar2 = read_auto_reject(fname)
ar.fit(epochs)
ar2.fit(epochs)
assert_equal(np.sum([ar.threshes_[k] - ar2.threshes_[k]
for k in ar.threshes_.keys()]), 0.)

assert_raises(ValueError, ar.save, fname)
ar.save(fname, overwrite=True)
ar3 = read_auto_reject(fname)
epochs_clean1, reject_log1 = ar.transform(epochs, return_log=True)
epochs_clean2, reject_log2 = ar3.transform(epochs, return_log=True)
assert_array_equal(epochs_clean1.get_data(), epochs_clean2.get_data())
assert_array_equal(reject_log1.labels, reject_log2.labels)
3 changes: 2 additions & 1 deletion doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ We recommend the `Anaconda Python distribution <https://www.continuum.io/downloa
$ pip install -U mne

An optional dependency is `tqdm <https://tqdm.github.io/>`_ if you want to use the verbosity flags `'tqdm'` or `'tqdm_notebook'`
for nice progressbars.
for nice progressbars. In case you want to be able to read and write `autoreject` objects using the HDF5 format,
you may also want to install `h5py <https://pypi.org/project/h5py/>`_.

Then install the latest release of autoreject use::

Expand Down
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Current
Changelog
~~~~~~~~~

- Introduced a new method :meth:`autoreject.AutoReject.save` and function :func:`autoreject.read_auto_reject`
for IO of autoreject objects, by `Mainak Jas`_ in `#120 <https://github.com/autoreject/autoreject/pull/120>`_

Bug
~~~

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ cover-package = autoreject

[flake8]
exclude = __init__.py
ignore = E241
ignore = E241, W504

0 comments on commit 873e4d4

Please sign in to comment.