Skip to content

Commit

Permalink
Make it work
Browse files Browse the repository at this point in the history
  • Loading branch information
jasmainak committed Nov 8, 2018
1 parent 913b3d4 commit 8ab45de
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
27 changes: 16 additions & 11 deletions autoreject/autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
mem = Memory(cachedir='cachedir')
mem.clear(warn=False)

INIT_PARAMS = ('n_interpolate', 'consensus', 'cv',
'picks', 'n_jobs', 'verbose', 'random_state',
INIT_PARAMS = ('consensus', 'n_interpolate', 'picks',
'verbose', 'n_jobs', 'cv', 'random_state',
'thresh_method')


Expand Down Expand Up @@ -859,24 +859,23 @@ def __getstate__(self):
"""Get the state of autoreject as a dictionary."""
state = dict()

fit_params = [
fit_params_cv = (
'n_interpolate_', 'consensus_', 'picks_',
'threshes_', 'loss_'
]

local_reject_params = [
'consensus', 'n_interpolate', 'picks', 'verbose']
)

for param in INIT_PARAMS:
state[param] = getattr(self, param)
for param in fit_params:
for param in fit_params_cv:
if hasattr(self, param):
state[param] = getattr(self, param)

if hasattr(self, 'local_reject_'):
state['local_reject_'] = dict()
for ch_type in self.local_reject_:
state['local_reject_'][ch_type] = dict()
for param in local_reject_params:
for param in INIT_PARAMS[:4] + \
('threshes_', 'n_interpolate_', 'consensus_'):
state['local_reject_'][ch_type][param] = \
getattr(self.local_reject_[ch_type], param)
return state
Expand All @@ -887,8 +886,14 @@ def __setstate__(self, state):
if param == 'local_reject_':
local_reject_ = dict()
for ch_type in state['local_reject_']:
local_reject_[ch_type] = \
_AutoReject(**state['local_reject_'][ch_type])
init_kwargs = {
key: state['local_reject_'][ch_type][key]
for key in INIT_PARAMS[:4]
}
local_reject_[ch_type] = _AutoReject(**init_kwargs)
for key in ('threshes_', 'n_interpolate_', 'consensus_'):
setattr(local_reject_[ch_type], key,
state['local_reject_'][ch_type][key])
self.local_reject_ = local_reject_
elif param not in INIT_PARAMS:
setattr(self, param, state[param])
Expand Down
4 changes: 3 additions & 1 deletion autoreject/tests/test_autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,4 +277,6 @@ def test_io():
assert_raises(ValueError, ar.save, fname)
ar.save(fname, overwrite=True)
ar3 = read_autoreject(fname)
epochs_clean = ar3.transform(epochs)
epochs_clean1 = ar.transform(epochs)
epochs_clean2 = ar3.transform(epochs)
assert_array_equal(epochs_clean1.get_data(), epochs_clean2.get_data())

0 comments on commit 8ab45de

Please sign in to comment.