"""
Graph2Seq: Training procedure
Date:
- Jan. 28, 2023
"""
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import math
import time
import random
import numpy as np
from utils import *
from params import *
from eval import evaluate
# set random seed
random.seed(RAND_SEED)
def train(data_train, data_val, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion,
          epochs, batch_size, device, output_lang,
          max_length=MAX_LENGTH):
    """
    Training process over several epochs and mini-batches
    :param data_train: the whole training data
    :param data_val: the validation data
    """
    num_instance = len(data_train)
    num_batch = math.ceil(num_instance / batch_size)
    print("INFO: Number of training instances: {}".format(num_instance))
    print("INFO: Number of batches per epoch: {}".format(num_batch))
    for epoch_idx in range(epochs):
        start_time_epoch = time.time()
        batch_losses = []
        train_batch_perf = []
        for b_idx in range(num_batch):
            start_idx = b_idx * batch_size
            end_idx = min(num_instance, start_idx + batch_size)  # exclusive slice bound, so the last instance is kept
            training_pairs = data_train[start_idx: end_idx]
            # =============== training
            encoder.train()
            decoder.train()
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            loss = 0
            for tr_pair in training_pairs:
                input_graph = tr_pair[0]
                target_tensor = tr_pair[1]
                # encoder: per-node encodings and a pooled graph embedding
                node_encs, pooled_ge = encoder(input_graph.x, input_graph.edge_index, input_graph.edge_attr)
                # decoder: hidden state initialised from the pooled graph embedding
                target_length = target_tensor.size(0)
                decoder_input = torch.tensor([[SOS_token]], device=device)
                decoder_hidden = pooled_ge
                use_teacher_forcing = random.random() < teacher_forcing_ratio
                if use_teacher_forcing:
                    # Teacher forcing: feed the target as the next input
                    for di in range(target_length):
                        decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden,
                                                                                    encoder_outputs=node_encs)
                        loss += criterion(decoder_output, target_tensor[di])
                        decoder_input = target_tensor[di]  # teacher forcing
                else:
                    # Without teacher forcing: use the decoder's own predictions as the next input
                    for di in range(target_length):
                        decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden,
                                                                                    encoder_outputs=node_encs,
                                                                                    max_length=max_length)
                        topv, topi = decoder_output.topk(1)
                        decoder_input = topi.squeeze().detach()  # detach from history as input
                        loss += criterion(decoder_output, target_tensor[di])
                        if decoder_input.item() == EOS_token:
                            break
            # back-propagate the loss accumulated over the batch
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()
            with torch.no_grad():
                batch_losses.append(loss.item())
                train_perf_metric = evaluate(encoder, decoder, data_train, output_lang, max_length)
                train_batch_perf.append(train_perf_metric)
        # =============== validation
        val_perf_metric = evaluate(encoder, decoder, data_val, output_lang, max_length)
        # ========== Print out some info...
        print("INFO: Epoch: {}, Elapsed time: {}s.".format(epoch_idx, str(time.time() - start_time_epoch)))
        print("INFO: \tEpoch mean loss: {}.".format(np.mean(batch_losses)))
        print("INFO: \tTraining performance: {}.".format(np.mean(train_batch_perf)))
        print("INFO: \tValidation performance: {}.".format(val_perf_metric))