-
Notifications
You must be signed in to change notification settings - Fork 5
/
hmmtrain_mvg.m
30 lines (25 loc) · 1.22 KB
/
hmmtrain_mvg.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
function [hmm] = hmmtrain_mvg(Xtrain, N, random_init, rand_restarts)
%HMMTRAIN_MVG Train an HMM with a multivariate Gaussian emission distribution.
%   HMM = HMMTRAIN_MVG(XTRAIN, N, RANDOM_INIT, RAND_RESTARTS) fits an
%   N-state hidden Markov model with Gaussian emissions to the sequences in
%   XTRAIN by calling hmmFit (EM, at most 100 iterations, convergence
%   tolerance 1e-7, non-verbose).
%
%   Inputs:
%     Xtrain        - numeric array indexed as Xtrain(seq, t, dim), i.e. of
%                     size nSeq x T x D. All sequences share the length T.
%     N             - number of hidden states.
%     random_init   - logical. If true, no initial parameters are passed and
%                     hmmFit chooses its own (random) initialisation. If
%                     false, initial parameters are derived from Xtrain
%                     (see below).
%     rand_restarts - forwarded to hmmFit as 'nRandomRestarts'.
%
%   Output:
%     hmm           - the model returned by hmmFit.
%
%   Data-derived initialisation (random_init == false):
%     * Every state's mean is the mean of the per-sequence time averages.
%     * Every state's covariance is the covariance of those per-sequence
%       time averages. When only one sequence is given, it is instead the
%       covariance of all its frames, because a single average has no spread.
%     * The initial-state and transition distributions are uniform.

[nSeq, T, D] = size(Xtrain);

% Convert to the format passed to hmmFit: an nSeq x 1 cell array holding one
% D x T matrix per sequence. reshape (rather than squeeze) keeps the D x T
% orientation even when D == 1 or T == 1.
Obs = cell(nSeq, 1);
for k = 1:nSeq
    Obs{k} = reshape(Xtrain(k, :, :), T, D).';
end

vars = {}; % empty => hmmFit randomizes the initial model parameters itself
if ~random_init
    % Time-average each sequence -> nSeq x D. reshape avoids squeeze turning
    % a single sequence (1 x 1 x D) into a D x 1 column.
    seqMeans = reshape(mean(Xtrain, 2), nSeq, D);

    % Same initial mean for every state: D x N.
    mu0 = repmat(mean(seqMeans, 1).', 1, N);

    if nSeq >= 2
        Sigma = cov(seqMeans);
    else
        % One sequence: its average has no spread, and cov of a 1 x D row
        % returns a scalar variance rather than a D x D matrix. Use the
        % covariance of that sequence's frames instead (Xtrain is 1 x T x D).
        Sigma = cov(reshape(Xtrain, T, D));
    end
    Sigma0 = repmat(Sigma, [1, 1, N]); % D x D x N, same for every state
    emission0 = condGaussCpdCreate(mu0, Sigma0);

    % equiprobable initial-state and transition distributions
    pi0 = ones(N, 1) / N;
    A0 = ones(N) / N;
    vars = {'pi0', pi0, 'trans0', A0, 'emission0', emission0};
end

% train HMM
hmm = hmmFit(Obs, N, 'gauss', ...
    'maxIter', 100, 'convTol', 1e-7, ...
    'nRandomRestarts', rand_restarts, 'verbose', false, vars{:});
end