forked from annypan/ilm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
acl20_repro.py
120 lines (112 loc) · 6.56 KB
/
acl20_repro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Google Drive links to the pre-masked datasets, indexed first by split
# ('train' / 'valid' / 'test') and then by '<dataset>_<granularity>'.
# The dataset prefixes match those used in PRETRAINED_MODELS
# (sto = stories, abs = abstracts, lyr = lyrics).
PREMASKED_DATA = {
    # Only the "mixture" granularity is published for the train/valid splits.
    'train': {
        'sto_mixture': 'https://drive.google.com/open?id=1LxlyPqz3OvAZsYRRC8yRdSoaCKGB0Ucg',
        'abs_mixture': 'https://drive.google.com/open?id=1rw45GKP4iRJLzXnRtX-rnk_NeGXOqWkU',
        'lyr_mixture': 'https://drive.google.com/open?id=1jGCjboxlFUF0jqvB0_-L0eeylhKWfZJV',
    },
    'valid': {
        'sto_mixture': 'https://drive.google.com/open?id=1Y4HRYrBnqwtdbziF5Q6b5WaIFxJd1v7m',
        'abs_mixture': 'https://drive.google.com/open?id=1hHdXX43qkkm-zpUCJz_iuv1vRRpyfbaP',
        'lyr_mixture': 'https://drive.google.com/open?id=1xR0LC5WHV1UQDPjWTN0HcOQ9C5jsXYef',
    },
    # Test sets, grouped by the table they are labelled with.
    'test': {
        # Table 1/6
        'sto_sentence': 'https://drive.google.com/open?id=1w02hGewoBk_Pq-thrtbOcU1JPGRdGL_U',
        'abs_sentence': 'https://drive.google.com/open?id=18aNMfcqC1wyC8wWJHbCfMxLY49Dbg-du',
        'lyr_sentence': 'https://drive.google.com/open?id=18Szj-HYwh3sjLmmfF8TNwAmB2oEddool',
        # Table 3
        'sto_document': 'https://drive.google.com/open?id=1ydEjL0SMbX8p-1w6XeLWNrzeVn8TleGT',
        'abs_document': 'https://drive.google.com/open?id=1UjPh51URE8hvK-yTw3xkVwkBwEcCo6Uz',
        'lyr_document': 'https://drive.google.com/open?id=1KNvdzn1xhpw0Xdh0pMWN3CtrDKEK8d4N',
        # Table 4
        'sto_mixture': 'https://drive.google.com/open?id=1Zsuj8Plrcs49f-5rV6dvJ5W2kIz_C30u',
        'abs_mixture': 'https://drive.google.com/open?id=1TA3ySrvcWxaNtoDPpN8Jk7uqjGKdPqda',
        'lyr_mixture': 'https://drive.google.com/open?id=1FGEL3CGzLvnWpgvUYWsHsOUgW65DVxgw',
        # Table 5
        'sto_paragraph': 'https://drive.google.com/open?id=1MBM96hfN2cGJidG-mi_4bE0K07xgWAxT',
        'abs_paragraph': 'https://drive.google.com/open?id=1xXJfjCNzRLXYZgHgUrNimP4CtUW0Ziph',
        'lyr_paragraph': 'https://drive.google.com/open?id=10ScpFR8sG3Ur0WpWdkPYxAsT94jNNmZh',
        # Table 7
        'sto_ngram': 'https://drive.google.com/open?id=1x8RBys_jbreSFO1zMdmwiT2ref2F8q_C',
        'abs_ngram': 'https://drive.google.com/open?id=1JJyh7clJjyPF-rm4rHFLyX7Y-l_doD0K',
        'lyr_ngram': 'https://drive.google.com/open?id=1dbCCc68TvY6segwTrrxYS1ukVbdC7zgJ',
        # Table 8
        'sto_word': 'https://drive.google.com/open?id=178joxkympgzDwZoExnalWujRq2Jv_37P',
        'abs_word': 'https://drive.google.com/open?id=1PdVg-TnG5VQt8GCQOQA841AGw1GR44yl',
        'lyr_word': 'https://drive.google.com/open?id=1Td-yr6g5cTxW4yoz_Wv4gSi-wbu1376R',
    },
}
# Google Drive links to pretrained model weights, keyed by
# '<dataset>_<model_type>'. The script's "model" mode saves the selected
# entry as pytorch_model.bin.
PRETRAINED_MODELS = {
    # Trained on stories
    'sto_lm': 'https://drive.google.com/open?id=1-FGKu-bodqOsCGrFCYY6Yyp2rTk2rRpc',
    'sto_lmrev': 'https://drive.google.com/open?id=1_uCgugc57tPGfFofKbU8doJN23cf4lEY',
    'sto_lmall': 'https://drive.google.com/open?id=1dPOLkggPbe-Pzn8VVkcrinuGJv2yRieR',
    'sto_ilm': 'https://drive.google.com/open?id=1oYFLxkX6mWbmpEwQH8BmgE7iKix2W0zU',
    'sto_lmscratch': 'https://drive.google.com/open?id=1vGxdfZUWtOB5ajpDgSGUXuHK5_BGY9GA',
    'sto_lmrevscratch': 'https://drive.google.com/open?id=1xbyQ5bMJpTxlsPtL1YsH2jmUUh_49gOI',
    'sto_lmallscratch': 'https://drive.google.com/open?id=1Qy13Dw60Jd5HqN89q8WvCMtwvTXJw7tj',
    'sto_ilmscratch': 'https://drive.google.com/open?id=14BFLWSaPi2JSsKsa68lcTSnCOnYV9jPm',
    # Trained on abstracts
    'abs_lm': 'https://drive.google.com/open?id=1BSIFfuSTznmHIKa4R-AnwIxN93b1Ap-b',
    'abs_lmrev': 'https://drive.google.com/open?id=1yl36oZq9R_d3IhlFWLlMGq46n8F9Lq1q',
    'abs_lmall': 'https://drive.google.com/open?id=1qyM0OCL8pI5dL7sfAag-y9X_bnlTS_1Z',
    'abs_ilm': 'https://drive.google.com/open?id=1FBY9DR60WWX05orILaFHuyZYlB4ChTpS',
    'abs_lmscratch': 'https://drive.google.com/open?id=103Cw2ZSb5g5PlTKslmbmhqCaxn3N65OO',
    'abs_lmrevscratch': 'https://drive.google.com/open?id=1HeuxA2A6iEs6SW26jlCom3x_tFQHnIGu',
    'abs_lmallscratch': 'https://drive.google.com/open?id=1XU61GMduqJeCzYqDk8BQ7S4M8tbzqF9g',
    'abs_ilmscratch': 'https://drive.google.com/open?id=1ZTZOO5fVTlnPBw7EC_4OOEzHmcs6tAFO',
    # Trained on lyrics
    'lyr_lm': 'https://drive.google.com/open?id=1FJBgz26lZPcanZTEf0iWxZCXIEM6esu6',
    'lyr_lmrev': 'https://drive.google.com/open?id=1XAug1jhm7sa5lksDV6GMyF8sFQLwk1Y6',
    'lyr_lmall': 'https://drive.google.com/open?id=1nrNkd4cBsdZS0eajA3wD1i5b6t6R6bow',
    'lyr_ilm': 'https://drive.google.com/open?id=1nYuYCS5fDP2_vB7A92guk0PWh5CC2I5x',
    'lyr_lmscratch': 'https://drive.google.com/open?id=1JzDRUSWVeyGnNaWKVYM8t1BPAs58t6uB',
    'lyr_lmrevscratch': 'https://drive.google.com/open?id=1Kkli5Brmc3D6qE0b5ww5daZdZroaN1YB',
    'lyr_lmallscratch': 'https://drive.google.com/open?id=18JYIBOtDfnksZPl4TW9cOzjOh_qDBCJP',
    'lyr_ilmscratch': 'https://drive.google.com/open?id=1RObPpSttNtMw4UQ1bGiVzEM-94QqkwHT',
}
# Files fetched alongside every pretrained model: the "model" mode of this
# script saves the first as config.json and the second as
# additional_ids_to_tokens.pkl next to the model weights.
PRETRAINED_MODEL_CONFIG_JSON = 'https://drive.google.com/open?id=15JnXi7L6LeEB2fq4dFK2WRvDKyX46hVi'
PRETRAINED_SPECIAL_VOCAB_PKL = 'https://drive.google.com/open?id=1nTQVe2tfkWV8dumbrLIHzMgPwpLIbYUd'
# Lookup from the model-type suffixes used in PRETRAINED_MODELS keys to
# another set of task names (presumably the identifiers used by the
# training code -- confirm against the rest of the project).
PAPER_TASK_TO_INTERNAL = dict(
    lm='lm',
    lmrev='reverse_lm',
    lmall='naive',
    ilm='ilm',
)
# Shell one-liner that downloads one Google Drive file with wget.
# An inner wget call saves session cookies and scrapes the "confirm" token
# out of the response with sed; the outer wget then reuses both to fetch
# the file into {local_path}. The cookie jar is removed afterwards.
# Placeholders filled in via str.format: {gdrive_id} (twice), {local_path}.
_DOWNLOAD_TEMPLATE = (
    'wget -nc --load-cookies /tmp/cookies.txt '
    '"https://docs.google.com/uc?export=download&confirm='
    '$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies '
    "--no-check-certificate 'https://docs.google.com/uc?export=download&id={gdrive_id}' -O- "
    "| sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')"
    '&id={gdrive_id}" -O {local_path} '
    '&& rm -rf /tmp/cookies.txt'
)
# Standard-library modules used by the command-line entry point below.
import os
import shlex
import sys
from urllib.parse import parse_qs, urlparse

