-
Notifications
You must be signed in to change notification settings - Fork 0
/
service.py
82 lines (71 loc) · 2.95 KB
/
service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from flask import request, jsonify
from config.config import MODEL_STORAGE_PATH, IMAGE_BATCH_PARAM_NAME, SINGLE_IMAGE_PARAM_NAME
from model.model import ImageClassifier
from service.app_factory import create_app
from service.exceptions.generic_error import GenericError
from service.image_predictor import ImagePredictor
from service.image_trainer import ImageTrainer
# Application object built by the project's app factory; the route and
# error-handler decorators in this module register themselves on it.
app = create_app()
@app.route('/')
def hello():
    """
    Landing endpoint for the service root.
    :return: Plain-text greeting.
    """
    greeting = 'Welcome to the hand written digits predictor service.'
    return greeting
@app.errorhandler(GenericError)
def invalid_request(error=None, message=None, status_code=None):
    """
    Renders an error as a JSON response.

    Flask invokes this handler with the raised GenericError as the single
    positional argument. The routes in this module also call it directly with
    ``message=...`` / ``status_code=...`` keywords; those keywords are accepted
    here so that such calls produce an error response instead of a TypeError.

    :param error: Raised error exposing ``to_dict()`` and ``status_code``.
    :param message: Error text, used only when no error object is given.
    :param status_code: HTTP status, used only when no error object is given
        (falls back to 400 when omitted).
    :return: Json with response.
    """
    if error is not None:
        body = error.to_dict()
        http_status = error.status_code
    else:
        # NOTE(review): payload shape assumed to mirror GenericError.to_dict();
        # confirm against service.exceptions.generic_error.
        body = {"message": message}
        http_status = 400 if status_code is None else status_code

    response = jsonify(body)
    response.status_code = http_status
    return response
# TODO: Use versioned API routes behind a reverse proxy.
@app.route('/digits/train', methods=['GET', 'POST'])
def train_digits():
    """
    Service that calls the digits training method using the images path.

    Reads the JSON request body, takes the value stored under
    IMAGE_BATCH_PARAM_NAME as the training path, builds an ImageClassifier and
    runs ImageTrainer.train_digits() on it.

    :return: Response: the training result as json with status 201, or a json
        error message with status 400 when the body is not a JSON object, the
        training result is falsy, or a ValueError is raised while processing.
    """
    def _reject(message):
        # invalid_request() is the GenericError handler and takes an error
        # object, not message/status_code keywords, so the 400 is built here.
        # NOTE(review): payload shape assumed to mirror GenericError.to_dict().
        return jsonify({"message": message}), 400

    # silent=True yields None for a missing, mistyped or malformed JSON body
    # instead of letting Flask abort the request before we can answer.
    request_data = request.get_json(silent=True)
    if not request_data or not isinstance(request_data, dict):
        return _reject("Source with wrong format.")

    try:
        image_path = request_data.get(IMAGE_BATCH_PARAM_NAME)
        classifier_model = ImageClassifier(training_path=image_path)
        image_trainer = ImageTrainer(classifier_model)
        training_response = image_trainer.train_digits()
    except ValueError:
        return _reject("Unexpected error during request.")

    if not training_response:
        return _reject("Wrong training set.")
    return jsonify(training_response), 201
@app.route('/digits/classify', methods=['GET', 'POST'])
def process_digit():
    """
    Service that calls the digits prediction method for a single image.

    Reads the JSON request body, takes the value stored under
    SINGLE_IMAGE_PARAM_NAME as the image path and classifies it with an
    ImagePredictor loaded from MODEL_STORAGE_PATH.

    :return: Predicted digit response as json with status 201, or a json error
        message with status 400 when the body is not a JSON object, no
        prediction is produced, or a ValueError is raised while processing.
    """
    def _reject(message):
        # invalid_request() is the GenericError handler and takes an error
        # object, not message/status_code keywords, so the 400 is built here.
        # NOTE(review): payload shape assumed to mirror GenericError.to_dict().
        return jsonify({"message": message}), 400

    # silent=True yields None for a missing, mistyped or malformed JSON body
    # instead of letting Flask abort the request before we can answer.
    request_data = request.get_json(silent=True)
    if not request_data or not isinstance(request_data, dict):
        return _reject("Source with wrong format")

    try:
        image_path = request_data.get(SINGLE_IMAGE_PARAM_NAME)
        image_predictor = ImagePredictor(MODEL_STORAGE_PATH)
        predicted_digit = image_predictor.classify_digit(image_path)
        # A predicted digit of 0 is falsy, so a plain truthiness test would
        # report a correct classification of "0" as a failure.
        # NOTE(review): assumes classify_digit may return the bare digit;
        # confirm its return type.
        has_prediction = bool(
            predicted_digit
            or (predicted_digit is not False and predicted_digit == 0)
        )
    except ValueError:
        return _reject("Unexpected error during request.")

    if not has_prediction:
        return _reject("Wrong Image or not able to predict.")
    return jsonify(predicted_digit), 201
if __name__ == "__main__":
    # Started directly as a script: serve on every network interface
    # (0.0.0.0) rather than only on localhost.
    app.run(host="0.0.0.0")