Skip to content

Commit

Permalink
adding additional output variables, "gammas" and "ll_norm"
Browse files Browse the repository at this point in the history
  • Loading branch information
sjvenditto committed Dec 14, 2021
1 parent a5ce8b3 commit 0b8ab2b
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 3 deletions.
9 changes: 7 additions & 2 deletions fitGlmHmm.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [model,ll,iter] = fitGlmHmm(y,x,w0,A0,varargin)
function [model,ll,iter,gammas,ll_norm] = fitGlmHmm(y,x,w0,A0,varargin)
%[model, ll, iter] = fitGlmHmm(y, x, w0, A0, varargin)
% Fits a GLM-HMM to data for a logistic sigmoid emissions models. Number of
% latent states is implicitly taken from the second dimension of w0.
Expand Down Expand Up @@ -36,6 +36,11 @@
% ll - (1 x NIter) log-likelihood of the model fit
% iter - number of iterations it took for the model to
% converge
% gammas - (NStates x NTrials) marginal posterior distribution
% from final fit (from runBaumWelch)
% ll_norm - (float) normalized log-likelihood of final fit
% computed as ll_norm = exp(ll(end)/NTrials) (from
% runBaumWelch)
%
% Example call:
% [model,ll,iter] = fitGlmHmm(y,x,w0,A0,'new_sess',new_sess,'maxiter',2000)
Expand Down Expand Up @@ -129,7 +134,7 @@
strcr = repmat('\b',1,length(strout));

end
[~, ~, ll(i+1)] = runBaumWelch(y,x,model,p.new_sess);
[gammas, ~, ll(i+1),ll_norm] = runBaumWelch(y,x,model,p.new_sess);
ll(isnan(ll)) = []; % get rid of nan values
iter = i;
fprintf(' done\n');
Expand Down
102 changes: 102 additions & 0 deletions runBaumWelch.asv
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
function [gammas,xis,ll] = runBaumWelch(y,x,model,new_sess)
%[gammas,xis,ll] = runBaumWelch(y, x, model, new_sess)
% Runs the Baum-Welch algorithm to compute latent state posterior
% distributions. Assumes a model of the form by fitGlmHmm.m
%
% Inputs:
% y - (1 x NTrials) obsesrvations
% x - (NFeatures x NTrials) design matrix
% model - GLM-HMM model parameters
% .w - (NFeatures x NStates) latent state GLM weights
% .pi - (NStates x 1) initial latent state probability
% .A - (NStates x NStates) latent state transition matrix
% new_sess - (1 x NTrials) logical array with 1s
% denoting the start of a new session. If unspecified
% or empty, treats the full set of trials as a single
% session.
%
% Outputs:
% gammas - (Nstates x NTrials) Marginal posterior distribution
% xis - (Nstates x Nstates x Ntrials) Joint posterior
% distribution
% ll - log-likelihood of the fit
% ll_norm - normalized log-li

%% initialize variables
nstates = size(model.w,2); % number of latent states
T = size(y,2); % number of time steps

if ~exist('new_sess','var') || isempty(new_sess)
new_sess = false(size(y));
new_sess(1) = true;
end

% data likelihood p(y|z) from emissions model
tmpy = 1./(1+exp(-model.w'*x)); %f(model.w,x(:,1));
py_z = y.*tmpy + (1-y).*(1-tmpy);


%% %%%%%%%%%%%%%%%%%%%%% Forward recursion %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% initialize variables
alphas = nan(nstates,T); % forward pass alphas
c = nan(T,1); % variable to store marginal

for t=1:T
if new_sess(t)
alphas(:,t) = model.pi.*py_z(:,t); % initial alpha, equation 13.37, reinitialize for new sessions
else
%alphas(:,t) = py_z(:,t)'.*sum(alphas(:,t-1).*model.A); % equation 13.36
alphas(:,t) = py_z(:,t).*(model.A'*alphas(:,t-1)); % equation 13.36
end

c(t) = sum(alphas(:,t)); % store marginal likelihood
if c(t) == 0, keyboard; end % this shouldn't happen, check if weights are out of control
alphas(:,t) = alphas(:,t)/c(t); % normalize 13.59
end
ll = sum(log(c)); % store log-likelihood

%% %%%%%%%%%%%%%%%%%%%%%% Backward recursion %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% initialize variables
betas = nan(nstates,T); % backward pass beta
betas(:,end) = ones(nstates,1); % initial beta (13.39)

% solve for remaining betas
for t = T-1:-1:1
if new_sess(t+1)
betas(:,t) = ones(nstates,1); % reinitialize backward pass if end of session
else
%betas(:,t) = sum(betas(:,t+1)'.*py_z(:,t+1)'.*model.A,2); % equation 13.38
betas(:,t) = model.A*(betas(:,t+1).*py_z(:,t+1)); % equation 13.38
betas(:,t) = betas(:,t)/c(t+1); % normalize 13.62
end
end


%% %%%%%%%%%%%%%%%%% Compute posterior distributions %%%%%%%%%%%%%%%%%%%%%%
% gamma - eq. 13.32, 13.64
gammas = alphas.*betas;

% trials to compute xi
% don't use trials at the start of any session to compute the
% transition matrix
Ts = 1:T;
Ts(new_sess) = [];

% xi - eq. 13.43, 13.65 (I'm pretty sure this equation is wrong, I derived
% it and it should be divided, not multiplied, by c(t))
% xis = nan(nstates,nstates,length(Ts));
% for t = 1:length(Ts)
% %xis(:,:,t) = alphas(:,Ts(t)-1).*py_z(:,Ts(t))'.*model.A.*betas(:,Ts(t))'/c(Ts(t));
% xis(:,:,t) = (alphas(:,Ts(t)-1)*(py_z(:,Ts(t)).*betas(:,Ts(t)))').*model.A/c(Ts(t));
% end

% NOTE: this isn't the "true" xi, but instead xi summed across time steps
% (the third dimention in the commented definition above). I'm using matrix
% math to compute the summed xi for speed, since we're using the lagrange
% multiplier to optimize the transition matrix in runGlmHmm.m. In a version
% that fits a transition matrix dependent on latent state, the "full" xis
% would need to be computed as commented out above
xis = ((alphas(:,Ts-1)./c(Ts)')*(py_z(:,Ts).*betas(:,Ts))').*model.A;

end

5 changes: 4 additions & 1 deletion runBaumWelch.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [gammas,xis,ll] = runBaumWelch(y,x,model,new_sess)
function [gammas,xis,ll,ll_norm] = runBaumWelch(y,x,model,new_sess)
%[gammas,xis,ll] = runBaumWelch(y, x, model, new_sess)
% Runs the Baum-Welch algorithm to compute latent state posterior
% distributions. Assumes a model of the form by fitGlmHmm.m
Expand All @@ -20,6 +20,8 @@
% xis - (Nstates x Nstates x Ntrials) Joint posterior
% distribution
% ll - log-likelihood of the fit
% ll_norm - normalized log-likelihood computed as
% ll_norm = exp(ll/NTrials)

%% initialize variables
nstates = size(model.w,2); % number of latent states
Expand Down Expand Up @@ -53,6 +55,7 @@
alphas(:,t) = alphas(:,t)/c(t); % normalize 13.59
end
ll = sum(log(c)); % store log-likelihood
ll_norm = exp(ll/T);

%% %%%%%%%%%%%%%%%%%%%%%% Backward recursion %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% initialize variables
Expand Down

0 comments on commit 0b8ab2b

Please sign in to comment.