-
Notifications
You must be signed in to change notification settings - Fork 2
/
getFoldedIdx.m
32 lines (25 loc) · 874 Bytes
/
getFoldedIdx.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
function [ foldIdxPerClass ] = getFoldedIdx( obsPerClass, nFolds )
%GETFOLDEDIDX Split each class's observations into contiguous folds.
%   An internal function used to split data into folds for
%   cross-validation.
%
%   foldIdxPerClass = getFoldedIdx(obsPerClass, nFolds)
%
%   Inputs:
%     obsPerClass - vector of length C containing the number of
%                   observations for each of C classes.
%     nFolds      - positive integer specifying the number of folds.
%
%   Output:
%     foldIdxPerClass - C x nFolds cell array. foldIdxPerClass{c,k} is a
%                       row vector of within-class indices assigned to
%                       fold k of class c. For each class the folds are
%                       consecutive, non-overlapping ranges that together
%                       cover 1:obsPerClass(c). The first
%                       mod(obsPerClass(c), nFolds) folds each hold one
%                       more index than the rest. A fold is empty (1x0)
%                       when a class has fewer observations than nFolds.
if ~(isscalar(nFolds) && isnumeric(nFolds) && isreal(nFolds) && ...
        isfinite(nFolds) && nFolds >= 1 && nFolds == fix(nFolds))
    error('getFoldedIdx:badFolds', 'nFolds must be a positive integer scalar.');
end
% Work in double precision: MATLAB integer division rounds to nearest
% instead of truncating, so integer-class inputs would otherwise give a
% wrong floor(), a negative remainder, and indices past obsPerClass(c).
obsPerClass = double(obsPerClass(:));
nFolds = double(nFolds);
nClasses = numel(obsPerClass);
foldIdxPerClass = cell(nClasses, nFolds);
for c = 1:nClasses
    minPerFold = floor(obsPerClass(c) / nFolds);
    remainder = obsPerClass(c) - minPerFold * nFolds;
    % The first `remainder` folds get one extra observation.
    if remainder > 0
        currIdx = 1:(minPerFold+1);
    else
        currIdx = 1:minPerFold;
    end
    for x = 1:nFolds
        foldIdxPerClass{c,x} = currIdx;
        % Shift the window so the next fold starts right after this one.
        currIdx = currIdx + length(currIdx);
        % Once the enlarged folds are used up, shrink back to minPerFold.
        if x == remainder
            currIdx(end) = [];
        end
    end
end
end