main.py
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from model import MemN2N
from helpers import dataloader, get_fname, get_params


def train(train_iter, model, optimizer, epochs, max_clip, valid_iter=None):
    total_loss = 0
    # Guard the default: list(None) would raise a TypeError.
    valid_data = list(valid_iter) if valid_iter is not None else []
    valid_loss = None
    next_epoch_to_report = 5
    pad = model.vocab.stoi['<pad>']

    for _, batch in enumerate(train_iter, start=1):
        story = batch.story
        query = batch.query
        answer = batch.answer

        optimizer.zero_grad()
        outputs = model(story, query)
        loss = F.nll_loss(outputs, answer.view(-1), ignore_index=pad, reduction='sum')
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_clip)
        optimizer.step()
        total_loss += loss.item()

        # Linear start (per the MemN2N paper): run with the memory softmax
        # disabled until validation loss stops improving, then re-enable it
        # for the rest of training by clearing `use_ls`.
        if model.use_ls and valid_data:
            loss = 0
            for k, batch in enumerate(valid_data, start=1):
                story = batch.story
                query = batch.query
                answer = batch.answer
                outputs = model(story, query)
                loss += F.nll_loss(outputs, answer.view(-1), ignore_index=pad,
                                   reduction='sum').item()
            loss = loss / k
            if valid_loss and valid_loss <= loss:
                model.use_ls = False
            else:
                valid_loss = loss

        # train_iter cycles over the data indefinitely; report the running
        # loss every five epochs, reset it at each epoch boundary, and stop
        # once the requested number of epochs has elapsed.
        if train_iter.epoch == next_epoch_to_report:
            print("#! epoch {:d} average batch loss: {:5.4f}".format(
                int(train_iter.epoch), total_loss / len(train_iter)))
            next_epoch_to_report += 5
        if int(train_iter.epoch) == train_iter.epoch:
            total_loss = 0
        if train_iter.epoch == epochs:
            break


def eval(test_iter, model):
    # Mean per-batch error rate: predictions are the argmax over the output
    # distribution, compared against the gold answers.
    total_error = 0
    for k, batch in enumerate(test_iter, start=1):
        story = batch.story
        query = batch.query
        answer = batch.answer

        outputs = model(story, query)
        _, outputs = torch.max(outputs, -1)
        total_error += torch.mean((outputs != answer.view(-1)).float()).item()
    print("#! average error: {:5.1f}".format(total_error / k * 100))


def run(config):
    print("#! preparing data...")
    train_iter, valid_iter, test_iter, vocab = dataloader(config.batch_size, config.memory_size,
                                                          config.task, config.joint, config.tenk)

    print("#! instantiating model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MemN2N(get_params(config), vocab).to(device)
    if config.file:
        # Resume from a checkpoint, mapping tensors onto the available device.
        with open(os.path.join(config.save_dir, config.file), 'rb') as f:
            if torch.cuda.is_available():
                state_dict = torch.load(f, map_location=lambda storage, loc: storage.cuda())
            else:
                state_dict = torch.load(f, map_location=lambda storage, loc: storage)
            model.load_state_dict(state_dict)

    if config.train:
        print("#! training...")
        optimizer = optim.Adam(model.parameters(), config.lr)
        train(train_iter, model, optimizer, config.num_epochs, config.max_clip, valid_iter)
        if not os.path.isdir(config.save_dir):
            os.makedirs(config.save_dir)
        torch.save(model.state_dict(), os.path.join(config.save_dir, get_fname(config)))

    print("#! testing...")
    with torch.no_grad():
        eval(test_iter, model)
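

# Usage sketch: a minimal, hypothetical way to drive run(). The repo's real
# entry point presumably builds `config` from command-line flags; only the
# attributes read directly in this file are listed below, their values are
# assumptions, and get_params(config) will likely expect further model
# hyperparameters (hops, embedding size, ...) not shown here.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        batch_size=32,     # minibatch size passed to dataloader()
        memory_size=50,    # memory slots per story
        task=1,            # bAbI task id
        joint=False,       # joint training over all tasks
        tenk=False,        # use the 10k-example bAbI variant
        file=None,         # checkpoint filename under save_dir, if resuming
        save_dir='save',   # where checkpoints are written
        train=True,        # train before evaluating
        lr=0.001,          # Adam learning rate
        num_epochs=100,    # epochs before train() stops
        max_clip=40.0,     # gradient-norm clipping threshold
    )
    run(config)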