-
Notifications
You must be signed in to change notification settings - Fork 187
/
predict.py
72 lines (61 loc) · 2.92 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from tensorflow.keras.models import load_model
from clean import downsample_mono, envelope
from kapre.time_frequency import STFT, Magnitude, ApplyFilterbank, MagnitudeToDecibel
from sklearn.preprocessing import LabelEncoder
import numpy as np
from glob import glob
import argparse
import os
import pandas as pd
from tqdm import tqdm
def _frame_signal(signal, step):
    """Cut a 1-D signal into consecutive, non-overlapping ``step``-sample windows.

    The final window is zero-padded up to ``step`` samples. An empty signal
    yields one all-zero window, so the caller always has something to score.

    Args:
        signal: 1-D array of audio samples.
        step: window length in samples; must be positive.

    Returns:
        float32 array of shape ``(n_windows, step, 1)`` where
        ``n_windows = max(1, ceil(len(signal) / step))``.

    Raises:
        ValueError: if ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError('window length must be positive, got {}'.format(step))
    n_samples = signal.shape[0]
    n_windows = max(1, -(-n_samples // step))  # ceil division, at least one window
    frames = np.zeros(n_windows * step, dtype=np.float32)
    frames[:n_samples] = signal  # tail beyond the signal stays zero (padding)
    return frames.reshape(n_windows, step, 1)


def make_prediction(args):
    """Classify every ``.wav`` file under ``args.src_dir`` with a saved Keras model.

    Each file is loaded through ``downsample_mono`` at ``args.sr`` and masked
    with ``envelope`` using ``args.threshold``. The kept samples are cut into
    ``args.dt``-second windows, with the last window zero-padded. The model's
    per-window outputs are averaged into one vector per file. The arg-max of
    that vector is printed next to the file's parent-directory name. All
    averaged vectors are written with ``np.save`` to
    ``os.path.join('logs', args.pred_fn)``; NumPy adds ``.npy`` if it is
    missing.

    Args:
        args: namespace providing ``model_fn``, ``src_dir``, ``sr``, ``dt``,
            ``threshold`` and ``pred_fn``.

    Raises:
        ValueError: if ``int(args.sr * args.dt)`` is not positive.
    """
    model = load_model(args.model_fn,
                       custom_objects={'STFT': STFT,
                                       'Magnitude': Magnitude,
                                       'ApplyFilterbank': ApplyFilterbank,
                                       'MagnitudeToDecibel': MagnitudeToDecibel})
    # Keep real .wav files only (case-insensitive). Normalise separators so
    # the '/'-based parent-directory lookup below also works on Windows.
    found = glob('{}/**'.format(args.src_dir), recursive=True)
    wav_paths = sorted(p.replace(os.sep, '/') for p in found
                       if p.lower().endswith('.wav') and os.path.isfile(p))
    # Maps a model output index to a class name.
    # NOTE(review): this assumes training derived its label indices from the
    # same sorted directory listing; confirm against the training script.
    classes = sorted(os.listdir(args.src_dir))
    step = int(args.sr * args.dt)  # samples per window
    results = []
    for wav_fn in tqdm(wav_paths):
        rate, wav = downsample_mono(wav_fn, args.sr)
        mask, _ = envelope(wav, rate, threshold=args.threshold)
        clean_wav = wav[mask]
        if clean_wav.shape[0] == 0:
            # Without this, an empty batch would reach model.predict. Scoring a
            # silent window keeps one output row per file.
            print('Warning: envelope removed every sample of {}; '
                  'predicting on a silent window'.format(wav_fn))
        X_batch = _frame_signal(clean_wav, step)
        y_pred = model.predict(X_batch)
        y_mean = np.mean(y_pred, axis=0)  # average over windows -> one vector per file
        pred_idx = np.argmax(y_mean)
        real_class = os.path.dirname(wav_fn).split('/')[-1]
        print('Actual class: {}, Predicted class: {}'.format(real_class, classes[pred_idx]))
        results.append(y_mean)
    os.makedirs('logs', exist_ok=True)  # np.save does not create parent directories
    np.save(os.path.join('logs', args.pred_fn), np.array(results))
def build_parser():
    """Return the command-line parser for this prediction script.

    Kept outside the ``__main__`` guard so the options can be inspected or
    reused without running a prediction.

    Returns:
        argparse.ArgumentParser with ``--model_fn``, ``--pred_fn``,
        ``--src_dir``, ``--dt``, ``--sr`` and ``--threshold``.
    """
    parser = argparse.ArgumentParser(description='Audio Classification Prediction')
    parser.add_argument('--model_fn', type=str, default='models/lstm.h5',
                        help='model file to make predictions')
    parser.add_argument('--pred_fn', type=str, default='y_pred',
                        help='fn to write predictions in logs dir')
    parser.add_argument('--src_dir', type=str, default='wavfiles',
                        help='directory containing wavfiles to predict')
    parser.add_argument('--dt', type=float, default=1.0,
                        help='time in seconds to sample audio')
    parser.add_argument('--sr', type=int, default=16000,
                        help='sample rate of clean audio')
    # Previously type=str: '--threshold 30' arrived as the string '30' while
    # the default was the int 20. Parsing as a number gives envelope() the
    # same kind of value either way.
    parser.add_argument('--threshold', type=float, default=20,
                        help='threshold magnitude for np.int16 dtype')
    return parser


if __name__ == '__main__':
    # parse_known_args ignores unrecognised options rather than exiting,
    # matching the script's existing behaviour.
    args, _ = build_parser().parse_known_args()
    make_prediction(args)