Support input for llama3.2 multi-modal model #69

Merged · 6 commits · Nov 25, 2024
Changes from 4 commits
32 changes: 31 additions & 1 deletion src/model.py

@@ -31,7 +31,9 @@
import queue
import threading
from typing import Dict, List

import base64
from PIL import Image
from io import BytesIO
import numpy as np
import torch
import triton_python_backend_utils as pb_utils

@@ -52,6 +54,12 @@ class TritonPythonModel:
    def auto_complete_config(auto_complete_model_config):
        inputs = [
            {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
            {
                "name": "multi_modal_data",
                "data_type": "TYPE_STRING",
                "dims": [1],
                "optional": True,
            },

@rmccorm4 (Contributor) commented on Nov 1, 2024:

@GuanLuo @krishung5 @kthui any concerns with passing a serialized JSON input vs. individual input tensors for "image", "audio", etc?

Looks like this is currently mimicking the style of inputs vllm itself expects, so it would be pretty intuitive to vllm users:

Current serialized JSON form:

            {
                "name": "multi_modal_data",
                "data_type": "TYPE_STRING",
                "dims": [1], # 1 "element" to Triton, arbitrary structure/size inside the JSON, validated by backend
                "optional": True,
            },

Example tensor form:

            {
                "name": "image",
                "data_type": "TYPE_STRING",
                "dims": [-1], # can be multiple images as separate elements
                "optional": True,
            },
            {
                "name": "audio",
                "data_type": "TYPE_STRING",
                "dims": [-1], # can be multiple audios as separate elements
                "optional": True,
            },

@kthui (Contributor) commented on Nov 1, 2024:

I think the individual input tensors are cleaner in terms of what inputs are expected, and less prone to user error since they do not involve the additional JSON layer.

Given that we need to tear down the JSON and convert each Base64 string into bytes, there is actually some work on the backend to verify the JSON is well-formed before the conversion can happen. I think it is easier to supply the image/audio as individual tensors, knowing they are already well-formed, and then convert each Base64 string into bytes and format them correctly for vLLM.

A third reviewer commented:

No actual concerns off the top of my head. Agree with Jacky that the tensor form looks cleaner and could simplify some checks. I think aligning the format with vLLM could slightly improve usability for vLLM backend users. However, since the required input changes seem minimal, the impact on vLLM users should be limited.
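
For reference, a minimal sketch (not part of this PR) of how the tensor-form alternative discussed above might be handled inside generate(); the optional "image" input tensor, its one-base64-string-per-element layout, and the surrounding request/prompt context are assumptions for illustration:

    # Hypothetical tensor-form handling (sketch only); assumes the same imports
    # (base64, BytesIO, Image, pb_utils) and request/prompt variables as
    # src/model.py in this PR.
    image_input_tensor = pb_utils.get_input_tensor_by_name(request, "image")
    if image_input_tensor:
        image_list = []
        for image_element in image_input_tensor.as_numpy():
            image_base64_string = image_element.decode("utf-8")
            if "base64," in image_base64_string:
                image_base64_string = image_base64_string.split("base64,")[-1]
            image_data = base64.b64decode(image_base64_string)
            image_list.append(Image.open(BytesIO(image_data)).convert("RGB"))
        prompt = {"prompt": prompt, "multi_modal_data": {"image": image_list}}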

            {
                "name": "stream",
                "data_type": "TYPE_BOOL",

@@ -385,6 +393,28 @@ async def generate(self, request):
        ).as_numpy()[0]
        if isinstance(prompt, bytes):
            prompt = prompt.decode("utf-8")

        multi_modal_data_input_tensor = pb_utils.get_input_tensor_by_name(
            request, "multi_modal_data"
        )
        if multi_modal_data_input_tensor:
            multi_modal_data = multi_modal_data_input_tensor.as_numpy()[0].decode("utf-8")
            multi_modal_data = json.loads(multi_modal_data)
            if "image" in multi_modal_data:
                image_list = []
                for image_base64_string in multi_modal_data["image"]:
                    if "base64," in image_base64_string:
                        image_base64_string = image_base64_string.split("base64,")[-1]
                    image_data = base64.b64decode(image_base64_string)
                    image = Image.open(BytesIO(image_data)).convert("RGB")

A contributor commented:

NOTE: May need to expose image formats other than RGB in the future, but seems like a sensible default / first support for now. We can probably defer exposing it until we have a use case requiring other formats.

ex: https://github.com/vllm-project/vllm/blob/1dd4cb2935fc3fff9c156b5772d18e0a0d1861f0/vllm/multimodal/utils.py#L33

                    image_list.append(image)
                prompt = {
                    "prompt": prompt,
                    "multi_modal_data": {
                        "image": image_list
                    }
                }

        stream = pb_utils.get_input_tensor_by_name(request, "stream")
        if stream:
            stream = stream.as_numpy()[0]
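
As a usage illustration (not part of the diff), here is a hedged client-side sketch of a request exercising the new multi_modal_data input via Triton's generate endpoint; the model name, server URL, prompt template, and image path are assumptions:

    import base64
    import json

    import requests

    # Read a local image and base64-encode it, as the new input expects.
    with open("example.jpg", "rb") as f:
        image_base64 = base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "text_input": "<|image|><|begin_of_text|>What is shown in this image?",
        "multi_modal_data": json.dumps({"image": [image_base64]}),
        "stream": False,
    }

    # Hypothetical model name and endpoint; adjust to the actual deployment.
    response = requests.post(
        "http://localhost:8000/v2/models/llama3.2_vision/generate", json=payload
    )
    print(response.json()["text_output"])
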
10 changes: 0 additions & 10 deletions src/utils/metrics.py

@@ -76,11 +76,6 @@ def __init__(self, labels: List[str], max_model_len: int):
description="Number of generation tokens processed.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
self.histogram_best_of_request_family = pb_utils.MetricFamily(
name="vllm:request_params_best_of",
description="Histogram of the best_of request parameter.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
self.histogram_n_request_family = pb_utils.MetricFamily(
name="vllm:request_params_n",
description="Histogram of the n request parameter.",

@@ -159,10 +154,6 @@ def __init__(self, labels: List[str], max_model_len: int):
                buckets=build_1_2_5_buckets(max_model_len),
            )
        )
        self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
            labels=labels,
            buckets=[1, 2, 5, 10, 20],
        )
        self.histogram_n_request = self.histogram_n_request_family.Metric(
            labels=labels,
            buckets=[1, 2, 5, 10, 20],

@@ -247,7 +238,6 @@ def log(self, stats: VllmStats) -> None:
                self.metrics.histogram_num_generation_tokens_request,
                stats.num_generation_tokens_requests,
            ),
            (self.metrics.histogram_best_of_request, stats.best_of_requests),
            (self.metrics.histogram_n_request, stats.n_requests),
        ]
