Skip to content

Commit

Permalink
#DONE fixing GLMbi code (bilinear filter); demo 2 works!!
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan Pillow committed Jan 22, 2016
1 parent a2b7d46 commit 3ce6716
Show file tree
Hide file tree
Showing 18 changed files with 350 additions and 451 deletions.
54 changes: 27 additions & 27 deletions demos/demo2_GLM_spatialStim.m
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
ttk = dtStim*(-nkt+1:0)'; % time bins for filter
kx = 1./sqrt(2*pi*4).*exp(-(xxk-nkx/2).^2/5);
k = kt*kx'; % Make space-time separable filter
k = k./norm(k(:))*2;
k = k./norm(k(:))/1.5;

% Insert into glm structure (created with default history filter)
ggsim = makeSimStruct_GLM(nkt,dtStim,dtSp); % Create GLM structure with default params
ggsim.k = k./norm(k(:))*3; % Insert into simulation struct
ggsim.k = k; % Insert into simulation struct
ggsim.dc = 2;

% === Make Fig: model params =======================
clf; subplot(3,3,[1 4]); % ------------------------------------------
Expand All @@ -46,7 +47,7 @@

%% 2. Generate some training data ========================================

slen = 5000; % Stimulus length (frames); More samples gives better fit
slen = 10000; % Stimulus length (frames); More samples gives better fit
Stim = round(rand(slen,nkx))*2-1; % Run model on long, binary stimulus
[tsp,sps,Itot,Isp] = simGLM(ggsim,Stim); % run model
nsp = length(tsp);
Expand Down Expand Up @@ -86,7 +87,7 @@
nkbasis = 8; % number of basis vectors for representing k
nhbasis = 8; % number of basis vectors for representing h
hpeakFinal = .1; % time of peak of last basis vector for h
gg0 = makeFittingStruct_GLM(dtStim,dtSp,nkt,nkbasis,nhbasis,hpeakFinal,sta);
gg0 = makeFittingStruct_GLM(dtStim,dtSp,nkt,nkbasis,sta,nhbasis,hpeakFinal);
gg0.sps = sps; % Insert binned spike train into fitting struct
gg0.mask = exptmask; % insert mask (optional)
gg0.ihw = randn(size(gg0.ihw))*1; % initialize spike-history weights randomly
Expand All @@ -103,16 +104,16 @@
%% 5. Fit GLM ("bilinear stim filter version") via max likelihood

% Initialize params for fitting --------------
Filter_rank = 1; % Number of column/row vector pairs to use
gg0 = makeFittingStruct_GLMbi(sta,dtSp,Filter_rank);
gg0.tsp = tsp;
gg0.mask = exptmask;
[logli1,rr1,tt] = neglogli_GLM(gg0,Stim); % Compute logli of initial params
fprintf('Initial value of negative log-li (GLMbi): %.3f\n', logli1);
k_rank = 1; % Number of column/row vector pairs to use
gg0b = makeFittingStruct_GLMbi(k_rank,dtStim,dtSp,nkt,nkbasis,sta,nhbasis,hpeakFinal);
gg0b.sps = sps;
gg0b.mask = exptmask;
logli0b = neglogli_GLM(gg0b,Stim); % Compute logli of initial params
fprintf('Initial value of negative log-li (GLMbi): %.3f\n', logli0b);

% Do ML estimation of model params
opts = {'display', 'iter', 'maxiter', 500};
[gg2, negloglival2] = MLfit_GLMbi(gg0,Stim,opts); % do ML (requires optimization toolbox)
opts = {'display', 'iter'};
[gg2, negloglival2] = MLfit_GLMbi(gg0b,Stim,opts); % do ML (requires optimization toolbox)


%% 6. Plot results ====================
Expand All @@ -121,35 +122,34 @@
subplot(231); % True filter % ---------------
imagesc(ggsim.k); colormap gray;
title('True Filter');ylabel('time');

subplot(232); % sta % ------------------------
imagesc(sta);
title('raw STA');
ylabel('time');
title('raw STA'); ylabel('time');

