diff --git a/tests/test_torchserve.sh b/tests/test_torchserve.sh index 9ee81358..4aff7fc3 100755 --- a/tests/test_torchserve.sh +++ b/tests/test_torchserve.sh @@ -2,11 +2,10 @@ set -e TMPDIR=/tmp -sudo apt-get update && sudo apt-get install -y curl wget +sudo apt-get update && sudo apt-get install -y curl jq wget sudo pip3 install torch-model-archiver +cp torchserve/custom_handler.py $TMPDIR/ cd $TMPDIR -git clone https://github.com/pytorch/serve -b v0.8.2 -cd serve/examples/object_detector/yolo/yolov8 # TODO: use gamutRF weights here. wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt torch-model-archiver --force --model-name yolov8n --version 1.0 --serialized-file yolov8n.pt --handler custom_handler.py @@ -15,4 +14,6 @@ mv yolov8n.mar model_store/ docker run -v $(pwd)/model_store:/model_store --net host -d iqtlabs/gamutrf-torchserve timeout 60s torchserve --start --model-store /model_store --ncs --foreground sleep 5 curl -X POST "localhost:8081/models?model_name=yolov8n&url=yolov8n.mar&initial_workers=4&batch_size=2" -curl http://127.0.0.1:8080/predictions/yolov8n -T persons.jpg +# TODO: use gamutRF test spectrogram image +wget https://github.com/pytorch/serve/raw/master/examples/object_detector/yolo/yolov8/persons.jpg +curl http://127.0.0.1:8080/predictions/yolov8n -T persons.jpg | jq diff --git a/torchserve/custom_handler.py b/torchserve/custom_handler.py new file mode 100644 index 00000000..c2d35f85 --- /dev/null +++ b/torchserve/custom_handler.py @@ -0,0 +1,74 @@ +# based on pytorch's yolov8n example. 

from collections import defaultdict
import os

import torch
from torchvision import transforms
from ultralytics import YOLO

from ts.torch_handler.object_detector import ObjectDetector

# Edge length, in pixels, used by the resize and center-crop transforms below.
IMG_SIZE = 640

# torch_xla is optional; record whether it could be imported so that
# initialize() can fall back to CUDA or CPU when it is absent.
try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


class Yolov8Handler(ObjectDetector):
    """TorchServe handler that serves a model loaded through ultralytics' YOLO.

    Detections are reported per input as a mapping from class name to a list
    of ``{"conf": ..., "xywh": ...}`` entries (see ``postprocess``).
    """

    # Resize, center-crop to IMG_SIZE, then convert to a tensor.
    # NOTE(review): presumably consumed by the inherited preprocessing step of
    # the TorchServe vision handlers -- confirm against the base class.
    image_processing = transforms.Compose(
        [
            transforms.Resize(IMG_SIZE),
            transforms.CenterCrop(IMG_SIZE),
            transforms.ToTensor(),
        ]
    )

    def initialize(self, context):
        """Select a device and load the model named in the archive manifest.

        Args:
            context: TorchServe context object; ``system_properties`` must
                provide ``model_dir`` and ``manifest["model"]`` must provide
                ``serializedFile``.

        Raises:
            RuntimeError: if the manifest does not name a serialized model
                file, since there would be nothing to load.
        """
        # Device preference: CUDA first, then XLA if importable, else CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif XLA_AVAILABLE:
            self.device = xm.xla_device()
        else:
            self.device = torch.device("cpu")

        properties = context.system_properties
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")

        self.model_pt_path = None
        model_info = self.manifest["model"]
        if "serializedFile" not in model_info:
            # Fail with a clear message instead of handing None to YOLO(),
            # which would only surface as an obscure error inside the library.
            raise RuntimeError(
                "Model manifest has no 'serializedFile' entry; cannot load model"
            )
        self.model_pt_path = os.path.join(model_dir, model_info["serializedFile"])

        self.model = self._load_torchscript_model(self.model_pt_path)
        self.initialized = True

    def _load_torchscript_model(self, model_pt_path):
        """Build a YOLO model from a weights file and move it to ``self.device``.

        Despite the method name, the file is loaded through ``YOLO(...)``
        rather than as a TorchScript module.

        Args:
            model_pt_path (str): path of the serialized model file.

        Returns:
            The YOLO model object, placed on ``self.device``.
        """
        # TODO: remove this method if https://github.com/pytorch/text/issues/1793 gets resolved

        model = YOLO(model_pt_path)
        model.to(self.device)
        return model

    def postprocess(self, res):
        """Convert raw detection results into plain, per-class dictionaries.

        Args:
            res: iterable of result objects, each exposing ``boxes.cls``,
                ``boxes.conf``, ``boxes.xywh`` and a ``names`` lookup from
                class index to class name.

        Returns:
            list: one dict per element of ``res``, mapping class name to a
            list of ``{"conf": float, "xywh": list}`` detections. A result
            with no boxes yields an empty dict.
        """
        output = []
        for data in res:
            result_dict = defaultdict(list)
            for cls, conf, xywh in zip(
                data.boxes.cls.tolist(), data.boxes.conf, data.boxes.xywh
            ):
                # cls arrives as a float after tolist(); names is keyed by int.
                name = data.names[int(cls)]
                result_dict[name].append({"conf": conf.item(), "xywh": xywh.tolist()})
            output.append(result_dict)
        return output