-
Notifications
You must be signed in to change notification settings - Fork 10
/
dataset.py.exit
227 lines (187 loc) · 8.03 KB
/
dataset.py.exit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import os
import random
import pickle
import numpy as np
import torch as th
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pack_sequence, pad_sequence
from utils import parse_scps, stft, apply_cmvn, EPSILON, get_logger
# Module-level logger shared by every class in this file, created through the
# project's get_logger helper (imported from utils)
logger = get_logger(__name__)
class SpectrogramReader(object):
    """
    Keyed access to the short-time fourier transform of each entry in a wave scp

    The scp file is parsed once at construction time; the transform itself is
    computed on demand each time an utterance is requested.
    """

    def __init__(self, wave_scp, **kwargs):
        """
        Arguments:
            wave_scp: path of the scp file, handed to parse_scps()
            kwargs: extra keyword arguments forwarded untouched to stft()
        Raises:
            FileNotFoundError: when wave_scp does not exist on disk
        """
        scp_missing = not os.path.exists(wave_scp)
        if scp_missing:
            raise FileNotFoundError("Could not find file {}".format(wave_scp))
        self.stft_kwargs = kwargs
        self.wave_dict = parse_scps(wave_scp)
        # snapshot of the keys so callers can address utterances by position
        self.wave_keys = list(self.wave_dict.keys())
        num_utts = len(self.wave_dict)
        message = "Create SpectrogramReader for {} with {} utterances".format(
            wave_scp, num_utts)
        logger.info(message)

    def __len__(self):
        """Number of entries parsed from the scp file"""
        return len(self.wave_dict)

    def __contains__(self, key):
        """True when key is one of the parsed entries"""
        return key in self.wave_dict

    def _load(self, key):
        """Run stft() on the entry stored under key, using the saved kwargs"""
        entry = self.wave_dict[key]
        return stft(entry, **self.stft_kwargs)

    # Sequential iteration is currently disabled:
    #
    # def __iter__(self):
    #     for key in self.wave_dict:
    #         yield key, self._load(key)

    def __getitem__(self, key):
        """
        Random access by key
        Raises:
            KeyError: when key is not one of the parsed entries
        """
        if key in self.wave_dict:
            return self._load(key)
        raise KeyError("Could not find utterance {}".format(key))
class Datasets(object):
    """
    Map-style dataset pairing each mixture with its per-speaker targets

    Indexing is positional: index i refers to the i-th key of the mixture
    reader, and the same key is looked up in every target reader.
    """

    def __init__(self, mix_reader, target_reader_list):
        """
        Arguments:
            mix_reader: reader of the mixtures; must provide wave_keys,
                        __len__ and key based __getitem__
            target_reader_list: one reader per target speaker; each must
                        support `key in reader` and `reader[key]`
        """
        self.mix_reader = mix_reader
        self.target_reader_list = target_reader_list
        self.key_list = mix_reader.wave_keys
        self.num_spks = len(target_reader_list)

    def __len__(self):
        """Number of mixtures"""
        return len(self.mix_reader)

    def _has_target(self, key):
        """True only when every target reader contains key"""
        return all(key in target for target in self.target_reader_list)

    @staticmethod
    def _to_single_precision(spec):
        """
        Cast a spectrogram to single precision without changing its kind

        A complex input becomes complex64 and a real input becomes float32.
        Casting complex data straight to float32 would silently throw away
        the imaginary part, and downstream code that inspects
        np.iscomplexobj() could then never see a complex spectrogram.
        """
        dtype = np.complex64 if np.iscomplexobj(spec) else np.float32
        return spec.astype(dtype)

    def __getitem__(self, index):
        """
        Arguments:
            index: position inside key_list
        Returns:
            dictionary with two keys:
                mix: single precision mixture spectrogram
                ref: list of single precision target spectrograms, in the
                     order of target_reader_list
        Raises:
            ValueError: when at least one target reader lacks the utterance
        """
        key = self.key_list[index]
        # membership test is cheap, so do it before loading anything
        if not self._has_target(key):
            raise ValueError('Not have Target Data')
        mix = self.mix_reader[key]
        ref = [reader[key] for reader in self.target_reader_list]
        return {
            'mix': self._to_single_precision(mix),
            "ref": [self._to_single_precision(spk) for spk in ref]
        }
