Skip to content

Commit

Permalink
added mode='constant' to flaten TMS artifact
Browse files Browse the repository at this point in the history
  • Loading branch information
fmamashli committed Oct 4, 2019
1 parent 37f2f45 commit ec95998
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
42 changes: 34 additions & 8 deletions mne/preprocessing/stim.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Authors: Daniel Strohmeier <[email protected]>
#
# Fahimeh Mamashli <[email protected]>
# Padma Sundaram <[email protected]>
# License: BSD (3-clause)

import numpy as np
Expand All @@ -21,7 +22,8 @@ def _get_window(start, end):
return window


def _fix_artifact(data, window, picks, first_samp, last_samp, mode):
def _fix_artifact(data, window, picks, first_samp, last_samp, base_tmin,
base_tmax, mode):
"""Modify original data by using parameter data."""
from scipy.interpolate import interp1d
if mode == 'linear':
Expand All @@ -33,10 +35,14 @@ def _fix_artifact(data, window, picks, first_samp, last_samp, mode):
if mode == 'window':
data[picks, first_samp:last_samp] = \
data[picks, first_samp:last_samp] * window[np.newaxis, :]
if mode == 'constant':
data[picks, first_samp:last_samp] = \
data[picks, base_tmin: base_tmax].mean(axis=1)[:, None]


def fix_stim_artifact(inst, events=None, event_id=None, tmin=0.,
tmax=0.01, mode='linear', stim_channel=None):
tmax=0.01, baseline=None,
mode='linear', stim_channel=None):
"""Eliminate stimulation's artifacts from instance.
.. note:: This function operates in-place, consider passing
Expand All @@ -55,10 +61,13 @@ def fix_stim_artifact(inst, events=None, event_id=None, tmin=0.,
Start time of the interpolation window in seconds.
tmax : float
End time of the interpolation window in seconds.
mode : 'linear' | 'window'
baseline: None or tuple of length 2
When mode = 'constant', baseline is required
mode : 'linear' | 'window' | 'constant'
Way to fill the artifacted time interval.
'linear' does linear interpolation
'window' applies a (1 - hanning) window.
'constant' use baseline avergae
stim_channel : str | None
Stim channel to use.
Expand All @@ -67,9 +76,16 @@ def fix_stim_artifact(inst, events=None, event_id=None, tmin=0.,
inst : instance of Raw or Evoked or Epochs
Instance with modified data
"""
_check_option('mode', mode, ['linear', 'window'])
_check_option('mode', mode, ['linear', 'window', 'constant'])
s_start = int(np.ceil(inst.info['sfreq'] * tmin))
s_end = int(np.ceil(inst.info['sfreq'] * tmax))
if (mode == "constant") and (baseline is None):
raise ValueError('Please provide the baseline')
if mode == 'constant':
b_start = int(np.ceil(inst.info['sfreq'] * baseline[0]))
b_end = int(np.ceil(inst.info['sfreq'] * baseline[1]))
else:
b_start, b_end = np.nan, np.nan
if (mode == "window") and (s_end - s_start) < 4:
raise ValueError('Time range is too short. Use a larger interval '
'or set mode to "linear".')
Expand All @@ -93,7 +109,10 @@ def fix_stim_artifact(inst, events=None, event_id=None, tmin=0.,
for event_idx in event_start:
first_samp = int(event_idx) - inst.first_samp + s_start
last_samp = int(event_idx) - inst.first_samp + s_end
_fix_artifact(data, window, picks, first_samp, last_samp, mode)
base_t1 = int(event_idx) - inst.first_samp + b_start
base_t2 = int(event_idx) - inst.first_samp + b_end
_fix_artifact(data, window, picks, first_samp, last_samp,
base_t1, base_t2, mode)

elif isinstance(inst, BaseEpochs):
if inst.reject is not None:
Expand All @@ -103,14 +122,21 @@ def fix_stim_artifact(inst, events=None, event_id=None, tmin=0.,
first_samp = s_start - e_start
last_samp = s_end - e_start
data = inst._data
base_t1 = e_start + b_start
base_t2 = e_start + b_end
for epoch in data:
_fix_artifact(epoch, window, picks, first_samp, last_samp, mode)
_fix_artifact(epoch, window, picks, first_samp, last_samp, base_t1,
base_t2, mode)

elif isinstance(inst, Evoked):
first_samp = s_start - inst.first
last_samp = s_end - inst.first
data = inst.data
_fix_artifact(data, window, picks, first_samp, last_samp, mode)
base_t1 = b_start - inst.first
base_t2 = b_end - inst.first

_fix_artifact(data, window, picks, first_samp, last_samp, base_t1,
base_t2, mode)

else:
raise TypeError('Not a Raw or Epochs or Evoked (got %s).' % type(inst))
Expand Down
16 changes: 16 additions & 0 deletions mne/preprocessing/tests/test_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def test_fix_stim_artifact():
# XXX This is a very weird check...
assert np.all(data_from_epochs_fix) == 0.

epochs = fix_stim_artifact(epochs, tmin=tmin, tmax=tmax,
baseline=(-0.1, -0.05), mode='constant')
data = epochs.get_data()[:, :, tmin_samp:tmax_samp]
assert np.all(np.diff(data[0][0])) == 0.

# use window before stimulus in raw
event_idx = np.where(events[:, 2] == 1)[0][0]
tmin, tmax = -0.045, -0.015
Expand All @@ -68,8 +73,14 @@ def test_fix_stim_artifact():
raw = fix_stim_artifact(raw, events, event_id=1, tmin=tmin,
tmax=tmax, mode='window')
data, times = raw[:, (tidx + tmin_samp):(tidx + tmax_samp)]

assert np.all(data) == 0.

raw = fix_stim_artifact(raw, events, event_id=1, tmin=tmin, tmax=tmax,
baseline=(-0.1, -0.05), mode='constant')
data, times = raw[:, (tidx + tmin_samp):(tidx + tmax_samp)]
assert np.all(np.diff(data[0])) == 0.

# get epochs from raw with fixed data
tmin, tmax, event_id = -0.2, 0.5, 1
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
Expand All @@ -95,3 +106,8 @@ def test_fix_stim_artifact():
evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax, mode='window')
data = evoked.data[:, tmin_samp:tmax_samp]
assert np.all(data) == 0.

evoked = fix_stim_artifact(evoked, tmin=tmin, tmax=tmax,
baseline=(-0.1, -0.05), mode='constant')
data = evoked.data[:, tmin_samp:tmax_samp]
assert np.all(np.diff(data[0])) == 0

0 comments on commit ec95998

Please sign in to comment.