-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.m
87 lines (66 loc) · 2.51 KB
/
main.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
% This script constructs a graph by interpreting the 28x28 pixel images from
% the popular MNIST digit dataset as a pointcloud in a 784-dimensional space
% and subsequently performs semi-supervised classification by solving the
% graph p-Laplacian equation on the graph.
%
% Authors: Daniel Tenbrinck, Samira Kabri, Tim Roith,
% Friedrich--Alexander-Universitaet Erlangen--Nuernberg
%% CLEAN UP
% Start from a blank slate: empty command window, no workspace variables,
% no open figures.
clc;
clear;
close all;

% Put the project's helper folders on the MATLAB search path
% (one call per folder, same order as before).
addpath ./algorithm
addpath ./data
addpath ./graph_construction

%% Output configuration
% Format strings used as visual separators in the console output:
% 50 repeated characters followed by the (still unescaped) text '\n',
% which fprintf later interprets as a newline.
sLine = [repmat('^', 1, 50), '\n'];   % section separator
cLine = [repmat('.', 1, 50), '\n'];   % sub-section separator
%% LOAD SETTINGS FOR DATA AND GRAPH CONSTRUCTION AND GET LABELS
% Read the MNIST test set (10k images). Each column of 'images' holds one
% 28x28 image.
images = loadMNISTImages('t10k-images.idx3-ubyte');
% Read the matching label file. NOTE(review): the exact layout of
% 'lGroups' and 'realLabels' is defined by labelgroups() -- confirm there.
[lGroups, realLabels] = labelgroups('t10k-labels.idx1-ubyte');

%% SET COORDINATES
% Transpose so that every image becomes one row, i.e. the gray values of a
% picture serve as its coordinates in 784-dimensional space.
coordinates = images';

% SET GRAPH PROPERTIES FOR POINT CLOUDS
% Data description handed to the graph constructor.
data.coordinates = coordinates;

% k-nearest-neighbor search: connect every point to its 7 nearest neighbors.
neighborhood = struct('type', 'kNN', 'numberOfNeighbors', 7);

% Distance used for the neighbor search: 'Euclidean' or 'Tangent'.
% NOTE(review): the original comment said "NO INFLUENCE ON TANGENT
% DISTANCE"; what exactly has no influence is not visible here -- confirm.
distanceFunction = 'Euclidean';

% Edge weight as a function of distance x (Gaussian-shaped decay).
weightFunction = struct('function', @(x)10*exp(-x.^2./10e+6));

% Parameters of the diffusion algorithm; both are forwarded unchanged to
% classifier() further down in the script.
p = 2;
lamda = 10;
%% CONSTRUCT GRAPH FROM DATA
% Announce the step, framed by the separator lines.
fprintf(sLine,'\n');
fprintf('Building graph from MNIST dataset.\n');
fprintf(cLine,'\n');

% Time the whole construction (stopped by the toc in the final fprintf).
tic

% 1) create the kNN graph object from the data and the settings above
G = kNNGraph(data, ...
             neighborhood, ...
             distanceFunction, ...
             weightFunction);

% 2) compute the adjacency matrix and keep it inside the graph object
G = G.computeAdjacencyMatrix;

% 3) symmetrize the kNN graph
G = G.symmetrizeGraph;

fprintf('Finished after %f seconds.\n', toc);
%% Get labeled Data
% Number of labels passed to labelIndices().
% NOTE(review): whether this is a total or a per-class count is decided
% inside labelIndices() -- confirm there.
numLabels = 100;

% Hard-coded per-class sample counts as a 10x1 column vector (they sum to
% 10000, the size of the loaded image set).
numInClass = [980; 1135; 1032; 1010; 982; 892; 958; 1028; 974; 1009];

% Pick the indices of the samples that are treated as labeled.
labeledIndices = labelIndices(numLabels, numInClass, lGroups);
%% Classification
% Announce the step, framed by the separator lines.
fprintf(sLine,'\n');
fprintf('Performing classfication via diffusion.\n');
fprintf(cLine,'\n');

% Run the diffusion-based classifier on graph G with exponent p, parameter
% lamda and the labeled indices. The trailing 1e-8 is passed through as is.
% NOTE(review): 1e-8 looks like a convergence tolerance -- confirm in
% classifier().
tic
labels = classifier(G, p, lamda, labeledIndices, 1e-8);
t = toc;   % elapsed wall-clock time of the classification in seconds

% Use %f for the numeric elapsed time. The previous %s made MATLAB fall
% back to exponential notation for non-integer values (or print a
% character for integer values). This also matches the timing message of
% the graph-construction step.
fprintf('Finished diffusion after %f seconds.\n', t);
fprintf(cLine,'\n');
%% Test Accuracy
% Compare the predicted labels with the reference labels.
% NOTE(review): 'accuracy' is scaled by 100 below, so it is presumably a
% fraction in [0, 1] -- confirm in test_accuracy().
accuracy = test_accuracy(labels, realLabels);

fprintf(sLine,'\n');
% The value is rounded to two decimals, so print exactly two decimals.
% The previous plain %f padded it to six (e.g. 93.450000).
fprintf('The classification had an average accuracy of %.2f .\n', round(100*accuracy,2));