-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtrain.py
78 lines (60 loc) · 2.31 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Trains Wav2Letter model using speech data
TODO:
* show accuracy metrics
* add more diverse datasets
* train, val, test split
"""
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from Wav2Letter.model import Wav2Letter
from Wav2Letter.data import GoogleSpeechCommand
from Wav2Letter.decoder import GreedyDecoder
def train(batch_size, epochs, data_path="./speech_data", learning_rate=1e-4):
    """Train a Wav2Letter model on saved Google Speech Command vectors.

    Loads the input/target arrays from ``data_path``, fits the model with
    CTC loss and the Adam optimizer, then runs the first training sample
    through the model and prints the greedy decoding next to its target.

    Args:
        batch_size: Batch size forwarded to ``model.fit``.
        epochs: Number of epochs forwarded to ``model.fit``.
        data_path: Directory passed to ``GoogleSpeechCommand.load_vectors``.
        learning_rate: Learning rate for the Adam optimizer.
    """
    # load saved numpy arrays for google speech command
    gs = GoogleSpeechCommand()
    _inputs, _targets = gs.load_vectors(data_path)

    # parameters
    mfcc_features = 13
    grapheme_count = gs.intencode.grapheme_count

    print("training google speech dataset")
    print("data size", len(_inputs))
    print("batch_size", batch_size)
    print("epochs", epochs)
    print("num_mfcc_features", mfcc_features)
    print("grapheme_count", grapheme_count)

    # torch tensors
    inputs = torch.Tensor(_inputs)
    targets = torch.IntTensor(_targets)

    print("input shape", inputs.shape)
    print("target shape", targets.shape)

    # Initialize model, loss, optimizer
    model = Wav2Letter(mfcc_features, grapheme_count)
    print(model.layers)

    ctc_loss = nn.CTCLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Each mfcc feature is a channel
    # https://pytorch.org/docs/stable/nn.html#torch.nn.Conv1d
    # transpose (sample_size, in_frame_len, mfcc_features)
    # to      (sample_size, mfcc_features, in_frame_len)
    inputs = inputs.transpose(1, 2)
    print("transposed input", inputs.shape)

    model.fit(inputs, targets, optimizer, ctc_loss, batch_size, epoch=epochs)

    # Sanity check on a single training example (not a held-out sample).
    sample = inputs[0]
    sample_target = targets[0]

    # NOTE(review): ``eval`` is called with a sample here, so it appears to be
    # a project-specific inference method rather than ``nn.Module.eval()`` --
    # confirm against Wav2Letter.model.
    log_probs = model.eval(sample)
    output = GreedyDecoder(log_probs)

    print("sample target", sample_target)
    print("predicted", output)
if __name__ == "__main__":
    # Command-line entry point: read the training options, then train.
    cli = argparse.ArgumentParser(description='Wav2Letter')
    cli.add_argument(
        '--batch-size',
        type=int,
        default=64,
        metavar='N',
        help='input batch size for training (default: 64)',
    )
    cli.add_argument(
        '--epochs',
        type=int,
        default=100,
        metavar='N',
        help='total epochs (default: 100)',
    )
    options = cli.parse_args()

    train(options.batch_size, options.epochs)