
Commit

Merge branch 'main' of github.com:triton-inference-server/vllm_backend into jacky-vllm-additional-outputs
kthui committed Nov 25, 2024
2 parents dae3c13 + 6c066f6 commit 2b531dd
Showing 2 changed files with 46 additions and 13 deletions.
31 changes: 30 additions & 1 deletion src/model.py
@@ -25,21 +25,25 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import asyncio
+import base64
 import gc
 import json
 import os
 import queue
 import threading
+from io import BytesIO
 from typing import Dict, List
 
 import numpy as np
 import torch
 import triton_python_backend_utils as pb_utils
+from PIL import Image
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from vllm.version import __version__ as _VLLM_VERSION
 
 from utils.metrics import VllmStatLogger

@@ -67,7 +71,7 @@ def auto_complete_config(cls, auto_complete_model_config):
 
     @staticmethod
     def _auto_complete_inputs_and_outputs(auto_complete_model_config):
-        # Inputs/Outputs expected by the backend.
+        # Inputs expected by the backend.
         inputs = [
             {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
             {
@@ -107,6 +111,16 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
                 "optional": True,
             },
         ]
+        if _VLLM_VERSION >= "0.6.3.post1":
+            inputs.append(
+                {
+                    "name": "image",
+                    "data_type": "TYPE_STRING",
+                    "dims": [-1],  # can be multiple images as separate elements
+                    "optional": True,
+                }
+            )
+        # Outputs expected by the backend.
         outputs = [
             {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
             {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]},
@@ -313,6 +327,21 @@ def _get_input_tensors(self, request):
         if isinstance(prompt, bytes):
             prompt = prompt.decode("utf-8")
 
+        # image
+        if _VLLM_VERSION >= "0.6.3.post1":
+            images = pb_utils.get_input_tensor_by_name(request, "image")
+            if images:
+                images_vllm = []
+                for image_np in images.as_numpy():
+                    image_b = base64.b64decode(image_np.decode("utf-8"))
+                    image_rgb = Image.open(BytesIO(image_b)).convert("RGB")
+                    images_vllm.append(image_rgb)
+                if len(images_vllm) > 0:
+                    prompt = {
+                        "prompt": prompt,
+                        "multi_modal_data": {"image": images_vllm},
+                    }
+
         # stream
         stream = pb_utils.get_input_tensor_by_name(request, "stream")
         if stream:
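
The new "image" input carries one base64-encoded image per element; the backend decodes each element to an RGB PIL image and passes the list to vLLM as multi_modal_data alongside the prompt. A minimal client-side sketch of exercising this path over gRPC (the server address, model name "vllm_model", image path, and prompt are illustrative assumptions; the tensor names are the ones declared above):

import base64

import numpy as np
import tritonclient.grpc as grpcclient

# Base64-encode an image file; the backend base64-decodes each element.
with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

client = grpcclient.InferenceServerClient(url="localhost:8001")

text_input = grpcclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(np.array(["Describe this image."], dtype=np.object_))

# "image" has dims [-1], so several images can be sent as separate elements.
image_input = grpcclient.InferInput("image", [1], "BYTES")
image_input.set_data_from_numpy(np.array([image_b64], dtype=np.object_))

result = client.infer("vllm_model", inputs=[text_input, image_input])
print(result.as_numpy("text_output"))

Note that vision-language models generally expect model-specific image placeholder tokens in the prompt; the prompt shown here is only a placeholder.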
28 changes: 16 additions & 12 deletions src/utils/metrics.py
@@ -32,7 +32,7 @@
 from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
 from vllm.engine.metrics import Stats as VllmStats
 from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
-
+from vllm.version import __version__ as _VLLM_VERSION
 
 class TritonMetrics:
     def __init__(self, labels: List[str], max_model_len: int):
@@ -76,11 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int):
             description="Number of generation tokens processed.",
             kind=pb_utils.MetricFamily.HISTOGRAM,
         )
-        self.histogram_best_of_request_family = pb_utils.MetricFamily(
-            name="vllm:request_params_best_of",
-            description="Histogram of the best_of request parameter.",
-            kind=pb_utils.MetricFamily.HISTOGRAM,
-        )
+        # 'best_of' metric has been hidden since vllm 0.6.3
+        # https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request_family = pb_utils.MetricFamily(
+                name="vllm:request_params_best_of",
+                description="Histogram of the best_of request parameter.",
+                kind=pb_utils.MetricFamily.HISTOGRAM,
+            )
         self.histogram_n_request_family = pb_utils.MetricFamily(
             name="vllm:request_params_n",
             description="Histogram of the n request parameter.",
@@ -159,10 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int):
                 buckets=build_1_2_5_buckets(max_model_len),
             )
         )
-        self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
-            labels=labels,
-            buckets=[1, 2, 5, 10, 20],
-        )
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
+                labels=labels,
+                buckets=[1, 2, 5, 10, 20],
+            )
         self.histogram_n_request = self.histogram_n_request_family.Metric(
             labels=labels,
             buckets=[1, 2, 5, 10, 20],
@@ -247,10 +251,10 @@ def log(self, stats: VllmStats) -> None:
                 self.metrics.histogram_num_generation_tokens_request,
                 stats.num_generation_tokens_requests,
             ),
-            (self.metrics.histogram_best_of_request, stats.best_of_requests),
             (self.metrics.histogram_n_request, stats.n_requests),
         ]
-
+        if _VLLM_VERSION < "0.6.3":
+            histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests))
         for metric, data in counter_metrics:
             self._log_counter(metric, data)
         for metric, data in histogram_metrics:
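
Both of the new metric guards, like the image-input guard in src/model.py, compare version strings lexicographically. That ordering happens to be correct for the versions involved here, but it breaks down once a version component reaches two digits (as strings, "0.10.0" < "0.6.3"). A sketch of an equivalent PEP 440-aware check, assuming the packaging package is available (an illustration, not part of this commit):

# Sketch: a version gate using packaging.version rather than raw string
# comparison. Assumes `packaging` is installed; the committed code compares
# strings directly.
from packaging.version import Version

from vllm.version import __version__ as _VLLM_VERSION

# Lexicographic: "0.10.0" < "0.6.3" is True (wrong ordering).
# PEP 440:       Version("0.10.0") < Version("0.6.3") is False (correct).
VLLM_REPORTS_BEST_OF = Version(_VLLM_VERSION) < Version("0.6.3")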
