Skip to content

Commit

Permalink
WIP start adding IO to autoreject
Browse files Browse the repository at this point in the history
  • Loading branch information
jasmainak committed Nov 2, 2018
1 parent 32fdbd4 commit 76fa2c7
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 2 deletions.
2 changes: 1 addition & 1 deletion autoreject/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__version__ = '0.2.dev0'

from .autoreject import _GlobalAutoReject, _AutoReject, AutoReject
from .autoreject import RejectLog
from .autoreject import RejectLog, read_autoreject
from .autoreject import compute_thresholds, validation_curve, get_rejection_threshold
from .ransac import Ransac
from .utils import set_matplotlib_defaults
54 changes: 54 additions & 0 deletions autoreject/autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Alexandre Gramfort <[email protected]>
# Denis A. Engemann <[email protected]>

import os.path as op
import warnings
from functools import partial

Expand All @@ -12,6 +13,7 @@

import mne
from mne import pick_types
from mne.externals.h5io import read_hdf5, write_hdf5

from sklearn.base import BaseEstimator
from sklearn.model_selection import RandomizedSearchCV
Expand Down Expand Up @@ -81,6 +83,23 @@ def validation_curve(epochs, y, param_name, param_range, cv=None):
return train_scores, test_scores


def read_autoreject(fname):
"""Read autoreject object.
Parameters
----------
fname : str
The filename where the autoreject object is saved.
Returns
-------
ar : instance of autoreject.AutoReject
"""
state = read_hdf5(fname, title='autoreject')
ar = AutoReject(state)
return ar


class BaseAutoReject(BaseEstimator):
"""Base class for rejection."""

Expand Down Expand Up @@ -830,6 +849,22 @@ def __repr__(self):
return '%s(%s)' % (class_name, _pprint(params,
offset=len(class_name),),)

def __getstate__(self):
"""Get the state of autoreject as a dictionary."""
state = dict()
opt_params = []

non_opt_params = [
'n_interpolate', 'n_interpolate_', 'consensus', 'consensus_',
'cv', 'picks', 'picks_', 'n_jobs', 'verbose', 'random_state',
'threshes_', 'loss_'] # local_reject_
for param in non_opt_params:
state[param] = getattr(self, param)
for param in opt_params:
if hasattr(self, param):
state[param] = getattr(self, param)
return state

def fit(self, epochs):
"""Fit the epochs on the AutoReject object.
Expand Down Expand Up @@ -999,6 +1034,25 @@ def fit_transform(self, epochs, return_log=False):
"""
return self.fit(epochs).transform(epochs, return_log=return_log)

def save(self, fname, overwrite=False):
"""Save autoreject object.
Parameters
----------
fname : str
The filename to save to. The filename must end
in '.h5' or '.hdf5'.
overwrite : bool
If True, overwrite file if it already exists. Defaults to False.
"""
fname = op.realpath(fname)
if not overwrite and op.isfile(fname):
raise ValueError('%s already exists. Please make overwrite=True'
'if you want to overwrite this file' % fname)

write_hdf5(fname, self.__getstate__(), overwrite=overwrite,
title='autoreject')


def _check_fit(epochs, threshes_, picks_):
if not all(epochs.ch_names[pp] in threshes_
Expand Down
29 changes: 28 additions & 1 deletion autoreject/tests/test_autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
# Denis A. Engemann <[email protected]>
# License: BSD (3-clause)

import os.path as op

import numpy as np
from numpy.testing import assert_array_equal

import mne
from mne.datasets import sample
from mne import io
from mne.utils import _TempDir

from autoreject import (_GlobalAutoReject, _AutoReject, AutoReject,
compute_thresholds, validation_curve,
get_rejection_threshold)
get_rejection_threshold, read_autoreject)
from autoreject.utils import _get_picks_by_type
from autoreject.autoreject import _get_interp_chs

Expand Down Expand Up @@ -241,3 +244,27 @@ def test_autoreject():
threshes_b = compute_thresholds(
epochs_fit, picks=picks, method='bayesian_optimization')
assert_equal(set(threshes_b.keys()), set(ch_names))


def test_io():
"""Test IO functionality."""
event_id = None
tmin, tmax = -0.2, 0.5
events = mne.find_events(raw)
savedir = _TempDir()
fname = op.join(savedir, 'autoreject.hdf5')

picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False,
eog=True, exclude=[])
# raise error if preload is false
epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
picks=picks, baseline=(None, 0),
reject=None, preload=False)

ar = AutoReject(cv=3, random_state=42, n_interpolate=[1, 2],
consensus=[0.5, 1])
ar.save(fname)
read_autoreject(fname)
ar.fit(epochs)
ar.save(fname)
read_autoreject(fname)

0 comments on commit 76fa2c7

Please sign in to comment.