-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathfista_lasso_backtracking.m
51 lines (46 loc) · 1.39 KB
/
fista_lasso_backtracking.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
function X = fista_lasso_backtracking(Y, D, Xinit, opts)
% FISTA_LASSO_BACKTRACKING  Column-by-column l1-regularized fit of Y by D*X.
%
%   X = fista_lasso_backtracking(Y, D, Xinit, opts)
%
%   For each column y of Y, fista_backtracking is invoked with
%       smooth cost   f(x) = 1/2 * normF2(y - D*x)
%       gradient      D'*D*x - D'*y
%       full cost     F(x) = f(x) + norm1(lambda .* x)
%   and its result is stored in the matching column of X.
%   (normF2 / norm1 are project helpers; presumably the squared Frobenius
%   norm and the entry-wise l1 norm -- confirm against their definitions.)
%
%   Inputs:
%     Y     - data matrix; each column is handled independently.
%     D     - matrix such that Y - D*X is well defined.
%     Xinit - starting point of size [size(D,2), size(Y,2)]; pass [] to
%             start from all zeros.
%     opts  - options struct. Handled in this function:
%               backtracking : set to false when the field is missing.
%               regul        : always overwritten with 'l1'.
%               lambda       : read after initOpts(opts); a scalar or an
%                              array with one weight per row of X.
%               check_grad   : when true, check_grad is run on column 1.
%               max_iter     : always overwritten with 500 (see note below).
%
%   Output:
%     X     - matrix of the same size as Xinit holding the per-column
%             results returned by fista_backtracking.

    if ~isfield(opts, 'backtracking')
        opts.backtracking = false;
    end
    opts.regul = 'l1';
    opts = initOpts(opts);
    lambda = opts.lambda;

    if numel(Xinit) == 0
        Xinit = zeros(size(D, 2), size(Y, 2));
    end

    % Index of the column of Y currently being processed. It lives in the
    % parent workspace so that the nested functions below can read it.
    col = 1;

    %% smooth part of the cost for the current column
    function cost = calc_f(X)
        cost = 1/2 * normF2(Y(:, col) - D*X);
    end

    %% full cost: smooth part + weighted l1 penalty
    function cost = calc_F(X)
        if numel(lambda) == 1                   % scalar weight
            cost = calc_f(X) + lambda*norm1(X);
        elseif numel(lambda) == numel(X)        % one weight per entry
            cost = calc_f(X) + norm1(lambda.*X);
        elseif numel(lambda) == size(X, 1)      % one weight per row
            % Replicate across the columns of X. The replication count
            % must be the number of columns, size(X, 2).
            lambda1 = repmat(lambda, 1, size(X, 2));
            cost = calc_f(X) + norm1(lambda1.*X);
        else
            % Fail loudly instead of returning with 'cost' unassigned.
            error('fista_lasso_backtracking:lambdaSize', ...
                'opts.lambda has %d elements, which matches neither a scalar, numel(X) = %d, nor size(X, 1) = %d.', ...
                numel(lambda), numel(X), size(X, 1));
        end
    end

    %% gradient of the smooth part for the current column
    DtD = D'*D;     % computed once, reused for every column
    DtY = D'*Y;
    function res = grad(X)
        res = DtD*X - DtY(:, col);
    end

    % Checking gradient. calc_f and grad work on a single column, so the
    % check is done on the first column only (skipped when there is none).
    if opts.check_grad && ~isempty(Xinit)
        col = 1;
        check_grad(@calc_f, @grad, Xinit(:, 1));
    end

    % NOTE(review): this overrides any max_iter supplied by the caller --
    % confirm that forcing 500 iterations is intended.
    opts.max_iter = 500;

    % for backtracking, we need to optimize one by one
    X = zeros(size(Xinit));
    for col = 1:size(X, 2)
        X(:, col) = fista_backtracking(@calc_f, @grad, Xinit(:, col), ...
            opts, @calc_F);
    end
end