function [ z, z_p, w, b, perf_val, perf_test, obj ] = frank_wolfe_optimization( ...
X, y, clips, annot, z, sup_idx, val_idx, test_idx, val, test, ...
tau, aleph, lambda, kappa, params )
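% FRANK_WOLFE_OPTIMIZATION  Frank-Wolfe optimization of the relaxed assignment z.
%   At each iteration the objective is linearized, the linearized problem is
%   solved by dynamic programming over the clips, and z is updated by a convex
%   combination with an optimal step size.  The argument roles listed below
%   are inferred from how they are used in this file:
%     X, y                          features and ground-truth labels
%     clips, annot                  per-clip sizes and annotations
%     z                             initial relaxed assignment matrix
%     sup_idx                       supervised rows (z is clamped to y there)
%     val_idx, test_idx, val, test  validation and test splits
%     tau, aleph, lambda, kappa     model hyper-parameters
%     params.niter                  number of iterations
%   Returns the relaxed assignment z, its rounded version z_p, the linear
%   classifier (w, b), per-iteration validation and test performance, and the
%   objective / duality-gap / timing trace obj.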
[N, ~] = size(X);
% pre-computing the expensive data-dependent matrix GXTP, reused at every
% iteration (gradient, step size, objective, and classifier updates)
GXTP = compute_GXTP(X, lambda);
L = sparse(N, N);
% enforcing supervision
z(sup_idx, :) = y(sup_idx, :);
% computing the gradient
grad = compute_gradient(X, z, N, tau, aleph, kappa, GXTP);
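% initializing the performance logs (evaluate with no arguments is assumed to
% return an empty/default performance struct used for pre-allocation)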
perf_val = evaluate();
perf_test = evaluate();
obj = struct('f', [], 'd', [], 't', []);
tic;
for i = 1:params.niter
% cutting the gradient into clips
l = mat2cell(grad, 17, clips);
% performing the Dynamic Programming over the clips
a = optimize_a(l, annot);
a = cell2mat(a);
% adding the supervision (can be seen as additional constraints on W)
a(sup_idx, :) = y(sup_idx, :);
% Linearization duality gap
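% (for a convex objective, d = <grad, z - a> upper-bounds f(z) - min f,
%  so it certifies how close the current iterate is to optimal)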
d = trace(grad*(z-a));
% getting the optimal step size
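% (exact line search along the segment from z to a; the usual Frank-Wolfe
%  fallback would be the fixed schedule gamma = 2/(i+2))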
gama = compute_gamma_optimal_FW(X, z, a, GXTP, tau, aleph, kappa);
% updating z
z = (1-gama) * z + gama * a;
% computing objective value
f = compute_objective(X, z, tau, aleph, kappa, GXTP);
% computing the gradient
grad = compute_gradient(X, z, N, tau, aleph, kappa, GXTP);
% rebuilding the classifiers
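% (w maps features to labels through the pre-computed GXTP; b is the column
%  mean of the residual z - X*w, i.e. a per-class bias)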
w = GXTP * z;
b = ones(1, N) * (z - X * w) / N;
% rounding
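% (z_p is the discrete labeling recovered from the relaxed z on each clip,
%  presumably constrained by the clip annotations)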
z_p = rounding(z, clips, annot);
% evaluating
perf_val(i) = evaluate(z(val_idx, :), y(val_idx, :), z_p(val_idx, :), clips(val));
perf_test(i) = evaluate(z(test_idx, :), y(test_idx, :), z_p(test_idx, :), clips(test));
% keeping track of the objective
obj(i).f = f;
obj(i).d = d;
obj(i).t = toc;
% printing the score
fprintf('iter=%3i f=%-+5.3e ', i, f);
fprintf('dgap=%-+5.3e ', d);
fprintf('acc=%5.3f ', perf_test(i).acc);
fprintf('p=%5.3f ', perf_test(i).precision);
fprintf('r=%5.3f ', perf_test(i).recall);
fprintf('jac=%5.3f ', perf_test(i).jacquard);
fprintf('jac_pred_nobg=%5.3f ', perf_test(i).jacquard_pred_nobg);
fprintf('map=%5.3f \n', perf_test(i).map);
end
end
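
% Example usage (a minimal sketch; the variable shapes, helper functions and
% hyper-parameter values are assumptions based on how they are used above):
%
%   params = struct('niter', 100);
%   [z, z_p, w, b, perf_val, perf_test, obj] = frank_wolfe_optimization( ...
%       X, y, clips, annot, z0, sup_idx, val_idx, test_idx, val, test, ...
%       tau, aleph, lambda, kappa, params);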