-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
132 lines (111 loc) · 5.18 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import matplotlib.pyplot as plt
import mne.io
import numpy as np
import numpy.linalg
import scipy.signal
import mne.time_frequency
from multiprocessing import Pool
import viz
# >>> Parameters
# EEG channels included in the analysis, with the 10-20 position each one
# stands for. Set PICKS to None to not restrict the analysis to a named list.
PICKS = [
    'EEG 10',  # Fp1
    'EEG 12',  # F3
    'EEG 18',  # F7
    'EEG 8',   # FpZ -- technically AFz, but the closest electrode to FpZ
    'EEG 6',   # Fz
    'EEG 5',   # Fp2
    'EEG 60',  # F4
    'EEG 58',  # F8
]
# Analysis window: the part of every recording between these two offsets
# (in seconds) is kept, the rest is cropped away.
START_TIME_SEC = 1 * 60
END_TIME_SEC = 4.5 * 60
# Frequency bands, as [low, high] in Hz, whose average power is calculated.
BAND_FREQUENCIES = dict(
    theta=[3.5, 7.5],
    alpha=[7.5, 13.0],
    beta=[13.0, 30.0],
)
def movingAverage(a, n=3):
    """Return the simple moving average of ``a`` over windows of ``n`` values.

    Element ``i`` of the result is the mean of ``a[i:i + n]``. Only full windows
    are averaged (no padding), so the result has ``len(a) - n + 1`` elements and
    is empty when ``a`` is shorter than ``n``. ``a`` is flattened by
    ``np.cumsum`` if it has more than one dimension.

    Args:
        a: sequence or array of numbers.
        n: window length; must be at least 1.

    Returns:
        1-D float ndarray of window means.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        # n == 0 divides by zero, and a negative n produced meaningless slices
        # instead of an error, so reject both explicitly.
        raise ValueError("window length n must be >= 1, got %r" % (n,))
    ret = np.cumsum(a, dtype=float)
    # Turn running totals into window totals: ret[i] becomes sum(a[i - n + 1:i + 1]).
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n
def calcBandPowers(raw):
    """Compute a smoothed, channel-averaged STFT trace for every band in BAND_FREQUENCIES.

    The data is restricted to EEG channels: those named in PICKS, or every EEG
    channel when PICKS is None. Channels listed in ``raw.info['bads']`` are
    excluded in both cases. For each band, the STFT magnitude is averaged over
    the selected channels and over the frequency bins strictly inside the
    band. The result is then smoothed with a 10-segment moving average.

    Args:
        raw: MNE Raw object whose data is loaded (``raw._data`` is read directly).

    Returns:
        dict mapping each band name in BAND_FREQUENCIES to a 1-D array with one
        value per smoothed STFT time segment.
    """
    # Always select through pick_types. With selection=None it keeps every EEG
    # channel, and its default exclude='bads' still drops bad channels. The old
    # code skipped picking entirely when PICKS was None, so bad channels and
    # non-EEG rows (e.g. the stimulus channel) were averaged in.
    pickIDs = mne.pick_types(raw.info, eeg=True, selection=PICKS)
    data = np.take(raw._data, pickIDs, axis=0)  # rows = selected channels
    # Pass the true sampling rate: int() would truncate non-integer rates and
    # shift the frequency axis.
    freq, _, stft = scipy.signal.stft(data, fs=raw.info['sfreq'])
    # NOTE(review): this is the STFT *magnitude* |X|, not |X|**2. The "power"
    # naming is kept from the original; confirm which quantity is intended.
    powers = np.abs(stft)
    result = {}
    for bandID, bandHz in BAND_FREQUENCIES.items():
        low, high = bandHz[0], bandHz[1]
        # Strict bounds, so a bin on a shared edge (e.g. 7.5 Hz) is in neither band.
        fPick = np.logical_and(low < freq, freq < high)
        # Average over channels (axis 0) and in-band frequency bins (axis 1),
        # leaving one value per STFT time segment.
        meanPower = np.mean(powers[:, fPick, :], axis=(0, 1))
        result[bandID] = movingAverage(meanPower, 10)
    return result
def bandStrength(pathAndBads):
    """Load one EDF recording, crop it to the analysis window and compute its band powers.

    Takes a single argument so it can be passed directly to ``Pool.map``.

    Args:
        pathAndBads: two-item sequence ``[path, badChannels]``.
            ``path`` is an EDF file name, resolved relative to the ``data/``
            directory. ``badChannels`` is the list of channel names to mark as
            bad for this recording.

    Returns:
        The dict produced by calcBandPowers: band name -> smoothed trace,
        averaged over the selected channels. It also gets a ``'path'`` entry
        holding the input file name.
    """
    path = pathAndBads[0]
    bads = pathAndBads[1]
    # Load the whole file into memory, then keep only START_TIME_SEC..END_TIME_SEC.
    recording = mne.io.read_raw_edf("data/" + path, preload=True)
    recording = recording.crop(tmin=START_TIME_SEC, tmax=END_TIME_SEC)
    recording.info['bads'] = bads
    bandPowers = calcBandPowers(recording)
    bandPowers['path'] = path
    return bandPowers
# One colour per person; each person contributes two consecutive recordings.
_PERSON_COLOURS = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (.9, .7, 0), (.5, 0, .5), (0, .5, .5)]


def _computeBandPowers(badMapping, nThreads):
    """Run bandStrength for every (path, bads) pair in worker processes, in path order."""
    jobs = sorted(([path, bads] for path, bads in badMapping.items()),
                  key=lambda job: job[0])
    print(jobs)
    pool = Pool(processes=nThreads)
    try:
        return pool.map(bandStrength, jobs)
    finally:
        # Release the worker processes; previously the pool was never closed.
        pool.close()
        pool.join()


def _plotLogDensity(axis, values, **lineStyle):
    """Plot the normalised histogram of log(values) on ``axis`` as a line through bin centres."""
    # density=True replaces `normed=True`, which NumPy removed; for equal-width
    # bins (the default) both give the same values.
    density, edges = np.histogram(np.log(values), density=True)
    # The midpoint of each pair of consecutive edges is that bin's centre.
    axis.plot(movingAverage(edges, 2), density, **lineStyle)


def powerBandAnalysis(badMapping, nThreads=4):
    """Compute and plot theta/beta power for a set of recordings.

    Every file is processed with bandStrength in a pool of ``nThreads`` worker
    *processes*. The plot has two rows of panels:
    - top row: log theta, log beta and log theta/beta ratio over time;
    - bottom row: the distribution of each of those quantities.

    Args:
        badMapping: dict mapping EDF file name (relative to ``data/``) to the
            list of bad channel names for that file.
        nThreads: number of worker processes.
    """
    results = _computeBandPowers(badMapping, nThreads)

    ax = viz.cleanSubplots(2, 3)
    ax[0, 0].set_title('log(Theta power)')
    ax[0, 1].set_title('log(Beta power)')
    ax[0, 2].set_title('log(Theta/Beta ratio)')
    ax[1, 0].set_title('Distribution of log(Theta)')
    ax[1, 1].set_title('Distribution of log(Beta)')
    ax[1, 2].set_title('Distribution of log(T/B)')

    if results:
        # The time axis is sized from the first recording.
        nSegments = len(results[0]['theta'])
        for column in range(3):
            ax[0, column].set_xlim([0, nSegments])

    for i, result in enumerate(results):
        # Results are sorted by path, so each person's two recordings sit at
        # indices 2k and 2k+1. The first gets a solid line, the second a dashed
        # one; with the file names used in __main__ that is Resting, then Focus.
        style = '-' if i % 2 == 0 else '--'
        # Cycle the palette instead of raising IndexError past six people.
        colour = _PERSON_COLOURS[(i // 2) % len(_PERSON_COLOURS)]
        t, b = result['theta'], result['beta']
        ax[0, 0].plot(np.log(t), c=colour, ls=style)
        ax[0, 1].plot(np.log(b), c=colour, ls=style)
        ax[0, 2].plot(np.log(t / b), c=colour, ls=style)
        _plotLogDensity(ax[1, 0], t, c=colour, ls=style, label=viz.shortName(result['path']))
        _plotLogDensity(ax[1, 1], b, c=colour, ls=style)
        _plotLogDensity(ax[1, 2], t / b, c=colour, ls=style)
    ax[1, 0].legend()
    plt.show()
if __name__ == '__main__':
    # Bad channels for each recording. File names are resolved relative to data/.
    badChannelsByFile = {
        'T013_D001_V00_2017_05_16_Emily-Resting-30Hzfilt.edf': ['STI 014', 'EEG 55', 'EEG VREF'],
        'T013_D002_V00_2017_05_16_Emily-Focus-30Hzfilt.edf': ['STI 014', 'EEG 55', 'EEG VREF'],
        'T013_D003_V00_2017_05_16_Giulio-Resting-State-30Hzfilt.edf': ['STI 014', 'EEG 10', 'EEG 63', 'EEG VREF'],
        'T013_D004_V00_2017_05_16_Giulio-Focused-30Hzfilt.edf': ['STI 014', 'EEG 10', 'EEG 63', 'EEG VREF'],
        'T013_D005_V00_2017_05_15_Patrick-Resting-State-30Hzfilt.edf': ['STI 014', 'EEG 10', 'EEG 63', 'EEG VREF'],
        'T013_D006_V00_2017_05_15_Patrick-Focus-30Hzfilt.edf': ['STI 014', 'EEG 10', 'EEG 63', 'EEG VREF'],
        'T013_D007_V00_2017_05_16_MichaelH-Resting-State-30Hzfilt.edf': ['STI 014', 'EEG 10', 'EEG 63', 'EEG VREF'],
        'T013_D008_V00_2017_05_16_MichaelH-Focus-30Hzfilt.edf': ['STI 014', 'EEG 10', 'EEG 63', 'EEG VREF'],
        'T013_D009_V00_2017_05_15_Yana-Resting-State-30Hzfilt.edf': ['STI 014', 'EEG 18', 'EEG 23', 'EEG 46', 'EEG 56', 'EEG VREF'],
        'T013_D010_V00_2017_05_15_Yana-Focus-30Hzfilt.edf': ['STI 014', 'EEG 18', 'EEG 56', 'EEG VREF'],
    }
    powerBandAnalysis(badChannelsByFile, nThreads=8)
    # To choose the bad channels for a new recording, inspect its raw signal first, e.g.:
    # viz.showEdfSignal('data/T013_D006_V00_2017_05_15_Patrick-Focus-30Hzfilt.edf')