diff --git a/CNN/cnntest.m b/CNN/cnntest.m index 64ba271..5726b28 100644 --- a/CNN/cnntest.m +++ b/CNN/cnntest.m @@ -1,9 +1,9 @@ -function [er, bad] = cnntest(net, x, y) +function [predicted_label, er, bad] = cnntest(net, x, y) % feedforward net = cnnff(net, x); [~, h] = max(net.o); [~, a] = max(y); bad = find(h ~= a); - + predicted_label = h-1; er = numel(bad) / size(y, 2); end diff --git a/README.md b/README.md index be277e4..94be00e 100644 --- a/README.md +++ b/README.md @@ -180,7 +180,7 @@ opts.numepochs = 1; cnn = cnntrain(cnn, train_x, train_y, opts); -[er, bad] = cnntest(cnn, test_x, test_y); +[predicted_label, er, bad] = cnntest(cnn, test_x, test_y); %plot mean squared error figure; plot(cnn.rL); diff --git a/tests/test_example_CNN.m b/tests/test_example_CNN.m index 05ce9d9..b502ae3 100644 --- a/tests/test_example_CNN.m +++ b/tests/test_example_CNN.m @@ -28,7 +28,7 @@ cnn = cnnsetup(cnn, train_x, train_y); cnn = cnntrain(cnn, train_x, train_y, opts); -[er, bad] = cnntest(cnn, test_x, test_y); +[predicted_label, er, bad] = cnntest(cnn, test_x, test_y); %plot mean squared error figure; plot(cnn.rL);