diff --git a/src/model.py b/src/model.py index 33fb7622..0f09c3d3 100644 --- a/src/model.py +++ b/src/model.py @@ -31,7 +31,9 @@ import queue import threading from typing import Dict, List - +import base64 +from PIL import Image +from io import BytesIO import numpy as np import torch import triton_python_backend_utils as pb_utils @@ -397,12 +399,21 @@ async def generate(self, request): ) if multi_modal_data_input_tensor: multi_modal_data = multi_modal_data_input_tensor.as_numpy()[0].decode("utf-8") - # Build TextPrompt format prompt for multi modal models multi_modal_data = json.loads(multi_modal_data) - prompt = { - "prompt": prompt, - "multi_modal_data": multi_modal_data - } + if "image" in multi_modal_data: + image_list = [] + for image_base64_string in multi_modal_data["image"]: + if "base64," in image_base64_string: + image_base64_string = image_base64_string.split("base64,")[-1] + image_data = base64.b64decode(image_base64_string) + image = Image.open(BytesIO(image_data)).convert("RGB") + image_list.append(image) + prompt = { + "prompt": prompt, + "multi_modal_data": { + "image": image_list + } + } stream = pb_utils.get_input_tensor_by_name(request, "stream") if stream: