pretrain.py
# Copyright 2018 Dong-Hyun Lee, Kakao Brain.
# (Strongly inspired by original Google BERT code and Hugging Face's code)
""" Pretrain transformer with Masked LM and Sentence Classification """
import fire
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

from dataloaders import tokenization, Preprocess4Pretrain, SentPairDataLoader
from models import modules, BertModel4Pretrain
from bin import optimizer
from bin import train
from utils.utils import set_seeds, get_device


def main(train_cfg='configs/pretrain.json',
         model_cfg='configs/bert_base.json',
         data_file='/media/newhd/BookCorpus/books_full.txt',
         model_file='/home/krishna/Krishna/Speech/PnGBERT/PnG-BERT/exp/bert/pretrain/model_steps_20000.pt',
         data_parallel=True,
         vocab='/media/newhd/BookCorpus/vocab.txt',
         save_dir='./exp/bert/pretrain',
         log_dir='./exp/bert/pretrain/runs',
         max_len=512,
         max_pred=20,
         mask_prob=0.15):
    cfg = train.Config.from_json(train_cfg)
    model_cfg = modules.Config.from_json(model_cfg)
    set_seeds(cfg.seed)

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab, do_lower_case=True)
    tokenize = lambda x: tokenizer.tokenize(tokenizer.convert_to_unicode(x))

    pipeline = [Preprocess4Pretrain(max_pred,
                                    mask_prob,
                                    list(tokenizer.vocab.keys()),
                                    tokenizer.convert_tokens_to_ids,
                                    max_len)]
    data_iter = SentPairDataLoader(data_file,
                                   cfg.batch_size,
                                   tokenize,
                                   tokenizer,
                                   max_len,
                                   pipeline=pipeline)

    model = BertModel4Pretrain(model_cfg)
    criterion = nn.CrossEntropyLoss(reduction='none')  # per-token loss; weighted and reduced in get_loss

    optim = optimizer.optim4GPU(cfg, model)
    trainer = train.Trainer(cfg, model, data_iter, optim, save_dir, get_device())
    writer = SummaryWriter(log_dir=log_dir)  # for tensorboardX
    def evaluate(model, batch):
        input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, word_ids = batch
        logits_lm = model(input_ids, segment_ids, input_mask, masked_pos, word_ids)
        # masked-LM accuracy over all prediction slots; cast to float so the
        # division cannot fall back to integer division on older PyTorch
        acc = (masked_ids.view(-1) == logits_lm.argmax(dim=-1).view(-1)).float().mean()
        return acc, acc
    def get_loss(model, batch, global_step):  # make sure loss is a tensor
        input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, word_ids = batch
        logits_lm = model(input_ids, segment_ids, input_mask, masked_pos, word_ids)
        loss_lm = criterion(logits_lm.transpose(1, 2), masked_ids)  # per-token masked LM loss
        loss_lm = (loss_lm * masked_weights.float()).mean()  # zero out unused prediction slots
        writer.add_scalars('data/scalar_group',
                           {'loss_lm': loss_lm.item(),
                            'loss_total': loss_lm.item(),
                            'lr': optim.get_lr()[0],
                            },
                           global_step)
        return loss_lm
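    # Illustrative arithmetic for the weighting above (hypothetical values):
    # a per-token loss of [2.0, 1.0, 3.0] with masked_weights [1, 1, 0] (the
    # last slot being padding in the fixed-size prediction buffer) reduces to
    # (2.0 + 1.0 + 0.0) / 3 = 1.0; padded slots contribute zero loss but
    # still count toward the mean's denominator.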
    trainer.train(get_loss, model_file, None, data_parallel)


if __name__ == '__main__':
    fire.Fire(main)
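# Usage sketch: fire.Fire(main) exposes main()'s keyword arguments as
# command-line flags. The paths below are just the defaults from the
# signature above and must exist on your machine:
#
#   python pretrain.py \
#       --train_cfg configs/pretrain.json \
#       --model_cfg configs/bert_base.json \
#       --data_file /media/newhd/BookCorpus/books_full.txt \
#       --vocab /media/newhd/BookCorpus/vocab.txt \
#       --save_dir ./exp/bert/pretrain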