forked from tensorpack/tensorpack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
create-lmdb.py
executable file
·135 lines (112 loc) · 4.49 KB
/
create-lmdb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: create-lmdb.py
# Author: Yuxin Wu
import argparse
import numpy as np
import os
import string
import bob.ap
import scipy.io.wavfile as wavfile
from tensorpack.dataflow import DataFlow, LMDBSerializer
from tensorpack.utils import fs, logger, serialize, get_tqdm
from tensorpack.utils.argtools import memoized
from tensorpack.utils.stats import OnlineMoments
# Characters that survive transcript cleaning: the 26 lowercase ASCII letters
# plus the space character.
CHARSET = set(string.ascii_lowercase) | {' '}

# Phoneme symbols in a fixed order; a symbol's position in this list is its
# integer class id (see PHONEME_DIC).
PHONEME_LIST = [
    'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay',
    'b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx',
    'eh', 'el', 'em', 'en', 'eng', 'epi', 'er', 'ey',
    'f', 'g', 'gcl', 'h#', 'hh', 'hv',
    'ih', 'ix', 'iy', 'jh', 'k', 'kcl', 'l', 'm',
    'n', 'ng', 'nx', 'ow', 'oy',
    'p', 'pau', 'pcl', 'q', 'r',
    's', 'sh', 't', 'tcl', 'th',
    'uh', 'uw', 'ux', 'v', 'w', 'y', 'z', 'zh',
]

# phoneme symbol -> integer id (its index in PHONEME_LIST)
PHONEME_DIC = dict(zip(PHONEME_LIST, range(len(PHONEME_LIST))))

# character -> integer id: 'a'..'z' map to 0..25, the space maps to 26
WORD_DIC = {ch: idx for idx, ch in enumerate(string.ascii_lowercase + ' ')}
def read_timit_txt(f):
    """Read a transcript file and encode its text as character ids.

    Only the first line of the file is used.  Its first two space-separated
    tokens are dropped (presumably the start/end sample offsets that prefix a
    TIMIT transcript -- confirm against the dataset), periods are removed, the
    text is lower-cased, and every character outside ``CHARSET`` is discarded.

    Args:
        f: path of the transcript file.

    Returns:
        A 1-D numpy array holding the ``WORD_DIC`` id of each kept character.

    Raises:
        IndexError: if the file is empty.
    """
    # The context manager guarantees the handle is released even when the
    # file is empty and the indexing below raises.
    with open(f) as fp:
        first_line = fp.readlines()[0]

    tokens = first_line.strip().split(' ')[2:]
    text = ' '.join(tokens).replace('.', '').lower()
    # Filtering on CHARSET first means the WORD_DIC lookup can never fail.
    return np.asarray([WORD_DIC[c] for c in text if c in CHARSET])
def read_timit_phoneme(f):
    """Read a phoneme annotation file and encode it as phoneme ids.

    Each line contributes one label: the last space-separated token of the
    stripped line is looked up in ``PHONEME_DIC``.

    Args:
        f: path of the phoneme annotation file.

    Returns:
        A 1-D numpy array with one phoneme id per line of the file.

    Raises:
        KeyError: if a line's last token is not a known phoneme symbol.
    """
    # The context manager closes the file even if an unknown symbol makes
    # the lookup raise part-way through.
    with open(f) as fp:
        ids = [PHONEME_DIC[line.strip().split(' ')[-1]] for line in fp]
    return np.asarray(ids)
@memoized
def get_bob_extractor(fs, win_length_ms=10, win_shift_ms=5,
                      n_filters=55, n_ceps=15, f_min=0., f_max=6000,
                      delta_win=2, pre_emphasis_coef=0.95, dct_norm=True,
                      mel_scale=True):
    """Build a ``bob.ap.Ceps`` cepstral feature extractor.

    All arguments are forwarded positionally to ``bob.ap.Ceps``.  The function
    is wrapped in ``memoized``, so repeated calls with the same arguments are
    presumably served from a cache instead of constructing a new extractor
    (verify against tensorpack's ``memoized``).

    Args:
        fs: sampling frequency of the audio the extractor will process.
        win_length_ms, win_shift_ms: analysis window length / shift.
        n_filters, n_ceps: number of filters / cepstral coefficients.
        f_min, f_max: frequency range of the filter bank.
        delta_win, pre_emphasis_coef, dct_norm, mel_scale: passed through
            unchanged to ``bob.ap.Ceps``.

    Returns:
        The constructed ``bob.ap.Ceps`` object.
    """
    # NOTE: the constructor receives mel_scale *before* dct_norm, which is the
    # reverse of their order in this function's signature.
    return bob.ap.Ceps(
        fs, win_length_ms, win_shift_ms,
        n_filters, n_ceps,
        f_min, f_max,
        delta_win, pre_emphasis_coef,
        mel_scale, dct_norm,
    )
