Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support input for llama3.2 multi-modal model #69

Merged
merged 6 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import queue
import threading
from typing import Dict, List

import base64
from PIL import Image
from io import BytesIO
import numpy as np
import torch
import triton_python_backend_utils as pb_utils
Expand All @@ -40,6 +42,7 @@
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.version import __version__ as _VLLM_VERSION

from utils.metrics import VllmStatLogger

Expand Down Expand Up @@ -71,6 +74,14 @@ def auto_complete_config(auto_complete_model_config):
"optional": True,
},
]
if _VLLM_VERSION >= "0.6.3.post1":
inputs.append({
"name": "image",
"data_type": "TYPE_STRING",
"dims": [-1], # can be multiple images as separate elements
"optional": True,
})

outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]

# Store the model configuration as a dictionary.
Expand Down Expand Up @@ -385,6 +396,25 @@ async def generate(self, request):
).as_numpy()[0]
if isinstance(prompt, bytes):
prompt = prompt.decode("utf-8")

if _VLLM_VERSION >= "0.6.3.post1":
image_input_tensor = pb_utils.get_input_tensor_by_name(
request, "image"
)
if image_input_tensor:
image_list = []
for image_raw in image_input_tensor.as_numpy():
image_data = base64.b64decode(image_raw.decode("utf-8"))
image = Image.open(BytesIO(image_data)).convert("RGB")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: We may need to expose image formats other than RGB in the future, but RGB seems like a sensible default for initial support. We can probably defer exposing this option until we have a use case that requires other formats.

ex: https://github.com/vllm-project/vllm/blob/1dd4cb2935fc3fff9c156b5772d18e0a0d1861f0/vllm/multimodal/utils.py#L33

image_list.append(image)
if len(image_list) > 0:
prompt = {
"prompt": prompt,
"multi_modal_data": {
"image": image_list
}
}

stream = pb_utils.get_input_tensor_by_name(request, "stream")
if stream:
stream = stream.as_numpy()[0]
Expand Down
28 changes: 16 additions & 12 deletions src/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
from vllm.engine.metrics import Stats as VllmStats
from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets

from vllm.version import __version__ as _VLLM_VERSION

class TritonMetrics:
def __init__(self, labels: List[str], max_model_len: int):
Expand Down Expand Up @@ -76,11 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int):
description="Number of generation tokens processed.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
self.histogram_best_of_request_family = pb_utils.MetricFamily(
name="vllm:request_params_best_of",
description="Histogram of the best_of request parameter.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
# 'best_of' metric has been hidden since vllm 0.6.3
# https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
if _VLLM_VERSION < "0.6.3":
self.histogram_best_of_request_family = pb_utils.MetricFamily(
name="vllm:request_params_best_of",
description="Histogram of the best_of request parameter.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
self.histogram_n_request_family = pb_utils.MetricFamily(
name="vllm:request_params_n",
description="Histogram of the n request parameter.",
Expand Down Expand Up @@ -159,10 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int):
buckets=build_1_2_5_buckets(max_model_len),
)
)
self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
labels=labels,
buckets=[1, 2, 5, 10, 20],
)
if _VLLM_VERSION < "0.6.3":
self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
labels=labels,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = self.histogram_n_request_family.Metric(
labels=labels,
buckets=[1, 2, 5, 10, 20],
Expand Down Expand Up @@ -247,10 +251,10 @@ def log(self, stats: VllmStats) -> None:
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests,
),
(self.metrics.histogram_best_of_request, stats.best_of_requests),
(self.metrics.histogram_n_request, stats.n_requests),
]

if _VLLM_VERSION < "0.6.3":
histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests))
for metric, data in counter_metrics:
self._log_counter(metric, data)
for metric, data in histogram_metrics:
Expand Down
Loading