loader.py (forked from gist-ailab/SleePyCo)
import os
import glob
import torch
import numpy as np
from transform import *
from torch.utils.data import Dataset


class EEGDataLoader(Dataset):
    """Sleep-staging dataset that serves sequences of EEG epochs from preprocessed .npz files."""

    def __init__(self, config, fold, set='train'):
        self.set = set    # 'train', 'val', or 'test'
        self.fold = fold  # 1-based cross-validation fold index

        self.sr = 100     # sampling rate (Hz) of the preprocessed signals

        self.dset_cfg = config['dataset']
        self.root_dir = self.dset_cfg['root_dir']
        self.dset_name = self.dset_cfg['name']
        self.num_splits = self.dset_cfg['num_splits']
        self.eeg_channel = self.dset_cfg['eeg_channel']
        self.seq_len = self.dset_cfg['seq_len']
        self.target_idx = self.dset_cfg['target_idx']

        self.training_mode = config['training_params']['mode']
        self.dataset_path = os.path.join(self.root_dir, 'dset', self.dset_name, 'npz')

        # inputs/labels hold one array per recording; epochs indexes every
        # valid (recording, start epoch, sequence length) triple for this split.
        self.inputs, self.labels, self.epochs = self.split_dataset()

        if self.training_mode == 'pretrain':
            # Stochastic signal augmentations; TwoTransform produces two
            # independently augmented views of the same input for pretraining.
            self.transform = Compose(
                transforms=[
                    RandomAmplitudeScale(),
                    RandomTimeShift(),
                    RandomDCShift(),
                    RandomZeroMasking(),
                    RandomAdditiveGaussianNoise(),
                    RandomBandStopFilter(),
                ]
            )
            self.two_transform = TwoTransform(self.transform)

    def __len__(self):
        return len(self.epochs)

    def __getitem__(self, idx):
        # Total number of samples in one sequence (30-s epochs at self.sr Hz).
        n_sample = 30 * self.sr * self.seq_len
        file_idx, idx, seq_len = self.epochs[idx]
        inputs = self.inputs[file_idx][idx:idx+seq_len]

        if self.set == 'train':
            if self.training_mode == 'pretrain':
                assert seq_len == 1
                # Two augmented views of the same epoch.
                input_a, input_b = self.two_transform(inputs)
                input_a = torch.from_numpy(input_a).float()
                input_b = torch.from_numpy(input_b).float()
                inputs = [input_a, input_b]
            elif self.training_mode in ['scratch', 'fullyfinetune', 'freezefinetune']:
                # Flatten the sequence of epochs into a single-channel signal.
                inputs = inputs.reshape(1, n_sample)
                inputs = torch.from_numpy(inputs).float()
            else:
                raise NotImplementedError
        else:
            if self.training_mode != 'pretrain':
                inputs = inputs.reshape(1, n_sample)
                inputs = torch.from_numpy(inputs).float()

        labels = self.labels[file_idx][idx:idx+seq_len]
        labels = torch.from_numpy(labels).long()
        labels = labels[self.target_idx]  # supervise only the target epoch of the sequence

        return inputs, labels

    def split_dataset(self):
        file_idx = 0
        inputs, labels, epochs = [], [], []
        data_root = os.path.join(self.dataset_path, self.eeg_channel)
        data_fname_list = [os.path.basename(x) for x in sorted(glob.glob(os.path.join(data_root, '*.npz')))]
        data_fname_dict = {'train': [], 'test': [], 'val': []}
        split_idx_list = np.load(os.path.join('./split_idx', 'idx_{}.npy'.format(self.dset_name)), allow_pickle=True)
        assert len(split_idx_list) == self.num_splits

        # Assign recordings to train/val/test for the current fold (subject-wise splits).
        if self.dset_name == 'Sleep-EDF-2013':
            for i in range(len(data_fname_list)):
                subject_idx = int(data_fname_list[i][3:5])
                if subject_idx == self.fold - 1:
                    data_fname_dict['test'].append(data_fname_list[i])
                elif subject_idx in split_idx_list[self.fold - 1]:
                    data_fname_dict['val'].append(data_fname_list[i])
                else:
                    data_fname_dict['train'].append(data_fname_list[i])
        elif self.dset_name == 'Sleep-EDF-2018':
            for i in range(len(data_fname_list)):
                subject_idx = int(data_fname_list[i][3:5])
                if subject_idx in split_idx_list[self.fold - 1][self.set]:
                    data_fname_dict[self.set].append(data_fname_list[i])
        elif self.dset_name == 'MASS' or self.dset_name == 'Physio2018' or self.dset_name == 'SHHS':
            for i in range(len(data_fname_list)):
                if i in split_idx_list[self.fold - 1][self.set]:
                    data_fname_dict[self.set].append(data_fname_list[i])
        else:
            raise NameError("dataset '{}' cannot be found.".format(self.dset_name))

        # Load the selected recordings and enumerate every valid sequence start.
        for data_fname in data_fname_dict[self.set]:
            npz_file = np.load(os.path.join(data_root, data_fname))
            inputs.append(npz_file['x'])
            labels.append(npz_file['y'])
            seq_len = self.seq_len
            # MASS subsets 02/04/05 are scored in 20-s epochs, so a 1.5x longer
            # sequence covers the same duration as seq_len 30-s epochs.
            if self.dset_name == 'MASS' and ('-02-' in data_fname or '-04-' in data_fname or '-05-' in data_fname):
                seq_len = int(self.seq_len * 1.5)
            for i in range(len(npz_file['y']) - seq_len + 1):
                epochs.append([file_idx, i, seq_len])
            file_idx += 1

        return inputs, labels, epochs
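

if __name__ == '__main__':
    # Usage sketch (not part of the original file): builds a minimal config dict
    # containing only the keys this loader reads and wraps the dataset in a
    # standard PyTorch DataLoader. All concrete values below (root_dir,
    # eeg_channel directory, seq_len, num_splits, ...) are placeholder
    # assumptions; in the repo they come from the YAML config files, and the
    # preprocessed .npz files plus ./split_idx/idx_<name>.npy must already exist.
    from torch.utils.data import DataLoader

    config = {
        'dataset': {
            'root_dir': '.',            # expects <root_dir>/dset/<name>/npz/<eeg_channel>/*.npz
            'name': 'Sleep-EDF-2013',
            'num_splits': 20,           # must equal len(./split_idx/idx_<name>.npy)
            'eeg_channel': 'Fpz-Cz',    # subdirectory holding the per-recording .npz files (assumed)
            'seq_len': 10,              # number of 30-s epochs per input sequence (assumed)
            'target_idx': -1,           # supervise the last epoch of each sequence
        },
        'training_params': {'mode': 'scratch'},
    }

    train_set = EEGDataLoader(config, fold=1, set='train')
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

    inputs, labels = next(iter(train_loader))
    print(inputs.shape, labels.shape)   # e.g. torch.Size([32, 1, 30000]), torch.Size([32])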