From ec9599864bbed1a9d6a249076cf8bc1319f0982e Mon Sep 17 00:00:00 2001 From: Fahimeh Mamashli Date: Fri, 4 Oct 2019 13:58:07 -0400 Subject: [PATCH] added mode='constant' to flaten TMS artifact --- mne/preprocessing/stim.py | 42 ++++++++++++++++++++++------ mne/preprocessing/tests/test_stim.py | 16 +++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/mne/preprocessing/stim.py b/mne/preprocessing/stim.py index c142885df2e..441887ddd7c 100644 --- a/mne/preprocessing/stim.py +++ b/mne/preprocessing/stim.py @@ -1,5 +1,6 @@ # Authors: Daniel Strohmeier -# +# Fahimeh Mamashli +# Padma Sundaram # License: BSD (3-clause) import numpy as np @@ -21,7 +22,8 @@ def _get_window(start, end): return window -def _fix_artifact(data, window, picks, first_samp, last_samp, mode): +def _fix_artifact(data, window, picks, first_samp, last_samp, base_tmin, + base_tmax, mode): """Modify original data by using parameter data.""" from scipy.interpolate import interp1d if mode == 'linear': @@ -33,10 +35,14 @@ def _fix_artifact(data, window, picks, first_samp, last_samp, mode): if mode == 'window': data[picks, first_samp:last_samp] = \ data[picks, first_samp:last_samp] * window[np.newaxis, :] + if mode == 'constant': + data[picks, first_samp:last_samp] = \ + data[picks, base_tmin: base_tmax].mean(axis=1)[:, None] def fix_stim_artifact(inst, events=None, event_id=None, tmin=0., - tmax=0.01, mode='linear', stim_channel=None): + tmax=0.01, baseline=None, + mode='linear', stim_channel=None): """Eliminate stimulation's artifacts from instance. .. note:: This function operates in-place, consider passing @@ -55,10 +61,13 @@ def fix_stim_artifact(inst, events=None, event_id=None, tmin=0., Start time of the interpolation window in seconds. tmax : float End time of the interpolation window in seconds. - mode : 'linear' | 'window' + baseline: None or tuple of length 2 + When mode = 'constant', baseline is required + mode : 'linear' | 'window' | 'constant' Way to fill the artifacted time interval. 'linear' does linear interpolation 'window' applies a (1 - hanning) window. + 'constant' use baseline avergae stim_channel : str | None Stim channel to use. @@ -67,9 +76,16 @@ def fix_stim_artifact(inst, events=None, event_id=None, tmin=0., inst : instance of Raw or Evoked or Epochs Instance with modified data """ - _check_option('mode', mode, ['linear', 'window']) + _check_option('mode', mode, ['linear', 'window', 'constant']) s_start = int(np.ceil(inst.info['sfreq'] * tmin)) s_end = int(np.ceil(inst.info['sfreq'] * tmax)) + if (mode == "constant") and (baseline is None): + raise ValueError('Please provide the baseline') + if mode == 'constant': + b_start = int(np.ceil(inst.info['sfreq'] * baseline[0])) + b_end = int(np.ceil(inst.info['sfreq'] * baseline[1])) + else: + b_start, b_end = np.nan, np.nan if (mode == "window") and (s_end - s_start) < 4: raise ValueError('Time range is too short. Use a larger interval ' 'or set mode to "linear".') @@ -93,7 +109,10 @@ def fix_stim_artifact(inst, events=None, event_id=None, tmin=0., for event_idx in event_start: first_samp = int(event_idx) - inst.first_samp + s_start last_samp = int(event_idx) - inst.first_samp + s_end - _fix_artifact(data, window, picks, first_samp, last_samp, mode) + base_t1 = int(event_idx) - inst.first_samp + b_start + base_t2 = int(event_idx) - inst.first_samp + b_end + _fix_artifact(data, window, picks, first_samp, last_samp, + base_t1, base_t2, mode) elif isinstance(inst, BaseEpochs): if inst.reject is not None: @@ -103,14 +122,21 @@ def fix_stim_artifact(inst, events=None, event_id=None, tmin=0., first_samp = s_start - e_start last_samp = s_end - e_start data = inst._data + base_t1 = e_start + b_start + base_t2 = e_start + b_end for epoch in data: - _fix_artifact(epoch, window, picks, first_samp, last_samp, mode) + _fix_artifact(epoch, window, picks, first_samp, last_samp, base_t1, + base_t2, mode) elif isinstance(inst, Evoked): first_samp = s_start - inst.first last_samp = s_end - inst.first data = inst.data - _fix_artifact(data, window, picks, first_samp, last_samp, mode) + base_t1 = b_start - inst.first + base_t2 = b_end - inst.first + + _fix_artifact(data, window, picks, first_samp, last_samp, base_t1, + base_t2, mode) else: raise TypeError('Not a Raw or Epochs or Evoked (got %s).' % type(inst)) diff --git a/mne/preprocessing/tests/test_stim.py b/mne/preprocessing/tests/test_stim.py index de15a1dabeb..3d05c7d657e 100644 --- a/mne/preprocessing/tests/test_stim.py +++ b/mne/preprocessing/tests/test_stim.py @@ -50,6 +50,11 @@ def test_fix_stim_artifact(): # XXX This is a very weird check... assert np.all(data_from_epochs_fix) == 0. + epochs = fix_stim_artifact(epochs, tmin=tmin, tmax=tmax, + baseline=(-0.1, -0.05), mode='constant') + data = epochs.get_data()[:, :, tmin_samp:tmax_samp] + assert np.all(np.diff(data[0][0])) == 0. + # use window before stimulus in raw event_idx = np.where(events[:, 2] == 1)[0][0] tmin, tmax = -0.045, -0.015 @@ -68,8 +73,14 @@ def test_fix_stim_artifact(): raw = fix_stim_artifact(raw, events, event_id=1, tmin=tmin, tmax=tmax, mode='window') data, times = raw[:, (tidx + tmin_samp):(tidx + tmax_samp)] + assert np.all(data) == 0. + raw = fix_stim_artifact(raw, events, event_id=1, tmin=tmin, tmax=tmax, + baseline=(-0.1, -0.05), mode='constant') + data, times = raw[:, (tidx + tmin_samp):(tidx + tmax_samp)] + assert np.all(np.diff(data[0])) == 0. + # get epochs from raw with fixed data tmin, tmax, event_id = -0.2, 0.5, 1 epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks, @@ -95,3 +106,8 @@ def test_fix_stim_artifact(): evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, mode='window') data = evoked.data[:, tmin_samp:tmax_samp] assert np.all(data) == 0. + + evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, + baseline=(-0.1, -0.05), mode='constant') + data = evoked.data[:, tmin_samp:tmax_samp] + assert np.all(np.diff(data[0])) == 0