forked from Machine-Learning-Tokyo/Poetry-GAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lang_model.py
93 lines (79 loc) · 3.43 KB
/
lang_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#################################################
### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
#################################################
# file to edit: notebooks/lang_model.ipynb
from fastai import *
from fastai.text import *
from fastai.callbacks.tracker import SaveModelCallback, EarlyStoppingCallback
# Map CLI model names to fastai architectures. 'AWD_LSTM' is accepted as an
# alias so train_lm's default argument resolves to a valid key (previously it
# raised KeyError).
models = {'AWD': AWD_LSTM, 'AWD_LSTM': AWD_LSTM, 'XL': TransformerXL}
# train language model with either AWD_LSTM or TransformerXL archs and generate preds
def train_lm(path, filename, model='AWD_LSTM',
             epochs=8, pretrained_fnames=None, preds=True):
    """Train a language model and optionally generate validation predictions.

    Args:
        path: directory containing the DataBunch saved by the preprocess step.
        filename: name of the saved DataBunch; also the prefix for saved weights.
        model: 'AWD' (or legacy alias 'AWD_LSTM') for AWD_LSTM, 'XL' for TransformerXL.
        epochs: number of unfrozen fine-tuning epochs.
        pretrained_fnames: optional comma-separated 'weights,itos' pair to warm-start from.
        preds: if True, generate predictions seeded from the validation set.
    """
    # BUG FIX: the default 'AWD_LSTM' was not a key of `models` ({'AWD','XL'}),
    # so calling with the default raised KeyError at models[model] below.
    if model == 'AWD_LSTM':
        model = 'AWD'
    # get data after running preprocess
    print(f'loading data from {path}/{filename};')
    data_lm = load_data(path, filename, bs=64, bptt=70)
    # TransformerXL needs an explicit config; AWD_LSTM uses fastai defaults.
    if model == 'XL':
        config = tfmerXL_lm_config.copy()
        config['mem_len'] = 150
        config['output_p'] = 0.1
        config['embed_p'] = 0.1
        config['ff_p'] = 0.1
        config['resid_p'] = 0.1
        config['d_inner'] = 1024
        config['d_model'] = 128
    else:
        config = None
    # load pretrained weights: fastai expects a ['weights', 'itos'] pair
    if pretrained_fnames:
        pretrained_fnames = pretrained_fnames.split(',')
    learn = language_model_learner(data_lm, models[model],
                                   config=config, pretrained=False,
                                   pretrained_fnames=pretrained_fnames)
    print(f'training lm model {model}; pretrained from {pretrained_fnames};')
    # save the best model at every epoch and stop early when validation stalls
    cb = [SaveModelCallback(learn), EarlyStoppingCallback(learn)]
    if pretrained_fnames:
        # layered training: fit the new head first, then unfreeze everything
        print(f'training lm model head;')
        learn.fit_one_cycle(1, 3e-3, moms=(0.8, 0.7))
        print(f'saving lm model head to {path}/{filename}_head;')
        learn.save(filename + '_head')
        learn.unfreeze()
    print(f'training for {epochs} epochs')
    learn.fit_one_cycle(epochs, 3e-4, moms=(0.8, 0.7), callbacks=cb)
    print(f'saving model to {path}/{filename}_finetuned')
    learn.save(filename + '_finetuned')
    # generate outputs from validation set
    if preds:
        # log message now matches the actual file written by get_valid_preds
        print(f'generating predictions and saving to {path}/{filename}_{model}_preds.txt;')
        get_valid_preds(learn, data_lm, filename + '_' + model + '_preds.txt')
import string
def post_process(text):
    """Detokenize fastai-style LM output.

    Applies the xxmaj/xxup capitalization markers to the following token,
    capitalizes bare 'i', re-attaches punctuation and contraction pieces
    without a leading space, drops the special xx* tokens, and restores
    literal newlines.
    """
    toks = text.replace('\n', '\\n').split()
    # Pass 1: apply capitalization markers in place. Order matters — a marker
    # capitalizes the following token before that token is itself inspected.
    for idx in range(len(toks)):
        if toks[idx] in ('xxmaj', 'xxup') and idx + 1 < len(toks):
            toks[idx + 1] = toks[idx + 1].capitalize()
        if toks[idx] == 'i':
            toks[idx] = 'I'
    # Pass 2: rebuild the string; punctuation and contraction suffixes attach
    # directly, ordinary words get a leading space, special tokens are dropped.
    pieces = []
    for tok in toks:
        if tok in string.punctuation or tok in ("'m", "n't", "'ll", "'s", "'ve"):
            pieces.append(tok)
        elif tok not in ('xxmaj', 'xxup', 'xxbos', 'xxrep'):
            pieces.append(' ' + tok)
    joined = ''.join(pieces)
    return joined.replace('\\n', '\n').replace('" ', '"').replace('&&', '')
def get_valid_preds(learn, data, file_name, test=None):
    """Generate LM continuations seeded from validation texts and save them.

    For each text, writes the original, the seed (its first quarter of words),
    and the model's continuation, separated by '----------$----------'.

    Args:
        learn: a fastai language-model Learner (provides .predict).
        data: the DataBunch whose valid_ds supplies default texts.
        file_name: output file name, resolved against the module-level `path`.
        test: optional list of texts; defaults to the validation-set texts.

    NOTE(review): `path` is read as a module-level global that is never
    defined in this file — presumably set in the notebook this script was
    exported from; confirm before running standalone.
    """
    if not test:
        # default to the raw texts of the validation set
        test = [data.valid_ds.x.get(i).text for i in range(len(data.valid_ds))]
    tot = ''
    # BUG FIX: the file was opened in the default read-only mode, so f.write
    # raised io.UnsupportedOperation; open for writing instead.
    with open(path / file_name, 'w') as f:
        for text in test:
            tot += text
            words = text.replace('\n', '\\n').split(' ')
            num_words = len(words)
            # use the first quarter of the words as the generation seed
            seed = ' '.join(words[:num_words // 4])
            tot += '----------$----------' + seed
            # predict the remaining three quarters and clean up the tokens
            pred = post_process(learn.predict(seed, num_words - num_words // 4))
            tot += '----------$----------' + pred
        # write the full accumulator once, after all texts are processed
        f.write(tot + '\n\n\n\n\n\n\n')
import fire
# Expose train_lm as a command-line interface via python-fire when run as a script.
if __name__ == '__main__': fire.Fire(train_lm)