From 8ab45defa869476168461ecd3c5f5cdd656f99a0 Mon Sep 17 00:00:00 2001 From: Mainak Jas Date: Thu, 8 Nov 2018 01:52:00 -0500 Subject: [PATCH] Make it work --- autoreject/autoreject.py | 27 ++++++++++++++++----------- autoreject/tests/test_autoreject.py | 4 +++- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/autoreject/autoreject.py b/autoreject/autoreject.py index 6d2c05e0..cb3863f8 100644 --- a/autoreject/autoreject.py +++ b/autoreject/autoreject.py @@ -30,8 +30,8 @@ mem = Memory(cachedir='cachedir') mem.clear(warn=False) -INIT_PARAMS = ('n_interpolate', 'consensus', 'cv', - 'picks', 'n_jobs', 'verbose', 'random_state', +INIT_PARAMS = ('consensus', 'n_interpolate', 'picks', + 'verbose', 'n_jobs', 'cv', 'random_state', 'thresh_method') @@ -859,24 +859,23 @@ def __getstate__(self): """Get the state of autoreject as a dictionary.""" state = dict() - fit_params = [ + fit_params_cv = ( 'n_interpolate_', 'consensus_', 'picks_', 'threshes_', 'loss_' - ] - - local_reject_params = [ - 'consensus', 'n_interpolate', 'picks', 'verbose'] + ) for param in INIT_PARAMS: state[param] = getattr(self, param) - for param in fit_params: + for param in fit_params_cv: if hasattr(self, param): state[param] = getattr(self, param) + if hasattr(self, 'local_reject_'): state['local_reject_'] = dict() for ch_type in self.local_reject_: state['local_reject_'][ch_type] = dict() - for param in local_reject_params: + for param in INIT_PARAMS[:4] + \ + ('threshes_', 'n_interpolate_', 'consensus_'): state['local_reject_'][ch_type][param] = \ getattr(self.local_reject_[ch_type], param) return state @@ -887,8 +886,14 @@ def __setstate__(self, state): if param == 'local_reject_': local_reject_ = dict() for ch_type in state['local_reject_']: - local_reject_[ch_type] = \ - _AutoReject(**state['local_reject_'][ch_type]) + init_kwargs = { + key: state['local_reject_'][ch_type][key] + for key in INIT_PARAMS[:4] + } + local_reject_[ch_type] = _AutoReject(**init_kwargs) + for key in ('threshes_', 'n_interpolate_', 'consensus_'): + setattr(local_reject_[ch_type], key, + state['local_reject_'][ch_type][key]) self.local_reject_ = local_reject_ elif param not in INIT_PARAMS: setattr(self, param, state[param]) diff --git a/autoreject/tests/test_autoreject.py b/autoreject/tests/test_autoreject.py index 0ecd2669..93aa7230 100644 --- a/autoreject/tests/test_autoreject.py +++ b/autoreject/tests/test_autoreject.py @@ -277,4 +277,6 @@ def test_io(): assert_raises(ValueError, ar.save, fname) ar.save(fname, overwrite=True) ar3 = read_autoreject(fname) - epochs_clean = ar3.transform(epochs) + epochs_clean1 = ar.transform(epochs) + epochs_clean2 = ar3.transform(epochs) + assert_array_equal(epochs_clean1.get_data(), epochs_clean2.get_data())