forked from Selmaan/DNN_Code
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conjgrad_1.m
91 lines (61 loc) · 2.19 KB
/
conjgrad_1.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
function [xs, is] = conjgrad_1( Afunc, b, x0, maxiters, miniters, Mdiag )
%CONJGRAD_1 Diagonally-preconditioned linear conjugate gradient with snapshots.
%
%   [xs, is] = conjgrad_1(Afunc, b, x0, maxiters, miniters, Mdiag) runs
%   preconditioned CG on the quadratic model
%       q(x) = 0.5*x'*A*x - b'*x
%   where A is only accessed through the matrix-vector product Afunc(v).
%
%   Inputs:
%     Afunc    - function handle returning A*v for a vector v
%     b        - right-hand side vector
%     x0       - starting point
%     maxiters - maximum number of CG iterations
%     miniters - the relative-progress stopping test cannot fire before
%                this iteration
%     Mdiag    - diagonal preconditioner; residuals are divided by it
%                element-wise
%
%   Outputs:
%     xs - cell array of saved iterates: one at each iteration
%          ceil(5*1.3^k) that is reached, plus the final iterate
%     is - iteration index associated with each entry of xs
%
%   The final iterate is stored exactly once. If no iteration is run
%   (maxiters <= 0), the outputs are xs = {x0} and is = 0.

    % Parameters of the relative-progress stopping test.
    tolerance = 5e-4;
    gapratio = 0.1;
    mingap = 10;

    % Circular buffer holding the most recent values of the quadratic model.
    maxtestgap = max(ceil(maxiters * gapratio), mingap) + 1;
    vals = zeros(maxtestgap,1);

    % Snapshot schedule: first at iteration 5, then geometrically spaced.
    inext = 5;
    imult = 1.3;

    is = [];
    xs = {};

    r = Afunc(x0) - b;   % residual A*x - b
    y = r./Mdiag;        % preconditioned residual
    p = -y;              % initial search direction
    x = x0;
    ry = r'*y;           % carried between iterations to avoid recomputation

    % Tracks the last iteration entered; stays 0 when the loop never runs
    % (the loop variable itself would be empty in that case).
    lastiter = 0;

    for i = 1:maxiters
        lastiter = i;

        %compute the matrix-vector product. This is where 95% of the work in
        %HF lies:
        Ap = Afunc(p);
        pAp = p'*Ap;

        %the Gauss-Newton matrix should never have negative curvature. The
        %Hessian easily could unless your objective is convex
        if pAp <= 0
            disp('Negative Curvature!');
            disp('Bailing...');
            break;
        end

        alpha = ry/pAp;
        x = x + alpha*p;
        r = r + alpha*Ap;
        y = r./Mdiag;

        ry_new = r'*y;
        beta = ry_new/ry;
        p = -y + beta*p;
        ry = ry_new;

        %val is the value of the quadratic model
        val = 0.5*double((-b+r)'*x);
        if ~isa(val, 'double')
            % Value lives in a non-CPU container (e.g. a Parallel Computing
            % Toolbox array); bring it back before storing it in vals.
            val = gather(val);
        end
        vals( mod(i-1, maxtestgap)+1 ) = val;

        testgap = max(ceil( i * gapratio ), mingap);
        prevval = vals( mod(i-testgap-1, maxtestgap)+1 ); %testgap steps ago

        if i == ceil(inext)
            is(end+1) = i;
            xs{end+1} = x;
            inext = inext*imult;
        end

        %the stopping criterion here becomes largely unimportant once you
        %optimize your function past a certain point, as it will almost never
        %kick in before you reach i = maxiters. And if the value of maxiters
        %is set so high that this never occurs, you probably have set it too
        %high
        if i > testgap && prevval < 0 && (val - prevval)/val < tolerance*testgap && i >= miniters
            break;
        end
    end

    % Store the final iterate unless it was already saved at this iteration.
    % Comparing against the last recorded index (rather than the already
    % advanced inext) avoids both a duplicated and a missing final entry.
    if isempty(is) || is(end) ~= lastiter
        is(end+1) = lastiter;
        xs{end+1} = x;
    end