-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun.ipy
executable file
·130 lines (114 loc) · 4.86 KB
/
run.ipy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import shutil
import json
import subprocess
import timeit
import argparse
# Command-line interface of the evaluation script.
# Each entry pairs a flag with the keyword arguments passed to add_argument();
# the order below is the order in which the options are registered.
_CLI_OPTIONS = (
    # Run mode and dataset name; both select files and metrics further down.
    ('--mode', dict(type=str, required=True)),
    ('--dataset', dict(type=str, required=True)),
    # Inclusive range of checkpoint epochs to evaluate.
    ('--start', dict(type=int, default=15)),
    ('--end', dict(type=int, default=50)),
    # Beam size forwarded to translate.py.
    ('--beam', dict(type=int, default=5)),
    # Directory layout: data files, test dataset, and checkpoints.
    ('--data_dir', dict(type=str, required=True)),
    ('--test_dir', dict(type=str, required=True)),
    ('--models_dir', dict(type=str, required=True)),
    # Optional JSON file naming the best epoch of an earlier run.
    ('--best_json', dict(type=str, default="")),
    ('--src_len', dict(type=int, default=200)),
    ('--tgt_len', dict(type=int, default=150)),
    ('--epochs', dict(type=int, default=60)),
    ('--batch_size', dict(type=int, default=1)),
    ('--trunc', dict(type=int, default=1)),
    ('--train_max', dict(type=int, default=100000)),
    ('--valid_max', dict(type=int, default=3000)),
)

parser = argparse.ArgumentParser(description='translate.py')
for _flag, _settings in _CLI_OPTIONS:
    parser.add_argument(_flag, **_settings)
# Parse the command line once; the rest of the script reads its settings from `opt`.
opt = parser.parse_args()

DATA_DIR = opt.data_dir
TEST_DIR = opt.test_dir
# --models_dir is appended to the data directory, so it is effectively relative to it.
MODELS_DIR = DATA_DIR + '/' + opt.models_dir

# Database connection settings passed to tools/getMetrics.py for the 'den' metric.
# Each value can be overridden through the environment so that credentials do
# not have to be edited into this file; the original literals stay as defaults.
DBHOST = os.environ.get('ATIS_DB_HOST', 'kingboo.cs.washington.edu')
DBUSER = os.environ.get('ATIS_DB_USER', 'sviyer')
DBPASS = os.environ.get('ATIS_DB_PASS', 'VerySecure')
# Start the run with a fresh, empty <MODELS_DIR>/preds/ directory so that
# predictions from an earlier run are not mixed in with the new ones.
_preds_dir = MODELS_DIR + '/preds/'
try:
    shutil.rmtree(_preds_dir)
except FileNotFoundError:
    # Nothing to clear, e.g. on the first run for this models directory.
    pass
except OSError as err:
    # Still best-effort, but report why old predictions may have survived.
    print('Warning: could not remove {}: {}'.format(_preds_dir, err))

try:
    # exist_ok covers the case where the removal above did not succeed.
    os.makedirs(_preds_dir, exist_ok=True)
except OSError as err:
    print('Warning: could not create {}: {}'.format(_preds_dir, err))
# Handle for <MODELS_DIR>/preds.json; remains None unless it is opened later on.
f_metrics = None

# Choose the evaluation files. "predict" mode reads predict.dataset and the
# ATIS validation references; every other mode reads test.dataset and the
# ATIS test references.
_predict_mode = opt.mode == "predict"

# JSON dataset to translate (input) ...
TESTING_FILE = TEST_DIR + ('/predict.dataset' if _predict_mode else '/test.dataset')
# ... and the plain-text views of it that are written out for the scorers.
TESTING_NL = DATA_DIR + ('/predict.nl' if _predict_mode else '/test.nl')
TESTING_GOLD = DATA_DIR + ('/predict.code' if _predict_mode else '/test.code')

# Gold SQL and template map used for the ATIS 'den' metric.
ATIS_TESTING_GOLD = 'atis/valid.sql' if _predict_mode else 'atis/test.sql'
ATIS_TESTING_TEM = 'atis/valid.nl.tem.map' if _predict_mode else 'atis/test.nl.tem.map'
# The scorers work on plain-text files, so extract the gold code and the
# natural-language inputs from the JSON dataset, one example per line.
# Context managers make sure every handle is closed even if an example is
# malformed and raises part-way through.
with open(TESTING_FILE, 'r') as dataset_file:
    test_dataset = json.loads(dataset_file.read())

