forked from pytorch/serve
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: dog_breed_classification_handler.py
33 lines (30 loc) · 1.24 KB
/
dog_breed_classification_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from ts.torch_handler.image_classifier import ImageClassifier
import json
class DogBreedClassifier(ImageClassifier):
    """Image classifier that only runs its model on rows labelled "dog".

    Each request row is read with ``row.get`` and must provide:

    * ``"cat_dog_classification"`` -- an upstream label (``bytes``,
      ``bytearray`` or ``str``); only the exact value ``"dog"`` selects the
      row for breed classification.
    * ``"pre_processing"`` -- the image payload (``bytes``, ``bytearray`` or
      ``str``) forwarded to the parent ``preprocess`` as the row's
      ``"body"``. It is only read for rows labelled ``"dog"``.

    Every row not labelled ``"dog"`` is answered with ``CAT_RESPONSE``.
    """

    # Reply used for every row whose upstream label is not "dog".
    CAT_RESPONSE = "It's a cat!"

    @staticmethod
    def _field_as_text(row, key):
        """Return ``row.get(key)`` as ``str``, decoding ``bytes``/``bytearray``.

        Args:
            row: a dict-like request row.
            key: the field to read.

        Returns:
            The field value as text.

        Raises:
            TypeError: if the field is missing/``None`` or is neither bytes
                nor text. The message names the key instead of failing with
                an opaque ``AttributeError`` on ``.decode()``.
        """
        value = row.get(key)
        if isinstance(value, (bytes, bytearray)):
            return value.decode()
        if isinstance(value, str):
            return value
        raise TypeError(
            f"expected bytes or str for {key!r}, got {type(value).__name__}"
        )

    def preprocess(self, data):
        """Select the dog rows and run the parent preprocess on their images.

        Args:
            data: list of dict-like request rows (see the class docstring).

        Returns:
            The parent ``preprocess`` output for the dog rows only, or
            ``None`` if no row was labelled ``"dog"``.

        Side effects:
            Sets ``self.is_dogs`` (one bool per row of ``data``), which
            ``postprocess`` reads to put predictions back in row order.
            NOTE(review): this per-instance state assumes one batch is in
            flight per handler instance -- confirm for the deployment.
        """
        self.is_dogs = [False] * len(data)
        inp_imgs = []
        for idx, row in enumerate(data):
            # NOTE(review): compared verbatim; assumes the upstream node emits
            # exactly "dog" (no surrounding whitespace/quotes) -- confirm.
            if self._field_as_text(row, "cat_dog_classification") != "dog":
                continue
            self.is_dogs[idx] = True
            # Wrap the input data into the format expected by the parent
            # preprocessing method.
            inp_imgs.append({"body": self._field_as_text(row, "pre_processing")})
        if not inp_imgs:
            # No dogs in this batch: inference/postprocess treat None as
            # "skip the model".
            return None
        return super().preprocess(inp_imgs)

    def inference(self, data, *args, **kwargs):
        """Run the parent inference unless ``preprocess`` found no dogs.

        Args:
            data: the output of ``preprocess``; ``None`` when no row was a dog.

        Returns:
            The parent inference result, or ``None`` when ``data`` is ``None``.
        """
        if data is None:
            return None
        return super().inference(data, *args, **kwargs)

    def postprocess(self, data):
        """Merge breed predictions back into one response per input row.

        Args:
            data: the output of ``inference``; ``None`` when no row was a dog.

        Returns:
            A list with one entry per row of the last ``preprocess`` batch.
            Dog rows get the parent's postprocessed predictions, in order.
            All other rows get ``CAT_RESPONSE``.
        """
        response = [self.CAT_RESPONSE] * len(self.is_dogs)
        if data is None:
            return response
        # Parent results cover only the dog rows, in the order they were queued
        # by preprocess, so consume them sequentially as dog rows are found.
        dog_predictions = iter(super().postprocess(data))
        for idx, is_dog in enumerate(self.is_dogs):
            if is_dog:
                response[idx] = next(dog_predictions)
        return response