class DataLoaders(object):
    """
    Mini-batch iterator built on top of torch.utils.data.DataLoader

    The wrapped DataLoader only groups raw examples into lists; feature
    extraction, sorting by length and zero padding are done in this class
    when a batch is yielded.
    """

    def __init__(self,
                 dataset,
                 num_workers=1,
                 batch_size=20,
                 shuffle=True,
                 drop_last=False,
                 apply_log=True,
                 mvn_dict=None):
        """
        Arguments:
            dataset: object exposing num_spks and yielding dictionaries with
                     keys 'mix' and 'ref' (see _transform)
            num_workers/batch_size/shuffle/drop_last: forwarded to DataLoader
            apply_log: take the log of the input feature
            mvn_dict: optional path of a pickled object handed to apply_cmvn()
        """
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.mvn_dict = mvn_dict
        self.drop_last = drop_last
        self.num_workers = num_workers
        self.apply_log = apply_log
        self.num_spks = dataset.num_spks
        # drop_last has to reach the DataLoader, otherwise the option is
        # accepted by this constructor but never takes effect
        self.eg_loader = DataLoader(dataset,
                                    batch_size=batch_size,
                                    num_workers=num_workers,
                                    shuffle=shuffle,
                                    drop_last=drop_last,
                                    collate_fn=self._collate)
        if mvn_dict:
            logger.info("Using cmvn dictionary from {}".format(mvn_dict))
            # NOTE(security): pickle.load can execute arbitrary code, only
            # point mvn_dict at files coming from a trusted source
            with open(mvn_dict, "rb") as f:
                self.mvn_dict = pickle.load(f)

    def _collate(self, batch):
        """Keep the examples of a batch as a plain list, without stacking"""
        return list(batch)

    def __len__(self):
        """Number of batches produced by the wrapped DataLoader"""
        return len(self.eg_loader)

    def _transform(self, eg):
        """
        Turn one raw example into network input and loss targets
        Arguments:
            eg: dictionary with two keys
                mix: non-log complex/real spectrogram of the mixture
                ref: list of non-log complex/real spectrograms, one per
                     target speaker
        Returns:
            python dictionary with four attributes:
                num_frames: size of the first axis of the mixture
                feature: input feature for networks (magnitude, then
                         optional log and optional cmvn)
                source_attr: dictionary with key spectrogram, plus phase
                             when the mixture is complex
                target_attr: same keys as source_attr, each key maps to a
                             list of tensors (one per target speaker)
        """
        mix_specs = eg['mix']
        targets_specs_list = eg['ref']
        is_complex = np.iscomplexobj(mix_specs)
        # NOTE: mix_specs may be complex or real
        input_spectra = np.abs(mix_specs) if is_complex else mix_specs
        # apply_log and cmvn, for nnet input
        if self.apply_log:
            input_spectra = np.log(np.maximum(input_spectra, EPSILON))
        if self.mvn_dict:
            input_spectra = apply_cmvn(input_spectra, self.mvn_dict)

        source_attr = {}
        target_attr = {}
        if is_complex:
            # magnitude and phase are kept apart for complex input
            source_attr["spectrogram"] = th.tensor(
                np.abs(mix_specs), dtype=th.float32)
            target_attr["spectrogram"] = [
                th.tensor(np.abs(t), dtype=th.float32)
                for t in targets_specs_list]
            source_attr["phase"] = th.tensor(
                np.angle(mix_specs), dtype=th.float32)
            target_attr["phase"] = [
                th.tensor(np.angle(t), dtype=th.float32)
                for t in targets_specs_list]
        else:
            source_attr["spectrogram"] = th.tensor(
                mix_specs, dtype=th.float32)
            target_attr["spectrogram"] = [
                th.tensor(t, dtype=th.float32) for t in targets_specs_list]

        return {
            "num_frames": mix_specs.shape[0],
            "feature": th.tensor(input_spectra, dtype=th.float32),
            "source_attr": source_attr,
            "target_attr": target_attr
        }

    def _process(self, egs):
        """
        Transform a list of raw examples into a padded minibatch
        Arguments:
            egs: a list type [{},{},{}], each item as accepted by _transform
        Returns:
            input_sizes: a tensor with num_frames of each utterance,
                         longest first
            input_feats: features zero padded to the longest utterance,
                         batch first
            source_attr/target_attr: dictionary contains padded
                         spectrogram/phase needed in loss computation
        Raises:
            ValueError: when egs is not a list
        """
        if not isinstance(egs, list):
            raise ValueError("Unsupported index type({})".format(type(egs)))

        # longest utterance first
        dict_list = sorted([self._transform(eg) for eg in egs],
                           key=lambda x: x['num_frames'], reverse=True)

        def pad_field(attr, key):
            # pad one field of the mixture across the batch
            return pad_sequence([d[attr][key] for d in dict_list],
                                batch_first=True)

        def pad_targets(key):
            # pad one field of every speaker across the batch
            return [
                pad_sequence([d["target_attr"][key][spk] for d in dict_list],
                             batch_first=True)
                for spk in range(self.num_spks)
            ]

        input_feats = pad_sequence([d['feature'] for d in dict_list],
                                   batch_first=True)
        # NOTE(review): lengths are stored as float32; kept as is because
        # callers outside this file may rely on the dtype - confirm
        input_sizes = th.tensor([d['num_frames'] for d in dict_list],
                                dtype=th.float32)

        source_attr = {'spectrogram': pad_field('source_attr', 'spectrogram')}
        target_attr = {'spectrogram': pad_targets('spectrogram')}

        first = dict_list[0]
        if 'phase' in first['source_attr'] and 'phase' in first['target_attr']:
            source_attr['phase'] = pad_field('source_attr', 'phase')
            target_attr['phase'] = pad_targets('phase')

        return input_sizes, input_feats, source_attr, target_attr

    def __iter__(self):
        """Yield processed minibatches, logging progress along the way"""
        num_utts = 0
        # roughly every 2000 utterances; never 0, which would make the
        # modulo below raise ZeroDivisionError for batch_size > 2000
        log_period = max(1, 2000 // self.batch_size)
        for e, eg in enumerate(self.eg_loader):
            num_utts += (len(eg) if isinstance(eg, list) else 1)
            if not (e + 1) % log_period:
                logger.info(
                    "Processed {} batches, {} utterances".format(
                        e + 1, num_utts))
            yield self._process(eg)
        logger.info("Processed {} utterances in total".format(num_utts))