training.py
import math
import sys

import torch


def progress(loss, epoch, batch, batch_size, dataset_size):
    """
    Print the progress of the training for each epoch
    """
    batches = math.ceil(float(dataset_size) / batch_size)
    # clamp so the bar does not overflow on the last (partial) batch
    count = min(batch * batch_size, dataset_size)
    bar_len = 40
    filled_len = int(round(bar_len * count / float(dataset_size)))

    bar = '=' * filled_len + '-' * (bar_len - filled_len)
    status = 'Epoch {}, Loss: {:.4f}'.format(epoch, loss)
    _progress_str = "\r[{}] ... {}".format(bar, status)
    sys.stdout.write(_progress_str)
    sys.stdout.flush()

    if batch == batches:
        print()


def train_dataset(_epoch, dataloader, model, loss_function, optimizer):
    # IMPORTANT: switch to train mode
    # enable regularization layers, such as Dropout
    model.train()
    running_loss = 0.0

    # obtain the model's device ID
    device = next(model.parameters()).device

    for index, batch in enumerate(dataloader, 1):
        # get the inputs (batch)
        inputs, labels, lengths = batch

        # move the batch tensors to the right device
        ...  # EX9

        # Step 1 - zero the gradients
        # Remember that PyTorch accumulates gradients.
        # We need to clear them out before each batch!
        ...  # EX9

        # Step 2 - forward pass: y' = model(x)
        ...  # EX9

        # Step 3 - compute loss: L = loss_function(y', y)
        loss = ...  # EX9

        # Step 4 - backward pass: compute gradients wrt model parameters
        ...  # EX9

        # Step 5 - update weights
        ...  # EX9

        running_loss += loss.item()

        # print statistics
        progress(loss=loss.item(),
                 epoch=_epoch,
                 batch=index,
                 batch_size=dataloader.batch_size,
                 dataset_size=len(dataloader.dataset))

    return running_loss / index
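

# The block below is one possible completion of the EX9 steps in
# train_dataset, kept as a separate reference function so the template
# above stays intact. It assumes the model's forward pass has the
# signature model(inputs, lengths) and that loss_function is a criterion
# such as nn.CrossEntropyLoss taking (logits, labels); both the function
# name and these signatures are assumptions, so adapt them to your model.
def train_dataset_example(_epoch, dataloader, model, loss_function, optimizer):
    model.train()
    running_loss = 0.0
    device = next(model.parameters()).device

    for index, batch in enumerate(dataloader, 1):
        inputs, labels, lengths = batch

        # move the batch tensors to the model's device
        inputs = inputs.to(device)
        labels = labels.to(device)
        lengths = lengths.to(device)

        # Step 1 - zero the accumulated gradients
        optimizer.zero_grad()

        # Step 2 - forward pass: y' = model(x)
        outputs = model(inputs, lengths)

        # Step 3 - compute loss: L = loss_function(y', y)
        loss = loss_function(outputs, labels)

        # Step 4 - backward pass: compute gradients wrt model parameters
        loss.backward()

        # Step 5 - update weights
        optimizer.step()

        running_loss += loss.item()
        progress(loss=loss.item(),
                 epoch=_epoch,
                 batch=index,
                 batch_size=dataloader.batch_size,
                 dataset_size=len(dataloader.dataset))

    return running_loss / index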


def eval_dataset(dataloader, model, loss_function):
    # IMPORTANT: switch to eval mode
    # disable regularization layers, such as Dropout
    model.eval()
    running_loss = 0.0

    y_pred = []  # the predicted labels
    y = []  # the gold labels

    # obtain the model's device ID
    device = next(model.parameters()).device

    # IMPORTANT: in evaluation mode, we don't want to keep the gradients,
    # so we do everything under torch.no_grad()
    with torch.no_grad():
        for index, batch in enumerate(dataloader, 1):
            # get the inputs (batch)
            inputs, labels, lengths = batch

            # Step 1 - move the batch tensors to the right device
            ...  # EX9

            # Step 2 - forward pass: y' = model(x)
            ...  # EX9

            # Step 3 - compute loss.
            # We compute the loss only for inspection (to compare the
            # train and test loss), because we do not backpropagate
            # at test time.
            loss = ...  # EX9

            # Step 4 - make predictions (class = argmax of posteriors)
            ...  # EX9

            # Step 5 - collect the predictions, gold labels and batch loss
            ...  # EX9

            running_loss += loss.item()

    return running_loss / index, (y_pred, y)
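

# Similarly, a possible completion of the EX9 steps in eval_dataset,
# under the same assumptions as train_dataset_example above
# (a model(inputs, lengths) forward signature and a logits-based
# criterion); the function name and the argmax over dim=1 are likewise
# assumptions about the surrounding code.
def eval_dataset_example(dataloader, model, loss_function):
    model.eval()
    running_loss = 0.0
    y_pred = []  # the predicted labels
    y = []  # the gold labels
    device = next(model.parameters()).device

    with torch.no_grad():
        for index, batch in enumerate(dataloader, 1):
            inputs, labels, lengths = batch

            # Step 1 - move the batch tensors to the model's device
            inputs = inputs.to(device)
            labels = labels.to(device)
            lengths = lengths.to(device)

            # Step 2 - forward pass: y' = model(x)
            outputs = model(inputs, lengths)

            # Step 3 - loss, computed only for inspection (no backprop here)
            loss = loss_function(outputs, labels)

            # Step 4 - predicted class = argmax over the class dimension
            predictions = torch.argmax(outputs, dim=1)

            # Step 5 - collect the predictions, gold labels and batch loss
            y_pred.extend(predictions.cpu().tolist())
            y.extend(labels.cpu().tolist())
            running_loss += loss.item()

    return running_loss / index, (y_pred, y)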