-
Notifications
You must be signed in to change notification settings - Fork 0
/
inspection.py
137 lines (121 loc) · 3.67 KB
/
inspection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from collections import Counter
from statistics import mean, stdev
from dataset import DataModule, load_datasets
def inspect_data(
    dataset="asdiv-a/fold0",
    data_path="data",
    tokenizer="roberta-base",
    compress_num=False,
    limit_depth=19,
    ignore_over_depth=False,
    multi=False,
):
    """Load a dataset and print descriptive statistics for its splits.

    The dataset is loaded with ``load_datasets`` and wrapped in a
    ``DataModule`` using a batch size of 1.  Every batch coming out of the
    train / validation / test loaders is unwrapped into a plain per-sample
    dict, and a series of summaries over the ``equation``,
    ``label_thoughts``, ``n_dds`` and ``n_thoughts`` fields is printed.

    Args:
        dataset: Dataset identifier, forwarded to ``load_datasets``.
        data_path: Data location, forwarded to ``load_datasets``.
        tokenizer: Tokenizer name, forwarded to ``load_datasets`` and
            ``DataModule``.
        compress_num: Forwarded to ``load_datasets``.
        limit_depth: Forwarded to ``load_datasets``.
        ignore_over_depth: Forwarded to ``load_datasets``.
        multi: Forwarded to ``load_datasets``.
    """
    datasets, *_ = load_datasets(
        data_path,
        dataset,
        limit_depth,
        compress_num,
        ignore_over_depth,
        tokenizer=tokenizer,
        multi=multi,
        label=True,
    )
    datamodule = DataModule(
        tokenizer,
        datasets,
        batch_size=1,
        collate_raw=True,
        test_on_validation=False,
    )

    def unbatch(loader):
        # With batch_size=1, element 0 of every field is taken as the sample.
        return [{key: field[0] for key, field in batch.items()} for batch in loader]

    def count_operators(equation):
        # "**" is folded into "^" first so a power is not counted as two "*".
        normalised = equation.replace("**", "^")
        return sum(normalised.count(symbol) for symbol in ("+", "-", "*", "/", "^"))

    # All three loaders are created before any of them is consumed.
    loaders = (
        datamodule.train_dataloader(),
        datamodule.val_dataloader(),
        datamodule.test_dataloader(),
    )
    trd, vd, tsd = map(unbatch, loaders)

    # NOTE(review): equal dev/test sizes are presumably taken to mean the test
    # split duplicates the dev split, so it is left out of "all" -- confirm.
    if len(vd) == len(tsd):
        ad = trd + vd
    else:
        ad = trd + vd + tsd

    splits = (trd, vd, tsd)
    splits_and_all = splits + (ad,)

    def column(key):
        # Lazily yields, per split, the list of one field's values.
        return ([sample[key] for sample in split] for split in splits_and_all)

    print_values(
        "n_equation",
        (len({sample["equation"] for sample in split}) for split in splits_and_all),
    )
    print_statistics(
        "n_operations",
        (
            [count_operators(sample["equation"]) for sample in split]
            for split in splits_and_all
        ),
    )
    print_statistics(
        "target depth",
        (
            [len(sample["label_thoughts"]) for sample in split]
            for split in splits_and_all
        ),
    )
    print_values("n_thoughts per depth", map(get_n_thoughts, splits))
    print_lists(
        "last depth n_thoughts",
        (
            [len(sample["label_thoughts"][-1]) for sample in split]
            for split in splits_and_all
        ),
    )
    print_lists("n_dds", column("n_dds"))
    print_lists("n_thoughts", column("n_thoughts"))
    print_values(
        "max(n_thoughts per depth)",
        (
            max(
                max(map(len, sample["label_thoughts"]))
                for sample in split
                if len(sample["label_thoughts"]) > 0
            )
            for split in splits
        ),
    )
def stderr(data):
    """Return the standard error of the mean of ``data``.

    The value is the sample standard deviation divided by the square root of
    the number of samples.

    Args:
        data: Sized collection of numbers.

    Returns:
        The standard error, or NaN when ``data`` holds fewer than two values.
        ``statistics.stdev`` raises ``StatisticsError`` in that case, which
        would otherwise abort a whole report because of one tiny split.
    """
    n = len(data)
    if n < 2:
        return float("nan")
    return stdev(data) / n ** 0.5
