Skip to content

Commit

Permalink
ENH: progress with saving and reading
Browse files Browse the repository at this point in the history
  • Loading branch information
jasmainak committed Nov 2, 2018
1 parent 76fa2c7 commit cbb83aa
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
28 changes: 20 additions & 8 deletions autoreject/autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
mem = Memory(cachedir='cachedir')
mem.clear(warn=False)

INIT_PARAMS = ['n_interpolate', 'consensus', 'cv',
'picks', 'n_jobs', 'verbose', 'random_state',
'thresh_method']


def _slicemean(obj, this_slice, axis):
mean = np.nan
Expand Down Expand Up @@ -96,7 +100,9 @@ def read_autoreject(fname):
ar : instance of autoreject.AutoReject
"""
state = read_hdf5(fname, title='autoreject')
ar = AutoReject(state)
init_kwargs = {param: state[param] for param in INIT_PARAMS}
ar = AutoReject(**init_kwargs)
ar.__setstate__(state)
return ar


Expand Down Expand Up @@ -852,19 +858,25 @@ def __repr__(self):
def __getstate__(self):
"""Get the state of autoreject as a dictionary."""
state = dict()
opt_params = []

non_opt_params = [
'n_interpolate', 'n_interpolate_', 'consensus', 'consensus_',
'cv', 'picks', 'picks_', 'n_jobs', 'verbose', 'random_state',
'threshes_', 'loss_'] # local_reject_
for param in non_opt_params:
fit_params = [
'n_interpolate_', 'consensus_', 'picks_',
'threshes_', 'loss_',
]

for param in INIT_PARAMS:
state[param] = getattr(self, param)
for param in opt_params:
for param in fit_params:
if hasattr(self, param):
state[param] = getattr(self, param)
return state

def __setstate__(self, state):
"""Set the state of autoreject."""
for param in state.keys():
if param not in INIT_PARAMS:
setattr(self, param, state[param])

def fit(self, epochs):
"""Fit the epochs on the AutoReject object.
Expand Down
32 changes: 21 additions & 11 deletions autoreject/tests/test_autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,17 +254,27 @@ def test_io():
savedir = _TempDir()
fname = op.join(savedir, 'autoreject.hdf5')

picks = mne.pick_types(raw.info, meg=True, eeg=True, stim=False,
eog=True, exclude=[])
include = [u'EEG %03d' % i for i in range(1, 45, 3)]
picks = mne.pick_types(raw.info, meg=False, eeg=False, stim=False,
eog=True, include=include, exclude=[])

# raise error if preload is false
epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
picks=picks, baseline=(None, 0),
reject=None, preload=False)

ar = AutoReject(cv=3, random_state=42, n_interpolate=[1, 2],
consensus=[0.5, 1])
ar.save(fname)
read_autoreject(fname)
picks=picks, baseline=(None, 0), decim=4,
reject=None, preload=True)[:10]
ar = AutoReject(cv=2, random_state=42, n_interpolate=[1],
consensus=[0.5])
ar.save(fname) # save without fitting

# check that fit after saving is the same as fit
# without saving
ar2 = read_autoreject(fname)
ar.fit(epochs)
ar.save(fname)
read_autoreject(fname)
ar2.fit(epochs)
assert_equal(np.sum([ar.threshes_[k] - ar2.threshes_[k]
for k in ar.threshes_.keys()]), 0.)

assert_raises(ValueError, ar.save, fname)
ar.save(fname, overwrite=True)
ar3 = read_autoreject(fname)
epochs_clean = ar3.transform(epochs)

0 comments on commit cbb83aa

Please sign in to comment.