Skip to content

Commit

Permalink
ENH: also save local_reject_
Browse files Browse the repository at this point in the history
  • Loading branch information
jasmainak committed Nov 7, 2018
1 parent cbb83aa commit 913b3d4
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions autoreject/autoreject.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
mem = Memory(cachedir='cachedir')
mem.clear(warn=False)

INIT_PARAMS = ['n_interpolate', 'consensus', 'cv',
INIT_PARAMS = ('n_interpolate', 'consensus', 'cv',
'picks', 'n_jobs', 'verbose', 'random_state',
'thresh_method']
'thresh_method')


def _slicemean(obj, this_slice, axis):
Expand Down Expand Up @@ -861,20 +861,36 @@ def __getstate__(self):

fit_params = [
'n_interpolate_', 'consensus_', 'picks_',
'threshes_', 'loss_',
'threshes_', 'loss_'
]

local_reject_params = [
'consensus', 'n_interpolate', 'picks', 'verbose']

for param in INIT_PARAMS:
state[param] = getattr(self, param)
for param in fit_params:
if hasattr(self, param):
state[param] = getattr(self, param)
if hasattr(self, 'local_reject_'):
state['local_reject_'] = dict()
for ch_type in self.local_reject_:
state['local_reject_'][ch_type] = dict()
for param in local_reject_params:
state['local_reject_'][ch_type][param] = \
getattr(self.local_reject_[ch_type], param)
return state

def __setstate__(self, state):
"""Set the state of autoreject."""
for param in state.keys():
if param not in INIT_PARAMS:
if param == 'local_reject_':
local_reject_ = dict()
for ch_type in state['local_reject_']:
local_reject_[ch_type] = \
_AutoReject(**state['local_reject_'][ch_type])
self.local_reject_ = local_reject_
elif param not in INIT_PARAMS:
setattr(self, param, state[param])

def fit(self, epochs):
Expand Down

0 comments on commit 913b3d4

Please sign in to comment.