-
Notifications
You must be signed in to change notification settings - Fork 11
/
evaluate.m
98 lines (74 loc) · 2.24 KB
/
evaluate.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
function [res] = evaluate(z, y, z_p, clips)
% EVALUATE  Score predictions against ground truth.
%
%   RES = EVALUATE() returns a result struct whose fields are all [] and
%   appear in the same order as in the full evaluation below.
%
%   RES = EVALUATE(Z, Y) computes only average precision:
%     RES.ap  - 1xK per-class AP (vl_pr's info.auc_pa08 on column k)
%     RES.map - mean of RES.ap
%
%   RES = EVALUATE(Z, Y, Z_P, CLIPS) computes the full set of metrics.
%     Z     - scores, column k used as the ranking score for class k in AP.
%     Y     - ground truth, one column per class; the label of each row is
%             its argmax. Presumably 0/1 indicators, since 2*Y-1 is passed
%             to vl_pr as +/-1 labels -- confirm against callers.
%     Z_P   - scores used for hard predictions (argmax over each row).
%     CLIPS - passed unchanged to all_jacquards (semantics defined there).
%
%   The last class, K = size(Y, 2), is treated as background: it is
%   excluded from recall/precision, from the "_nobg" Jaccard means, from
%   jac_per_class and from mean_jac_per_clip. (This previously used a
%   hard-coded label 17, which only agreed with K when there are 17 classes.)
%   Field names keep the historical "jacquard" spelling for compatibility.

if nargin == 0
    % Empty template; field order matches the full evaluation branch.
    res = [];
    res.acc = [];
    res.recall = [];
    res.precision = [];
    res.r_per_class = [];
    res.p_per_class = [];
    res.f1_per_class = [];
    res.jacquard = [];
    res.jacquard_nobg = [];
    res.jacquard_pred = [];
    res.jacquard_pred_nobg = [];
    res.jac_per_class = [];
    res.jac_per_clip = [];
    res.mean_jac_per_clip = [];
    res.ap = [];
    res.map = [];
    return
end

if nargin == 1
    error('evaluate:notEnoughInputs', ...
          'evaluate requires both Z and Y (or Z, Y, Z_P, CLIPS).');
end

if nargin == 2
    % AP-only mode.
    [ap, mean_ap] = average_precisions(z, y);
    res = [];
    res.ap = ap;
    res.map = mean_ap;
    return
end

if nargin < 4
    error('evaluate:notEnoughInputs', ...
          'evaluate(Z, Y, Z_P, CLIPS) requires CLIPS for the Jaccard metrics.');
end

[~, K] = size(y);
bg = K;                        % background class = last column of Y
[~, gt] = max(y, [], 2);       % ground-truth label per row
[~, pred] = max(z_p, [], 2);   % predicted label per row
hit = (pred == gt);

res = [];
res.acc = sum(hit) / length(gt);
% Recall/precision over non-background rows only.
res.recall = sum(hit & gt ~= bg) / sum(gt ~= bg);
res.precision = sum(hit & gt ~= bg) / sum(pred ~= bg);

res.r_per_class = zeros(1, K);
res.p_per_class = zeros(1, K);
for k = 1:K
    tp = sum(hit & gt == k);
    res.r_per_class(k) = tp / sum(gt == k);     % NaN if class k never occurs
    res.p_per_class(k) = tp / sum(pred == k);   % NaN if class k never predicted
end

% F1 is reported as 0 whenever recall or precision is undefined (NaN) or 0,
% which also avoids the 0/0 case. The raw r/p are used before NaNs are zeroed.
r = res.r_per_class;
p = res.p_per_class;
undefined = isnan(r) | r == 0 | isnan(p) | p == 0;
res.r_per_class(isnan(r)) = 0;
res.p_per_class(isnan(p)) = 0;
f1 = 2 * r .* p ./ (r + p);
f1(undefined) = 0;
res.f1_per_class = f1;

% gt_j is the label vector returned by all_jacquards, aligned with j and
% j_pred (presumably per-frame ground truth -- confirm in all_jacquards).
[gt_j, j, ~, j_pred, j_perclip, gt_perclip] = all_jacquards(z_p, y, clips);
fg = (gt_j ~= bg);
res.jacquard = mean(j);
res.jacquard_nobg = mean(j(fg));
res.jacquard_pred = mean(j_pred);
res.jacquard_pred_nobg = mean(j_pred(fg));

% Per-class Jaccard for foreground classes only (background K excluded).
res.jac_per_class = zeros(1, K - 1);
for k = 1:K-1
    res.jac_per_class(k) = mean(j_pred(gt_j == k));
end

res.jac_per_clip = j_perclip;
res.mean_jac_per_clip = zeros(length(j_perclip), 1);
for c = 1:length(j_perclip)
    res.mean_jac_per_clip(c) = mean(j_perclip{c}(gt_perclip{c} ~= bg));
end

[ap, mean_ap] = average_precisions(z, y);
res.ap = ap;
res.map = mean_ap;
end

function [ap, mean_ap] = average_precisions(z, y)
% AVERAGE_PRECISIONS  Per-class AP from VLFeat's vl_pr and their mean.
%   Column k of Z is the score and column k of Y the target for class k.
%   AP is info.auc_pa08 (per VLFeat docs, PASCAL VOC 2008 11-point AP).
[~, K] = size(y);
ap = zeros(1, K);
for k = 1:K
    % vl_pr takes labels in {-1,+1}; 2*y-1 maps 0/1 targets onto that range.
    [~, ~, info] = vl_pr(2 * y(:, k) - 1, z(:, k));
    ap(k) = info.auc_pa08;
end
mean_ap = mean(ap);
end