Skip to content

Commit

Permalink
Support input for llama3.2 multi-modal model (#69)
Browse files Browse the repository at this point in the history
Co-authored-by: jibxie <[email protected]>
  • Loading branch information
xiejibing and jibxie authored Nov 25, 2024
1 parent b71088a commit 6c066f6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
32 changes: 31 additions & 1 deletion src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import queue
import threading
from typing import Dict, List

import base64
from PIL import Image
from io import BytesIO
import numpy as np
import torch
import triton_python_backend_utils as pb_utils
Expand All @@ -40,6 +42,7 @@
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.version import __version__ as _VLLM_VERSION

from utils.metrics import VllmStatLogger

Expand Down Expand Up @@ -71,6 +74,14 @@ def auto_complete_config(auto_complete_model_config):
"optional": True,
},
]
if _VLLM_VERSION >= "0.6.3.post1":
inputs.append({
"name": "image",
"data_type": "TYPE_STRING",
"dims": [-1], # can be multiple images as separate elements
"optional": True,
})

outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]

# Store the model configuration as a dictionary.
Expand Down Expand Up @@ -385,6 +396,25 @@ async def generate(self, request):
).as_numpy()[0]
if isinstance(prompt, bytes):
prompt = prompt.decode("utf-8")

if _VLLM_VERSION >= "0.6.3.post1":
image_input_tensor = pb_utils.get_input_tensor_by_name(
request, "image"
)
if image_input_tensor:
image_list = []
for image_raw in image_input_tensor.as_numpy():
image_data = base64.b64decode(image_raw.decode("utf-8"))
image = Image.open(BytesIO(image_data)).convert("RGB")
image_list.append(image)
if len(image_list) > 0:
prompt = {
"prompt": prompt,
"multi_modal_data": {
"image": image_list
}
}

stream = pb_utils.get_input_tensor_by_name(request, "stream")
if stream:
stream = stream.as_numpy()[0]
Expand Down
28 changes: 16 additions & 12 deletions src/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
from vllm.engine.metrics import Stats as VllmStats
from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets

from vllm.version import __version__ as _VLLM_VERSION

class TritonMetrics:
def __init__(self, labels: List[str], max_model_len: int):
Expand Down Expand Up @@ -76,11 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int):
description="Number of generation tokens processed.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
self.histogram_best_of_request_family = pb_utils.MetricFamily(
name="vllm:request_params_best_of",
description="Histogram of the best_of request parameter.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
# 'best_of' metric has been hidden since vllm 0.6.3
# https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
if _VLLM_VERSION < "0.6.3":
self.histogram_best_of_request_family = pb_utils.MetricFamily(
name="vllm:request_params_best_of",
description="Histogram of the best_of request parameter.",
kind=pb_utils.MetricFamily.HISTOGRAM,
)
self.histogram_n_request_family = pb_utils.MetricFamily(
name="vllm:request_params_n",
description="Histogram of the n request parameter.",
Expand Down Expand Up @@ -159,10 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int):
buckets=build_1_2_5_buckets(max_model_len),
)
)
self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
labels=labels,
buckets=[1, 2, 5, 10, 20],
)
if _VLLM_VERSION < "0.6.3":
self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
labels=labels,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = self.histogram_n_request_family.Metric(
labels=labels,
buckets=[1, 2, 5, 10, 20],
Expand Down Expand Up @@ -247,10 +251,10 @@ def log(self, stats: VllmStats) -> None:
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests,
),
(self.metrics.histogram_best_of_request, stats.best_of_requests),
(self.metrics.histogram_n_request, stats.n_requests),
]

if _VLLM_VERSION < "0.6.3":
histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests))
for metric, data in counter_metrics:
self._log_counter(metric, data)
for metric, data in histogram_metrics:
Expand Down

0 comments on commit 6c066f6

Please sign in to comment.