function [fe,ll] = hmmfe(data,T,hmm,Gamma,Xi,preproc,grouping)
% Computes the Free Energy of an HMM
%
% INPUT
% data      observations: either a struct with X (time series) and C (classes, optional),
%           or just a matrix containing the time series
% T         length of each time series (vector, or cell with one entry per subject/session)
% hmm       hmm structure
% Gamma     probability of states conditioned on data (optional)
% Xi        joint probability of past and future states conditioned on data (optional)
% preproc   whether to apply the preprocessing options in hmm.train to the data (optional, default 1)
% grouping  assigns each session to a group (optional, default: all sessions in one group)
%
% OUTPUT
% fe the variational free energy
% ll log-likelihood per time point
%
% Author: Diego Vidaurre, OHBA, University of Oxford (2017)
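%
% Example (a minimal sketch; assumes hmm, Gamma and Xi come from training
% with hmmmar on the same data and options, and X/T/options are illustrative):
%   % X: (time points x channels) matrix; T: vector of session lengths
%   [hmm,Gamma,Xi] = hmmmar(X,T,options);
%   [fe,ll] = hmmfe(X,T,hmm,Gamma,Xi);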
% Fix potential compatibility issues with previous versions
hmm = versCompatibilityFix(hmm);
mixture_model = isfield(hmm.train,'id_mixture') && hmm.train.id_mixture;
p = hmm.train.lowrank; do_HMM_pca = (p > 0);
if nargin<4, Gamma = []; end
if nargin<5, Xi = []; end
if nargin<6 || isempty(preproc), preproc = 1; end
if nargin<7 , grouping = ones(length(T),1); end
if size(grouping,1)==1, grouping = grouping'; end
if isstruct(data), data = data.X; end
options = hmm.train;
hmm.train.grouping = grouping;
stochastic_learn = isfield(options,'BIGNbatch') && ...
(options.BIGNbatch < length(T) && options.BIGNbatch > 0);
if iscell(T)
for i = 1:length(T)
if size(T{i},1)==1, T{i} = T{i}'; end
end
if size(T,1)==1, T = T'; end
end
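% For non-stochastic inference the sessions are concatenated into a single
% matrix; for stochastic inference the data is kept (or split) as a cell,
% one element per subject/session, to be loaded file by file below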
if ~stochastic_learn
if iscell(T)
T = cell2mat(T);
end
checkdatacell;
%data = data2struct(data,T,hmm.train);
else
if ~iscell(data)
N = length(T);
dat = cell(N,1); TT = cell(N,1);
for i=1:N
t = 1:T(i);
dat{i} = data(t,:); TT{i} = T(i);
try data(t,:) = [];
catch, error('The dimension of data does not correspond to T');
end
end
if ~isempty(data)
error('The dimension of data does not correspond to T');
end
data = dat; T = TT; clear dat TT
end
end
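% Apply the preprocessing options stored in hmm.train (non-stochastic case
% only; for stochastic inference each file is prepared when it is loaded below)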
if preproc && ~stochastic_learn
% Standardise data and control for awkward trials
data = standardisedata(data,T,options.standardise);
% Filtering
if ~isempty(options.filter)
data = filterdata(data,T,options.Fs,options.filter);
end
% Detrend data
if options.detrend
data = detrenddata(data,T);
end
% Leakage correction
if options.leakagecorr ~= 0
data = leakcorr(data,T,options.leakagecorr);
end
% Hilbert envelope
if options.onpower
data = rawsignal2power(data,T);
end
% Leading Phase Eigenvectors
if options.leida
data = leadingPhEigenvector(data,T);
end
% pre-embedded PCA transform
if length(options.pca_spatial) > 1 || (options.pca_spatial > 0 && options.pca_spatial ~= 1)
if isfield(options,'As')
data.X = bsxfun(@minus,data.X,mean(data.X));
data.X = data.X * options.As;
else
[options.As,data.X] = highdim_pca(data.X,T,options.pca_spatial);
options.pca_spatial = size(options.As,2);
end
end
% Embedding
if length(options.embeddedlags)>1
[data,T] = embeddata(data,T,options.embeddedlags);
end
% PCA transform
if isfield(options,'A')
data = bsxfun(@minus,data,mean(data)); % must center
data = data * options.A;
% Standardise principal components and control for awkward trials
data = standardisedata(data,T,options.standardise_pc);
end
% Downsampling
if options.downsample > 0
[data,T] = downsampledata(data,T,options.downsample,options.Fs);
end
end
% Get residuals (not needed for stochastic learning or HMM-PCA)
if ~stochastic_learn && ~do_HMM_pca
if isfield(hmm.state(1),'W')
ndim = size(hmm.state(1).W.Mu_W,2);
else
ndim = size(hmm.state(1).Omega.Gam_rate,2);
end
if ~isfield(hmm.train,'Sind')
orders = formorders(hmm.train.order,hmm.train.orderoffset,hmm.train.timelag,hmm.train.exptimelag);
hmm.train.Sind = formindexes(orders,hmm.train.S) == 1;
if ~hmm.train.zeromean, hmm.train.Sind = [true(1,ndim); hmm.train.Sind]; end
end
residuals = getresiduals(data,T,hmm.train.S,hmm.train.maxorder,hmm.train.order,...
hmm.train.orderoffset,hmm.train.timelag,hmm.train.exptimelag,hmm.train.zeromean);
else
residuals = [];
end
% Get the state time courses by decoding, if they were not provided
if isempty(Gamma) || isempty(Xi)
if ~(mixture_model && ~isempty(Gamma)) % for a mixture model, Gamma suffices and Xi is not needed
[Gamma,Xi] = hmmdecode(data,T,hmm,0,residuals,0);
end
end
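% Stochastic case: assemble the free energy term by term, computing the KL
% and Gamma terms once and the data log-likelihood file by file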
if stochastic_learn
if hmm.train.downsample > 0
downs_ratio = (hmm.train.downsample/hmm.train.Fs);
else
downs_ratio = 1;
end
Tmat = downs_ratio*cell2mat(T);
if length(hmm.train.embeddedlags)>1
maxorder = 0;
L = -min(hmm.train.embeddedlags) + max(hmm.train.embeddedlags);
Tmat = Tmat - L;
else
maxorder = hmm.train.maxorder;
end
% KL terms for the transition probabilities (P) and initial state probabilities (Pi)
fe = sum(evalfreeenergy([],[],[],[],hmm,[],[],[0 0 0 1 0]));
% Gamma entropy & log-likelihood
fe = fe + sum(evalfreeenergy([],Tmat,Gamma,Xi,hmm,[],[],[1 0 1 0 1]));
tacc = 0; tacc2 = 0; fell = 0; ll = [];
for i = 1:length(T)
[X,XX,residuals,Ti] = loadfile(data{i},T{i},options);
if hmm.train.lowrank > 0, XX = X; end
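% Index windows into Gamma and Xi for this file: each segment loses
% maxorder time points in Gamma, and one more time point in Xi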
t = (1:(sum(Ti)-length(Ti)*maxorder)) + tacc;
t2 = (1:(sum(Ti)-length(Ti)*(maxorder+1))) + tacc2;
tacc = tacc + length(t); tacc2 = tacc2 + length(t2);
if ~isempty(Xi)
[f,l] = evalfreeenergy(X,Ti,Gamma(t,:),Xi(t2,:,:),hmm,residuals,XX,[0 1 0 0 0]);
else % mixture model: Xi is not used
[f,l] = evalfreeenergy(X,Ti,Gamma(t,:),[],hmm,residuals,XX,[0 1 0 0 0]);
end
fell = fell + sum(f); % data log-likelihood
ll = [ll; l]; % keep per-time-point log-likelihoods in file order
end
fe = fe + fell;
else
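% Non-stochastic case: all free energy terms are evaluated in a single call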
[fe,ll] = evalfreeenergy(data,T,Gamma,Xi,hmm,residuals);
fe = sum(fe);
end
end