forked from JusperLee/UtterancePIT-Speech-Separation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
169 lines (149 loc) · 4.95 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import warnings
import yaml
import logging
import librosa as audio_lib
import scipy.io.wavfile as wf
import numpy as np
# Largest value of a signed 16-bit integer (32767).
MAX_INT16 = np.iinfo(np.int16).max

# Machine epsilon of float32; used as a floor before taking logarithms.
EPSILON = np.finfo(np.float32).eps

# Top-level sections that a YAML configuration must contain.
config_keys = [
    "trainer",
    "model",
    "spectrogram_reader",
    "dataloader",
    "train_scp_conf",
    "valid_scp_conf",
    "debug_scp_conf",
]
def nfft(window_size):
    """Return the smallest power of two that is >= ``window_size``.

    The result is used as the FFT size for a window of ``window_size``
    samples, so it must never be smaller than the window itself.

    Args:
        window_size: window length in samples; must be >= 1.

    Returns:
        The FFT size as a Python ``int``.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("Invalid window_size: {}".format(window_size))
    # Round up, not down: truncating the exponent before the ceiling (as the
    # previous implementation did) gave e.g. 512 for a 1000-sample window.
    # Integer bit arithmetic avoids floating-point log2 rounding issues.
    size = int(np.ceil(window_size))
    return 1 << (size - 1).bit_length()
# return F x T or T x F
def stft(file,
         frame_length=1024,
         frame_shift=256,
         center=False,
         window="hann",
         return_samps=False,
         apply_abs=False,
         apply_log=False,
         apply_pow=False,
         transpose=True):
    """Load an audio file and compute its short-time Fourier transform.

    Post-processing is applied in a fixed order: magnitude (``apply_abs``),
    then squaring (``apply_pow``), then natural log (``apply_log``), then
    transposition (``transpose``).

    Args:
        file: path of the audio file to read (loaded at its native rate).
        frame_length: window length in samples; the FFT size is
            ``nfft(frame_length)``.
        frame_shift: hop length in samples.
        center: forwarded to ``librosa.stft``.
        window: window specification forwarded to ``librosa.stft``.
        return_samps: if True, also return the loaded samples.
        apply_abs: take the magnitude of the complex spectrogram.
        apply_log: take the log of the spectrogram, floored at ``EPSILON``;
            forces ``apply_abs`` on (with a warning) because the log is
            taken on real values.
        apply_pow: square the spectrogram.
        transpose: transpose the matrix before returning it (T x F instead
            of F x T).

    Returns:
        The spectrogram matrix, or ``(samps, stft_mat)`` when
        ``return_samps`` is True.

    Raises:
        FileNotFoundError: if ``file`` does not exist.
    """
    if not os.path.exists(file):
        raise FileNotFoundError("Input file {} do not exists!".format(file))
    if apply_log and not apply_abs:
        apply_abs = True
        warnings.warn(
            "Ignore apply_abs=False cause function return real values")
    samps, _ = audio_lib.load(file, sr=None)
    # Pass the framing parameters by keyword: recent librosa releases accept
    # them as keyword-only arguments, and older releases accept keywords too.
    stft_mat = audio_lib.stft(
        samps,
        n_fft=nfft(frame_length),
        hop_length=frame_shift,
        win_length=frame_length,
        window=window,
        center=center)
    if apply_abs:
        stft_mat = np.abs(stft_mat)
    if apply_pow:
        stft_mat = np.power(stft_mat, 2)
    if apply_log:
        # Floor at EPSILON so that log(0) never produces -inf.
        stft_mat = np.log(np.maximum(stft_mat, EPSILON))
    if transpose:
        stft_mat = np.transpose(stft_mat)
    return stft_mat if not return_samps else (samps, stft_mat)
def istft(file,
          stft_mat,
          frame_length=1024,
          frame_shift=256,
          center=False,
          window="hann",
          transpose=True,
          norm=None,
          fs=16000,
          nsamps=None):
    """Invert an STFT matrix and write the waveform to a 16-bit WAV file.

    Args:
        file: output WAV path; its parent directory is created if missing.
        stft_mat: STFT matrix to invert.
        frame_length: window length in samples.
        frame_shift: hop length in samples.
        center: forwarded to ``librosa.istft``.
        window: window specification forwarded to ``librosa.istft``.
        transpose: transpose ``stft_mat`` before inversion (i.e. the input
            is laid out the way ``stft(..., transpose=True)`` returns it).
        norm: if truthy, rescale the waveform so its peak absolute value
            (infinity norm) equals ``norm``.
        fs: sample rate written to the WAV header.
        nsamps: forwarded to ``librosa.istft`` as ``length``.
    """
    if transpose:
        stft_mat = np.transpose(stft_mat)
    # Keyword arguments: recent librosa releases make these keyword-only,
    # and older releases accept keywords as well.
    samps = audio_lib.istft(
        stft_mat,
        hop_length=frame_shift,
        win_length=frame_length,
        window=window,
        center=center,
        length=nsamps)
    # renorm if needed
    if norm:
        samps_norm = np.linalg.norm(samps, np.inf)
        # An all-zero (silent) signal has zero peak; skip rescaling instead
        # of dividing by zero and filling the output with NaNs.
        if samps_norm > 0:
            samps = samps * norm / samps_norm
    # same as MATLAB and kaldi
    # Clip to the int16 range first so out-of-range samples saturate
    # instead of wrapping around during the cast.
    samps_int16 = np.clip(samps * MAX_INT16, -MAX_INT16 - 1,
                          MAX_INT16).astype(np.int16)
    fdir = os.path.dirname(file)
    if fdir:
        # exist_ok avoids a race when several writers create the directory.
        os.makedirs(fdir, exist_ok=True)
    # NOTE: librosa 0.6.0 seems could not write non-float narray
    #       so use scipy.io.wavfile instead
    wf.write(file, fs, samps_int16)
def apply_cmvn(feats, cmvn_dict):
    """Apply mean and/or variance normalization to ``feats``.

    Args:
        feats: features to normalize (anything supporting ``-`` and ``/``
            with the statistics, e.g. a numpy array).
        cmvn_dict: dictionary optionally holding ``'mean'`` (subtracted
            first) and ``'std'`` (divided by afterwards). Missing keys are
            skipped.

    Returns:
        The normalized features; the input object is not modified in place.

    Raises:
        TypeError: if ``cmvn_dict`` is not a dictionary.
    """
    # isinstance (rather than an exact type comparison) also accepts dict
    # subclasses such as OrderedDict.
    if not isinstance(cmvn_dict, dict):
        raise TypeError("Input must be a python dictionary")
    if 'mean' in cmvn_dict:
        feats = feats - cmvn_dict['mean']
    if 'std' in cmvn_dict:
        feats = feats / cmvn_dict['std']
    return feats
def parse_scps(scp_path):
    """Parse a two-column script (scp) file into a dictionary.

    Every line must consist of exactly two whitespace-separated tokens,
    ``<key> <address>``.

    Args:
        scp_path: path of the scp file.

    Returns:
        A dict mapping each key to its address string.

    Raises:
        FileNotFoundError: if ``scp_path`` does not exist.
        RuntimeError: if a line does not contain exactly two tokens.
        ValueError: if a key appears more than once.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(scp_path):
        raise FileNotFoundError(
            "Input file {} do not exists!".format(scp_path))
    scp_dict = dict()
    with open(scp_path, 'r') as f:
        for scp in f:
            scp_tokens = scp.strip().split()
            if len(scp_tokens) != 2:
                raise RuntimeError(
                    "Error format of context \'{}\'".format(scp))
            key, addr = scp_tokens
            if key in scp_dict:
                raise ValueError("Duplicate key \'{}\' exists!".format(key))
            scp_dict[key] = addr
    return scp_dict