def diff_feature(feat, nd=1):
    """Append frame-to-frame differences to a feature matrix.

    For a ``(num_frames, dim)`` array, the k-th order difference is the
    difference of consecutive rows of the (k-1)-th order one.  The features
    and every difference up to order ``nd`` are cropped to their last
    ``num_frames - nd`` rows (so all rows line up) and stacked column-wise.

    Args:
        feat: 2-D numpy array, one row per frame.
        nd: highest difference order to append.  Must be a non-negative
            integer; ``0`` returns a copy of ``feat`` with nothing appended.

    Returns:
        A numpy array of shape ``(max(num_frames - nd, 0), dim * (nd + 1))``.

    Raises:
        ValueError: if ``nd`` is negative or not an integral value.
    """
    order = int(nd)
    if order != nd or order < 0:
        # Previously any value other than 1 or 2 fell through and returned
        # None, which only surfaced later as a confusing error downstream.
        raise ValueError(
            "nd must be a non-negative integer, got {!r}".format(nd))

    # levels[k] holds the k-th order difference and has num_frames - k rows.
    levels = [feat]
    for _ in range(order):
        prev = levels[-1]
        levels.append(prev[1:] - prev[:-1])

    # Drop the leading rows of the lower orders so every level ends up with
    # the same number of rows as the highest-order difference.
    aligned = [level[order - k:] for k, level in enumerate(levels)]
    return np.concatenate(aligned, axis=1)
def get_feature(f):
    """Compute the feature matrix for one wav file.

    The audio is read with ``scipy.io.wavfile``, converted to float64 and run
    through the extractor returned by ``get_bob_extractor`` (26 filters,
    13 cepstral coefficients); the result is then passed through
    ``diff_feature`` with ``nd=2``.

    Args:
        f: path of the wav file.

    Returns:
        The array returned by ``diff_feature``.
    """
    sample_rate, samples = wavfile.read(f)
    extractor = get_bob_extractor(sample_rate, n_filters=26, n_ceps=13)
    cepstra = extractor(samples.astype('float64'))
    return diff_feature(cepstra, nd=2)
class RawTIMIT(DataFlow):
    """DataFlow yielding ``[feature, label]`` for every wav file in a directory.

    The directory is walked recursively once, at construction time, and every
    path ending in ``.wav`` is recorded.  Labels are read from a sibling file
    with the same stem: ``.PHN`` for phoneme labels, ``.TXT`` for letter labels.
    """

    def __init__(self, dirname, label='phoneme'):
        """
        Args:
            dirname: root directory to search for wav files; must exist.
            label: which labels to produce, ``'phoneme'`` or ``'letter'``.
        """
        self.dirname = dirname
        assert os.path.isdir(dirname), dirname

        wav_paths = []
        for path in fs.recursive_walk(self.dirname):
            if path.endswith('.wav'):
                wav_paths.append(path)
        self.filelists = wav_paths

        logger.info("Found {} wav files ...".format(len(self.filelists)))
        assert len(self.filelists), "Found no '.wav' files!"
        assert label in ['phoneme', 'letter'], label
        self.label = label

    def __len__(self):
        """Number of wav files found at construction time."""
        return len(self.filelists)

    def __iter__(self):
        """Yield ``[feat, label]`` for each wav file, in discovery order."""
        for wav_path in self.filelists:
            feat = get_feature(wav_path)
            # Strip the 4-character '.wav' suffix to get the shared file stem.
            stem = wav_path[:-4]
            if self.label == 'phoneme':
                label = read_timit_phoneme(stem + '.PHN')
            elif self.label == 'letter':
                label = read_timit_txt(stem + '.TXT')
            yield [feat, label]
def compute_mean_std(db, fname):
    """Compute feature statistics over an LMDB dataset and save them.

    Every row of every datapoint's first component is fed to an
    ``OnlineMoments`` accumulator; its resulting ``mean`` and ``std`` are
    serialized as a two-element list and written to ``fname``.

    Args:
        db: path of the LMDB file to read (loaded without shuffling).
        fname: path of the statistics file to write (binary mode).
    """
    dataflow = LMDBSerializer.load(db, shuffle=False)
    dataflow.reset_state()

    moments = OnlineMoments()
    for datapoint in get_tqdm(dataflow):
        frames = datapoint[0]  # presumably (length, dim), one row per frame
        for frame in frames:
            moments.feed(frame)

    logger.info("Writing to {} ...".format(fname))
    with open(fname, 'wb') as out:
        out.write(serialize.dumps([moments.mean, moments.std]))
if __name__ == '__main__':
    # Command-line entry point with two sub-commands:
    #   build  -- extract features from a TIMIT directory into an LMDB file
    #   stat   -- compute mean/std of the features stored in an LMDB file
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(title='command', dest='command')
    # On Python 3 sub-commands are optional by default, so running the script
    # without one would silently do nothing.  Marking them required makes
    # argparse print a usage error instead.
    subparsers.required = True

    parser_db = subparsers.add_parser('build', help='build a LMDB database')
    parser_db.add_argument('--dataset',
                           help='path to TIMIT TRAIN or TEST directory', required=True)
    parser_db.add_argument('--db', help='output lmdb file', required=True)

    parser_stat = subparsers.add_parser('stat', help='compute statistics (mean/std) of dataset')
    parser_stat.add_argument('--db', help='input lmdb file', required=True)
    parser_stat.add_argument('-o', '--output',
                             help='output statistics file', default='stats.data')

    args = parser.parse_args()
    if args.command == 'build':
        ds = RawTIMIT(args.dataset)
        LMDBSerializer.save(ds, args.db)
    elif args.command == 'stat':
        compute_mean_std(args.db, args.output)