training_phobert.py
import json

import torch
from torch import nn
# AdamW is taken from torch.optim; transformers' own AdamW is deprecated
# and has been removed in recent versions of the library
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModel, AutoTokenizer

from phobert_finetuned import PhoBERT_finetuned

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Load the train and validation datasets
with open('content.json', 'r', encoding="utf-8") as c:
    contents = json.load(c)
with open('val_content.json', 'r', encoding="utf-8") as v:
    val_contents = json.load(v)
# Load the pretrained PhoBERT model and its tokenizer
phobert = AutoModel.from_pretrained('vinai/phobert-base')
tokenizer = AutoTokenizer.from_pretrained('vinai/phobert-base')
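# Note: PhoBERT's authors recommend word-segmenting Vietnamese text
# (e.g. with VnCoreNLP) before tokenization; the patterns in content.json
# are assumed here to be pre-segmented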
# Process the train dataset:
tags = []
X = []
y = []
for content in contents['intents']:
    tag = content['tag']
    for pattern in content['patterns']:
        X.append(pattern)
        tags.append(tag)
# Map each tag to an integer label via its index in the sorted tag set
tags_set = sorted(set(tags))
for tag in tags:
    label = tags_set.index(tag)
    y.append(label)
# Tokenize the training patterns (padded/truncated to a fixed length;
# max_length=13 assumes the intent patterns are short)
token_train = tokenizer.batch_encode_plus(
    X,
    max_length=13,
    padding='max_length',
    truncation=True
)
X_train_mask = torch.tensor(token_train['attention_mask'])
X_train = torch.tensor(token_train['input_ids'])
y_train = torch.tensor(y)
# Process the validation dataset:
tags_val = []
X_val = []
y_val = []
for val_content in val_contents['intents']:
    tag_val = val_content['tag']
    for val_pattern in val_content['patterns']:
        X_val.append(val_pattern)
        tags_val.append(tag_val)
for tag_val in tags_val:
    label = tags_set.index(tag_val)
    y_val.append(label)
# Tokenize the validation patterns with the same settings
token_val = tokenizer.batch_encode_plus(
    X_val,
    max_length=13,
    padding='max_length',
    truncation=True
)
X_val_mask = torch.tensor(token_val['attention_mask'])
X_val = torch.tensor(token_val['input_ids'])
y_val = torch.tensor(y_val)
# Hyperparameters:
batch_size = 8
hidden_size = 512
num_class = len(tags_set)
lr = 7.5e-5
num_epoch = 100

dataset = TensorDataset(X_train, X_train_mask, y_train)
train_data = DataLoader(dataset=dataset, batch_size=batch_size,
                        shuffle=True)
val_dataset = TensorDataset(X_val, X_val_mask, y_val)
# no need to shuffle the validation set; a fixed order is fine
val_data = DataLoader(dataset=val_dataset, batch_size=batch_size,
                      shuffle=False)
# Model:
# Freeze the PhoBERT backbone so that only the classifier head is trained
for param in phobert.parameters():
    param.requires_grad = False
model = PhoBERT_finetuned(phobert, hidden_size=hidden_size,
                          num_class=num_class)
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=lr)
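# NLLLoss expects log-probabilities, so PhoBERT_finetuned's forward pass
# is assumed to end with nn.LogSoftmax(dim=1)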
loss_f = nn.NLLLoss()
def train():
    print("Training...")
    model.train()
    total_loss = 0
    for step, batch in enumerate(train_data):
        # push the batch to the device
        batch = [r.to(device) for r in batch]
        sent_id, mask, labels = batch
        # clear previously calculated gradients
        model.zero_grad()
        pred = model(sent_id, mask)
        loss = loss_f(pred, labels)
        total_loss += loss.item()
        loss.backward()
        # clip the gradients to 1.0 to help prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    avg_loss = total_loss / len(train_data)
    return avg_loss
def evaluate():
    print("Evaluating...")
    # deactivate dropout layers
    model.eval()
    total_loss = 0
    # iterate over batches
    for step, batch in enumerate(val_data):
        # push the batch to the device
        batch = [t.to(device) for t in batch]
        sent_id_val, mask_val, labels_val = batch
        # deactivate autograd
        with torch.no_grad():
            # model predictions
            preds = model(sent_id_val, mask_val)
            # compute the validation loss between actual and predicted values
            loss = loss_f(preds, labels_val)
            total_loss = total_loss + loss.item()
    # compute the average validation loss of the epoch
    avg_loss = total_loss / len(val_data)
    return avg_loss
best_valid_loss = float('inf')
train_losses = []
valid_losses = []
for epoch in range(num_epoch):
    print('\n Epoch {:}/{:}'.format(epoch + 1, num_epoch))
    train_loss = train()
    # evaluate model
    valid_loss = evaluate()
    # keep the checkpoint with the lowest validation loss
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'saved_weights.pth')
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')
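
# A minimal inference sketch, as an illustration rather than part of the
# original training flow: reload the best checkpoint and classify one
# utterance. It assumes the model's forward pass returns log-probabilities
# aligned with tags_set; the sample sentence is purely illustrative.
model.load_state_dict(torch.load('saved_weights.pth', map_location=device))
model.eval()
sample = tokenizer.batch_encode_plus(
    ['xin chào'],
    max_length=13,
    padding='max_length',
    truncation=True
)
sample_ids = torch.tensor(sample['input_ids']).to(device)
sample_mask = torch.tensor(sample['attention_mask']).to(device)
with torch.no_grad():
    log_probs = model(sample_ids, sample_mask)
predicted_tag = tags_set[log_probs.argmax(dim=1).item()]
print(f'Predicted intent: {predicted_tag}')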