-
Notifications
You must be signed in to change notification settings - Fork 187
/
train.py
121 lines (100 loc) · 4.87 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import tensorflow as tf
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint
from tensorflow.keras.utils import to_categorical
import os
from scipy.io import wavfile
import pandas as pd
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from models import Conv1D, Conv2D, LSTM
from tqdm import tqdm
from glob import glob
import argparse
import warnings
class DataGenerator(tf.keras.utils.Sequence):
    """Keras Sequence yielding batches of (waveform, one-hot label) pairs.

    Each batch X has shape (batch_size, int(sr*dt), 1) float32 and each
    Y has shape (batch_size, n_classes) float32. Waveforms are read from
    disk with scipy.io.wavfile on every access.
    """

    def __init__(self, wav_paths, labels, sr, dt, n_classes,
                 batch_size=32, shuffle=True):
        """
        Args:
            wav_paths: list of paths to .wav files.
            labels: integer class indices aligned with wav_paths.
            sr: sample rate of the audio files (Hz).
            dt: clip duration in seconds; clips must hold int(sr*dt) samples.
            n_classes: number of output classes for one-hot encoding.
            batch_size: samples per batch; a trailing partial batch is dropped.
            shuffle: whether to reshuffle sample order at the end of each epoch.
        """
        self.wav_paths = wav_paths
        self.labels = labels
        self.sr = sr
        self.dt = dt
        self.n_classes = n_classes
        self.batch_size = batch_size
        # BUG FIX: was `self.shuffle = True`, which silently ignored the
        # `shuffle` argument — validation generators got shuffled too.
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        # Number of complete batches; leftover samples are not served.
        return int(np.floor(len(self.wav_paths) / self.batch_size))

    def __getitem__(self, index):
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        wav_paths = [self.wav_paths[k] for k in indexes]
        labels = [self.labels[k] for k in indexes]
        # generate a batch of time data
        X = np.empty((self.batch_size, int(self.sr*self.dt), 1), dtype=np.float32)
        Y = np.empty((self.batch_size, self.n_classes), dtype=np.float32)
        for i, (path, label) in enumerate(zip(wav_paths, labels)):
            rate, wav = wavfile.read(path)
            # assumes each clip is exactly int(sr*dt) samples and mono;
            # the reshape raises otherwise — TODO confirm upstream cleaning
            X[i,] = wav.reshape(-1, 1)
            Y[i,] = to_categorical(label, num_classes=self.n_classes)
        return X, Y

    def on_epoch_end(self):
        # Rebuild (and optionally reshuffle) the sample order once per epoch.
        self.indexes = np.arange(len(self.wav_paths))
        if self.shuffle:
            np.random.shuffle(self.indexes)
def train(args):
    """Train an audio-classification model selected by args.model_type.

    Builds file/label lists from args.src_root (one subdirectory per class),
    splits them 90/10 into train/validation, and fits the chosen Keras model
    for 30 epochs, checkpointing the best weights to models/<type>.h5 and
    logging history to logs/<type>_history.csv.

    Args:
        args: argparse.Namespace with src_root, sample_rate, delta_time,
            batch_size, and model_type attributes.
    """
    src_root = args.src_root
    sr = args.sample_rate
    dt = args.delta_time
    batch_size = args.batch_size
    model_type = args.model_type
    # NOTE(review): os.listdir counts every entry, not just class
    # directories — a stray file in src_root inflates N_CLASSES; verify.
    params = {'N_CLASSES':len(os.listdir(args.src_root)),
              'SR':sr,
              'DT':dt}
    # All three candidate models are constructed eagerly; only the one
    # matching model_type is trained.
    models = {'conv1d':Conv1D(**params),
              'conv2d':Conv2D(**params),
              'lstm':  LSTM(**params)}
    assert model_type in models.keys(), '{} not an available model'.format(model_type)
    csv_path = os.path.join('logs', '{}_history.csv'.format(model_type))
    # Collect every .wav file under src_root; paths are normalized to
    # forward slashes so labels can be parsed uniformly below.
    wav_paths = glob('{}/**'.format(src_root), recursive=True)
    wav_paths = [x.replace(os.sep, '/') for x in wav_paths if '.wav' in x]
    # Class names are the sorted subdirectory names; LabelEncoder maps
    # them to integer indices.
    classes = sorted(os.listdir(args.src_root))
    le = LabelEncoder()
    le.fit(classes)
    # The class label for each file is its immediate parent directory name.
    labels = [os.path.split(x)[0].split('/')[-1] for x in wav_paths]
    labels = le.transform(labels)
    # 90/10 train/validation split; random_state fixed for reproducibility.
    # NOTE(review): split is not stratified, hence the class-coverage
    # warnings below.
    wav_train, wav_val, label_train, label_val = train_test_split(wav_paths,
                                                                  labels,
                                                                  test_size=0.1,
                                                                  random_state=0)
    assert len(label_train) >= args.batch_size, 'Number of train samples must be >= batch_size'
    if len(set(label_train)) != params['N_CLASSES']:
        warnings.warn('Found {}/{} classes in training data. Increase data size or change random_state.'.format(len(set(label_train)), params['N_CLASSES']))
    if len(set(label_val)) != params['N_CLASSES']:
        warnings.warn('Found {}/{} classes in validation data. Increase data size or change random_state.'.format(len(set(label_val)), params['N_CLASSES']))
    tg = DataGenerator(wav_train, label_train, sr, dt,
                       params['N_CLASSES'], batch_size=batch_size)
    vg = DataGenerator(wav_val, label_val, sr, dt,
                       params['N_CLASSES'], batch_size=batch_size)
    model = models[model_type]
    # Save only the best model (by validation loss) once per epoch.
    cp = ModelCheckpoint('models/{}.h5'.format(model_type), monitor='val_loss',
                         save_best_only=True, save_weights_only=False,
                         mode='auto', save_freq='epoch', verbose=1)
    csv_logger = CSVLogger(csv_path, append=False)
    model.fit(tg, validation_data=vg,
              epochs=30, verbose=1,
              callbacks=[csv_logger, cp])
if __name__ == '__main__':
    # CLI entry point: collect training hyper-parameters, then run train().
    parser = argparse.ArgumentParser(description='Audio Classification Training')
    # Data-driven option table: (flags, add_argument keyword args).
    cli_options = [
        (('--model_type',), dict(type=str, default='lstm',
                                 help='model to run. i.e. conv1d, conv2d, lstm')),
        (('--src_root',), dict(type=str, default='clean',
                               help='directory of audio files in total duration')),
        (('--batch_size',), dict(type=int, default=16,
                                 help='batch size')),
        (('--delta_time', '-dt'), dict(type=float, default=1.0,
                                       help='time in seconds to sample audio')),
        (('--sample_rate', '-sr'), dict(type=int, default=16000,
                                        help='sample rate of clean audio')),
    ]
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)
    # parse_known_args tolerates unrecognized extra flags instead of erroring.
    args, _ = parser.parse_known_args()
    train(args)