forked from wtiandong/tide
-
Notifications
You must be signed in to change notification settings - Fork 1
/
show_errors.py
88 lines (75 loc) · 3.08 KB
/
show_errors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from tidecv import TIDE, datasets
import argparse
import seaborn as sns
import matplotlib.pyplot as plt
import json
import pandas as pd
import numpy as np
import math
from matplotlib.ticker import MultipleLocator
def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` and ``"0"`` -- would be
    parsed as ``True``. This converter interprets the usual spellings instead.

    Args:
        value: The raw string from the command line (or an actual ``bool``).

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, exactly as ``bool('')`` was before.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def arg(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv``, as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``annotation``,
        ``result``, ``name``, ``show`` and ``normalize``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--annotation', '-a', type=str,
                        default='examples/instances_val2017.json')
    parser.add_argument('--result', '-r', type=str,
                        default='examples/coco_instances_results.json')
    parser.add_argument('--name', '-n', type=str, default='')
    # ``nargs='?'`` with ``const=True`` lets the switches be used bare
    # (``--normalize``) as well as with an explicit value
    # (``--normalize false``).
    parser.add_argument('--show', type=_str2bool, nargs='?', const=True,
                        default=True)
    parser.add_argument('--normalize', type=_str2bool, nargs='?', const=True,
                        default=False)
    return parser.parse_args(argv)
def _run_key(per_run, name):
    """Return the key under which the evaluated run is stored in *per_run*.

    NOTE(review): the per-run result dicts are presumably keyed by the run
    name passed to ``tide.evaluate`` -- confirm against the TIDE fork in use.
    The original code always looked up ``''``, which only matches the default
    ``--name``; the empty key is kept as the fallback.
    """
    return name if name in per_run else ''


def _plot_per_class_errors(df, error_names,
                           out_path='per_accuracy_and_errors.png'):
    """Draw one strip plot per metric (AP plus each error type) and save it.

    Args:
        df: DataFrame with the columns ``Name``, ``AP`` and one column per
            entry of *error_names*.
        error_names: Names of the error columns, in plotting order.
        out_path: File the figure is written to.
    """
    # All error panels share one x-range so they can be compared visually.
    # At least 1 so that all-zero errors do not yield a zero-width axis.
    maximum_dap = max(math.ceil(df[error_names].max().max()), 1)

    sns.set(font_scale=0.4)
    titles = ['AP'] + error_names
    g = sns.PairGrid(df.sort_values("AP", ascending=False),
                     x_vars=titles,
                     y_vars=["Name"],
                     height=8, aspect=.25)
    g.map(sns.stripplot, size=8, orient="h", jitter=False,
          palette="flare_r", linewidth=1, edgecolor="w")

    # The first panel is AP (fixed 0-100 range); the rest are error panels.
    g.axes[0, 0].set_xlim(0, 100)
    for col in range(1, g.axes.shape[1]):
        g.axes[0, col].set_xlim(0, maximum_dap)
        g.axes[0, col].xaxis.set_major_locator(MultipleLocator(5))

    for ax, title in zip(g.axes.flat, titles):
        ax.set(title=title)
        ax.xaxis.grid(False)
        ax.yaxis.grid(True)

    sns.despine(left=True, bottom=True)
    plt.subplots_adjust(left=0.05, top=0.98)
    plt.savefig(out_path)


def _plot_confusion_matrix(dat, class_names, normalize,
                           out_path='class_error_confusion_matrix.png'):
    """Render the class confusion matrix as a heat map and save it.

    Args:
        dat: Square array-like of confusion values; rows are labelled "GT"
            and columns "Predict" on the plot.
        class_names: Labels used for both axes.
        normalize: When true, divide every column by its sum.
        out_path: File the figure is written to.
    """
    if normalize:
        # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the
        # documented replacement and yields the same dtype.
        # NOTE(review): a column whose sum is zero presumably becomes NaN
        # here -- confirm this is acceptable for the heat map.
        as_float = dat.astype(float)
        dat = as_float / as_float.sum(axis=0)

    cm = pd.DataFrame(data=dat, index=class_names, columns=class_names)
    sns.set(font_scale=0.4)
    fig, _ = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, square=True, cbar=True, annot=False, cmap='Blues',
                xticklabels=True, yticklabels=True,
                linewidths=.5)
    plt.xlabel("Predict", fontsize=13)
    plt.ylabel("GT", fontsize=13)
    fig.subplots_adjust(bottom=0.15)
    plt.savefig(out_path)


def main():
    """Evaluate detection results with TIDE and plot per-class diagnostics.

    Steps:
      1. Load the COCO annotations and results named on the command line and
         evaluate them in box mode (use ``TIDE.MASK`` for masks).
      2. Print the TIDE summary and write the standard plots to ``./result``.
      3. Save ``per_accuracy_and_errors.png``: AP and each main error type
         per class.
      4. Save ``class_error_confusion_matrix.png``: the class confusion
         matrix, column-normalised when ``--normalize`` is set.
    """
    args = arg()
    tide = TIDE()
    gt = datasets.COCO(args.annotation)
    bbox_results = datasets.COCOResult(args.result)
    run = tide.evaluate(gt, bbox_results, mode=TIDE.BOX, name=args.name)
    tide.summarize()
    tide.plot('./result')

    errors = tide.get_main_per_class_errors()
    run_errors = errors[_run_key(errors, args.name)]
    error_names = ['Cls', 'Loc', 'Both', 'Dupe', 'Bkg', 'Miss']
    # One row per class id: class name, AP, then one value per error type.
    total_table = [
        [gt.classes[idx], run.ap_data.objs[idx].get_ap()]
        + [run_errors[name][idx] for name in error_names]
        for idx in sorted(run_errors[error_names[0]].keys())
    ]
    df = pd.DataFrame(total_table, columns=['Name', 'AP'] + error_names)
    _plot_per_class_errors(df, error_names)

    ret = tide.get_confusion_matrix()
    _plot_confusion_matrix(ret[_run_key(ret, args.name)].T,
                           list(gt.classes.values()),
                           args.normalize)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()