Skip to content

Commit

Permalink
custom, custom handler that outputs confidence and bbox co-ords.
Browse files Browse the repository at this point in the history
  • Loading branch information
anarkiwi committed Oct 17, 2023
1 parent a2f6021 commit aac8fc9
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/test_torchserve.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

set -e
TMPDIR=/tmp
sudo apt-get update && sudo apt-get install -y curl wget
sudo apt-get update && sudo apt-get install -y curl jq wget
sudo pip3 install torch-model-archiver
cp torchserve/custom_handler.py $TMPDIR/
cd $TMPDIR
git clone https://github.com/pytorch/serve -b v0.8.2
cd serve/examples/object_detector/yolo/yolov8
# TODO: use gamutRF weights here.
wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt
torch-model-archiver --force --model-name yolov8n --version 1.0 --serialized-file yolov8n.pt --handler custom_handler.py
Expand All @@ -15,4 +14,6 @@ mv yolov8n.mar model_store/
docker run -v $(pwd)/model_store:/model_store --net host -d iqtlabs/gamutrf-torchserve timeout 60s torchserve --start --model-store /model_store --ncs --foreground
sleep 5
curl -X POST "localhost:8081/models?model_name=yolov8n&url=yolov8n.mar&initial_workers=4&batch_size=2"
curl http://127.0.0.1:8080/predictions/yolov8n -T persons.jpg
# TODO: use gamutRF test spectogram image
wget https://github.com/pytorch/serve/raw/master/examples/object_detector/yolo/yolov8/persons.jpg
curl http://127.0.0.1:8080/predictions/yolov8n -T persons.jpg | jq
74 changes: 74 additions & 0 deletions torchserve/custom_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# based on pytorch's yolov8n example.

from collections import defaultdict
import os

import torch
from torchvision import transforms
from ultralytics import YOLO

from ts.torch_handler.object_detector import ObjectDetector

IMG_SIZE = 640

try:
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
except ImportError as error:
XLA_AVAILABLE = False


class Yolov8Handler(ObjectDetector):
image_processing = transforms.Compose(
[
transforms.Resize(IMG_SIZE),
transforms.CenterCrop(IMG_SIZE),
transforms.ToTensor(),
]
)

def initialize(self, context):
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif XLA_AVAILABLE:
self.device = xm.xla_device()
else:
self.device = torch.device("cpu")

properties = context.system_properties
self.manifest = context.manifest
model_dir = properties.get("model_dir")
self.model_pt_path = None
if "serializedFile" in self.manifest["model"]:
serialized_file = self.manifest["model"]["serializedFile"]
self.model_pt_path = os.path.join(model_dir, serialized_file)
self.model = self._load_torchscript_model(self.model_pt_path)
self.initialized = True

def _load_torchscript_model(self, model_pt_path):
"""Loads the PyTorch model and returns the NN model object.
Args:
model_pt_path (str): denotes the path of the model file.
Returns:
(NN Model Object) : Loads the model object.
"""
# TODO: remove this method if https://github.com/pytorch/text/issues/1793 gets resolved

model = YOLO(model_pt_path)
model.to(self.device)
return model

def postprocess(self, res):
output = []
for data in res:
result_dict = defaultdict(list)
for cls, conf, xywh in zip(
data.boxes.cls.tolist(), data.boxes.conf, data.boxes.xywh
):
name = data.names[int(cls)]
result_dict[name].append({"conf": conf.item(), "xywh": xywh.tolist()})
output.append(result_dict)
return output

0 comments on commit aac8fc9

Please sign in to comment.