-
Notifications
You must be signed in to change notification settings - Fork 5
/
classify.py
executable file
·58 lines (51 loc) · 1.84 KB
/
classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/python
#----------------------------------------------------------------------------------------
#This code is released under the GPL license
#
#Author: Abhijit Bendale
# Vision and Security Technology lab
# University of Colorado, Colorado Springs
#
#Date: May 22,2009
#
#Usage: python classify.py <list of testing images> <list of training images> <sift binary path>
#This file creates test.clslbl file with following entries:
#<filename> <category label from Caltech 101> <category label assigned by SVM-KNN>
#---------------------------------------------------------------------------------------
import sys
import os
from utility_functions import *
from PyML import *
from PyML.classifiers import multi
# Script body: classify every image listed in the test list and write one
# "<image path> <assigned category>" line per image to test.clslbl.
#
# NOTE(review): the file header describes three columns per line (filename,
# Caltech 101 label, SVM-KNN label), but only the filename and the assigned
# category are written here -- confirm which format downstream tools expect.

# Fail with a readable usage message instead of a bare IndexError when the
# three required command-line arguments are not all present.
if len(sys.argv) < 4:
    sys.exit("Usage: python classify.py <list of testing images> "
             "<list of training images> <sift binary path>")

testLst = sys.argv[1]
trainLst = sys.argv[2]
siftPath = sys.argv[3]

# Open the input list first so that an unreadable test list fails before an
# existing test.clslbl is truncated.  The nested `with` blocks guarantee both
# files are closed even if one of the helpers below raises.
with open(testLst, "r") as testFile:
    with open("test.clslbl", "w") as resultFile:
        for query in testFile:
            # Strip only the line terminator.  Slicing with [:-1] would eat a
            # real character when the final line has no trailing newline and
            # would leave a stray "\r" on CRLF-terminated lists.
            query = query.rstrip("\r\n")
            if not query:
                # Blank lines carry no image path; skip them.
                continue
            print(query)

            # presumably KNN holds the nearest training neighbours of the
            # query and queryDescriptor its SIFT descriptors -- both come from
            # utility_functions; verify there.
            KNN, queryDescriptor = getNeighbours(query, trainLst, siftPath)
            sameClass, sameCat = checkSameClass(KNN)
            if sameClass:
                # checkSameClass reported a single category for the
                # neighbours, so no SVM is needed.
                resultFile.write(query + " " + sameCat + "\n")
            else:
                # Train a one-against-rest multi-class SVM on the data derived
                # from the neighbours.
                trainData, trainLabels = getTrainingData(KNN)
                multiSvmTrainData = VectorDataSet(trainData, L=trainLabels)
                mclass = multi.OneAgainstRest(svm.SVM())
                mclass.train(multiSvmTrainData)

                # The test data set is given one label per descriptor row;
                # the label is the name of the directory containing the image.
                lbl = query.split("/")[-2]
                testLabels = [lbl] * len(queryDescriptor)
                testData = VectorDataSet(queryDescriptor, L=testLabels)
                result = mclass.test(testData)
                predictedLabels = result.getPredictedLabels()

                # The original author marked this step "check this".
                # presumably max_occuring_cat returns the most frequent
                # predicted label -- TODO confirm in utility_functions.
                svmCat = max_occuring_cat(predictedLabels)
                resultFile.write(query + " " + svmCat + "\n")