-
Notifications
You must be signed in to change notification settings - Fork 5
/
partitiondata.m
50 lines (43 loc) · 2.12 KB
/
partitiondata.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
function [Xtrain, Xtest, ytrain, ytest] = partitiondata(X, y, train_percent, classes)
%PARTITIONDATA Stratified train/test split of the dataset (X, y).
%   Splits the data so that every requested class contributes the same
%   proportion of its subjects to the training set. Within each class the
%   first ceil(train_percent * n) subjects (in their original order) go to
%   the training set and the remainder go to the test set; no shuffling is
%   performed.
%
%   Parameters:
%     X [NumSubjects, NumTimeSlices, NumFeatures]: fMRI scan data
%     y [NumSubjects, 1]: class labels
%     train_percent: FRACTION (0..1) of each class to allocate for training
%     classes: cell array of class groups to consider. Each cell holds one
%              or more original labels that together form one output class.
%              For example, to exclude label 2 (impulsive) from the dataset,
%              let classes = {0, 1, 3}. To combine classes, for example to
%              do binary ADHD classification, let classes = {0, 1:3}.
%              A plain numeric vector is also accepted and is treated as one
%              label per group.
%
%   Returns:
%     Xtrain, Xtest: subjects selected for training / testing, grouped by
%                    class in the order given by `classes`
%     ytrain, ytest: labels for those subjects, RELABELED to the 0-based
%                    index of the group they belong to (0 for classes{1},
%                    1 for classes{2}, ...). Subjects whose label is in no
%                    group are dropped.
%
%   Note: a label listed in more than one group places that subject in
%   every group that lists it.

    [~, numT, numF] = size(X);

    % Accept a numeric vector of labels as shorthand for one label per group.
    if ~iscell(classes)
        classes = num2cell(classes);
    end
    num_classes = numel(classes);

    % Find the subjects belonging to each class group and how many of them
    % go to the training set.
    class_inds = cell(1, num_classes);
    class_counts = zeros(1, num_classes);
    train_counts = zeros(1, num_classes);
    for cidx = 1:num_classes
        class_inds{cidx} = find(ismember(y(:), classes{cidx}));
        class_counts(cidx) = numel(class_inds{cidx});
        train_counts(cidx) = ceil(train_percent * class_counts(cidx));
    end

    % Preallocate outputs. Only subjects in a requested class are counted,
    % so excluded classes do not leave zero-filled rows in the test set.
    num_selected_subjects = sum(class_counts);
    num_train_subjects = sum(train_counts);
    num_test_subjects = num_selected_subjects - num_train_subjects;

    Xtrain = zeros(num_train_subjects, numT, numF);
    Xtest = zeros(num_test_subjects, numT, numF);
    ytrain = zeros(num_train_subjects, 1);
    ytest = zeros(num_test_subjects, 1);

    % Partition the data class by class.
    train_samples = 0;
    test_samples = 0;
    for cidx = 1:num_classes
        sample_inds = class_inds{cidx};
        num_train = train_counts(cidx);
        num_test = class_counts(cidx) - num_train;

        train_rows = train_samples+1:train_samples+num_train;
        test_rows = test_samples+1:test_samples+num_test;

        Xtrain(train_rows,:,:) = X(sample_inds(1:num_train),:,:);
        Xtest(test_rows,:,:) = X(sample_inds(num_train+1:end),:,:);

        % Relabel to the 0-based group index so merged groups share a label.
        ytrain(train_rows) = cidx - 1;
        ytest(test_rows) = cidx - 1;

        train_samples = train_samples + num_train;
        test_samples = test_samples + num_test;
    end

    % Sanity checks.
    assert(train_samples == num_train_subjects && test_samples == num_test_subjects, ...
        'The number of train+test subjects should equal the number of subjects in the selected classes');
    % isequal (rather than ==) so an empty class fails this assertion with
    % its message instead of raising a dimension-mismatch error.
    assert(isequal(0:num_classes-1, unique([ytrain; ytest])'), ...
        'The classes in ytrain and ytest should be the same as the classes parameter');
end