From ceb59619a44eb09ca58373decacd795974197aae Mon Sep 17 00:00:00 2001 From: Jacky <18255193+kthui@users.noreply.github.com> Date: Mon, 25 Nov 2024 19:41:11 -0800 Subject: [PATCH] feat: Support sending additional outputs from vLLM inference (#70) --- README.md | 5 + .../additional_outputs_test.py | 189 +++++++++ ci/L0_additional_outputs_vllm/test.sh | 66 ++++ .../multi_lora/download.py | 0 .../multi_lora/multi_lora_test.py | 0 .../multi_lora/test.sh | 0 .../test.sh | 0 .../vllm_backend/test.sh | 0 .../vllm_backend/vllm_multi_gpu_test.py | 0 ci/common/util.sh | 4 +- docs/additional_outputs.md | 107 ++++++ src/model.py | 361 +++++++++++------- src/utils/metrics.py | 13 +- 13 files changed, 610 insertions(+), 135 deletions(-) create mode 100644 ci/L0_additional_outputs_vllm/additional_outputs_test.py create mode 100755 ci/L0_additional_outputs_vllm/test.sh rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/multi_lora/download.py (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/multi_lora/multi_lora_test.py (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/multi_lora/test.sh (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/test.sh (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/vllm_backend/test.sh (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/vllm_backend/vllm_multi_gpu_test.py (100%) create mode 100644 docs/additional_outputs.md diff --git a/README.md b/README.md index 8a993d99..a157ed61 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,11 @@ you need to specify a different `shm-region-prefix-name` for each server. See [here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server) for more information. +## Additional vLLM outputs + +Additional vLLM outputs may be requested optionally on a per-request basis. See +[this docs](docs/additional_outputs.md) for more information. + ## Triton Metrics Starting with the 24.08 release of Triton, users can now obtain specific vLLM metrics by querying the Triton metrics endpoint (see complete vLLM metrics diff --git a/ci/L0_additional_outputs_vllm/additional_outputs_test.py b/ci/L0_additional_outputs_vllm/additional_outputs_test.py new file mode 100644 index 00000000..5a8eefbd --- /dev/null +++ b/ci/L0_additional_outputs_vllm/additional_outputs_test.py @@ -0,0 +1,189 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json + +import numpy as np +import pytest +import tritonclient.grpc as grpcclient + + +class TestAdditionalOutputs: + _grpc_url = "localhost:8001" + _model_name = "vllm_opt" + _sampling_parameters = {"temperature": "0", "top_p": "1"} + _prompt = "In this example," + + def _get_inputs( + self, + prompt, + stream=True, + sampling_parameters=None, + return_finish_reason=None, + return_cumulative_logprob=None, + return_num_output_tokens=None, + ): + inputs = [] + + inputs.append(grpcclient.InferInput("text_input", [1], "BYTES")) + inputs[-1].set_data_from_numpy( + np.array([prompt.encode("utf-8")], dtype=np.object_) + ) + + inputs.append(grpcclient.InferInput("stream", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([stream], dtype=bool)) + + if sampling_parameters is not None: + inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES")) + inputs[-1].set_data_from_numpy( + np.array( + [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_ + ) + ) + + if return_finish_reason is not None: + inputs.append(grpcclient.InferInput("return_finish_reason", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([return_finish_reason], dtype=bool)) + + if return_cumulative_logprob is not None: + inputs.append( + grpcclient.InferInput("return_cumulative_logprob", [1], "BOOL") + ) + inputs[-1].set_data_from_numpy( + np.array([return_cumulative_logprob], dtype=bool) + ) + + if return_num_output_tokens is not None: + inputs.append( + grpcclient.InferInput("return_num_output_tokens", [1], "BOOL") + ) + inputs[-1].set_data_from_numpy( + np.array([return_num_output_tokens], dtype=bool) + ) + + return inputs + + def _callback(self, result, error): + self._responses.append({"result": result, "error": error}) + + def _llm_infer(self, inputs): + self._responses = [] + with grpcclient.InferenceServerClient(self._grpc_url) as client: + client.start_stream(self._callback) + client.async_stream_infer( + self._model_name, inputs=inputs, parameters=self._sampling_parameters + ) + client.stop_stream() + assert len(self._responses) > 0 + + def _assert_text_output_valid(self): + text_output = "" + for response in self._responses: + result, error = response["result"], response["error"] + assert error is None + text_output += result.as_numpy(name="text_output")[0].decode("utf-8") + assert len(text_output) > 0, "output is empty" + assert text_output.count(" ") > 4, "output is not a sentence" + + def _assert_finish_reason(self, return_finish_reason): + for i in range(len(self._responses)): + result, error = self._responses[i]["result"], self._responses[i]["error"] + assert error is None + finish_reason_np = result.as_numpy(name="finish_reason") + if return_finish_reason is None or return_finish_reason == False: + assert finish_reason_np is None + continue + finish_reason = finish_reason_np[0].decode("utf-8") + if i < len(self._responses) - 1: + assert finish_reason == "None" + else: + assert finish_reason == "length" + + def _assert_cumulative_logprob(self, 
return_cumulative_logprob): + prev_cumulative_logprob = 0.0 + for response in self._responses: + result, error = response["result"], response["error"] + assert error is None + cumulative_logprob_np = result.as_numpy(name="cumulative_logprob") + if return_cumulative_logprob is None or return_cumulative_logprob == False: + assert cumulative_logprob_np is None + continue + cumulative_logprob = cumulative_logprob_np[0].astype(float) + assert cumulative_logprob != prev_cumulative_logprob + prev_cumulative_logprob = cumulative_logprob + + def _assert_num_output_tokens(self, return_num_output_tokens): + for response in self._responses: + result, error = response["result"], response["error"] + assert error is None + num_output_tokens_np = result.as_numpy(name="num_output_tokens") + if return_num_output_tokens is None or return_num_output_tokens == False: + assert num_output_tokens_np is None + continue + num_output_tokens = num_output_tokens_np[0].astype(int) + # TODO: vLLM may return token ids identical to the previous one when + # streaming, for example: + # + # prev: None + # curr: text=' the', token_ids=array('l', [5]) + # + # prev: text=' the', token_ids=array('l', [5, 1385]) + # curr: text=' the term', token_ids=array('l', [5, 1385]) + # + # prev: text=' the term', token_ids=array('l', [5, 1385, 44]) + # curr: text=' the term', token_ids=array('l', [5, 1385, 44]) + # + # prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48]) + # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48]) + # + # If this is no longer the case in a future release, change the assert + # to assert num_output_tokens > 0. + assert num_output_tokens >= 0 + + @pytest.mark.parametrize("stream", [True, False]) + @pytest.mark.parametrize("return_finish_reason", [None, True, False]) + @pytest.mark.parametrize("return_cumulative_logprob", [None, True, False]) + @pytest.mark.parametrize("return_num_output_tokens", [None, True, False]) + def test_additional_outputs( + self, + stream, + return_finish_reason, + return_cumulative_logprob, + return_num_output_tokens, + ): + inputs = self._get_inputs( + self._prompt, + stream=stream, + sampling_parameters=self._sampling_parameters, + return_finish_reason=return_finish_reason, + return_cumulative_logprob=return_cumulative_logprob, + return_num_output_tokens=return_num_output_tokens, + ) + self._llm_infer(inputs) + self._assert_text_output_valid() + self._assert_finish_reason(return_finish_reason) + self._assert_cumulative_logprob(return_cumulative_logprob) + self._assert_num_output_tokens(return_num_output_tokens) diff --git a/ci/L0_additional_outputs_vllm/test.sh b/ci/L0_additional_outputs_vllm/test.sh new file mode 100755 index 00000000..880f918f --- /dev/null +++ b/ci/L0_additional_outputs_vllm/test.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +export CUDA_VISIBLE_DEVICES=0 +source ../common/util.sh + +pip3 install pytest==8.1.1 +pip3 install tritonclient[grpc] + +# Prepare Model +rm -rf models vllm_baseline_output.pkl && mkdir -p models +SAMPLE_MODELS_REPO="../../samples/model_repository" +cp -r $SAMPLE_MODELS_REPO/vllm_model models/vllm_opt +sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.3/' models/vllm_opt/1/model.json + +RET=0 + +# Test +SERVER_LOG="vllm_opt.server.log" +SERVER_ARGS="--model-repository=models" +run_server +if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 +fi +set +e +python3 -m pytest --junitxml=test_additional_outputs.xml -s -v additional_outputs_test.py +if [ $? -ne 0 ]; then + echo -e "\n***\n*** additional_outputs_test FAILED. \n***" + RET=1 +fi +set -e +kill $SERVER_PID +wait $SERVER_PID + +if [ $RET -eq 0 ]; then + echo -e "\n***\n*** Test Passed\n***" +else + echo -e "\n***\n*** Test FAILED\n***" +fi +exit $RET diff --git a/ci/L0_multi_gpu/multi_lora/download.py b/ci/L0_multi_gpu_vllm/multi_lora/download.py similarity index 100% rename from ci/L0_multi_gpu/multi_lora/download.py rename to ci/L0_multi_gpu_vllm/multi_lora/download.py diff --git a/ci/L0_multi_gpu/multi_lora/multi_lora_test.py b/ci/L0_multi_gpu_vllm/multi_lora/multi_lora_test.py similarity index 100% rename from ci/L0_multi_gpu/multi_lora/multi_lora_test.py rename to ci/L0_multi_gpu_vllm/multi_lora/multi_lora_test.py diff --git a/ci/L0_multi_gpu/multi_lora/test.sh b/ci/L0_multi_gpu_vllm/multi_lora/test.sh similarity index 100% rename from ci/L0_multi_gpu/multi_lora/test.sh rename to ci/L0_multi_gpu_vllm/multi_lora/test.sh diff --git a/ci/L0_multi_gpu/test.sh b/ci/L0_multi_gpu_vllm/test.sh similarity index 100% rename from ci/L0_multi_gpu/test.sh rename to ci/L0_multi_gpu_vllm/test.sh diff --git a/ci/L0_multi_gpu/vllm_backend/test.sh b/ci/L0_multi_gpu_vllm/vllm_backend/test.sh similarity index 100% rename from ci/L0_multi_gpu/vllm_backend/test.sh rename to ci/L0_multi_gpu_vllm/vllm_backend/test.sh diff --git a/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py b/ci/L0_multi_gpu_vllm/vllm_backend/vllm_multi_gpu_test.py similarity index 100% rename from ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py rename to ci/L0_multi_gpu_vllm/vllm_backend/vllm_multi_gpu_test.py diff --git a/ci/common/util.sh b/ci/common/util.sh index 8baf4f92..0b2022ce 100755 --- a/ci/common/util.sh +++ b/ci/common/util.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -25,7 +25,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +SERVER=${SERVER:=/opt/tritonserver/bin/tritonserver} SERVER_IPADDR=${TRITONSERVER_IPADDR:=localhost} SERVER_LOG=${SERVER_LOG:=./server.log} SERVER_TIMEOUT=${SERVER_TIMEOUT:=120} diff --git a/docs/additional_outputs.md b/docs/additional_outputs.md new file mode 100644 index 00000000..5c103e89 --- /dev/null +++ b/docs/additional_outputs.md @@ -0,0 +1,107 @@ + + +# Additional Outputs from vLLM + +The vLLM backend supports sending additional outputs from vLLM on top of the +usual `text_output` when requested. + +All additional outputs are disabled by default and they need to be enabled on a +per-request basis. If enabled, the corresponding output tensor will be set for +all responses from the request. + +## Supported Additional Outputs + +### Finish Reason + +The reason why the sequence is finished. See +[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L26) +for more details. + +To enable, set `return_finish_reason` input tensor to `True`. The reason will be +sent as a string on the `finish_reason` output tensor. + +Supported since r24.12. + +### Cumulative Log Probabilities + +The cumulative log probability of the generated output text. See +[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L22) +for more details. + +To enable, set `return_cumulative_logprob` input tensor to `True`. The floating +point value will be sent on the `cumulative_logprob` output tensor. + +Supported since r24.12. + +### Number of Output Tokens + +The number of token IDs of the generated output text sent on this response. It +is the difference in length of the token IDs generated from the last response to +this response. If this is the first response, the last response length is +presumed to be zero. See +[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L21) +for more details on the token IDs of the generated output text. + +To enable, set `return_num_output_tokens` input tensor to `True`. The unsigned +integer value will be sent on the `num_output_tokens` output tensor. + +Supported since r24.12. + +## Examples + +### Add Finish Reason to Outputs + +```python +import numpy as np +import tritonclient.grpc as grpcclient + +inputs = [] + +inputs.append(grpcclient.InferInput("text_input", [1], "BYTES")) +inputs[-1].set_data_from_numpy( + np.array(["example prompt".encode("utf-8")], dtype=np.object_) +) + +inputs.append(grpcclient.InferInput("return_finish_reason", [1], "BOOL")) +inputs[-1].set_data_from_numpy(np.array([True], dtype=bool)) + +def callback(result, error): + ... + print(result.as_numpy(name="finish_reason")) + +with grpcclient.InferenceServerClient("localhost:8001") as client: + client.start_stream(callback) + client.async_stream_infer("vLLM_model_name", inputs=inputs, ...) + client.stop_stream() +``` + +## Notes + +* Enabling additional outputs may impact performance, only add additional +outputs when necessary. diff --git a/src/model.py b/src/model.py index 0fdbe0ce..d7b550c6 100644 --- a/src/model.py +++ b/src/model.py @@ -25,18 +25,19 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import asyncio +import base64 import gc import json import os import queue import threading -from typing import Dict, List -import base64 -from PIL import Image from io import BytesIO +from typing import Dict, List + import numpy as np import torch import triton_python_backend_utils as pb_utils +from PIL import Image from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.lora.request import LoRARequest @@ -51,8 +52,26 @@ class TritonPythonModel: + @classmethod + def auto_complete_config(cls, auto_complete_model_config): + # Add inputs/outputs to the model config. + cls._auto_complete_inputs_and_outputs(auto_complete_model_config) + + # We need to use decoupled transaction policy for saturating + # vLLM engine for max throughtput. + # TODO [DLIS:5233]: Allow asynchronous execution to lift this + # restriction for cases there is exactly a single response to + # a single request. + auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True)) + + # Disabling batching in Triton, let vLLM handle the batching on its own. + auto_complete_model_config.set_max_batch_size(0) + + return auto_complete_model_config + @staticmethod - def auto_complete_config(auto_complete_model_config): + def _auto_complete_inputs_and_outputs(auto_complete_model_config): + # Inputs expected by the backend. inputs = [ {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}, { @@ -73,18 +92,43 @@ def auto_complete_config(auto_complete_model_config): "dims": [1], "optional": True, }, + { + "name": "return_finish_reason", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "return_cumulative_logprob", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "return_num_output_tokens", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, ] if _VLLM_VERSION >= "0.6.3.post1": - inputs.append({ - "name": "image", - "data_type": "TYPE_STRING", - "dims": [-1], # can be multiple images as separate elements - "optional": True, - }) - - outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}] + inputs.append( + { + "name": "image", + "data_type": "TYPE_STRING", + "dims": [-1], # can be multiple images as separate elements + "optional": True, + } + ) + # Outputs expected by the backend. + outputs = [ + {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]}, + {"name": "num_output_tokens", "data_type": "TYPE_UINT32", "dims": [-1]}, + ] - # Store the model configuration as a dictionary. + # Collect input and output names from the provided model config. config = auto_complete_model_config.as_dict() input_names = [] output_names = [] @@ -93,7 +137,7 @@ def auto_complete_config(auto_complete_model_config): for output in config["output"]: output_names.append(output["name"]) - # Add only missing inputs and output to the model configuration. + # Add missing inputs and outputs to the model config. for input in inputs: if input["name"] not in input_names: auto_complete_model_config.add_input(input) @@ -101,18 +145,6 @@ def auto_complete_config(auto_complete_model_config): if output["name"] not in output_names: auto_complete_model_config.add_output(output) - # We need to use decoupled transaction policy for saturating - # vLLM engine for max throughtput. 
- # TODO [DLIS:5233]: Allow asynchronous execution to lift this - # restriction for cases there is exactly a single response to - # a single request. - auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True)) - - # Disabling batching in Triton, let vLLM handle the batching on its own. - auto_complete_model_config.set_max_batch_size(0) - - return auto_complete_model_config - def initialize(self, args): self.args = args self.logger = pb_utils.Logger @@ -289,6 +321,78 @@ async def await_shutdown(self): self.logger.log_info("[vllm] Shutdown complete") + def _get_input_tensors(self, request): + # prompt + prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0] + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + + # image + if _VLLM_VERSION >= "0.6.3.post1": + images = pb_utils.get_input_tensor_by_name(request, "image") + if images: + images_vllm = [] + for image_np in images.as_numpy(): + image_b = base64.b64decode(image_np.decode("utf-8")) + image_rgb = Image.open(BytesIO(image_b)).convert("RGB") + images_vllm.append(image_rgb) + if len(images_vllm) > 0: + prompt = { + "prompt": prompt, + "multi_modal_data": {"image": images_vllm}, + } + + # stream + stream = pb_utils.get_input_tensor_by_name(request, "stream") + if stream: + stream = stream.as_numpy()[0] + else: + stream = False + + # prepend_input / exclude_input_in_output + prepend_input = pb_utils.get_input_tensor_by_name( + request, "exclude_input_in_output" + ) + if prepend_input: + # When `exclude_input_in_output` is False, we want to prepend input prompt + # to output, thus prepend_input should be True, and vice versa. + prepend_input = not prepend_input.as_numpy()[0] + elif prepend_input is None and stream: + prepend_input = False + else: + prepend_input = True + if prepend_input and stream: + raise ValueError( + "When streaming, `exclude_input_in_output` = False is not allowed." + ) + + # parameters / sampling_parameters + # An alternative mechanism to receive serialized parameters as an input tensor, + # because request parameters are not yet supported via BLS. + sampling_parameters = pb_utils.get_input_tensor_by_name( + request, "sampling_parameters" + ) + if sampling_parameters: + parameters = sampling_parameters.as_numpy()[0].decode("utf-8") + else: + parameters = request.parameters() + + # return_finish_reason, return_cumulative_logprob, return_num_output_tokens + additional_outputs = { + "return_finish_reason": None, + "return_cumulative_logprob": None, + "return_num_output_tokens": None, + } + for tensor_name in additional_outputs.keys(): + tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) + if tensor: + tensor = bool(tensor.as_numpy()[0]) + else: + tensor = False + additional_outputs[tensor_name] = tensor + + return prompt, stream, prepend_input, parameters, additional_outputs + def get_sampling_params_dict(self, params_json): """ This functions parses the dictionary values into their @@ -342,40 +446,78 @@ def response_loop(self): if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL: self.ongoing_request_count -= 1 - def create_response(self, vllm_output, prepend_input): - """ - Parses the output from the vLLM engine into Triton - response. 
- """ - prompt = "" - if prepend_input: - prompt = vllm_output.prompt - text_outputs = [ - (prompt + output.text).encode("utf-8") for output in vllm_output.outputs + def _create_response( + self, prev_request_output, request_output, prepend_input, additional_outputs + ): + output_tensors = [] + + # text_output + prepend_prompt = "" + if prev_request_output is None: + # this is the first response + if prepend_input: + prepend_prompt = request_output.prompt + prev_lens = [0] * len(request_output.outputs) + else: + # this is a subsequent response + prev_lens = [ + len(prev_output.text) for prev_output in prev_request_output.outputs + ] + text_output = [ + (prepend_prompt + output.text[prev_len:]).encode("utf-8") + for output, prev_len in zip(request_output.outputs, prev_lens) ] - triton_output_tensor = pb_utils.Tensor( - "text_output", np.asarray(text_outputs, dtype=self.output_dtype) + output_tensors.append( + pb_utils.Tensor( + "text_output", np.asarray(text_output, dtype=self.output_dtype) + ) ) - return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) - def create_stream_response(self, vllm_output, previous_outputs_lengths): - """ - Parses the output from the vLLM engine, extracts only newly generated - text and packs it into Triton response. - """ - if previous_outputs_lengths is None: - return self.create_response(vllm_output, prepend_input=False) + # finish_reason + if additional_outputs["return_finish_reason"]: + finish_reason = [ + str(output.finish_reason) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "finish_reason", np.asarray(finish_reason, dtype=np.object_) + ) + ) - text_outputs = [ - (output.text[prev_output_length:]).encode("utf-8") - for output, prev_output_length in zip( - vllm_output.outputs, previous_outputs_lengths + # cumulative_logprob + if additional_outputs["return_cumulative_logprob"]: + cumulative_logprob = [ + output.cumulative_logprob for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "cumulative_logprob", + np.asarray(cumulative_logprob, dtype=np.float32), + ) ) - ] - triton_output_tensor = pb_utils.Tensor( - "text_output", np.asarray(text_outputs, dtype=self.output_dtype) - ) - return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) + + # num_output_tokens + if additional_outputs["return_num_output_tokens"]: + if prev_request_output is None: + # this is the first response + prev_lens = [0] * len(request_output.outputs) + else: + # this is a subsequent response + prev_lens = [ + len(prev_output.token_ids) + for prev_output in prev_request_output.outputs + ] + num_output_tokens = [ + (len(output.token_ids) - prev_len) + for output, prev_len in zip(request_output.outputs, prev_lens) + ] + output_tensors.append( + pb_utils.Tensor( + "num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32) + ) + ) + + return pb_utils.InferenceResponse(output_tensors=output_tensors) async def generate(self, request): """ @@ -391,70 +533,17 @@ async def generate(self, request): decrement_ongoing_request_count = True try: request_id = random_uuid() - prompt = pb_utils.get_input_tensor_by_name( - request, "text_input" - ).as_numpy()[0] - if isinstance(prompt, bytes): - prompt = prompt.decode("utf-8") - - if _VLLM_VERSION >= "0.6.3.post1": - image_input_tensor = pb_utils.get_input_tensor_by_name( - request, "image" - ) - if image_input_tensor: - image_list = [] - for image_raw in image_input_tensor.as_numpy(): - image_data = base64.b64decode(image_raw.decode("utf-8")) 
- image = Image.open(BytesIO(image_data)).convert("RGB") - image_list.append(image) - if len(image_list) > 0: - prompt = { - "prompt": prompt, - "multi_modal_data": { - "image": image_list - } - } - - stream = pb_utils.get_input_tensor_by_name(request, "stream") - if stream: - stream = stream.as_numpy()[0] - else: - stream = False - prepend_input = pb_utils.get_input_tensor_by_name( - request, "exclude_input_in_output" - ) - if prepend_input: - # When `exclude_input_in_output` is False, we want to prepend - # input prompt to output, thus prepend_input should be True, - # and vice versa. - prepend_input = not prepend_input.as_numpy()[0] - elif prepend_input is None and stream: - prepend_input = False - else: - prepend_input = True - - if prepend_input and stream: - raise ValueError( - "When streaming, `exclude_input_in_output` = False is not allowed." - ) - - # Request parameters are not yet supported via - # BLS. Provide an optional mechanism to receive serialized - # parameters as an input tensor until support is added - - parameters_input_tensor = pb_utils.get_input_tensor_by_name( - request, "sampling_parameters" - ) - if parameters_input_tensor: - parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") - else: - parameters = request.parameters() + ( + prompt, + stream, + prepend_input, + parameters, + additional_outputs, + ) = self._get_input_tensors(request) sampling_params_dict = self.get_sampling_params_dict(parameters) lora_name = sampling_params_dict.pop("lora_name", None) sampling_params = SamplingParams(**sampling_params_dict) - last_output = None - prev_outputs = None lora_request = None if lora_name is not None: lora_id = str(self.supported_loras.index(lora_name) + 1) @@ -466,7 +555,11 @@ async def generate(self, request): request_id, prompt, sampling_params, lora_request=lora_request ) - async for output in response_iterator: + prev_request_output = None + async for request_output in response_iterator: + # Cancellation state will be checked by the response loop and written to + # the response state if streaming. If not streaming, cancellation state + # needs to be checked here. is_cancelled = response_state["is_cancelled"] if not stream: is_cancelled = response_sender.is_cancelled() @@ -474,7 +567,9 @@ async def generate(self, request): self.logger.log_info("[vllm] Cancelling the request") await self.llm_engine.abort(request_id) self.logger.log_info("[vllm] Successfully cancelled the request") + if stream: + # Add cancelled final response to response loop. response_state["last_response_generated"] = True response = pb_utils.InferenceResponse( error=pb_utils.TritonError( @@ -487,44 +582,52 @@ async def generate(self, request): self._response_queue.put_nowait( (response_state, response, flags) ) + break + + # Send each response if streaming. 
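+            # A streamed response contains only the text generated since the
+            # previous response (the delta is computed in _create_response from
+            # prev_request_output), plus any additional output tensors
+            # (finish_reason, cumulative_logprob, num_output_tokens) that were
+            # requested via the corresponding return_* input tensors.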
if stream: - prev_outputs_lengths = None - if prev_outputs is not None: - prev_outputs_lengths = [ - len(prev_output.text) - for prev_output in prev_outputs.outputs - ] - response = self.create_stream_response(output, prev_outputs_lengths) + response = self._create_response( + prev_request_output, + request_output, + prepend_input=False, + additional_outputs=additional_outputs, + ) flags = 0 - if output.finished: + if request_output.finished: response_state["last_response_generated"] = True flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL decrement_ongoing_request_count = False self._response_queue.put_nowait((response_state, response, flags)) - prev_outputs = output - last_output = output + prev_request_output = request_output + # Send the last response which contains all the outputs if not streaming. if not stream: response_sender.send( - self.create_response(last_output, prepend_input), + self._create_response( + prev_request_output=None, + request_output=request_output, + prepend_input=prepend_input, + additional_outputs=additional_outputs, + ), flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, ) except Exception as e: self.logger.log_error(f"[vllm] Error generating stream: {e}") error = pb_utils.TritonError(f"Error generating stream: {e}") - triton_output_tensor = pb_utils.Tensor( + text_output_tensor = pb_utils.Tensor( "text_output", np.asarray(["N/A"], dtype=self.output_dtype) ) response = pb_utils.InferenceResponse( - output_tensors=[triton_output_tensor], error=error + output_tensors=[text_output_tensor], error=error ) response_sender.send( response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL ) raise e + finally: if decrement_ongoing_request_count: self.ongoing_request_count -= 1 diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 0504eef9..48b77a2c 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -34,6 +34,7 @@ from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets from vllm.version import __version__ as _VLLM_VERSION + class TritonMetrics: def __init__(self, labels: List[str], max_model_len: int): # Initialize metric families @@ -163,9 +164,11 @@ def __init__(self, labels: List[str], max_model_len: int): ) ) if _VLLM_VERSION < "0.6.3": - self.histogram_best_of_request = self.histogram_best_of_request_family.Metric( - labels=labels, - buckets=[1, 2, 5, 10, 20], + self.histogram_best_of_request = ( + self.histogram_best_of_request_family.Metric( + labels=labels, + buckets=[1, 2, 5, 10, 20], + ) ) self.histogram_n_request = self.histogram_n_request_family.Metric( labels=labels, @@ -254,7 +257,9 @@ def log(self, stats: VllmStats) -> None: (self.metrics.histogram_n_request, stats.n_requests), ] if _VLLM_VERSION < "0.6.3": - histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests)) + histogram_metrics.append( + (self.metrics.histogram_best_of_request, stats.best_of_requests) + ) for metric, data in counter_metrics: self._log_counter(metric, data) for metric, data in histogram_metrics:
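The example in `docs/additional_outputs.md` above requests only `finish_reason`. The sketch below is a minimal complement that requests all three additional outputs in a single streaming request. It mirrors `additional_outputs_test.py`, so the model name `vllm_opt`, the `localhost:8001` gRPC address, and the sampling parameters are assumptions carried over from that test and may differ in other deployments.

```python
import json

import numpy as np
import tritonclient.grpc as grpcclient

inputs = []

inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
inputs[-1].set_data_from_numpy(
    np.array(["In this example,".encode("utf-8")], dtype=np.object_)
)

inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))

inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
inputs[-1].set_data_from_numpy(
    np.array(
        [json.dumps({"temperature": "0", "top_p": "1"}).encode("utf-8")],
        dtype=np.object_,
    )
)

# Enable all three additional outputs for this request.
for name in (
    "return_finish_reason",
    "return_cumulative_logprob",
    "return_num_output_tokens",
):
    inputs.append(grpcclient.InferInput(name, [1], "BOOL"))
    inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))


def callback(result, error):
    if error is not None:
        print(error)
        return
    # Each streamed response carries the newly generated text plus the
    # requested additional output tensors.
    print(result.as_numpy(name="text_output")[0].decode("utf-8"))
    print(result.as_numpy(name="finish_reason")[0].decode("utf-8"))
    print(result.as_numpy(name="cumulative_logprob")[0])
    print(result.as_numpy(name="num_output_tokens")[0])


with grpcclient.InferenceServerClient("localhost:8001") as client:
    client.start_stream(callback)
    client.async_stream_infer("vllm_opt", inputs=inputs)
    client.stop_stream()
```

With streaming enabled, `finish_reason` is the string `"None"` on intermediate responses and the actual reason (for example `"length"`) on the final response, as exercised by `_assert_finish_reason` in the test above.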