-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
97 lines (80 loc) · 3.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os

import torch
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split

from data import TreebankDataset, tree_collate
from model import treeLoss, Score
def log(string):
    """
    Append one line to the run log at logs/run.log.

    :param string: text to write as a single log line
    """
    # Ensure the log directory exists — otherwise the first call on a
    # fresh checkout raises FileNotFoundError and aborts the training run.
    os.makedirs('logs', exist_ok=True)
    with open('logs/run.log', 'a') as logFile:
        print(string, file=logFile)
def train(args):
    """
    Train the Score model on the treebank and checkpoint after each epoch.

    :param args: argparse namespace with:
        batches (int, default 50)    - batch size
        epochs  (int, default 5)     - number of training epochs
        split   (float, default 0.8) - train/test split ratio
        samples (int, default 2000)  - number of treebank samples to use
        cuda    (bool, default True) - use the GPU when available

    Per-batch training losses are appended to logs/run.log; a checkpoint
    is saved to models/run-<epoch>.pt after every epoch.
    """
    batches = args.batches
    epochs = args.epochs
    split = args.split
    samples = args.samples
    cuda = args.cuda

    # Create train test split
    treebank = TreebankDataset(train=True, samples=samples)
    trainSize = int(split * len(treebank))
    testSize = len(treebank) - trainSize
    print("Generating train test split")
    treebankTrain, treebankTest = random_split(treebank, [trainSize, testSize])
    trainLoader = DataLoader(treebankTrain, batch_size=batches, collate_fn=tree_collate)
    testLoader = DataLoader(treebankTest, batch_size=batches, collate_fn=tree_collate)

    if cuda and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    model = Score(device).to(device)
    optimizer = torch.optim.Adadelta(model.parameters(), rho=0.95)

    for epoch in range(epochs):
        print("Epoch step {} of {}".format(epoch+1, epochs))

        # Train
        print("Training..")
        # Switch training-mode layers (e.g. dropout/batchnorm, if Score has
        # any) back on — validation below puts the model into eval mode.
        model.train()
        epochLoss = 0.0
        for i, batch in enumerate(tqdm(trainLoader)):
            scores = model(batch.to(device, dtype=torch.float))
            loss = treeLoss(scores)
            log("[%d, %d] loss: %.3f" %
                (epoch+1, i+1, loss.item()))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epochLoss += loss.item()
        # NOTE(review): this averages summed per-batch losses over the number
        # of training samples, matching the original behavior.
        print("[%d, %d] loss: %.3f" %
              (epoch+1, epochs, epochLoss / trainSize))

        # Test
        print("Validation..")
        # Eval mode so stochastic layers behave deterministically; no_grad
        # hoisted around the whole loop instead of re-entered per batch.
        model.eval()
        testLoss = 0.0
        with torch.no_grad():
            for batch in tqdm(testLoader):
                scores = model(batch.to(device, dtype=torch.float))
                loss = treeLoss(scores)
                testLoss += loss.item()
        log("[%d %d] val loss: %.3f" %
            (epoch+1, epochs, testLoss / testSize))
        print("[%d %d] val loss: %.3f" %
              (epoch+1, epochs, testLoss / testSize))

        torch.save(model.state_dict(), 'models/run-' + str(epoch+1) + '.pt')
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Train score function")
    parser.add_argument('--batches', action="store", dest="batches", type=int, default=50)
    parser.add_argument('--epochs', action="store", dest="epochs", type=int, default=5)
    parser.add_argument('--split', action="store", dest="split", type=float, default=0.8)
    parser.add_argument('--samples', action="store", dest="samples", type=int, default=2000)
    # --cuda was store_true with default=True, so the flag could never turn
    # CUDA off. Keep --cuda for backward compatibility and add --no-cuda as
    # the working way to disable GPU use.
    parser.add_argument('--cuda', action="store_true", dest="cuda", default=True)
    parser.add_argument('--no-cuda', action="store_false", dest="cuda")
    args = parser.parse_args()
    train(args)