-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
133 lines (100 loc) · 4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
## global import ---------------------------------------------------------------
import os
import sys
import time

import numpy as np
## local import ----------------------------------------------------------------
from network import NetworkBuilder as NetworkBuilder
from network import save
from network import load
from usps import read as read
# from usps import save as save
import network
import usps
import config
import debug
## location of usps data sets --------------------------------------------------
USPS_TEST_SET = "data_sets/test"
USPS_TRAIN_SET = "data_sets/train"
USPS_TRAIN100_SET = "data_sets/train100"
USPS_TRAIN1000_SET = "data_sets/train1000"

# Image classified by the "single input" test (testing menu option 1).
USPS_SAMPLE_IMAGE = USPS_TEST_SET + "/0/1.bmp"

# Training menu choice -> training-set name. The name is passed to
# usps.get_training_sequences and reused as the sub-directory of "usps/"
# into which the trained network is saved.
_TRAIN_DIRS = {1: "train100", 2: "train1000", 3: "train"}

# Testing menu choice -> data-set directory whose class sub-folders are
# classified.
_TEST_DIRS = {2: USPS_TRAIN100_SET, 3: USPS_TRAIN1000_SET, 4: USPS_TEST_SET}


def _read_line():
    """Read one line from stdin on both Python 2 and Python 3."""
    try:
        reader = raw_input  # Python 2; input() there would eval the text
    except NameError:
        reader = input  # Python 3
    return reader()


def _parse_choice(text):
    """Return *text* as an integer menu choice, or None if it is not one."""
    try:
        return int(text)
    except (TypeError, ValueError):
        return None


def _read_choice():
    """Prompt until the user types an integer and return it.

    Returns None on end of input. Every menu treats a non-listed value,
    including None, as "Quit".
    """
    while True:
        try:
            line = _read_line()
        except EOFError:
            return None
        choice = _parse_choice(line)
        if choice is not None:
            return choice
        print("Please enter the number of a menu option.")


def _ensure_dir(path):
    """Create *path* and any missing parents; do nothing if it already exists.

    Raises:
        OSError: if *path* cannot be created (e.g. a file is in the way).
    """
    try:
        os.makedirs(path)
    except OSError:
        # Python 2 has no exist_ok=True: only tolerate "already a directory".
        if not os.path.isdir(path):
            raise


def _train(choice):
    """Build, train and save a new HTM on the training set picked by *choice*.

    *choice* must be a key of _TRAIN_DIRS. The network is saved under
    "usps/<set name>/". Returns the trained network.
    """
    htm = NetworkBuilder(config.usps_net).build()
    htm.start()
    t0 = time.time()
    print("")
    print("*** Training HTM **")
    directory = _TRAIN_DIRS[choice]
    # Filled in by get_training_sequences with per-layer sequence counts.
    seq_count = {}
    sequences = usps.get_training_sequences(directory, uSeqCount=seq_count)
    print("Starting training...")
    # NOTE: to profile training, run this call through profile.runctx.
    htm.train(sequences)
    print("Saving network on file...")
    out_dir = "usps/" + directory
    _ensure_dir(out_dir)
    save(htm, out_dir + "/")
    print("*** Summary **")
    print("Number of training sequences generated:")
    print(" * Entry layer:  %s" % (seq_count[network.ENTRY],))
    print(" * Intermediate layer:  %s" % (seq_count[network.INTERMEDIATE],))
    print(" * Output layer:  %s" % (seq_count[network.OUTPUT],))
    print("Training completed in  %s seconds" % (time.time() - t0,))
    return htm


def _load():
    """Ask the user for a directory and return the HTM loaded from it."""
    print("Enter the directory:")
    return load(_read_line().strip())


def _evaluate(htm, directory):
    """Classify every sample under *directory* with *htm* and print accuracy.

    *directory* holds one sub-folder per class, named by the class index
    ("0", "1", ...). A sample counts as correct when the argmax of the
    network output equals that index. Entries that are not digit-named
    directories (e.g. hidden files) are skipped.
    """
    total = 0
    correct = 0
    for class_name in sorted(os.listdir(directory)):
        class_dir = directory + "/" + class_name
        if not class_name.isdigit() or not os.path.isdir(class_dir):
            continue
        expected = int(class_name)
        for sample in os.listdir(class_dir):
            total += 1
            output = htm.inference(read(class_dir + "/" + sample))
            if np.argmax(np.array(output)) == expected:
                correct += 1
    print("Total: %d" % total)
    print("Correct: %d" % correct)
    if total:
        print("Correctness ratio: %s" % (correct / float(total),))
    else:
        print("Correctness ratio: n/a (no samples found)")


def main():
    """Interactive entry point: train (or load) an HTM, then test it."""
    # Print numpy arrays in full. The old threshold='nan' idiom is rejected by
    # current numpy ("threshold must be numeric"); sys.maxsize is the value
    # numpy itself recommends for untruncated output.
    np.set_printoptions(threshold=sys.maxsize)

    print("*** Select the training set: ***")
    print("1. Train HTM on USPS train100 training set")
    print("2. Train HTM on USPS train1000 training set")
    print("3. Train HTM on USPS full training set (over 7000 elements)")
    print("4. Load HTM from file")
    print("5. Quit")
    choice = _read_choice()
    if choice in _TRAIN_DIRS:
        htm = _train(choice)
    elif choice == 4:
        htm = _load()
    else:
        sys.exit(0)

    print("*** HTM Testing ***")
    print("1. Test HTM on single input")
    print("2. Test HTM on USPS train100 training set")
    print("3. Test HTM on USPS train1000 training set")
    print("4. Test HTM on USPS full test set (over 2000 elements)")
    print("5. Quit")
    choice = _read_choice()
    if choice != 1 and choice not in _TEST_DIRS:
        sys.exit(0)

    t0 = time.time()
    print("")
    print("*** Testing HTM ***")
    if choice == 1:
        # NOTE: to profile a single inference, run it through profile.runctx.
        print(htm.inference(read(USPS_SAMPLE_IMAGE)))
    else:
        _evaluate(htm, _TEST_DIRS[choice])
    print("Completed in  %s seconds" % (time.time() - t0,))


if __name__ == "__main__":
    main()