forked from lazyprogrammer/machine_learning_examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathknn.py
72 lines (60 loc) · 2.15 KB
/
knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
'''
This is an example of a K-Nearest Neighbors classifier on MNIST data.
We try k=1...5 to show how we might choose the best k.
This is not production code!
See the tutorial here:
http://lazyprogrammer.me/post/114114892404/tutorial-k-nearest-neighbor-classifier-for-mnist
'''
import pandas as pd
import numpy as np
from sortedcontainers import SortedDict
# Load the MNIST CSV exports as NumPy arrays. Feature files keep one row per
# sample; label files are flattened into 1-D vectors aligned with those rows.
# ``DataFrame.as_matrix()`` was removed in pandas 1.0; ``.values`` is the
# long-standing equivalent that works on old and new pandas alike.
Xtest = pd.read_csv("mnist_csv/Xtest.txt", header=None).values
Xtrain = pd.read_csv("mnist_csv/Xtrain.txt", header=None).values
Ytest = pd.read_csv("mnist_csv/label_test.txt", header=None).values.flatten()
Ytrain = pd.read_csv("mnist_csv/label_train.txt", header=None).values.flatten()
class KNN(object):
    """Brute-force k-nearest-neighbours classifier using Euclidean distance.

    ``fit`` only memorises the training data; all work happens in ``predict``.
    """

    def __init__(self, k):
        """Create a classifier.

        Args:
            k: number of nearest training points that vote on each prediction.
        """
        self.k = k

    def fit(self, X, y):
        """Store the training set.

        Args:
            X: 2-D array-like of training samples, one row per sample.
            y: 1-D array-like of class labels aligned with the rows of X.
        """
        # Work in float so ``self.X - x`` cannot wrap around for unsigned
        # integer inputs (e.g. uint8 pixel values).
        self.X = np.asarray(X, dtype=float)
        self.y = np.asarray(y)

    def predict(self, X):
        """Predict a class label for every row of X.

        Each row is assigned the majority label among its ``k`` nearest
        training points (fewer if the training set is smaller than ``k``).
        Training points at equal distance keep their training order, and a
        tie in the vote goes to the class whose member is nearest.

        Args:
            X: 2-D array-like of query samples, one row per sample.

        Returns:
            1-D float array of predicted labels; every entry is -1 when the
            training set is empty.

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        from collections import Counter

        if self.k < 1:
            raise ValueError("k must be at least 1, got %r" % (self.k,))

        X = np.asarray(X, dtype=float)
        y = np.zeros(len(X))
        if len(self.X) == 0:
            # No neighbours to vote: keep the original "-1 = no class" result.
            y[:] = -1
            return y

        k = min(self.k, len(self.X))
        for i, x in enumerate(X):
            dists = np.linalg.norm(self.X - x, axis=1)
            # A stable sort keeps every equidistant training point instead of
            # letting one overwrite another (the old distance-keyed dict bug).
            nearest = np.argsort(dists, kind="stable")[:k]
            # Labels are counted in increasing-distance order, and
            # most_common() orders equal counts by first appearance, so vote
            # ties are broken in favour of the nearest neighbour's class.
            votes = Counter(self.y[nearest].tolist())
            y[i] = votes.most_common(1)[0][0]
        return y
# Evaluate k = 1..5 on the test set and report, for each k, a confusion matrix
# (rows = true label, columns = predicted label) and the overall accuracy.
for k in (1, 2, 3, 4, 5):
    knn = KNN(k)
    knn.fit(Xtrain, Ytrain)
    # predict() returns a float array; cast so labels can be used as indices
    # (NumPy rejects float indices).
    Ypred = knn.predict(Xtest).astype(int)
    # Plain ``int``: the ``np.int`` alias was removed in NumPy 1.24.
    C = np.zeros((10, 10), dtype=int)
    for p, t in zip(Ypred, Ytest):
        C[t, p] += 1
    print("Confusion matrix for k = %d:" % k)
    print(C)
    # Divide by the actual number of test samples rather than a fixed 500.
    print("Accuracy:", np.trace(C) / float(len(Ytest)))