-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_compareNewtonLaplaceEvidence.m
183 lines (136 loc) · 6.38 KB
/
test_compareNewtonLaplaceEvidence.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
% test_compareNewtonLaplaceEvidence.m
%
% Tests out an alternative to decoupled Laplace called (for now)
% Newton-Laplace. In this case, we update the MAP estimate "wmap" just as we
% do in decoupled Laplace. Then we evaluate the log-likelihood Hessian at
% this new value of wmap.
%
% Two discoveries:
% 1. Our closed-form update to wmap is equivalent to a single Newton step.
% 2. The new method seems to fix gradient issues with decoupled Laplace.
%% 1. Set up simulated example
addpath('utils');

% Dimensions and hyperparameters
varprior = 2;               % variance of the Gaussian prior the true weights are drawn from
nw       = 20;              % length of the weight vector
nstim    = 100;             % number of stimuli (rows of the design matrix)
vlims    = log10([0.1 4]);  % log10 endpoints of the sig^2 grid built in section 3
theta0   = 0.8;             % prior variance used as the anchor point in section 4

% Draw the true weights from the prior
wts = sqrt(varprior)*randn(nw,1);

% Simulate stimuli and the Bernoulli GLM response
xx    = randn(nstim,nw);      % stimulus matrix, one row per stimulus
xproj = xx*wts;               % stimulus projected onto the weights
pp    = logistic(xproj);      % probability that the response is 1
yy    = pp > rand(nstim,1);   % Bernoulli outputs (logical column)
%% 2. Compute MAP estimate of weights given true hyperparams

% Make struct with log-likelihood and prior function pointers
% (this struct is what neglogpost_GLM and the later sections receive)
mstruct.neglogli = @neglogli_bernoulliGLM; % neg-logli function handle
mstruct.logprior = @logprior_stdnormal;    % log-prior function handle
mstruct.liargs   = {xx,yy};                % arguments for log-likelihood
mstruct.priargs  = {};                     % extra arguments for prior (besides theta)

% Set optimization parameters for fminunc: trust-region, with the gradient
% and Hessian supplied by the objective function itself
opts = optimoptions('fminunc','algorithm','trust-region','SpecifyObjectiveGradient',true,'HessianFcn','objective','display','off');

% Negative log-posterior as a function of the weights, at the true prior variance
lfunc = @(w)(neglogpost_GLM(w,varprior,mstruct));

% % Optional: Check that analytic gradient and Hessian are correct
% % (evaluated at a small random point; not needed for the optimization)
% w0 = randn(nw,1)*.1;
% HessCheck(lfunc, w0);

% Compute MAP estimate, starting the search from the origin
[wmap,neglogpost,exitflag] = fminunc(lfunc,zeros(nw,1),opts);
if exitflag <= 0
    % 'display' is 'off', so without this a failed solve would go unnoticed
    warning('fminunc did not converge while computing wmap (exitflag = %d).', exitflag);
end

% Plot the true weights against the MAP estimate
subplot(211);
tt = 1:nw; % grid of coefficient indices
plot(tt,wts,tt,wmap);
title('true weights and MAP estimate'); box off;
xlabel('coefficient #'); ylabel('weight');
legend('true weights', 'MAP estim');
%% 3. Evaluate Full Laplace Evidence on a grid

% Log-spaced grid of candidate prior variances
ngrid   = 25;                                   % number of grid points
vargrid = logspace(vlims(1),vlims(2),ngrid);    % 1 x ngrid row of sig^2 values

% Log Laplace evidence at each grid value, collected as an ngrid x 1 column.
% Every evaluation is handed the same wmap and optimizer options.
logLaplaceEv = arrayfun(@(v) compLogLaplaceEv(v,mstruct,wmap,opts), vargrid(:));

% Grid value with the largest log-evidence
[logLaplEvMax,ii] = max(logLaplaceEv);
varHat = vargrid(ii);

% Plot the evidence curve and mark its maximum over the grid
subplot(212);
plot(vargrid,logLaplaceEv,varHat,logLaplEvMax,'*');
xlabel('sig^2'); ylabel('log-evidence');
title('log-evidence vs. theta'); box off;
%theta0=varHat
%% 4. Now evaluate approximate evidence functions on the same grid

% Normalizing constant shared by the Gaussian log prior & log posterior
log2piconst = - nw/2*log(2*pi);

% First, compute MAP estimate given this value of theta (theta0)
lfunc = @(w)(neglogpost_GLM(w,theta0,mstruct));
[wmap0,~,exitflag0] = fminunc(lfunc,zeros(nw,1),opts);
if exitflag0 <= 0
    % Every quantity below is anchored at wmap0 and 'display' is 'off',
    % so make a failed solve visible.
    warning('fminunc did not converge while computing wmap0 (exitflag = %d).', exitflag0);
end

% Value and Hessian of the negative log-likelihood at wmap0
% (the gradient output is not needed here)
[negL0,~,ddnegL0] = mstruct.neglogli(wmap0,mstruct.liargs{:});

