# plotting.py — plot validation metrics from tensorboard event logs
# (forked from gkakogeorgiou/spot)
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator
# Size guidance for tensorboard's EventAccumulator: load every scalar (0 = no limit)
# but only one sample of each other series type, to keep memory usage low.
SIZE_GUIDANCE = {
    event_accumulator.COMPRESSED_HISTOGRAMS: 1,
    event_accumulator.IMAGES: 1,
    event_accumulator.AUDIO: 1,
    event_accumulator.HISTOGRAMS: 1,
    event_accumulator.TENSORS: 1,
    event_accumulator.SCALARS: 0
}
# Some pretty colors, with assignment to proper fusion methods.
# get_color_for_run matches these keys as substrings of a run name.
COLORS = {
    "baseline": "#7F7F91",  # grey
    "residual": "#FFA600",  # orange
    "mean": "#3D1D9C",      # blue
    "max": "#C7007C",       # violet
    "sum": "#FF4647"        # salmon
}
# Flag whether the " (slots)" suffix should be removed from labels in plots
REMOVE_SLOTS_FROM_LABELS = True
# Constant base paths: where run logs are read from and where plots are written to
BASE_PATH = "logs/"
PLOT_PATH = "plots/"
def results_for_run(run_name):
    """Read the validation scalars logged for one run into a DataFrame.

    Parameters
    ----------
    run_name : str
        Run directory, relative to BASE_PATH. May contain one extra level of
        timestamped subdirectories, in which case the latest one is used.

    Returns
    -------
    pd.DataFrame
        One column per "VAL/..." scalar (prefix stripped), indexed by epoch.

    Raises
    ------
    AssertionError
        If the directory does not contain exactly one tensorboard event file.
    """
    p_log_dir = os.path.join(BASE_PATH, run_name)
    # Sort the listing: os.listdir order is arbitrary, and timestamped
    # subdirectory names sort chronologically, so [-1] is truly the latest.
    log_dir_content = sorted(os.listdir(p_log_dir))
    # If the last element is a directory again, the timestamp level of the
    # log path was not removed - descend into the latest timestamp directory
    if os.path.isdir(os.path.join(p_log_dir, log_dir_content[-1])):
        print(f"Found subdirectories for log path {p_log_dir} - taking last subdirectory (last timestamp)")
        p_log_dir = os.path.join(p_log_dir, log_dir_content[-1])
        log_dir_content = sorted(os.listdir(p_log_dir))
    # Find tensorboard file
    tb_file_name_list = [x for x in log_dir_content if x.startswith("events.out.tfevents")]
    assert len(tb_file_name_list) == 1, "The provided directory must contain exactly one tensorboard event file"
    tb_file_name = tb_file_name_list[0]
    # Open the file and read the scalars for the validation
    p_tb_file = os.path.join(p_log_dir, tb_file_name)
    ev_acc = event_accumulator.EventAccumulator(p_tb_file, SIZE_GUIDANCE)
    ev_acc.Reload()
    scalars = ev_acc.Tags()["scalars"]
    val_scalars = [s for s in scalars if s.startswith("VAL")]
    # Create one joint pd.DataFrame from it
    scalar_df = pd.DataFrame()
    scalar_df.index.name = "Epoch"
    for scalar in val_scalars:
        scalar_read = ev_acc.Scalars(scalar)
        scalar_vals = [e.value for e in scalar_read]
        scalar_df[scalar[4:]] = scalar_vals  # Scalar starting from 4: to remove VAL/
    # Delete from memory to avoid OOM, then return
    del ev_acc
    return scalar_df
def get_color_for_run(run):
    """Return the hex color assigned to the fusion method named in *run*.

    The first key of COLORS found as a substring of *run* wins.

    Raises
    ------
    ValueError
        If no fusion-method key of COLORS occurs in *run*.
        (ValueError subclasses Exception, so existing handlers still match.)
    """
    for fusion, col in COLORS.items():
        if fusion in run:
            return col
    raise ValueError(f"No respective color for run {run} could be found, must include fusion method")
def make_metrics_pretty(metrics):
    """Convert raw metric names into LaTeX-formatted plot labels.

    Optionally strips the " (slots)" suffix (see REMOVE_SLOTS_FROM_LABELS)
    and upper-cases the metric acronyms (mbo -> mBO, miou -> mIoU).
    """
    if REMOVE_SLOTS_FROM_LABELS:
        metrics = [name.replace(" (slots)", "") for name in metrics]
    labels = []
    for name in metrics:
        acronym = name.replace("mbo", "mBO").replace("miou", "mIoU")
        labels.append(r"${" + acronym + "}$")
    return labels
