forked from spro/char-rnn.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
executable file
·110 lines (90 loc) · 3.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python
# https://github.com/spro/char-rnn.pytorch
import argparse
import os
import random
import time

import torch
import torch.nn as nn
from torch.autograd import Variable
from tqdm import tqdm

from helpers import *
from model import *
from generate import *
# Parse command line arguments
argparser = argparse.ArgumentParser(description='Train a character-level RNN on a text file.')
argparser.add_argument('filename', type=str,
                       help='path of the training text file')
argparser.add_argument('--model', type=str, default="gru",
                       help='recurrent cell type passed to CharRNN (default: gru)')
argparser.add_argument('--n_epochs', type=int, default=2000,
                       help='number of training steps; each step samples one random batch')
argparser.add_argument('--print_every', type=int, default=100,
                       help='report loss and print a generated sample every N epochs')
argparser.add_argument('--hidden_size', type=int, default=100,
                       help='hidden size passed to CharRNN')
argparser.add_argument('--n_layers', type=int, default=2,
                       help='number of layers passed to CharRNN')
argparser.add_argument('--learning_rate', type=float, default=0.01,
                       help='Adam learning rate')
argparser.add_argument('--chunk_len', type=int, default=200,
                       help='characters per training sequence')
argparser.add_argument('--batch_size', type=int, default=100,
                       help='sequences per training batch')
argparser.add_argument('--shuffle', action='store_true',
                       help='accepted for compatibility; not referenced elsewhere in this script')
argparser.add_argument('--cuda', action='store_true',
                       help='train on the GPU')
args = argparser.parse_args()

# Fail fast with a clear message: print_every == 0 would cause a modulo by
# zero in the training loop, and a zero chunk/batch leaves nothing to train on.
for option in ('print_every', 'chunk_len', 'batch_size'):
    if getattr(args, option) < 1:
        argparser.error('--%s must be a positive integer' % option)

if args.cuda:
    # Without this check the failure surfaces later as an opaque CUDA error
    # when the model or a tensor is moved to the GPU.
    if not torch.cuda.is_available():
        argparser.error('--cuda was given but torch.cuda.is_available() is False')
    print("Using CUDA")

# Training text and its length, as returned by helpers.read_file.
file, file_len = read_file(args.filename)
def random_training_set(chunk_len, batch_size):
    """Sample a random batch of (input, target) sequences from the training text.

    Each of the ``batch_size`` rows comes from a random window of
    ``chunk_len + 1`` consecutive characters of ``file``.  The input is the
    window without its last character and the target is the window without
    its first character, so ``target[b, t]`` is the character that follows
    ``inp[b, t]``.

    Args:
        chunk_len: Number of time steps per sequence.
        batch_size: Number of sequences in the batch.

    Returns:
        ``(inp, target)``: two ``torch.LongTensor`` of shape
        ``(batch_size, chunk_len)``, moved to the GPU when ``--cuda`` is set.

    Raises:
        ValueError: if the training text is too short to hold a window of
            ``chunk_len + 1`` characters.
    """
    if file_len <= chunk_len:
        raise ValueError(
            'training text has %d characters; need more than chunk_len=%d'
            % (file_len, chunk_len)
        )
    inp = torch.LongTensor(batch_size, chunk_len)
    target = torch.LongTensor(batch_size, chunk_len)
    for bi in range(batch_size):
        # randint is inclusive, so the largest start index still leaves room
        # for a full window of chunk_len + 1 characters ending at file_len.
        start_index = random.randint(0, file_len - chunk_len - 1)
        chunk = file[start_index:start_index + chunk_len + 1]
        inp[bi] = char_tensor(chunk[:-1])
        target[bi] = char_tensor(chunk[1:])
    # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4, and
    # this file already requires >= 0.4 (loss.item()), so plain tensors suffice.
    if args.cuda:
        inp = inp.cuda()
        target = target.cuda()
    return inp, target
def train(inp, target):
    """Run one optimisation step over a batch of character sequences.

    The hidden state is reset for every call.  The decoder is then unrolled
    one time step at a time, cross-entropy losses are summed over all steps,
    and a single backward pass plus optimizer step is applied.

    Batch size and sequence length are taken from ``inp`` itself rather than
    from ``args``.  Batches produced by ``random_training_set`` behave exactly
    as before, and batches of other sizes also work.

    Args:
        inp: 2-D index tensor indexed as ``inp[:, t]`` per time step.
        target: Tensor of the same shape holding the next-character indices.

    Returns:
        The summed loss divided by the sequence length, as a Python float.

    Raises:
        ValueError: if ``inp`` has zero time steps (there would be no loss to
            back-propagate).
    """
    batch_size, chunk_len = inp.size(0), inp.size(1)
    if chunk_len == 0:
        raise ValueError('cannot train on a zero-length sequence')

    hidden = decoder.init_hidden(batch_size)
    if args.cuda:
        if args.model == "gru":
            hidden = hidden.cuda()
        else:
            # Non-GRU models return a pair of hidden tensors (presumably the
            # LSTM (h, c) state); move both.
            hidden = (hidden[0].cuda(), hidden[1].cuda())

    decoder.zero_grad()
    loss = 0
    for c in range(chunk_len):
        output, hidden = decoder(inp[:, c], hidden)
        loss += criterion(output.view(batch_size, -1), target[:, c])

    loss.backward()
    decoder_optimizer.step()
    return loss.item() / chunk_len
def save():
    """Serialize the whole ``decoder`` module with ``torch.save``.

    The output path is the input file's base name with its extension replaced
    by ``.pt``, relative to the current working directory
    (e.g. ``data/shakespeare.txt`` -> ``shakespeare.pt``).
    """
    input_name = os.path.basename(args.filename)
    stem = os.path.splitext(input_name)[0]
    save_filename = stem + '.pt'

    torch.save(decoder, save_filename)
    print('Saved as %s' % save_filename)
# Initialize models and start training
decoder = CharRNN(
    n_characters,
    args.hidden_size,
    n_characters,
    model=args.model,
    n_layers=args.n_layers,
)
# Move the model to the GPU *before* building the optimizer so the optimizer
# is constructed over the parameters in their final location, as recommended
# by the torch.optim documentation.
if args.cuda:
    decoder.cuda()
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=args.learning_rate)
criterion = nn.CrossEntropyLoss()

start = time.time()
all_losses = []  # mean training loss of each print_every-epoch window
loss_avg = 0     # running sum of per-epoch losses in the current window

try:
    print("Training for %d epochs..." % args.n_epochs)
    for epoch in tqdm(range(1, args.n_epochs + 1)):
        loss = train(*random_training_set(args.chunk_len, args.batch_size))
        loss_avg += loss

        if epoch % args.print_every == 0:
            # Report the mean over the window rather than the noisy loss of
            # the last batch, record it, and start a new window.
            window_loss = loss_avg / args.print_every
            all_losses.append(window_loss)
            loss_avg = 0
            # tqdm.write keeps the progress bar intact, unlike a bare print().
            tqdm.write('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, epoch / args.n_epochs * 100, window_loss))
            sample = generate(decoder, 'Wh', 100, cuda=args.cuda)
            tqdm.write(str(sample) + '\n')

    print("Saving...")
    save()
except KeyboardInterrupt:
    print("Saving before quit...")
    save()