-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathfista_general.m
108 lines (105 loc) · 3.64 KB
/
fista_general.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
function [X, iter, min_cost] = fista_general(grad, proj, Xinit, L, opts, calc_F)
% function [X, iter, min_cost] = fista_general(grad, proj, Xinit, L, opts, calc_F)
% * A Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) for
%   Linear Inverse Problems.
% * Solve the problem: X = arg min_X F(X) = f(X) + lambda*g(X) where:
%   - X: variable, can be a matrix.
%   - f(X): a smooth convex function whose gradient is Lipschitz
%     continuous with constant `L(f)`.
%   - g(X): the regularizer; it is never evaluated here, only handled
%     through `proj` (expected to be the matching shrinkage/proximal step).
% INPUT:
%   grad  : a function calculating the gradient of f(X) given X,
%           called as feval(grad, Y).
%   proj  : a function calculating pL(x) -- projection, called as
%           feval(proj, U, opts_proj) where opts_proj is a copy of `opts`
%           whose `lambda` field has been replaced by opts.lambda/L.
%   Xinit : a matrix -- initial guess.
%   L     : a scalar, the Lipschitz constant of the gradient of f(X).
%   opts  : a struct
%       opts.lambda  : (required) a regularization parameter, can be either
%                      a scalar or a weighted matrix.
%       opts.max_iter: maximum iterations of the algorithm.
%                      Default 500.
%       opts.tol     : a tolerance, the algorithm will stop if the
%                      difference between two successive X is smaller than
%                      this value. Default 1e-8.
%       opts.verbose : showing progress after each iteration or not.
%                      Default false.
%       opts.regul   : default 'l1'; not used here, forwarded to `proj`.
%       opts.pos     : default false; not used here, forwarded to `proj`.
%   calc_F: optional, a function calculating the value of F at X
%           via feval(calc_F, X). Required only when `min_cost` is
%           requested; when verbose and omitted, a compact dot/counter
%           progress line is printed instead of the cost.
% OUTPUT:
%   X        : solution (equals Xinit if no iteration is run)
%   iter     : number of run iterations
%   min_cost : the achieved cost, feval(calc_F, X)
% Modifications:
%   06/17/2016: set default value for opts.pos = false
% -------------------------------------
% Author: Tiep Vu, thv102, 4/6/2016
% (http://www.personal.psu.edu/thv102/)
% -------------------------------------
    %% default options
    if ~isfield(opts, 'max_iter')
        opts.max_iter = 500;
    end
    if ~isfield(opts, 'regul')
        opts.regul = 'l1';
    end
    if ~isfield(opts, 'pos')
        opts.pos = false;
    end
    if ~isfield(opts, 'tol')
        opts.tol = 1e-8;
    end
    if ~isfield(opts, 'verbose')
        opts.verbose = false;
    end

    % calc_F is optional: detect whether the caller actually supplied it.
    has_calc_F = (nargin >= 6) && ~isempty(calc_F);

    Linv = 1/L;
    lambdaLiv = opts.lambda*Linv;

    % Options handed to proj: same as opts but with the threshold scaled
    % by the step size 1/L.
    opts_proj = opts;
    opts_proj.lambda = lambdaLiv;

    x_old = Xinit;
    y_old = Xinit;
    x_new = Xinit;   % keeps X defined when opts.max_iter <= 0
    t_old = 1;
    iter = 0;
    cost_old = Inf;  % any finite first cost counts as a decrease

    %% MAIN LOOP
    while iter < opts.max_iter
        iter = iter + 1;
        % gradient step on the extrapolated point, then projection/shrinkage
        x_new = feval(proj, y_old - Linv*feval(grad, y_old), opts_proj);
        % momentum coefficient and extrapolation
        t_new = 0.5*(1 + sqrt(1 + 4*t_old^2));
        y_new = x_new + (t_old - 1)/t_new * (x_new - x_old);
        %% check stop criteria
        % norm1 is a helper defined outside this file -- presumably the
        % entrywise l1 norm, making `e` the mean absolute change; confirm.
        e = norm1(x_new - x_old)/numel(x_new);
        if e < opts.tol
            break;
        end
        %% update
        x_old = x_new;
        t_old = t_new;
        y_old = y_new;
        %% show progress
        if opts.verbose
            if has_calc_F
                cost_new = feval(calc_F, x_new);
                if cost_new <= cost_old
                    stt = 'YES.';
                else
                    stt = 'NO, check your code.';
                end
                fprintf('iter = %3d, cost = %f, cost decreases? %s\n', ...
                    iter, cost_new, stt);
                cost_old = cost_new;
            else
                % no cost function available: print a compact progress bar
                if mod(iter, 5) == 0
                    fprintf('.');
                end
                if mod(iter, 10) == 0
                    fprintf('%d', iter);
                end
            end
        end
    end

    X = x_new;
    if nargout >= 3
        if ~has_calc_F
            error('fista_general:missingCalcF', ...
                'calc_F must be provided when min_cost is requested.');
        end
        min_cost = feval(calc_F, X);
    end
end