forked from ilarinieminen/SOM-Toolbox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
som_kmeans.m
145 lines (123 loc) · 4.15 KB
/
som_kmeans.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
function [codes,clusters,err] = som_kmeans(method, D, k, epochs, verbose)
% SOM_KMEANS K-means algorithm.
%
% [codes,clusters,err] = som_kmeans(method, D, k, [epochs], [verbose])
%
%  Input and output arguments ([]'s are optional):
%    method     (string) k-means algorithm type: 'batch' or 'seq'
%    D          (matrix) data matrix, one sample per row
%               (struct) data struct (field .data is used) or
%                        map struct (field .codebook is used)
%    k          (scalar) number of centroids
%    [epochs]   (scalar) number of training epochs, default 100
%    [verbose]  (scalar) if <> 0 display additional information
%
%    codes      (matrix) codebook vectors, one centroid per row
%    clusters   (vector) cluster number for each sample
%    err        (scalar) quantization error for the data set: the sum of
%                        squared distances to the assigned centroids; for
%                        one-dimensional data the mean absolute distance
%                        returned by SCALAR_KMEANS
%
%  Initial centroids are k randomly chosen samples. One-dimensional data
%  is always handled by SCALAR_KMEANS, regardless of 'method'. An unknown
%  method prints a message to stderr and returns the initial centroids.
%
% See also KMEANS_CLUSTERS, SOM_MAKE, SOM_BATCHTRAIN, SOM_SEQTRAIN.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Function has been renamed by Kimmo Raivio, because MATLAB 6.5 also has
% a kmeans function (1.10.02)

%% input arguments

if isstruct(D),
  switch D.type,
    case 'som_map',  data = D.codebook;
    case 'som_data', data = D.data;
    otherwise
      % without this branch 'data' would be left undefined
      error('som_kmeans: unsupported struct type ''%s''.', D.type);
  end
else
  data = D;
end
[l, dim] = size(data);

if nargin < 4 || isempty(epochs) || isnan(epochs), epochs = 100; end
if nargin < 5 || isempty(verbose), verbose = 0; end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% action

rand('state', sum(100*clock)); % init rand generator

lr = 0.5;                      % initial learning rate for sequential k-means
temp = randperm(l);
centroids = data(temp(1:k),:); % k distinct random samples as initial centroids
clusters = zeros(1, l);

if dim==1,
  [codes,clusters,err] = scalar_kmeans(data,k,epochs);
  return;
end

switch method
  case 'seq',
    len = epochs * l;
    l_rate = linspace(lr,0,len); % learning rate decays linearly to zero
    order = randperm(l);
    for iter = 1:len
      % sample from the extracted data matrix (D itself may be a struct)
      x = data(order(rem(iter,l)+1),:);
      dx = x(ones(k,1),:) - centroids;
      [dummy, nearest] = min(sum(dx.^2,2));
      centroids(nearest,:) = centroids(nearest,:) + l_rate(iter)*dx(nearest,:);
    end
    clusters = nearest_centroid(data, centroids);

  case 'batch',
    iter = 0;
    old_clusters = zeros(1, l);
    while iter<epochs
      clusters = nearest_centroid(data, centroids);

      for i = 1:k
        members = (clusters==i);
        s = sum(members);
        % explicit dimension: a single-member cluster is a row vector and
        % sum() without it would add up the coordinates instead
        if s, centroids(i,:) = sum(data(members,:), 1) / s; end
      end

      % converged when no sample changed its cluster since the last epoch
      if iter && all(old_clusters==clusters)
        if verbose, fprintf(1, 'Convergence in %d iterations\n', iter); end
        break;
      end

      old_clusters = clusters;
      iter = iter + 1;
    end
    clusters = nearest_centroid(data, centroids);

  otherwise,
    fprintf(2, 'Unknown method\n');
end

% sum of squared distances between each sample and its centroid
err = 0;
for i = 1:k
  f = find(clusters==i);
  s = length(f);
  if s, err = err + sum(sum((data(f,:)-ones(s,1)*centroids(i,:)).^2,2)); end
end

codes = centroids;
return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function clusters = nearest_centroid(data, centroids)
% NEAREST_CENTROID Row vector holding, for each row of data, the index of
% the centroid with the smallest squared Euclidean distance.
l = size(data,1);
k = size(centroids,1);
% |x-c|^2 = |x|^2 + |c|^2 - 2*x*c', evaluated for all pairs at once (l x k)
d2 = sum(data.^2,2)*ones(1,k) + ones(l,1)*sum(centroids.^2,2)' - ...
     2.*(data*(centroids'));
% explicit dimension so that k == 1 still yields one index per sample
[dummy, clusters] = min(d2, [], 2);
clusters = clusters';
return;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [y,bm,qe] = scalar_kmeans(x,k,maxepochs)
% SCALAR_KMEANS K-means for one-dimensional data.
%
%  x          (vector) data values; non-finite entries are excluded from
%                      the computation
%  k          (scalar) number of centroids
%  maxepochs  (scalar) maximum number of iterations
%
%  y          (vector) k x 1 centroid values
%  bm         (vector) centroid index for each element of x, NaN for the
%                      non-finite elements
%  qe         (scalar) mean absolute distance between the finite values
%                      and their centroids (only computed if requested)

nans = ~isfinite(x);
x(nans) = [];
n = length(x);
mi = min(x); ma = max(x);
y = linspace(mi,ma,k)';   % initial centroids evenly spaced over the data range
bm = ones(n,1);
bmold = zeros(n,1);
i = 0;
while ~all(bm==bmold) && i<maxepochs,
  bmold = bm;
  % bin edges are the midpoints between adjacent centroids
  [c bm] = histc(x,[-Inf; (y(2:end)+y(1:end-1))/2; Inf]);
  y = full(sum(sparse(bm,1:n,x,k,n),2));  % per-centroid sum of member values
  zh = (c(1:end-1)==0);                   % centroids that received no samples
  y(~zh) = y(~zh)./c(~zh);
  inds = find(zh)';
  % an empty centroid is placed at the data minimum (first one) or just
  % above the preceding centroid
  for j=inds, if j==1, y(j) = mi; else y(j) = y(j-1) + eps; end, end
  i=i+1;
end
% the centroids moved after the last assignment, so refresh the assignment
if i==maxepochs, [c bm] = histc(x,[-Inf; (y(2:end)+y(1:end-1))/2; Inf]); end
if nargout>2, qe = sum(abs(x-y(bm)))/n; end
if any(nans),
  % only bm is per-sample: put it back to the original length, with NaN at
  % the removed positions; y (k centroids) and qe (scalar) stay unchanged
  bmfull = NaN(size(nans));
  bmfull(~nans) = bm;
  bm = bmfull;
end
return;