-
Notifications
You must be signed in to change notification settings - Fork 0
/
predictClassification2.m
executable file
·69 lines (60 loc) · 3.13 KB
/
predictClassification2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% %
% MASTER'S THESIS %
% %
% Student: Martin Hellwagner %
% Supervisor: Prof. Stefan Weinzierl (TU Berlin) %
% Advisor: Prof. Anders Friberg (KTH Stockholm) %
% %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% %
% Based on the code by Prof. Anders Friberg %
% Re-written and modified by Martin Hellwagner %
% %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [accuracyPLS,accuracySVM] = predictClassification2(features,groundTruth,components,speakerIndices)
% PREDICTCLASSIFICATION2 Leave-one-speaker-out classification accuracy.
%
% Inputs:
%   features       - struct with field 'data' (observations x features matrix)
%   groundTruth    - struct with field 'data' (observations x 1 class labels)
%   components     - number of PLS components passed to plsregress
%   speakerIndices - vector assigning each observation to a speaker
%                    (1..max(speakerIndices)); defines the CV folds
%
% Outputs:
%   accuracyPLS    - fraction of fragments correctly classified by PLS
%                    regression (predictions rounded to nearest label)
%   accuracySVM    - fraction of fragments correctly classified by SVM
%
% NOTE(review): the original code wrapped the cross-validation in a
% "numberIterations = 100" loop, but every pass is fully deterministic
% (fixed fold split, plsregress and fitcsvm on identical data), so the
% loop only multiplied numerator and denominator of each accuracy by 100.
% It has been removed; the returned accuracies are unchanged.
% initializating values
numberSpeakers = max(speakerIndices); % number of speakers (for cross-validation)
numberCorrectPLS = 0; % number of correctly classified PLS fragments
numberTotalPLS = 0; % total number of PLS fragments
numberCorrectSVM = 0; % number of correctly classified SVM fragments
numberTotalSVM = 0; % total number of SVM fragments
% computing Z-scores of data arrays
zX = zscore(features.data);
% running leave-one-speaker-out cross-validation
for j = 1:numberSpeakers
    % vectorized fold split: fold j validates on speaker j's fragments
    validationIndices = find(speakerIndices == j);
    trainingIndices = find(speakerIndices ~= j);
    trainingX = zX(trainingIndices,:);
    validationX = zX(validationIndices,:);
    trainingY = groundTruth.data(trainingIndices,:);
    validationY = groundTruth.data(validationIndices,:);
    % computing results for PLS (regression coefficients include intercept,
    % hence the prepended column of ones)
    [~,~,~,~,BETA,~,~,~] = plsregress(trainingX,trainingY,components);
    resultPLS = [ones(size(validationX,1),1) validationX]*BETA;
    numberCorrectPLS = numberCorrectPLS+sum(validationY == round(resultPLS));
    numberTotalPLS = numberTotalPLS+length(validationY);
    % computing results for SVM
    model = fitcsvm(trainingX,trainingY); % MATLAB's internal SVM classification is used because
    resultSVM = predict(model,validationX); % of the large overhead using the LIBSVM library
    numberCorrectSVM = numberCorrectSVM+sum(validationY == resultSVM);
    numberTotalSVM = numberTotalSVM+length(validationY);
end
% calculating overall accuracy of prediction
accuracyPLS = numberCorrectPLS/numberTotalPLS;
accuracySVM = numberCorrectSVM/numberTotalSVM;
end