-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathserver.py
126 lines (98 loc) · 3.89 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import SocketServer, time, socket, re, os
import matplotlib.pyplot as plt
import numpy as np
import caffe
# Compute mode: set use_gpu to True to run Caffe on the GPU for parallel
# computation; leave it False to stay on the CPU.
use_gpu = False

# IP address and port the server listens on.
ip_addr = '127.0.0.1'
port = 10000

# Root of the Caffe checkout. The CAFFE_ROOT environment variable, when set,
# overrides the built-in default so the server is not tied to one machine.
caffe_root = os.environ.get('CAFFE_ROOT', '/Users/EdwardDing/caffe/')
# The paths below (and other code in this file) are built by plain string
# concatenation, so the root must always end with a slash.
if not caffe_root.endswith('/'):
    caffe_root += '/'

# Network definition, pretrained weights and label list, relative to the root.
MODEL_FILE = caffe_root + 'examples/imagenet/imagenet_deploy.prototxt'
PRETRAINED = caffe_root + 'examples/imagenet/caffe_reference_imagenet_model'
LABEL_FILE = caffe_root + 'data/ilsvrc12/synset_words.txt'

# Regular expression that captures the id between <id> and </id> at the start
# of an upload header.
PATTERN = r'<id>(.*?)</id>'
def loadLabel(label_file=None):
    """Load all labels from the label file and return them as a list.

    Every line of the file is split on ','. In the first field the leading
    whitespace-separated token is dropped (presumably the synset id used in
    synset_words.txt) and the remaining words are re-joined with single
    spaces. All other fields are kept exactly as read, including any leading
    space and the trailing newline of the last field.

    Args:
        label_file: path of the label file; defaults to the module-level
            LABEL_FILE when omitted.

    Returns:
        A list with one entry per line; each entry is the list of fields of
        that line, with the cleaned-up primary name at index 0.
    """
    if label_file is None:
        label_file = LABEL_FILE
    labels = []
    # 'with' guarantees the file handle is closed, even if parsing fails.
    with open(label_file) as label_stream:
        for line in label_stream:
            fields = line.split(',')
            fields[0] = ' '.join(fields[0].split()[1:])
            labels.append(fields)
    return labels
def CNN_Classify(imageFile, should_oversample):
    """Classify one image file and return the best-matching label entry.

    A new caffe.Classifier is built from MODEL_FILE / PRETRAINED on every
    call, switched to test phase, and put in GPU or CPU mode according to the
    module-level use_gpu flag.

    Args:
        imageFile: path of the image, loaded with caffe.io.load_image.
        should_oversample: when truthy, predict() is called with its default
            settings; otherwise it is called with oversample=False.

    Returns:
        The entry of loadLabel() at the arg-max of the first prediction row.
        The same entry is also printed to the terminal.
    """
    mean_file = caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy'
    all_labels = loadLabel()

    # Build the network.
    classifier = caffe.Classifier(MODEL_FILE,
                                  PRETRAINED,
                                  mean=np.load(mean_file),
                                  channel_swap=(2, 1, 0),
                                  raw_scale=255)
    classifier.set_phase_test()

    # Pick the compute mode; only the selected setter is looked up and called.
    select_mode = classifier.set_mode_gpu if use_gpu else classifier.set_mode_cpu
    select_mode()

    image = caffe.io.load_image(imageFile)

    # Pass oversample=False only when oversampling is switched off, so the
    # enabled case keeps relying on predict()'s own default.
    predict_options = {} if should_oversample else {'oversample': False}
    scores = classifier.predict([image], **predict_options)

    best = all_labels[scores[0].argmax(axis=0)]
    print(best)
    return best
class MyServer(SocketServer.BaseRequestHandler):
    """Per-connection handler for the picture classification protocol.

    Each chunk read from the client is dispatched on its prefix:

    * 'OVERSAMPLE...' -- oversampling is enabled when the chunk ends with
      'TRUE' and disabled otherwise (for pictures already cropped to the
      region of interest).
    * '<id>NAME</id>' -- starts an upload: the bytes that follow, up to the
      END_MARKER, are stored as NAME.png, classified with CNN_Classify, and
      the first element of the resulting label is sent back.
    * 'bye'           -- ends the session.

    Chunks with any other prefix are ignored.
    """

    # Sentinel the client sends after the last byte of the picture.
    END_MARKER = '<END OF FILE>'

    def handle(self):
        """Serve one client until it says 'bye', disconnects or misbehaves."""
        print('Connected from %s' % (self.client_address,))
        should_oversample = True
        while True:
            # Receive data from the client
            receivedData = self.request.recv(1024)
            if not receivedData:
                # recv() returns an empty string once the peer has closed the
                # connection; looping again would spin forever.
                break
            if receivedData.startswith('OVERSAMPLE'):
                should_oversample = receivedData.endswith('TRUE')
            elif receivedData.startswith('<id>'):
                if not self._handle_upload(receivedData, should_oversample):
                    break
            elif receivedData.startswith('bye'):
                break
        self.request.close()
        print('Disconnected from %s' % (self.client_address,))
        print('')

    def _handle_upload(self, header, should_oversample):
        """Receive one picture, classify it and send the best label back.

        Returns True when the session may continue, False when the connection
        should be dropped (malformed header or upload cut short).
        """
        usrID = re.match(PATTERN, header)
        if usrID is None:
            print('malformed <id> header, closing connection')
            return False
        # The id is supplied by the client: keep only its last path component
        # so it cannot name a file outside the working directory.
        fileName = os.path.basename(usrID.group(1)) + '.png'
        # NOTE(review): bytes following '</id>' in the header chunk are
        # discarded, as before -- confirm the client sends the header as a
        # separate message ahead of the picture data.
        try:
            if not self._receive_file(fileName):
                print('connection closed before the end of the file')
                return False
            # Get the result from CNN
            result = CNN_Classify(fileName, should_oversample)
            # Send the best result (only one) to the client
            self.request.sendall(result[0])
        finally:
            # Delete the temporary picture, also when something went wrong.
            if os.path.exists(fileName):
                os.remove(fileName)
        return True

    def _receive_file(self, fileName):
        """Write incoming bytes to fileName until END_MARKER is seen.

        The marker is located at its actual position, even when it is split
        across two recv() calls; anything after it is dropped. Returns True
        when the marker was found, False when the peer closed the connection
        first.
        """
        marker = self.END_MARKER
        # A marker cut in half by a chunk boundary can have at most
        # len(marker) - 1 of its bytes at the tail of the buffer, so that
        # many bytes are held back until more data arrives.
        keep = len(marker) - 1
        pending = ''
        count = 0
        with open(fileName, 'wb') as f:
            while True:
                end = pending.find(marker)
                if end >= 0:
                    f.write(pending[:end])
                    print('finished')
                    return True
                if len(pending) > keep:
                    f.write(pending[:-keep])
                    pending = pending[-keep:]
                data = self.request.recv(8192)
                if not data:
                    return False
                print('package:  %d' % count)
                count += 1
                pending += data
# Script entry point: start a threaded TCP server on (ip_addr, port) that
# serves every connection with MyServer, and run it until interrupted.
if __name__ == '__main__':
    print('Server is started\nwaiting for connection...\n')
    server = SocketServer.ThreadingTCPServer((ip_addr, port), MyServer)
    server.serve_forever()