% Log-prior at wmap0 and its Hessian (negCinv)
[logpri,~,negCinv] = mstruct.logprior(wmap0,theta0,mstruct.priargs{:});

% Compute log-evidence at theta0 using Laplace approximation
postHess0 = ddnegL0-negCinv;                     % posterior Hessian
logpost   = .5*logdet(postHess0) + log2piconst;  % log-posterior at wmap0
logEv0    = (-negL0) + logpri - logpost;         % log-evidence

% Posterior Hessian times wmap0. The original note describes this as the
% Hessian of the negative log-likelihood times the log-likelihood mean;
% it is the right-hand side of the closed-form wmap update in the grid loop.
Lmu0 = postHess0*wmap0;

% allocate storage for approximate Laplace Evidence (ALE)
logALE = zeros(ngrid,1);
logALE_Newton = zeros(ngrid,1);
for jj = 1:ngrid

    % Inverse prior covariance for this grid value, and its log-determinant
    Cinv_giventheta = (1/vargrid(jj))*eye(nw);
    logdetCinv = -nw*log(vargrid(jj));

    % Updated posterior Hessian: the log-likelihood Hessian stays fixed at
    % its value from wmap0; only the prior term changes
    Hess_giventheta = (ddnegL0+Cinv_giventheta);

    % Updated w_MAP (closed-form update)
    wmap_giventheta = Hess_giventheta\Lmu0;

    % Evaluate the negative log-likelihood once at the updated weights and
    % keep both its value and its Hessian; the two approximations below
    % share the value and the Newton variant also needs the Hessian.
    [negL1,~,ddnL1] = mstruct.neglogli(wmap_giventheta,mstruct.liargs{:});
    logL_ale = -negL1;  % log-likelihood at the updated weights

    % Log prior at the updated weights
    logp_ale = -.5*sum(wmap_giventheta.^2)/vargrid(jj)+ .5*logdetCinv + log2piconst;

    % =======================
    % compute ALE
    % =======================

    % Log posterior at its mode (note quadratic term is 0)
    logpost_ale = .5*logdet(Hess_giventheta) + log2piconst;

    % ALE (moving)
    logALE(jj) = logL_ale + logp_ale - logpost_ale;

    % ===========================================================
    % Compute Newton-Laplace ALE
    % ===========================================================

    % Posterior Hessian rebuilt from the log-likelihood Hessian at the new wmap
    Hess_updated = (ddnL1+Cinv_giventheta);

    % Log posterior at its mode (note quadratic term is 0)
    logpost_Newton = .5*logdet(Hess_updated) + log2piconst;

    % ALE (Newton-moving)
    logALE_Newton(jj) = logL_ale + logp_ale - logpost_Newton;

end
%% 5. Perform numerical optimization of NewtonLaplace ALE

% fminunc options for this search: no user-supplied gradient, and print
% progress at every iteration
opts2 = optimoptions('fminunc', ...
    'SpecifyObjectiveGradient',false, ...
    'display','iter');
%opts2 = optimoptions('fminunc','algorithm','quasi-newton','SpecifyObjectiveGradient',true,'display','iter');
%opts2 = optimoptions('fminunc','algorithm','trust-region','SpecifyObjectiveGradient',true,'display','iter');

% Objective over log(theta), with the quantities computed at wmap0 held fixed
f_neglogEv = @(logtheta) neglogEv_NewtonLaplace(logtheta,mstruct,wmap0,ddnegL0,Lmu0);

% Minimize over log(theta), starting from log(theta0)
[logtheta1,negALE1] = fminunc(f_neglogEv,log(theta0),opts2);

logALEnewton_max = -negALE1;        % flip sign to get newton-ALE at maximum
varHat_ALE       = exp(logtheta1);  % transform log(sigma^2) to sigma^2
%% 6. Make plot of LE and ALE
% Overlays the full Laplace evidence, the ALE and the Newton-Laplace ALE
% curves on the same axes, together with three markers: the evidence at
% theta0, the grid maximum of the Laplace evidence, and the optimizer's
% Newton-ALE maximum.
subplot(212);
plot(vargrid,logLaplaceEv,vargrid,logALE,vargrid,logALE_Newton,...
    theta0,logEv0,'ko',varHat,logLaplEvMax,'*', varHat_ALE, logALEnewton_max,'s');
xlabel('prior variance (\sigma^2)'); ylabel('log-evidence');
title('log-evidence vs \sigma^2'); box off;
legend('Laplace Evidence', 'ALE', 'ALE-Newton','theta_0','theta max','Newton-ALE max', ...
    'location', 'southeast');

% y-limits: the lower limit still follows the bottom of the full Laplace
% curve. The upper limit now takes every plotted series into account, so the
% Newton-Laplace curve and the markers cannot be clipped off the top.
% Non-finite values are dropped because 'ylim' rejects them.
ytop = [logALE(:); logALE_Newton(:); logLaplaceEv(:); logEv0; logALEnewton_max];
ytop = ytop(isfinite(ytop));
set(gca,'ylim',[min(logLaplaceEv)-1,max(ytop)+1]);