-
Notifications
You must be signed in to change notification settings - Fork 2
/
s6_eval.py
75 lines (61 loc) · 2.3 KB
/
s6_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import sqlite3
from os.path import join
import numpy as np
import pandas as pd
from parlai.core.metrics import BleuMetric, F1Metric, RougeMetric
from core.utils import create_dir_if_not_exist
from utils import DATA_DIR
def calc(compute_func):
    """Average a per-example metric over every response that has a label.

    Reads the module-level ``responses`` and ``labels`` dictionaries (both
    keyed by example id), which the ``__main__`` section populates before
    calling this function. Responses whose id is absent from ``labels`` are
    skipped.

    Args:
        compute_func: Callable invoked as ``compute_func(response, label)``.
            It must return either a single value accepted by ``float()`` or
            a tuple of such values.

    Returns:
        The mean of the per-example values. When ``compute_func`` returns
        tuples, the result is an array of component-wise means.

    Raises:
        ValueError: If no response id is present in ``labels``, so there is
            nothing to average.
    """
    results = []
    for r_id, response in responses.items():
        if r_id not in labels:
            continue
        computed_result = compute_func(response, labels[r_id])
        if isinstance(computed_result, tuple):
            # Multi-valued metric: keep one float per component so that the
            # average below is taken per component (axis=0).
            results.append([float(part) for part in computed_result])
        else:
            results.append(float(computed_result))

    if not results:
        # Averaging an empty list would silently produce NaN and only fail
        # later, far from the real cause; fail here with a clear message.
        raise ValueError(
            "calc: no response id matches any label id; nothing to average")

    return np.average(results, axis=0)
if __name__ == '__main__':
    # Script entry point: load model responses from SQLite and reference
    # labels from a tab-separated CSV, score them with calc(), and print one
    # tab-separated row of results.
    #
    # NOTE: `responses` and `labels` are intentionally bound at module level
    # here because calc() reads them as globals; do not move them into a
    # function without also changing calc().
    limit = 3136
    table = "original"

    # Create dir for experiments.
    experiment_dir = join(DATA_DIR, "eval_llm")
    create_dir_if_not_exist(experiment_dir)

    # Extract responses. Each row of the table is unpacked as an
    # (id, response) pair. The connection is closed explicitly, even when
    # the query or the unpacking fails.
    con = sqlite3.connect(join(experiment_dir, "answers.sqlite3"))
    try:
        cursor = con.cursor()
        cursor.execute(f'SELECT * FROM {table}')
        responses = {int(r_id): response for r_id, response in cursor}
    finally:
        con.close()

    # Extract labels: only the first `limit` rows of the file are used.
    df = pd.read_csv(join(experiment_dir, "valid_original_no-cand-labeled.csv"), sep="\t")
    labeled_rows = df.iloc[:limit]
    labels = dict(zip(labeled_rows["id"].tolist(), labeled_rows["label"].tolist()))

    # Compute every metric with the label passed as the single reference answer.
    scores = {}
    for k in (1, 2, 3, 4):
        # `k=k` binds the current order at definition time.
        scores[f"bleu-{k}"] = calc(
            lambda guess, label, k=k: BleuMetric.compute(guess=guess, answers=[label], k=k))
    scores["f1"] = calc(
        lambda guess, label: F1Metric.compute(guess=guess, answers=[label], expose_p_and_r=True))
    scores["rouge"] = calc(
        lambda guess, label: RougeMetric.compute_many(guess=guess, answers=[label]))

    # Output columns, in order. The two empty cells are placeholders for
    # values this script does not compute.
    # NOTE(review): the f1 tuple is presumably (precision, recall, f1) and
    # the rouge tuple (rouge-1, rouge-2, rouge-L) -- confirm against ParlAI.
    content = [
        scores["bleu-1"],
        scores["bleu-2"],
        scores["bleu-3"],
        scores["bleu-4"],
        "",  # ppl
        "",  # en-acc
        scores["f1"][2],
        scores["f1"][1],
        scores["f1"][0],
        scores["rouge"][0],
        scores["rouge"][1],
        scores["rouge"][2],
    ]

    print(table)
    print("\t".join([str(v) for v in content]))