def filekey(path):
    """Return the base name of ``path`` with its last extension removed.

    Only the text after the final dot is dropped, so ``a.b.c`` yields
    ``a.b``; a name without any dot is returned unchanged.

    Raises:
        ValueError: if ``path`` has an empty base name (e.g. it ends with a
            path separator).
    """
    base = os.path.basename(path)
    if len(base) == 0:
        raise ValueError("{}(Is directory path?)".format(path))
    # Split on the last dot; `dot` is empty when the name contains none.
    stem, dot, _ext = base.rpartition(".")
    return stem if dot else base
def parse_yaml(yaml_conf):
    """Load and validate a YAML training configuration.

    Checks that every section listed in ``config_keys`` is present, that
    the dataloader batch size is positive, and that the train/valid scp
    sections have the same number of entries. The model's ``num_spks`` is
    overwritten (with a warning) when it disagrees with the number of
    ``spk*`` entries in ``train_scp_conf``.

    Args:
        yaml_conf: path of the YAML configuration file.

    Returns:
        A tuple ``(num_bins, config_dict)`` where ``num_bins`` is
        ``nfft(frame_length) // 2 + 1`` for the configured spectrogram
        frame length.

    Raises:
        FileNotFoundError: if ``yaml_conf`` does not exist.
        KeyError: if a required section is missing.
        ValueError: if the batch size is not positive or the train/valid
            scp sections differ in size.
    """
    if not os.path.exists(yaml_conf):
        raise FileNotFoundError(
            "Could not find configure files...{}".format(yaml_conf))
    with open(yaml_conf, 'r') as f:
        # safe_load instead of yaml.load(f): calling load() without a Loader
        # fails on current PyYAML releases and the default loader could
        # construct arbitrary Python objects from the file.
        # An empty file parses to None; treat it as an empty mapping so the
        # missing-section check below reports the problem.
        config_dict = yaml.safe_load(f) or {}

    for key in config_keys:
        if key not in config_dict:
            raise KeyError("Missing {} configs in yaml".format(key))

    batch_size = config_dict["dataloader"]["batch_size"]
    if batch_size <= 0:
        raise ValueError("Invalid batch_size: {}".format(batch_size))

    num_frames = config_dict["spectrogram_reader"]["frame_length"]
    num_bins = nfft(num_frames) // 2 + 1

    if len(config_dict["train_scp_conf"]) != len(
            config_dict["valid_scp_conf"]):
        raise ValueError("Check configures in train_scp_conf/valid_scp_conf")

    # One speaker per "spk*" entry of the training scp section.
    num_spks = sum(
        1 for key in config_dict["train_scp_conf"] if key.startswith("spk"))
    if num_spks != config_dict["model"]["num_spks"]:
        warnings.warn(
            "Number of speakers configured in trainer do not match *_scp_conf, "
            " correct to {}".format(num_spks))
        config_dict["model"]["num_spks"] = num_spks
    return num_bins, config_dict
def get_logger(
        name,
        format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s"):
    """Return an INFO-level logger that writes to a stream handler.

    Calling this repeatedly with the same ``name`` reuses the logger's
    existing plain ``StreamHandler`` instead of stacking a new one, so each
    message is emitted once. The most recent ``format_str`` wins.

    Args:
        name: logger name passed to ``logging.getLogger``.
        format_str: format string for the handler's formatter.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Exact type match on purpose: FileHandler and friends subclass
    # StreamHandler and must not be mistaken for the console handler.
    handler = next(
        (h for h in logger.handlers if type(h) is logging.StreamHandler),
        None)
    if handler is None:
        handler = logging.StreamHandler()
        logger.addHandler(handler)
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter(format_str))
    return logger