subplot(233); % sta-projection % ---------------
imagesc(gg0.k)
title('low-rank STA');
imagesc(gg0.k); title('low-rank STA');

subplot(234); % estimated filter % ---------------
imagesc(gg1.k)
title('ML estimate: full filter'); xlabel('space'); ylabel('time');
imagesc(gg1.k); title('ML estimate: full filter'); xlabel('space'); ylabel('time');

subplot(235); % estimated filter % ---------------
imagesc(gg2.k)
title('ML estimate: bilinear filter'); xlabel('space');
imagesc(gg2.k); title('ML estimate: bilinear filter'); xlabel('space');

subplot(236); % ----------------------------------
plot(ggsim.iht,exp(ggsim.ih),'k', gg1.iht,exp(gg1.ihbas*gg1.ihw),'b',...
gg2.iht, exp(gg2.ihbas*gg2.ihw), 'r');
title('post-spike kernel');
axis tight;
title('post-spike kernel'); axis tight;
xlabel('time after spike (s)');

% Errors in STA and ML estimate
ktmu = normalizecols([mean(ggsim.k,2),mean(gg1.k,2),mean(gg2.k,2)]);
kxmu = normalizecols([mean(ggsim.k)',mean(gg1.k)',mean(gg2.k)']);
Errs_T = [subspace(ktmu(:,1),ktmu(:,2)), subspace(ktmu(:,1),ktmu(:,3))]
Errs_X = [subspace(kxmu(:,1),kxmu(:,2)), subspace(kxmu(:,1),kxmu(:,3))]

errfun = @(x,y)(sum((x(:)-y(:)).^2));
Errs_Total = [errfun(ggsim.k,gg1.k), errfun(ggsim.k, gg2.k)]
msefun = @(k)(sum((k(:)-ggsim.k(:)).^2)); % error function
fprintf(['K-filter errors (GLM vs. GLMbilinear):\n', ...
'Temporal error: %.3f %.3f\n', ...
' Spatial error: %.3f %.3f\n', ...
' Total error: %.3f %.3f\n'], ...
subspace(ktmu(:,1),ktmu(:,2)), subspace(ktmu(:,1),ktmu(:,3)), ...
subspace(kxmu(:,1),kxmu(:,2)), subspace(kxmu(:,1),kxmu(:,3)), ...
msefun(gg1.k), msefun(gg2.k));
10 changes: 5 additions & 5 deletions glmtools_fitting/Loss_GLM_logli.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@
% Inputs:
% prs = [kprs - weights for stimulus kernel
% dc - dc current injection
% ihprs - weights on post-spike current];
% ihprs - weights on post-spike current]
%
% Outputs:
% logli = negative log likelihood of spike train
% dL = gradient with respect to prs
% H = hessian

% Extract some vals from Xstruct (Opt Prs);
nktot = Xstruct.nkx*Xstruct.nkt; % total # params for k
dt = Xstruct.dtSp; % absolute bin size for spike train (in sec)

% Unpack GLM prs;
nktot = Xstruct.nkx*Xstruct.nkt; % total # params for k
kprs = prs(1:nktot);
dc = prs(nktot+1);
ihprs = prs(nktot+2:end);
Expand All @@ -32,6 +31,7 @@
ihflag = Xstruct.ihflag; % flag
rlen = Xstruct.rlen; % number of bins in spike train vector
nsp = (sum(bsps)); % number of spikes
dt = Xstruct.dtSp; % absolute bin size for spike train (in sec)

% -------- Compute sum of filter reponses -----------------------
if Xstruct.ihflag
Expand Down Expand Up @@ -65,7 +65,7 @@
if ihflag, dLdh1 = (frac1'*XXsp(bsps,:))';
end

% Combine terms
% Combine Term 1 and Term 2
dLdk = dLdk0*dt - dLdk1;
dLdb = dLdb0*dt - dLdb1;
if ihflag, dLdh = dLdh0*dt - dLdh1;
Expand Down
Loading

0 comments on commit 3ce6716

Please sign in to comment.