-
Notifications
You must be signed in to change notification settings - Fork 5
/
create_train_test.py
executable file
·86 lines (74 loc) · 2.29 KB
/
create_train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/python
#----------------------------------------------------------------------------------------
#Original Author: Nicolas Pinto, MIT ([email protected])
#
#Modifier: Abhijit Bendale
# Vision and Security Technology lab
# University of Colorado, Colorado Springs
#
#Date: May 22,2009
#
#Usage: python create_train_test.py <path to caltech 101 directory> <no. of training images per category>
#This script creates train.lst and test.lst in the current directory; they list the
#training and testing image files, respectively.
#
#This code is heavily influenced by Nicolas Pinto's (MIT) code on the V1 model.
#For the original code's license information, refer to http://pinto.scripts.mit.edu/Main/HomePage
#
#---------------------------------------------------------------------------------------
import sys
import os
import random
EXTENSIONS = ['.png', '.jpg', '.pgm']  # file suffixes collected as images (exact, case-sensitive match)

# -- Parse and validate the command-line arguments.
if len(sys.argv) < 3:
    # Fail with a usage line instead of an IndexError traceback.
    sys.exit("Usage: python %s <path to caltech 101 directory> "
             "<no of training images>" % sys.argv[0])
img_path = os.path.abspath(sys.argv[1])
ntrain = int(sys.argv[2])  # number of training images taken from each category
if ntrain < 0:
    # A negative value would turn the later [:ntrain] slice into
    # "all but the last |ntrain| images", which is never intended.
    raise ValueError("number of training images must be >= 0, got %d" % ntrain)
if not os.path.isdir(img_path):
    # Call-form raise works on both Python 2 and Python 3.
    raise ValueError("%s is not a directory" % (img_path))
# -- Walk the image tree. The expected layout is one level of category
# -- sub-directories below img_path, each containing only files.
tree = os.walk(img_path)
filelist = []
# The first walk entry is img_path itself; its sub-directory names are the
# categories. Files located directly in img_path are therefore not collected.
# next(tree) replaces the Python-2-only tree.next().
categories = next(tree)[1]
for root, dirs, files in tree:
    if dirs:
        # Deeper nesting is not supported: report every offending directory.
        msgs = ["invalid image tree structure:"]
        msgs += [" " + "/".join([root, d]) for d in dirs]
        raise Exception("\n".join(msgs))
    # Keep only files whose extension is one of EXTENSIONS; paths are joined
    # with '/' because the grouping step later splits on '/'.
    filelist += [root + '/' + f for f in files
                 if os.path.splitext(f)[-1] in EXTENSIONS]
# -- Sort the collected paths so each category's name list starts out sorted,
# -- then bucket every image's base name under its containing directory path.
filelist.sort()
kwargs = {'filelist': filelist}  # NOTE(review): not read anywhere else in this file
cats = {}
for path in filelist:
    # rpartition splits at the last '/': (directory path, '/', base name).
    cat, _, name = path.rpartition('/')
    cats.setdefault(cat, []).append(name)
# -- Shuffle each category's images into a new random order, then split:
# -- the first ntrain images go to training, all remaining ones to testing.
filelists_dict = {}
train = {}
test = {}
# Seed once (from OS randomness or the current time), not once per category.
random.seed()
for cat in cats:
    names = cats[cat]
    random.shuffle(names)  # shuffles the list stored in cats in place
    paths = [cat + '/' + fname for fname in names]
    filelists_dict[cat] = paths
    train[cat] = paths[:ntrain]
    # Start the test slice at ntrain (not ntrain+1): the previous slice
    # silently dropped the image at index ntrain from both sets.
    test[cat] = paths[ntrain:]
    if not test[cat]:
        sys.stderr.write("warning: category %s has only %d images; "
                         "none left for testing\n" % (cat, len(paths)))
# -- Write train.lst and test.lst in the current directory, one image path
# -- per line. train and test are filled in the same loop, so they share keys.
# The with-statement closes both files even if a write fails.
with open('train.lst', 'w') as trfile, open('test.lst', 'w') as tsfile:
    for cat in train:
        trfile.writelines(p + "\n" for p in train[cat])
        tsfile.writelines(p + "\n" for p in test[cat])