-
Notifications
You must be signed in to change notification settings - Fork 5
/
hmmtrain_gmm.m
23 lines (19 loc) · 987 Bytes
/
hmmtrain_gmm.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
function [hmm] = hmmtrain_gmm(Xtrain, N, K, random_init, rand_restarts)
%HMMTRAIN_GMM Train an HMM with a Gaussian-mixture emission distribution.
%   HMM = HMMTRAIN_GMM(XTRAIN, N, K, RANDOM_INIT, RAND_RESTARTS) fits an
%   N-state HMM with K-component Gaussian-mixture emissions (hmmFit type
%   'mixGaussTied') to the training sequences in XTRAIN, using hmmFit's EM
%   learning with maxIter = 100, convTol = 1e-7 and verbose output off.
%
%   Inputs:
%     Xtrain        - 3-D array; slice Xtrain(i,:,:) is the i-th training
%                     sequence. Each sequence is handed to hmmFit as a
%                     size(Xtrain,3)-by-size(Xtrain,2) matrix.
%                     NOTE(review): hmmFit is assumed to expect d-by-T
%                     sequences, i.e. Xtrain is nSeq-by-T-by-d -- confirm.
%     N             - number of hidden states.
%     K             - number of mixture components (hmmFit 'nmix').
%     random_init   - logical. If true, hmmFit chooses the initial state
%                     distribution and transition matrix itself; if false,
%                     both are initialised to uniform (1/N) values.
%                     Emission parameters are always initialised by hmmFit.
%     rand_restarts - number of restarts (hmmFit 'nRandomRestarts').
%
%   Output:
%     hmm           - the model returned by hmmFit.

% Preprocess the observations into the form used by the EM learning
% function: an nSeq-by-1 cell array with one matrix per sequence.
% reshape is used instead of squeeze so every cell has the same
% orientation: squeeze leaves 2-D arrays untouched and turns 1x1xB into
% Bx1, which would transpose the cells for univariate data or
% length-1 sequences.
Obs = permute(Xtrain, [1,3,2]);
[~, nRows, nCols] = size(Obs);   % nCols absorbs a dropped trailing singleton
Obs = cellfun(@(seq) reshape(seq, nRows, nCols), num2cell(Obs, [2,3]), ...
    'UniformOutput', false);

% With no extra options, hmmFit uses its own (per the original note,
% random) initial model parameters.
vars = {};
if ~random_init
    % Equiprobable initial state distribution and transition rows.
    pi0 = ones(N,1)/N;
    A0 = ones(N)/N;
    vars = {'pi0', pi0, 'trans0', A0};
end

% Train the HMM.
hmm = hmmFit(Obs, N, 'mixGaussTied', 'nmix', K, ...
    'maxIter', 100, 'convTol', 1e-7, ...
    'nRandomRestarts', rand_restarts, 'verbose', false, vars{:});
end