# Shown (on stderr) whenever the command line cannot be interpreted.
_USAGE = (
    'usage: acl20_repro.py model <dataset> <model_type>\n'
    '       acl20_repro.py data_train <dataset>\n'
    '       acl20_repro.py data_eval <dataset>'
)

# Granularities fetched by "data_eval", in download order.
_EVAL_GRANULARITIES = (
    'mixture', 'document', 'paragraph', 'sentence', 'ngram', 'word')


def _gdrive_id(url):
    """Return the file id carried in the ``id`` query parameter of `url`.

    Parsing the query string (instead of splitting on ``=``) keeps the id
    correct even when the link has extra parameters such as ``&usp=sharing``.

    Raises:
        ValueError: if `url` has no ``id`` query parameter.
    """
    ids = parse_qs(urlparse(url).query).get('id')
    if not ids:
        raise ValueError('No Google Drive file id in URL: {!r}'.format(url))
    return ids[0]


def _plan_downloads(args, base_dir):
    """Translate command-line arguments into a download plan.

    Args:
        args: command-line arguments without the program name, i.e.
            ``[mode, dataset]`` or ``['model', dataset, model_type]``.
            Only the first three characters of ``dataset`` are significant.
        base_dir: root directory under which files are to be stored.

    Returns:
        ``(out_dir, downloads)`` where `downloads` is a list of
        ``(gdrive_id, local_path)`` pairs.

    Raises:
        ValueError: on a malformed command line or an unknown
            model/dataset; the message is suitable for showing to the user.
    """
    if not args:
        raise ValueError(_USAGE)
    mode, rest = args[0], args[1:]

    if mode == 'model':
        if len(rest) != 2:
            raise ValueError(_USAGE)
        data_tag, model_type = rest
        model_tag = '{}_{}'.format(data_tag[:3], model_type)
        if model_tag not in PRETRAINED_MODELS:
            raise ValueError('Unknown model {!r}; choose from: {}'.format(
                model_tag, ', '.join(sorted(PRETRAINED_MODELS))))
        out_dir = os.path.join(base_dir, 'models', model_tag)
        urls = [
            PRETRAINED_MODELS[model_tag],
            PRETRAINED_MODEL_CONFIG_JSON,
            PRETRAINED_SPECIAL_VOCAB_PKL,
        ]
        names = [
            'pytorch_model.bin',
            'config.json',
            'additional_ids_to_tokens.pkl',
        ]
    elif mode in ('data_train', 'data_eval'):
        if not rest:
            raise ValueError(_USAGE)
        data_tag = rest[0][:3]
        out_dir = os.path.join(base_dir, 'data')
        if mode == 'data_train':
            # (split, dataset key, local file name) for each file to fetch.
            wanted = [
                (split,
                 '{}_mixture'.format(data_tag),
                 '{}_mixture_{}.pkl'.format(data_tag, split))
                for split in ('train', 'valid')]
        else:
            wanted = [
                ('test',
                 '{}_{}'.format(data_tag, gran),
                 '{}_{}_test.pkl'.format(data_tag, gran))
                for gran in _EVAL_GRANULARITIES]
        try:
            urls = [PREMASKED_DATA[split][key] for split, key, _ in wanted]
        except KeyError:
            raise ValueError('Unknown dataset {!r}'.format(rest[0]))
        names = [name for _, _, name in wanted]
    else:
        raise ValueError(_USAGE)

    # Resolve every id up front so a bad link fails before anything is
    # printed (the output is a command list that may be fed to a shell).
    downloads = [
        (_gdrive_id(url), os.path.join(out_dir, name))
        for url, name in zip(urls, names)]
    return out_dir, downloads


if __name__ == '__main__':
    # Prints the shell commands that create the output directory and fetch
    # the requested files; nothing is downloaded by this script itself.
    # The root directory comes from $ILM_DIR, defaulting to /tmp/ilm.
    try:
        out_dir, downloads = _plan_downloads(
            sys.argv[1:], os.environ.get('ILM_DIR', '/tmp/ilm'))
    except ValueError as err:
        # sys.exit(str) writes to stderr, keeping stdout free of non-commands.
        sys.exit(str(err))

    # Paths are shell-quoted so directories containing spaces or shell
    # metacharacters still produce valid commands.
    print('mkdir -p {}'.format(shlex.quote(out_dir)))
    for gdrive_id, local_path in downloads:
        print(_DOWNLOAD_TEMPLATE.format(
            gdrive_id=gdrive_id,
            local_path=shlex.quote(local_path)))