forked from sdemyanov/ConvNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcnnclassify.m
32 lines (28 loc) · 887 Bytes
/
cnnclassify.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
function pred = cnnclassify(layers, weights_in, test_x, type)
% CNNCLASSIFY Run a classification (forward) pass of a CNN on test data.
%
%   pred = cnnclassify(layers, weights_in, test_x, type)
%
%   layers     - cell array of layer descriptions. layers{1} may carry
%                'mean' and/or 'stdev' fields (presumably input
%                normalization statistics - TODO confirm against the
%                classify_mex / classify_mat backends).
%   weights_in - network weights; passed through unchanged to the backend.
%   test_x     - input data. If it is a 3-D array, a singleton 3rd
%                dimension is inserted (its 3rd dimension moves to the
%                4th). Presumably this maps H x W x N to H x W x 1 x N,
%                i.e. a single input map - verify against the backends.
%   type       - backend selector: 'mexfun' (calls classify_mex) or
%                'matlab' (calls classify_mat).
%
%   pred       - the backend's output. For 'mexfun' the returned matrix
%                is transposed back before being returned.
%
%   Side effect: prints the total classification time via disp.
%   Raises an error if TYPE is neither 'mexfun' nor 'matlab'.

% Validate the backend selector first, so a bad TYPE fails before any
% data is permuted or timed.
if ~(strcmp(type, 'mexfun') || strcmp(type, 'matlab'))
  error('"%s" - wrong type, must be either "mexfun" or "matlab"', type);
end

if ndims(test_x) == 3
  % insert singleton maps index
  test_x = permute(test_x, [1 2 4 3]);
end

% Use a private timer handle so that a caller's own tic/toc is not reset.
timer_id = tic;
if strcmp(type, 'mexfun')
  % The mex backend receives the first two dimensions swapped (presumably
  % a row-/column-major layout difference - TODO confirm in classify_mex).
  test_x = permute(test_x, [2 1 3 4]);
  if isfield(layers{1}, 'mean')
    layers{1}.mean = permute(layers{1}.mean, [2 1 3]);
  end
  if isfield(layers{1}, 'stdev')
    layers{1}.stdev = permute(layers{1}.stdev, [2 1 3]);
  end
  pred = classify_mex(layers, weights_in, test_x);
  % Undo the dimension swap on the returned matrix.
  pred = permute(pred, [2 1]);
else
  pred = classify_mat(layers, weights_in, test_x);
end
% Raw backend output is returned. Optional softmax normalization (disabled):
%z = logsumexp(pred, 2);
%pred = exp(bsxfun(@minus, pred, z));
t = toc(timer_id);
disp(['Total classification time: ' num2str(t)]);
end