forked from var-skip/var-skip
-
Notifications
You must be signed in to change notification settings - Fork 0
/
summarize.py
executable file
·111 lines (103 loc) · 3.41 KB
/
summarize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python
import numpy as np
import json
import collections
import glob
import os
import sys
# When the REPORT_LOSS environment variable is set (to any value), the final
# summary prints the mean/std of test_loss instead of the sorted sample values.
REPORT_LOSS = "REPORT_LOSS" in os.environ

USAGE = "Usage: ./summarize.py <RESULTS_SUBDIR> [REQUIRE_KEY=V, [K=V, [...]]]"

# Sample counts and quantile names used to build the "psample_*" result keys.
SAMPLES = [10, 40, 100, 400, 1000, 4000, 10000]
QUANTILES = ["max", "p99", "median"]


def _parse_constraints(args):
    """Turn ``KEY=VALUE`` command-line arguments into a dict of strings.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=``. Exits with an error message if an argument has no ``=``.
    """
    constraints = {}
    for arg in args:
        name, sep, wanted = arg.partition("=")
        if not sep:
            sys.exit("Invalid constraint {!r}: expected KEY=VALUE".format(arg))
        constraints[name] = wanted
    return constraints


def _load_final_results(path, constraints):
    """Read one ``result.json`` file (one JSON record per line).

    Returns ``(key, results)`` built from the last record in the file, where
    ``key`` is ``(dataset, per_row_dropout or dropout, order_seed, path)`` and
    ``results`` is the record's ``"results"`` dict with ``"test_loss"`` copied
    into it. Returns ``None`` if the file has no records or if any record's
    config fails a constraint (compared as ``str(config[k]) != v``).
    """
    key = None
    results = None
    with open(path) as handle:
        for raw in handle:
            if not raw.strip():
                # Tolerate blank lines instead of crashing in json.loads.
                continue
            record = json.loads(raw)
            config = record["config"]

            matches = True
            for name, wanted in constraints.items():
                if str(config[name]) != wanted:
                    print("Ignoring", name, config[name], "not matching",
                          wanted)
                    matches = False
            if not matches:
                return None

            key = (config["dataset"],
                   config["per_row_dropout"] or config["dropout"],
                   config["order_seed"], path)
            results = record["results"]
            results["test_loss"] = record.get("test_loss")
    if key is None:
        return None
    return key, results


def _build_rows(data):
    """Expand every loaded run into sorted ``(label, key, skip, quantile)`` rows.

    Runs with a truthy dropout value produce a "mask-ctrl" and a "mask-skip"
    row per quantile; all other runs produce a single "control" row.
    """
    rows = []
    for key in data:
        dataset, dropout, seed, _ = key
        for quantile in QUANTILES:
            if dropout:
                rows.append(("{}_{}-mask-ctrl-{}".format(
                    seed, dataset, quantile), key, False, quantile))
                rows.append(("{}_{}-mask-skip-{}".format(
                    seed, dataset, quantile), key, True, quantile))
            else:
                rows.append(("{}_{}-control-{}".format(
                    seed, dataset, quantile), key, False, quantile))
    return sorted(rows)


def _print_rows(rows, data):
    """Print one CSV line per row and collect values grouped by statistic.

    Returns ``(by_stat, by_stat_loss, by_stat_std)``, each keyed by
    ``(dataset, type, n, quantile)``. ``by_stat`` and ``by_stat_loss`` hold
    lists with one entry per contributing row; ``by_stat_std`` keeps only the
    std of the last row processed for that key.
    """
    by_stat = collections.defaultdict(list)
    by_stat_loss = collections.defaultdict(list)
    by_stat_std = {}
    for label, key, skip, quantile in rows:
        dataset, dropout, _, _ = key
        results = data[key]
        prefix = "shortcircuit_" if skip else ""
        if not dropout:
            tpe = "control"
        elif skip:
            tpe = "dropout-skip"
        else:
            tpe = "dropout-ctrl"

        cols = [label]
        for n in SAMPLES:
            try:
                loss = results["test_loss"]
                value = results["psample_{}{}_{}".format(prefix, n, quantile)]
                std = results["psample_{}{}_{}_std".format(
                    prefix, n, quantile)]
            except KeyError:
                # This run did not record the statistic for this sample count.
                continue
            cols.append(str(value))
            stat_key = (dataset, tpe, n, quantile)
            by_stat[stat_key].append(value)
            by_stat_std[stat_key] = std
            by_stat_loss[stat_key].append(loss)
        print(",".join(cols))
    print()
    return by_stat, by_stat_loss, by_stat_std


def _print_summary(by_stat, by_stat_loss, by_stat_std, report_loss):
    """Print the grouped statistics, one section per sample count."""
    for n in SAMPLES:
        print("====", n, "=====")
        for stat_key, values in by_stat.items():
            dataset, tpe, ni, quantile = stat_key
            if ni != n:
                continue
            if report_loss:
                losses = by_stat_loss[stat_key]
                print(dataset, tpe, n, quantile, "loss", np.mean(losses),
                      "std", np.std(losses))
            else:
                print(dataset, tpe, n, quantile, sorted(values), "bstd",
                      by_stat_std[stat_key])
        print()


def main(argv=None, report_loss=None):
    """Summarize ``results/<subdir>/ray_results/*/*/result.json`` files.

    Args:
        argv: Argument list shaped like ``sys.argv`` (program name first);
            defaults to ``sys.argv``.
        report_loss: Print loss mean/std in the summary instead of sample
            values; defaults to the module-level ``REPORT_LOSS`` flag.
    """
    if argv is None:
        argv = sys.argv
    if report_loss is None:
        report_loss = REPORT_LOSS

    if len(argv) < 2:
        print(USAGE)
        sys.exit()

    pattern = os.path.expanduser(
        "results/{}/ray_results/*/*/result.json".format(argv[1]))
    print("SEARCHING FOR", pattern)
    constraints = _parse_constraints(argv[2:])
    print("CONSTRAINTS", constraints)

    data = {}
    for path in glob.glob(pattern):
        loaded = _load_final_results(path, constraints)
        if loaded is not None:
            key, results = loaded
            data[key] = results

    rows = _build_rows(data)
    by_stat, by_stat_loss, by_stat_std = _print_rows(rows, data)
    _print_summary(by_stat, by_stat_loss, by_stat_std, report_loss)


if __name__ == "__main__":
    main()