with open(TESTING_GOLD, 'w') as test_dataset_targets, \
        open(TESTING_NL, 'w') as test_dataset_nls:
    for example in test_dataset:
        # Join the code tokens and strip the concodeclass_/concodefunc_ markers.
        gold_code = ' '.join(example['code']).replace('concodeclass_', '').replace('concodefunc_', '')
        test_dataset_targets.write(gold_code + '\n')
        test_dataset_nls.write(' '.join(example['nl']) + '\n')
# Command line of the scorer behind each metric. The "concode" dataset is
# scored with BLEU and exact match; any other dataset with SQL exact match
# and the database-backed 'den' metric.
if opt.dataset == "concode":
    _scorer_commands = {
        'bleu': ['python', 'tools/bleu.py'],
        'exact': ['python', 'tools/exact.py'],
    }
else:
    _scorer_commands = {
        'exact': ['python', 'tools/exact_sql.py'],
        'den': ['python', 'tools/getMetrics.py', '-db', 'atis', '-user', DBUSER, '-passwd', DBPASS, '-host', DBHOST],
    }

# Per-metric bookkeeping: 'best' is the best epoch seen so far (-1 until an
# epoch has been scored) and 'scores' maps epoch -> score.
metrics = {
    name: {'command': command, 'best': -1, 'scores': {}}
    for name, command in _scorer_commands.items()
}
if opt.best_json != "":
    # Evaluate only the epoch that an earlier run recorded as best
    # ('exact' for concode, 'den' otherwise). No preds.json is written.
    with open(opt.best_json) as best_file:
        js = json.loads(best_file.read())
    opt.start = js['exact' if opt.dataset == "concode" else "den"]['best']
    opt.end = opt.start
else:
    # Full sweep over --start..--end: the collected metrics are saved to
    # preds.json once the sweep is finished.
    f_metrics = open(MODELS_DIR + '/preds.json', 'w')
for i in range(opt.start, opt.end + 1):
start = timeit.timeit()
fname = !ls {MODELS_DIR}/*_acc_*e{i}.pt
f = os.path.basename(fname[0])
print(f)
# This is really important coz the prediction file appends to this. dev and test have the same filename
!rm {MODELS_DIR}/preds/{f}.nl.prediction*
# Prod is just a dummy here
!python translate.py -beam_size {opt.beam} -model {fname[0]} -src {TESTING_FILE} -output {MODELS_DIR}/preds/{f}.nl.prediction -max_sent_length {opt.tgt_len} -replace_unk -batch_size {opt.batch_size} -trunc {opt.trunc} -dataset {opt.dataset}
if 'den' in metrics:
!~/miniconda3/envs/p36/bin/python tools/atis_templatize.py -mapfile {ATIS_TESTING_TEM} -sqlfile {MODELS_DIR}/preds/{f}.nl.prediction -inst deanonymize -alignments atis/alignment.txt
for metric in metrics:
p = subprocess.Popen(metrics[metric]['command'] + ['-nl'] + [TESTING_NL] + ['-gold'] + [TESTING_GOLD if metric != "den" else ATIS_TESTING_GOLD] + ['-pred'] + [MODELS_DIR + '/preds/' + f + '.nl.prediction' + ('.deanon' if metric == "den" else '')], stdout=subprocess.PIPE)
score = p.stdout.read()
if metric == 'bleu':
score = float(score.decode('ascii').split(',')[0])
else:
score = float(score.decode('ascii'))
if metrics[metric]['best'] == -1 or score >= metrics[metric]['scores'][metrics[metric]['best']]:
metrics[metric]['best'] = i
metrics[metric]['scores'][i] = score
for metric in metrics:
output = ['Best']
output.append(metric)
output.append('Epoch:')
output.append(metrics[metric]['best'])
output.append(metrics[metric]['scores'][metrics[metric]['best']])
for k in metrics.keys():
output.append(k)
output.append(metrics[k]['scores'][metrics[metric]['best']])
print(output)
end = timeit.timeit()
print('Time taken: {}'.format(end - start))
# f_metrics is an open file only when no --best_json was given; in that case
# save the scores and best epochs of all metrics as JSON.
if f_metrics is not None:
    serialized_metrics = json.dumps(metrics)
    f_metrics.write(serialized_metrics)
    f_metrics.close()