def print_statistics(name, datas):
    """Print max, min, mean+-standard error, stdev and a histogram per split.

    One line is printed for each entry of ``datas``, tagged in order as
    train / dev / test / all, followed by a single blank line.

    Args:
        name: Label printed at the start of every line.
        datas: Iterable of number sequences, one per split.  Entries beyond
            the fourth are ignored.
    """
    for mode, data in zip(("(train)", " (dev)", " (test)", " (all)"), datas):
        if not data:
            # max(), min() and mean() all raise on an empty split.
            print(f"{name} {mode} : no data")
            continue
        # stdev() is undefined below two samples.  It is computed once and the
        # standard error is derived from it rather than recomputing it.
        spread = stdev(data) if len(data) > 1 else float("nan")
        sem = spread / len(data) ** 0.5
        # Each field is its own argument so that ``sep`` separates all of
        # them; adjacent f-strings would be silently concatenated.
        print(
            f"{name} {mode} : {max(data)=}",
            f"{min(data)=}",
            f"{mean(data)=:.2f}+-{sem:.2f}",
            f"stdev(data)={spread:.2f}",
            f"{Counter(data)}",
            sep=", ",
        )
    print()
def print_lists(name, datas):
    """Print max, min, mean+-standard error and the largest values per split.

    One line is printed for each entry of ``datas``, tagged in order as
    train / dev / test / all, followed by a single blank line.  The values
    are listed in descending order, truncated to the first 200.

    Args:
        name: Label printed at the start of every line.
        datas: Iterable of number sequences, one per split.  Entries beyond
            the fourth are ignored.
    """
    for mode, data in zip(("(train)", " (dev)", " (test)", " (all)"), datas):
        if not data:
            # max(), min() and mean() all raise on an empty split.
            print(f"{name} {mode} : no data")
            continue
        # The standard error needs at least two samples.
        sem = stderr(data) if len(data) > 1 else float("nan")
        # Each field is its own argument so that ``sep`` actually applies;
        # adjacent f-strings would be silently concatenated into one.
        print(
            f"{name} {mode} : {max(data)=}",
            f"{min(data)=}",
            f"{mean(data)=:.2f}+-{sem:.2f}",
            f"{sorted(data, reverse=True)[:200]}",
            sep=", ",
        )
    print()
def print_values(name, values):
    """Print one labelled line per split value, then a blank line.

    Args:
        name: Label printed at the start of every line.
        values: Iterable of already-computed values, tagged in order as
            train / dev / test / all.  Entries beyond the fourth are ignored.
    """
    split_tags = ("(train)", " (dev)", " (test)", " (all)")
    for tag, item in zip(split_tags, values):
        line = f"{name} {tag} : {item}"
        print(line)
    print()
def get_n_thoughts(data):
    """Summarise, per depth index, the sizes of the ``label_thoughts`` entries.

    For every position ``i`` in the samples' ``label_thoughts`` sequences, the
    lengths of the ``i``-th entries are collected across all samples and
    rendered as one text line.

    Args:
        data: Sequence of sample dicts, each with a ``label_thoughts``
            sequence of sized entries.  It is iterated twice.

    Returns:
        A multi-line string that starts with a newline.  Each line shows the
        depth index and the number of samples reaching it, followed by
        max / min / mean+-standard error of the sizes, or the raw list when
        only one sample reaches that depth.  With no samples the result is
        just the leading ``"\\n "``.
    """
    # default=0 keeps an empty split from raising ValueError in max().
    deepest = max((len(d["label_thoughts"]) for d in data), default=0)
    ns = [[] for _ in range(deepest)]
    for d in data:
        for i, es in enumerate(d["label_thoughts"]):
            ns[i].append(len(es))
    # The loop variable must stay ``n``: the ``=`` specifier prints the
    # expression text as part of the output.
    return "\n " + "\n ".join(
        f"{i:>2}: {len(n)=:5}"
        + (
            # The standard error is taken over this depth's counts, not over
            # the list of sample dicts.
            f", {max(n)=:5}, {min(n)=:2}, {mean(n)=:3.2f}+-{stderr(n):.2f}"
            if len(n) > 1
            else f", {n=}"
        )
        for i, n in enumerate(ns)
        if n
    )