def plot_one_metric(runs, run_dfs, pretty_labels, metric="miou (slots)"):
    """Plot one validation metric over epochs for all runs and save the figure.

    Parameters
    ----------
    runs : list[str]
        Run names (keys into run_dfs); the part after the last "/" is used as
        the legend label when pretty_labels is empty.
    run_dfs : dict[str, pd.DataFrame]
        Per-run DataFrames as produced by results_for_run.
    pretty_labels : list[str]
        Optional legend labels matching runs; empty list to derive from names.
    metric : str
        Column of the DataFrames to plot.
    """
    # Set modern style
    plt.style.use("ggplot")
    # Create the plot
    fig, ax = plt.subplots(1, 1)
    ax.grid(True, linestyle="--", linewidth=0.8, alpha=0.5, zorder=-1)
    # Plot the results
    for idx, run in enumerate(runs):
        label = run.split("/")[-1] if len(pretty_labels) == 0 else pretty_labels[idx]
        ax.plot(run_dfs[run].index, run_dfs[run][metric], label=label, color=get_color_for_run(run))
    # Aesthetics and saving
    plt.xlabel("Epochs")
    plt.ylabel(make_metrics_pretty([metric])[0])
    plt.legend()
    # savefig does not create missing directories - ensure the target exists
    os.makedirs(PLOT_PATH, exist_ok=True)
    plt.savefig(os.path.join(PLOT_PATH, f"comparison_{metric}"))
    # Close the figure so repeated calls do not accumulate open figures
    plt.close(fig)
def plot_comparison_in_multiple_metrics(
    runs, run_dfs, pretty_labels, metrics=np.array(["mbo_c (slots)", "mbo_i (slots)", "miou (slots)"]),
    ylim=(0, 70)
):
    """Bar-plot the final-epoch value of several metrics for all runs.

    Parameters
    ----------
    runs : list[str]
        Run names (keys into run_dfs); the part after the last "/" is used as
        the legend label when pretty_labels is empty.
    run_dfs : dict[str, pd.DataFrame]
        Per-run DataFrames as produced by results_for_run.
    pretty_labels : list[str]
        Optional legend labels matching runs; empty list to derive from names.
    metrics : np.ndarray
        Metric columns to compare; the last row (latest epoch) of each df is used.
    ylim : tuple[float, float]
        Y-axis limits of the score axis (previously hard-coded to (0, 70)).
    """
    # Set modern style
    plt.style.use("ggplot")
    # Create the plot
    fig, ax = plt.subplots(1, 1)
    ax.grid(True, linestyle='--', linewidth=0.9, alpha=0.75, zorder=-1)
    # Create the ticks with the metrics as labels
    diff_between_metrics = int(len(runs) * 2)
    tick_lst = [diff_between_metrics * x for x in range(len(metrics))]
    tick_label_lst = make_metrics_pretty(metrics)
    ax.set_xticks(tick_lst, labels=tick_label_lst)
    # Create offset list so the bar group of each metric is centered on its tick
    offsets = list(range(-int(len(runs) / 2), int((len(runs) + 1) / 2)))
    if len(runs) % 2 == 0:
        # Even number of runs: no bar sits on the tick itself; shift the right
        # half over by one slot and pull all bars half a slot towards the tick
        offsets[int(len(runs) / 2):] = [x + 1 for x in offsets[int(len(runs) / 2):]]
        offsets = [x + 0.5 if x < 0 else x - 0.5 for x in offsets]
    # Create the bars with respective colors, labels and values
    for idx in range(len(run_dfs)):
        run = runs[idx]
        # Final-epoch value of each requested metric
        data = run_dfs[run].iloc[-1][metrics].values
        label = run.split("/")[-1] if len(pretty_labels) == 0 else pretty_labels[idx]
        ax.bar(np.array(list(range(len(metrics)))) * diff_between_metrics + offsets[idx], data, color=get_color_for_run(run), label=label, zorder=2, edgecolor="black")
    # Aesthetics and saving
    ax.set_ylabel("Score")
    ax.set_ylim(*ylim)
    plt.legend()
    # savefig does not create missing directories - ensure the target exists
    os.makedirs(PLOT_PATH, exist_ok=True)
    plt.savefig(os.path.join(PLOT_PATH, "comparison_all_metrics"), bbox_inches="tight")
    # Close the figure so repeated calls do not accumulate open figures
    plt.close(fig)
def main():
    """Load the configured runs and produce the comparison plots."""
    # INSERT ALL RELEVANT RUNS FOR PLOTTING HERE
    runs = [
        "baseline/dinosaur_baseline",
        "ablations/mean_9_10_11",
        "ablations/residual_5_8_11",
        "ablations/sum_9_10_11",
        "ablations/max_9_10_11"
    ]
    # Optional: Pretty labels, name them corresponding runs above. If none are to be used: leave as empty list
    pretty_labels = [
        "Baseline",
        "Mean",
        "Residual",
        "Sum",
        "Max"
    ]
    # Sanity check
    assert len(runs) == len(pretty_labels) or len(pretty_labels) == 0, "Either supply as many labels as there are runs, or none at all"
    # Read every run's validation scalars into one dictionary of DataFrames
    run_dfs = {run: results_for_run(run) for run in runs}
    # Basic plot for one metric
    # WHICH METRIC IS TO BE PLOTTED
    # Possible metrics:
    # 'mse', 'ari (slots)', 'ari (decoder)', 'mbo_c', 'mbo_i', 'miou',
    # 'mbo_c (slots)', 'mbo_i (slots)', 'miou (slots)', 'best_loss'
    plot_one_metric(runs, run_dfs, pretty_labels, "miou (slots)")
    # Plot comparison in all metrics
    plot_comparison_in_multiple_metrics(runs, run_dfs, pretty_labels, np.array(["mbo_c (slots)", "mbo_i (slots)", "miou (slots)"]))


if __name__ == "__main__":
    main()