From 10a5b944a17fabf031eca006fd357b87e7e44662 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Tue, 29 Oct 2024 18:47:15 -0700 Subject: [PATCH 01/14] Add additional outputs and their input switches to auto complete * [WIP] Add additional outputs to auto complete * [WIP] Use individual input tensor to control per additional output * [WIP] Parse additional output flags from request --- src/model.py | 182 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 122 insertions(+), 60 deletions(-) diff --git a/src/model.py b/src/model.py index 3f6e23bb..d44688ac 100644 --- a/src/model.py +++ b/src/model.py @@ -48,8 +48,26 @@ class TritonPythonModel: + @classmethod + def auto_complete_config(cls, auto_complete_model_config): + # Add inputs/outputs to the model config. + cls._auto_complete_inputs_and_outputs(auto_complete_model_config) + + # We need to use decoupled transaction policy for saturating + # vLLM engine for max throughtput. + # TODO [DLIS:5233]: Allow asynchronous execution to lift this + # restriction for cases there is exactly a single response to + # a single request. + auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True)) + + # Disabling batching in Triton, let vLLM handle the batching on its own. + auto_complete_model_config.set_max_batch_size(0) + + return auto_complete_model_config + @staticmethod - def auto_complete_config(auto_complete_model_config): + def _auto_complete_inputs_and_outputs(auto_complete_model_config): + # Inputs/Outputs expected by the backend. inputs = [ {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}, { @@ -70,10 +88,33 @@ def auto_complete_config(auto_complete_model_config): "dims": [1], "optional": True, }, + { + "name": "output_finish_reason", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "output_cumulative_logprob", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + { + "name": "output_num_token_ids", + "data_type": "TYPE_BOOL", + "dims": [1], + "optional": True, + }, + ] + outputs = [ + {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]}, + {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]}, + {"name": "num_token_ids", "data_type": "TYPE_UINT32", "dims": [-1]}, ] - outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}] - # Store the model configuration as a dictionary. + # Collect input and output names from the provided model config. config = auto_complete_model_config.as_dict() input_names = [] output_names = [] @@ -82,7 +123,7 @@ def auto_complete_config(auto_complete_model_config): for output in config["output"]: output_names.append(output["name"]) - # Add only missing inputs and output to the model configuration. + # Add missing inputs and outputs to the model config. for input in inputs: if input["name"] not in input_names: auto_complete_model_config.add_input(input) @@ -90,18 +131,6 @@ def auto_complete_config(auto_complete_model_config): if output["name"] not in output_names: auto_complete_model_config.add_output(output) - # We need to use decoupled transaction policy for saturating - # vLLM engine for max throughtput. - # TODO [DLIS:5233]: Allow asynchronous execution to lift this - # restriction for cases there is exactly a single response to - # a single request. - auto_complete_model_config.set_model_transaction_policy(dict(decoupled=True)) - - # Disabling batching in Triton, let vLLM handle the batching on its own. - auto_complete_model_config.set_max_batch_size(0) - - return auto_complete_model_config - def initialize(self, args): self.args = args self.logger = pb_utils.Logger @@ -278,6 +307,63 @@ async def await_shutdown(self): self.logger.log_info("[vllm] Shutdown complete") + def _get_input_tensors(self, request): + # prompt + prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0] + if isinstance(prompt, bytes): + prompt = prompt.decode("utf-8") + + # stream + stream = pb_utils.get_input_tensor_by_name(request, "stream") + if stream: + stream = stream.as_numpy()[0] + else: + stream = False + + # prepend_input / exclude_input_in_output + prepend_input = pb_utils.get_input_tensor_by_name( + request, "exclude_input_in_output" + ) + if prepend_input: + # When `exclude_input_in_output` is False, we want to prepend input prompt + # to output, thus prepend_input should be True, and vice versa. + prepend_input = not prepend_input.as_numpy()[0] + elif prepend_input is None and stream: + prepend_input = False + else: + prepend_input = True + if prepend_input and stream: + raise ValueError( + "When streaming, `exclude_input_in_output` = False is not allowed." + ) + + # parameters / sampling_parameters + # An alternative mechanism to receive serialized parameters as an input tensor, + # because request parameters are not yet supported via BLS. + sampling_parameters = pb_utils.get_input_tensor_by_name( + request, "sampling_parameters" + ) + if sampling_parameters: + parameters = sampling_parameters.as_numpy()[0].decode("utf-8") + else: + parameters = request.parameters() + + # output_finish_reason, output_cumulative_logprob, output_num_token_ids + additional_outputs = { + "output_finish_reason": None, + "output_cumulative_logprob": None, + "output_num_token_ids": None, + } + for tensor_name in additional_outputs.keys(): + tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) + if tensor: + tensor = bool(tensor.as_numpy()[0]) + else: + tensor = False + additional_outputs[tensor_name] = tensor + + return prompt, stream, prepend_input, parameters, additional_outputs + def get_sampling_params_dict(self, params_json): """ This functions parses the dictionary values into their @@ -331,7 +417,7 @@ def response_loop(self): if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL: self.ongoing_request_count -= 1 - def create_response(self, vllm_output, prepend_input): + def create_response(self, vllm_output, prepend_input, additional_outputs): """ Parses the output from the vLLM engine into Triton response. @@ -347,13 +433,17 @@ def create_response(self, vllm_output, prepend_input): ) return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) - def create_stream_response(self, vllm_output, previous_outputs_lengths): + def create_stream_response( + self, vllm_output, previous_outputs_lengths, additional_outputs + ): """ Parses the output from the vLLM engine, extracts only newly generated text and packs it into Triton response. """ if previous_outputs_lengths is None: - return self.create_response(vllm_output, prepend_input=False) + return self.create_response( + vllm_output, prepend_input=False, additional_outputs=additional_outputs + ) text_outputs = [ (output.text[prev_output_length:]).encode("utf-8") @@ -380,45 +470,13 @@ async def generate(self, request): decrement_ongoing_request_count = True try: request_id = random_uuid() - prompt = pb_utils.get_input_tensor_by_name( - request, "text_input" - ).as_numpy()[0] - if isinstance(prompt, bytes): - prompt = prompt.decode("utf-8") - stream = pb_utils.get_input_tensor_by_name(request, "stream") - if stream: - stream = stream.as_numpy()[0] - else: - stream = False - prepend_input = pb_utils.get_input_tensor_by_name( - request, "exclude_input_in_output" - ) - if prepend_input: - # When `exclude_input_in_output` is False, we want to prepend - # input prompt to output, thus prepend_input should be True, - # and vice versa. - prepend_input = not prepend_input.as_numpy()[0] - elif prepend_input is None and stream: - prepend_input = False - else: - prepend_input = True - - if prepend_input and stream: - raise ValueError( - "When streaming, `exclude_input_in_output` = False is not allowed." - ) - - # Request parameters are not yet supported via - # BLS. Provide an optional mechanism to receive serialized - # parameters as an input tensor until support is added - - parameters_input_tensor = pb_utils.get_input_tensor_by_name( - request, "sampling_parameters" - ) - if parameters_input_tensor: - parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") - else: - parameters = request.parameters() + ( + prompt, + stream, + prepend_input, + parameters, + additional_outputs, + ) = self._get_input_tensors(request) sampling_params_dict = self.get_sampling_params_dict(parameters) lora_name = sampling_params_dict.pop("lora_name", None) @@ -465,7 +523,9 @@ async def generate(self, request): len(prev_output.text) for prev_output in prev_outputs.outputs ] - response = self.create_stream_response(output, prev_outputs_lengths) + response = self.create_stream_response( + output, prev_outputs_lengths, additional_outputs + ) flags = 0 if output.finished: response_state["last_response_generated"] = True @@ -478,7 +538,9 @@ async def generate(self, request): if not stream: response_sender.send( - self.create_response(last_output, prepend_input), + self.create_response( + last_output, prepend_input, additional_outputs + ), flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, ) From 892f0d0eb9af8b26e5502759e355ca50e201a4bd Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Thu, 31 Oct 2024 12:21:36 -0700 Subject: [PATCH 02/14] chore: Refactor generate function --- src/model.py | 94 ++++++++++++++++++++++------------------------------ 1 file changed, 40 insertions(+), 54 deletions(-) diff --git a/src/model.py b/src/model.py index d44688ac..5525edca 100644 --- a/src/model.py +++ b/src/model.py @@ -417,44 +417,30 @@ def response_loop(self): if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL: self.ongoing_request_count -= 1 - def create_response(self, vllm_output, prepend_input, additional_outputs): - """ - Parses the output from the vLLM engine into Triton - response. - """ - prompt = "" - if prepend_input: - prompt = vllm_output.prompt - text_outputs = [ - (prompt + output.text).encode("utf-8") for output in vllm_output.outputs - ] - triton_output_tensor = pb_utils.Tensor( - "text_output", np.asarray(text_outputs, dtype=self.output_dtype) - ) - return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) - - def create_stream_response( - self, vllm_output, previous_outputs_lengths, additional_outputs + def _create_response( + self, prev_request_output, request_output, prepend_input=False ): - """ - Parses the output from the vLLM engine, extracts only newly generated - text and packs it into Triton response. - """ - if previous_outputs_lengths is None: - return self.create_response( - vllm_output, prepend_input=False, additional_outputs=additional_outputs - ) - - text_outputs = [ - (output.text[prev_output_length:]).encode("utf-8") - for output, prev_output_length in zip( - vllm_output.outputs, previous_outputs_lengths - ) + # text_output + prepend_prompt = "" + if prev_request_output is None: + # this is the first response + if prepend_input: + prepend_prompt = request_output.prompt + prev_lens = [0] * len(request_output.outputs) + else: + # this is a subsequent response + prev_lens = [ + len(prev_output.text) for prev_output in prev_request_output.outputs + ] + text_output = [ + (prepend_prompt + output.text[prev_len:]).encode("utf-8") + for output, prev_len in zip(request_output.outputs, prev_lens) ] - triton_output_tensor = pb_utils.Tensor( - "text_output", np.asarray(text_outputs, dtype=self.output_dtype) + text_output_tensor = pb_utils.Tensor( + "text_output", np.asarray(text_output, dtype=self.output_dtype) ) - return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor]) + + return pb_utils.InferenceResponse(output_tensors=[text_output_tensor]) async def generate(self, request): """ @@ -481,8 +467,6 @@ async def generate(self, request): sampling_params_dict = self.get_sampling_params_dict(parameters) lora_name = sampling_params_dict.pop("lora_name", None) sampling_params = SamplingParams(**sampling_params_dict) - last_output = None - prev_outputs = None lora_request = None if lora_name is not None: lora_id = str(self.supported_loras.index(lora_name) + 1) @@ -494,7 +478,11 @@ async def generate(self, request): request_id, prompt, sampling_params, lora_request=lora_request ) - async for output in response_iterator: + prev_request_output = None + async for request_output in response_iterator: + # Cancellation state will be checked by the response loop and written to + # the response state if streaming. If not streaming, cancellation state + # needs to be checked here. is_cancelled = response_state["is_cancelled"] if not stream: is_cancelled = response_sender.is_cancelled() @@ -502,7 +490,9 @@ async def generate(self, request): self.logger.log_info("[vllm] Cancelling the request") await self.llm_engine.abort(request_id) self.logger.log_info("[vllm] Successfully cancelled the request") + if stream: + # Add cancelled final response to response loop. response_state["last_response_generated"] = True response = pb_utils.InferenceResponse( error=pb_utils.TritonError( @@ -515,48 +505,44 @@ async def generate(self, request): self._response_queue.put_nowait( (response_state, response, flags) ) + break + + # Send each response if streaming. if stream: - prev_outputs_lengths = None - if prev_outputs is not None: - prev_outputs_lengths = [ - len(prev_output.text) - for prev_output in prev_outputs.outputs - ] - response = self.create_stream_response( - output, prev_outputs_lengths, additional_outputs + response = self._create_response( + prev_request_output, request_output ) flags = 0 - if output.finished: + if request_output.finished: response_state["last_response_generated"] = True flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL decrement_ongoing_request_count = False self._response_queue.put_nowait((response_state, response, flags)) - prev_outputs = output - last_output = output + prev_request_output = request_output + # Send the last response which contains all the outputs if not streaming. if not stream: response_sender.send( - self.create_response( - last_output, prepend_input, additional_outputs - ), + self._create_response(None, request_output, prepend_input), flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, ) except Exception as e: self.logger.log_error(f"[vllm] Error generating stream: {e}") error = pb_utils.TritonError(f"Error generating stream: {e}") - triton_output_tensor = pb_utils.Tensor( + text_output_tensor = pb_utils.Tensor( "text_output", np.asarray(["N/A"], dtype=self.output_dtype) ) response = pb_utils.InferenceResponse( - output_tensors=[triton_output_tensor], error=error + output_tensors=[text_output_tensor], error=error ) response_sender.send( response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL ) raise e + finally: if decrement_ongoing_request_count: self.ongoing_request_count -= 1 From 58ee481d6c4a7c76c2f4bb272b440709a02d3e54 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Thu, 31 Oct 2024 18:03:22 -0700 Subject: [PATCH 03/14] Add additional outputs to response --- src/model.py | 68 +++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 62 insertions(+), 6 deletions(-) diff --git a/src/model.py b/src/model.py index 5525edca..0b4a6759 100644 --- a/src/model.py +++ b/src/model.py @@ -418,8 +418,10 @@ def response_loop(self): self.ongoing_request_count -= 1 def _create_response( - self, prev_request_output, request_output, prepend_input=False + self, prev_request_output, request_output, prepend_input, additional_outputs ): + output_tensors = [] + # text_output prepend_prompt = "" if prev_request_output is None: @@ -436,11 +438,57 @@ def _create_response( (prepend_prompt + output.text[prev_len:]).encode("utf-8") for output, prev_len in zip(request_output.outputs, prev_lens) ] - text_output_tensor = pb_utils.Tensor( - "text_output", np.asarray(text_output, dtype=self.output_dtype) + output_tensors.append( + pb_utils.Tensor( + "text_output", np.asarray(text_output, dtype=self.output_dtype) + ) ) - return pb_utils.InferenceResponse(output_tensors=[text_output_tensor]) + # finish_reason + if additional_outputs["output_finish_reason"]: + finish_reason = [ + str(output.finish_reason) for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "finish_reason", np.asarray(finish_reason, dtype=np.object_) + ) + ) + + # cumulative_logprob + if additional_outputs["output_cumulative_logprob"]: + cumulative_logprob = [ + output.cumulative_logprob for output in request_output.outputs + ] + output_tensors.append( + pb_utils.Tensor( + "cumulative_logprob", + np.asarray(cumulative_logprob, dtype=np.float32), + ) + ) + + # num_token_ids + if additional_outputs["output_num_token_ids"]: + if prev_request_output is None: + # this is the first response + prev_lens = [0] * len(request_output.outputs) + else: + # this is a subsequent response + prev_lens = [ + len(prev_output.token_ids) + for prev_output in prev_request_output.outputs + ] + num_token_ids = [ + (len(output.token_ids) - prev_len) + for output, prev_len in zip(request_output.outputs, prev_lens) + ] + output_tensors.append( + pb_utils.Tensor( + "num_token_ids", np.asarray(num_token_ids, dtype=np.uint32) + ) + ) + + return pb_utils.InferenceResponse(output_tensors=output_tensors) async def generate(self, request): """ @@ -511,7 +559,10 @@ async def generate(self, request): # Send each response if streaming. if stream: response = self._create_response( - prev_request_output, request_output + prev_request_output, + request_output, + prepend_input=False, + additional_outputs=additional_outputs, ) flags = 0 if request_output.finished: @@ -525,7 +576,12 @@ async def generate(self, request): # Send the last response which contains all the outputs if not streaming. if not stream: response_sender.send( - self._create_response(None, request_output, prepend_input), + self._create_response( + prev_request_output=None, + request_output=request_output, + prepend_input=prepend_input, + additional_outputs=additional_outputs, + ), flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, ) From 5e605ca9199235ad3715b0ba2dc56b4e6108f646 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Fri, 1 Nov 2024 17:14:01 -0700 Subject: [PATCH 04/14] Add test for additional outputs * Add additional outputs test * Update copyright * Some test enhancement and notes --- .../additional_outputs_test.py | 198 ++++++++++++++++++ ci/L0_vllm_additional_outputs/test.sh | 67 ++++++ ci/common/util.sh | 4 +- 3 files changed, 267 insertions(+), 2 deletions(-) create mode 100644 ci/L0_vllm_additional_outputs/additional_outputs_test.py create mode 100755 ci/L0_vllm_additional_outputs/test.sh diff --git a/ci/L0_vllm_additional_outputs/additional_outputs_test.py b/ci/L0_vllm_additional_outputs/additional_outputs_test.py new file mode 100644 index 00000000..0c2b3cda --- /dev/null +++ b/ci/L0_vllm_additional_outputs/additional_outputs_test.py @@ -0,0 +1,198 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import unittest + +import numpy as np +import tritonclient.grpc as grpcclient + + +class InferTest(unittest.TestCase): + _grpc_url = "localhost:8001" + _model_name = "vllm_opt" + _sampling_parameters = {"temperature": "0", "top_p": "1"} + _prompt = "In this example," + + def _get_inputs( + self, + prompt, + stream=True, + sampling_parameters=None, + output_finish_reason=None, + output_cumulative_logprob=None, + output_num_token_ids=None, + ): + inputs = [] + + inputs.append(grpcclient.InferInput("text_input", [1], "BYTES")) + inputs[-1].set_data_from_numpy( + np.array([prompt.encode("utf-8")], dtype=np.object_) + ) + + inputs.append(grpcclient.InferInput("stream", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([stream], dtype=bool)) + + if sampling_parameters is not None: + inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES")) + inputs[-1].set_data_from_numpy( + np.array( + [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_ + ) + ) + + if output_finish_reason is not None: + inputs.append(grpcclient.InferInput("output_finish_reason", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([output_finish_reason], dtype=bool)) + + if output_cumulative_logprob is not None: + inputs.append( + grpcclient.InferInput("output_cumulative_logprob", [1], "BOOL") + ) + inputs[-1].set_data_from_numpy( + np.array([output_cumulative_logprob], dtype=bool) + ) + + if output_num_token_ids is not None: + inputs.append(grpcclient.InferInput("output_num_token_ids", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([output_num_token_ids], dtype=bool)) + + return inputs + + def _callback(self, result, error): + self._responses.append({"result": result, "error": error}) + + def _llm_infer(self, inputs): + self._responses = [] + with grpcclient.InferenceServerClient(self._grpc_url) as client: + client.start_stream(self._callback) + client.async_stream_infer( + self._model_name, inputs=inputs, parameters=self._sampling_parameters + ) + client.stop_stream() + self.assertGreater(len(self._responses), 0) + + def _assert_text_output_valid(self): + text_output = "" + for response in self._responses: + result, error = response["result"], response["error"] + self.assertIsNone(error) + text_output += result.as_numpy(name="text_output")[0].decode("utf-8") + self.assertGreater(len(text_output), 0, "output is empty") + self.assertGreater(text_output.count(" "), 4, "output is not a sentence") + + def _assert_finish_reason(self, output_finish_reason): + for i in range(len(self._responses)): + result, error = self._responses[i]["result"], self._responses[i]["error"] + self.assertIsNone(error) + finish_reason_np = result.as_numpy(name="finish_reason") + if output_finish_reason is None or output_finish_reason == False: + self.assertIsNone(finish_reason_np) + continue + finish_reason = finish_reason_np[0].decode("utf-8") + if i < len(self._responses) - 1: + self.assertEqual(finish_reason, "None") + else: + self.assertEqual(finish_reason, "length") + + def _assert_cumulative_logprob(self, output_cumulative_logprob): + prev_cumulative_logprob = 0.0 + for response in self._responses: + result, error = response["result"], response["error"] + self.assertIsNone(error) + cumulative_logprob_np = result.as_numpy(name="cumulative_logprob") + if output_cumulative_logprob is None or output_cumulative_logprob == False: + self.assertIsNone(cumulative_logprob_np) + continue + cumulative_logprob = cumulative_logprob_np[0].astype(float) + self.assertNotEqual(cumulative_logprob, prev_cumulative_logprob) + prev_cumulative_logprob = cumulative_logprob + + def _assert_num_token_ids(self, output_num_token_ids): + for response in self._responses: + result, error = response["result"], response["error"] + self.assertIsNone(error) + num_token_ids_np = result.as_numpy(name="num_token_ids") + if output_num_token_ids is None or output_num_token_ids == False: + self.assertIsNone(num_token_ids_np) + continue + num_token_ids = num_token_ids_np[0].astype(int) + # TODO: vLLM may return token ids identical to the previous one when + # streaming, for example: + # + # prev: None + # curr: text=' the', token_ids=array('l', [5]) + # + # prev: text=' the', token_ids=array('l', [5, 1385]) + # curr: text=' the term', token_ids=array('l', [5, 1385]) + # + # prev: text=' the term', token_ids=array('l', [5, 1385, 44]) + # curr: text=' the term', token_ids=array('l', [5, 1385, 44]) + # + # prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48]) + # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48]) + # + # If this is no longer the case in a future release, change the assert + # to assertGreater(). + self.assertGreaterEqual(num_token_ids, 0) + + def _assert_additional_outputs_valid( + self, + stream, + output_finish_reason, + output_cumulative_logprob, + output_num_token_ids, + ): + inputs = self._get_inputs( + self._prompt, + stream=stream, + sampling_parameters=self._sampling_parameters, + output_finish_reason=output_finish_reason, + output_cumulative_logprob=output_cumulative_logprob, + output_num_token_ids=output_num_token_ids, + ) + self._llm_infer(inputs) + self._assert_text_output_valid() + self._assert_finish_reason(output_finish_reason) + self._assert_cumulative_logprob(output_cumulative_logprob) + self._assert_num_token_ids(output_num_token_ids) + + def test_additional_outputs(self): + for stream in [True, False]: + choices = [None, False, True] + for output_finish_reason in choices: + for output_cumulative_logprob in choices: + for output_num_token_ids in choices: + self._assert_additional_outputs_valid( + stream, + output_finish_reason, + output_cumulative_logprob, + output_num_token_ids, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/ci/L0_vllm_additional_outputs/test.sh b/ci/L0_vllm_additional_outputs/test.sh new file mode 100755 index 00000000..fffa2ec6 --- /dev/null +++ b/ci/L0_vllm_additional_outputs/test.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +export CUDA_VISIBLE_DEVICES=0 +source ../common/util.sh + +pip3 install tritonclient[grpc] + +# Prepare Model +rm -rf models vllm_baseline_output.pkl && mkdir -p models +SAMPLE_MODELS_REPO="../../samples/model_repository" +cp -r $SAMPLE_MODELS_REPO/vllm_model models/vllm_opt +sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.3/' models/vllm_opt/1/model.json + +RET=0 + +# Infer Test +CLIENT_LOG="vllm_opt.log" +SERVER_LOG="vllm_opt.server.log" +SERVER_ARGS="--model-repository=models" +run_server +if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 +fi +set +e +python3 additional_outputs_test.py > $CLIENT_LOG 2>&1 +if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** additional_outputs_test FAILED. \n***" + RET=1 +fi +set -e +kill $SERVER_PID +wait $SERVER_PID + +if [ $RET -eq 0 ]; then + echo -e "\n***\n*** Test Passed\n***" +else + echo -e "\n***\n*** Test FAILED\n***" +fi +exit $RET diff --git a/ci/common/util.sh b/ci/common/util.sh index 8baf4f92..0b2022ce 100755 --- a/ci/common/util.sh +++ b/ci/common/util.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -25,7 +25,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +SERVER=${SERVER:=/opt/tritonserver/bin/tritonserver} SERVER_IPADDR=${TRITONSERVER_IPADDR:=localhost} SERVER_LOG=${SERVER_LOG:=./server.log} SERVER_TIMEOUT=${SERVER_TIMEOUT:=120} From f35e9c40d1e09a91f6d2f621c6634a1294d49d2f Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Mon, 4 Nov 2024 14:21:35 -0800 Subject: [PATCH 05/14] Add docs for additonal outputs --- README.md | 5 ++ docs/additional_outputs.md | 107 +++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 docs/additional_outputs.md diff --git a/README.md b/README.md index 8a993d99..a157ed61 100644 --- a/README.md +++ b/README.md @@ -203,6 +203,11 @@ you need to specify a different `shm-region-prefix-name` for each server. See [here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server) for more information. +## Additional vLLM outputs + +Additional vLLM outputs may be requested optionally on a per-request basis. See +[this docs](docs/additional_outputs.md) for more information. + ## Triton Metrics Starting with the 24.08 release of Triton, users can now obtain specific vLLM metrics by querying the Triton metrics endpoint (see complete vLLM metrics diff --git a/docs/additional_outputs.md b/docs/additional_outputs.md new file mode 100644 index 00000000..c874435e --- /dev/null +++ b/docs/additional_outputs.md @@ -0,0 +1,107 @@ + + +# Additional Outputs from vLLM + +The vLLM backend supports sending additional outputs from vLLM on top of the +usual `text_output` when requested. + +All additional outputs are disabled by default and they need to be enabled on a +per-request basis. If enabled, the corresponding output tensor will be set for +all responses from the request. + +## Supported Additional Outputs + +### Finish Reason + +The reason why the sequence is finished. See +[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L26) +for more details. + +To enable, set `output_finish_reason` input tensor to `True`. The reason will be +sent as a string on the `finish_reason` output tensor. + +Supported since r24.11. + +### Cumulative Log Probabilities + +The cumulative log probability of the generated output text. See +[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L22) +for more details. + +To enable, set `output_cumulative_logprob` input tensor to `True`. The floating +point value will be sent on the `cumulative_logprob` output tensor. + +Supported since r24.11. + +### Number of token IDs + +The number of token IDs of the generated output text sent on this response. It +is the difference in length of the token IDs generated from the last response to +this response. If this is the first response, the last response length is +presumed to be zero. See +[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L21) +for more details on the token IDs of the generated output text. + +To enable, set `output_num_token_ids` input tensor to `True`. The unsigned +integer value will be sent on the `num_token_ids` output tensor. + +Supported since r24.11. + +## Examples + +### Add Finish Reason to Outputs + +```python +import numpy as np +import tritonclient.grpc as grpcclient + +inputs = [] + +inputs.append(grpcclient.InferInput("text_input", [1], "BYTES")) +inputs[-1].set_data_from_numpy( + np.array(["example prompt".encode("utf-8")], dtype=np.object_) +) + +inputs.append(grpcclient.InferInput("output_finish_reason", [1], "BOOL")) +inputs[-1].set_data_from_numpy(np.array([True], dtype=bool)) + +def callback(result, error): + ... + print(result.as_numpy(name="finish_reason")) + +with grpcclient.InferenceServerClient("localhost:8001") as client: + client.start_stream(callback) + client.async_stream_infer("vLLM_model_name", inputs=inputs, ...) + client.stop_stream() +``` + +## Notes + +* Enabling additional outputs may impact performance, only add additional +outputs when necessary. From e6e64048f4c448818f5f66c55fab7040f63d39d6 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:34:36 -0800 Subject: [PATCH 06/14] chore: Unify vLLM test names --- .../additional_outputs_test.py | 0 .../test.sh | 0 ci/{L0_multi_gpu => L0_multi_gpu_vllm}/multi_lora/download.py | 0 .../multi_lora/multi_lora_test.py | 0 ci/{L0_multi_gpu => L0_multi_gpu_vllm}/multi_lora/test.sh | 0 ci/{L0_multi_gpu => L0_multi_gpu_vllm}/test.sh | 0 ci/{L0_multi_gpu => L0_multi_gpu_vllm}/vllm_backend/test.sh | 0 .../vllm_backend/vllm_multi_gpu_test.py | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename ci/{L0_vllm_additional_outputs => L0_additional_outputs_vllm}/additional_outputs_test.py (100%) rename ci/{L0_vllm_additional_outputs => L0_additional_outputs_vllm}/test.sh (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/multi_lora/download.py (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/multi_lora/multi_lora_test.py (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/multi_lora/test.sh (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/test.sh (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/vllm_backend/test.sh (100%) rename ci/{L0_multi_gpu => L0_multi_gpu_vllm}/vllm_backend/vllm_multi_gpu_test.py (100%) diff --git a/ci/L0_vllm_additional_outputs/additional_outputs_test.py b/ci/L0_additional_outputs_vllm/additional_outputs_test.py similarity index 100% rename from ci/L0_vllm_additional_outputs/additional_outputs_test.py rename to ci/L0_additional_outputs_vllm/additional_outputs_test.py diff --git a/ci/L0_vllm_additional_outputs/test.sh b/ci/L0_additional_outputs_vllm/test.sh similarity index 100% rename from ci/L0_vllm_additional_outputs/test.sh rename to ci/L0_additional_outputs_vllm/test.sh diff --git a/ci/L0_multi_gpu/multi_lora/download.py b/ci/L0_multi_gpu_vllm/multi_lora/download.py similarity index 100% rename from ci/L0_multi_gpu/multi_lora/download.py rename to ci/L0_multi_gpu_vllm/multi_lora/download.py diff --git a/ci/L0_multi_gpu/multi_lora/multi_lora_test.py b/ci/L0_multi_gpu_vllm/multi_lora/multi_lora_test.py similarity index 100% rename from ci/L0_multi_gpu/multi_lora/multi_lora_test.py rename to ci/L0_multi_gpu_vllm/multi_lora/multi_lora_test.py diff --git a/ci/L0_multi_gpu/multi_lora/test.sh b/ci/L0_multi_gpu_vllm/multi_lora/test.sh similarity index 100% rename from ci/L0_multi_gpu/multi_lora/test.sh rename to ci/L0_multi_gpu_vllm/multi_lora/test.sh diff --git a/ci/L0_multi_gpu/test.sh b/ci/L0_multi_gpu_vllm/test.sh similarity index 100% rename from ci/L0_multi_gpu/test.sh rename to ci/L0_multi_gpu_vllm/test.sh diff --git a/ci/L0_multi_gpu/vllm_backend/test.sh b/ci/L0_multi_gpu_vllm/vllm_backend/test.sh similarity index 100% rename from ci/L0_multi_gpu/vllm_backend/test.sh rename to ci/L0_multi_gpu_vllm/vllm_backend/test.sh diff --git a/ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py b/ci/L0_multi_gpu_vllm/vllm_backend/vllm_multi_gpu_test.py similarity index 100% rename from ci/L0_multi_gpu/vllm_backend/vllm_multi_gpu_test.py rename to ci/L0_multi_gpu_vllm/vllm_backend/vllm_multi_gpu_test.py From 44edd6e3cc1b8393d13a5b9256697e6609386259 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Wed, 6 Nov 2024 15:36:38 -0800 Subject: [PATCH 07/14] Switch to pytest --- .../additional_outputs_test.py | 59 ++++++++----------- ci/L0_additional_outputs_vllm/test.sh | 7 +-- 2 files changed, 26 insertions(+), 40 deletions(-) diff --git a/ci/L0_additional_outputs_vllm/additional_outputs_test.py b/ci/L0_additional_outputs_vllm/additional_outputs_test.py index 0c2b3cda..08cfc91e 100644 --- a/ci/L0_additional_outputs_vllm/additional_outputs_test.py +++ b/ci/L0_additional_outputs_vllm/additional_outputs_test.py @@ -25,13 +25,13 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import json -import unittest import numpy as np +import pytest import tritonclient.grpc as grpcclient -class InferTest(unittest.TestCase): +class TestAdditionalOutputs: _grpc_url = "localhost:8001" _model_name = "vllm_opt" _sampling_parameters = {"temperature": "0", "top_p": "1"} @@ -93,51 +93,51 @@ def _llm_infer(self, inputs): self._model_name, inputs=inputs, parameters=self._sampling_parameters ) client.stop_stream() - self.assertGreater(len(self._responses), 0) + assert len(self._responses) > 0 def _assert_text_output_valid(self): text_output = "" for response in self._responses: result, error = response["result"], response["error"] - self.assertIsNone(error) + assert error is None text_output += result.as_numpy(name="text_output")[0].decode("utf-8") - self.assertGreater(len(text_output), 0, "output is empty") - self.assertGreater(text_output.count(" "), 4, "output is not a sentence") + assert len(text_output) > 0, "output is empty" + assert text_output.count(" ") > 4, "output is not a sentence" def _assert_finish_reason(self, output_finish_reason): for i in range(len(self._responses)): result, error = self._responses[i]["result"], self._responses[i]["error"] - self.assertIsNone(error) + assert error is None finish_reason_np = result.as_numpy(name="finish_reason") if output_finish_reason is None or output_finish_reason == False: - self.assertIsNone(finish_reason_np) + assert finish_reason_np is None continue finish_reason = finish_reason_np[0].decode("utf-8") if i < len(self._responses) - 1: - self.assertEqual(finish_reason, "None") + assert finish_reason == "None" else: - self.assertEqual(finish_reason, "length") + assert finish_reason == "length" def _assert_cumulative_logprob(self, output_cumulative_logprob): prev_cumulative_logprob = 0.0 for response in self._responses: result, error = response["result"], response["error"] - self.assertIsNone(error) + assert error is None cumulative_logprob_np = result.as_numpy(name="cumulative_logprob") if output_cumulative_logprob is None or output_cumulative_logprob == False: - self.assertIsNone(cumulative_logprob_np) + assert cumulative_logprob_np is None continue cumulative_logprob = cumulative_logprob_np[0].astype(float) - self.assertNotEqual(cumulative_logprob, prev_cumulative_logprob) + assert cumulative_logprob != prev_cumulative_logprob prev_cumulative_logprob = cumulative_logprob def _assert_num_token_ids(self, output_num_token_ids): for response in self._responses: result, error = response["result"], response["error"] - self.assertIsNone(error) + assert error is None num_token_ids_np = result.as_numpy(name="num_token_ids") if output_num_token_ids is None or output_num_token_ids == False: - self.assertIsNone(num_token_ids_np) + assert num_token_ids_np is None continue num_token_ids = num_token_ids_np[0].astype(int) # TODO: vLLM may return token ids identical to the previous one when @@ -156,10 +156,14 @@ def _assert_num_token_ids(self, output_num_token_ids): # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48]) # # If this is no longer the case in a future release, change the assert - # to assertGreater(). - self.assertGreaterEqual(num_token_ids, 0) - - def _assert_additional_outputs_valid( + # to assert num_token_ids > 0. + assert num_token_ids >= 0 + + @pytest.mark.parametrize("stream", [True, False]) + @pytest.mark.parametrize("output_finish_reason", [None, True, False]) + @pytest.mark.parametrize("output_cumulative_logprob", [None, True, False]) + @pytest.mark.parametrize("output_num_token_ids", [None, True, False]) + def test_additional_outputs( self, stream, output_finish_reason, @@ -179,20 +183,3 @@ def _assert_additional_outputs_valid( self._assert_finish_reason(output_finish_reason) self._assert_cumulative_logprob(output_cumulative_logprob) self._assert_num_token_ids(output_num_token_ids) - - def test_additional_outputs(self): - for stream in [True, False]: - choices = [None, False, True] - for output_finish_reason in choices: - for output_cumulative_logprob in choices: - for output_num_token_ids in choices: - self._assert_additional_outputs_valid( - stream, - output_finish_reason, - output_cumulative_logprob, - output_num_token_ids, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/ci/L0_additional_outputs_vllm/test.sh b/ci/L0_additional_outputs_vllm/test.sh index fffa2ec6..9800e4d3 100755 --- a/ci/L0_additional_outputs_vllm/test.sh +++ b/ci/L0_additional_outputs_vllm/test.sh @@ -28,6 +28,7 @@ export CUDA_VISIBLE_DEVICES=0 source ../common/util.sh +pip3 install pytest==8.1.1 pip3 install tritonclient[grpc] # Prepare Model @@ -38,8 +39,7 @@ sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.3/' models/v RET=0 -# Infer Test -CLIENT_LOG="vllm_opt.log" +# Test SERVER_LOG="vllm_opt.server.log" SERVER_ARGS="--model-repository=models" run_server @@ -49,9 +49,8 @@ if [ "$SERVER_PID" == "0" ]; then exit 1 fi set +e -python3 additional_outputs_test.py > $CLIENT_LOG 2>&1 +python3 -m pytest -s -v additional_outputs_test.py if [ $? -ne 0 ]; then - cat $CLIENT_LOG echo -e "\n***\n*** additional_outputs_test FAILED. \n***" RET=1 fi From 1773deafb683c463cde7dc6efc79f10d0eca381f Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Wed, 6 Nov 2024 15:51:16 -0800 Subject: [PATCH 08/14] pytest to dump additional outputs --- ci/L0_additional_outputs_vllm/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/L0_additional_outputs_vllm/test.sh b/ci/L0_additional_outputs_vllm/test.sh index 9800e4d3..880f918f 100755 --- a/ci/L0_additional_outputs_vllm/test.sh +++ b/ci/L0_additional_outputs_vllm/test.sh @@ -49,7 +49,7 @@ if [ "$SERVER_PID" == "0" ]; then exit 1 fi set +e -python3 -m pytest -s -v additional_outputs_test.py +python3 -m pytest --junitxml=test_additional_outputs.xml -s -v additional_outputs_test.py if [ $? -ne 0 ]; then echo -e "\n***\n*** additional_outputs_test FAILED. \n***" RET=1 From 29099dfe119a283319711313cdb597915c05994c Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Wed, 6 Nov 2024 16:48:52 -0800 Subject: [PATCH 09/14] Rename output_* to return_* --- .../additional_outputs_test.py | 60 +++++++++---------- docs/additional_outputs.md | 8 +-- src/model.py | 20 +++---- 3 files changed, 44 insertions(+), 44 deletions(-) diff --git a/ci/L0_additional_outputs_vllm/additional_outputs_test.py b/ci/L0_additional_outputs_vllm/additional_outputs_test.py index 08cfc91e..a8dfb24d 100644 --- a/ci/L0_additional_outputs_vllm/additional_outputs_test.py +++ b/ci/L0_additional_outputs_vllm/additional_outputs_test.py @@ -42,9 +42,9 @@ def _get_inputs( prompt, stream=True, sampling_parameters=None, - output_finish_reason=None, - output_cumulative_logprob=None, - output_num_token_ids=None, + return_finish_reason=None, + return_cumulative_logprob=None, + return_num_token_ids=None, ): inputs = [] @@ -64,21 +64,21 @@ def _get_inputs( ) ) - if output_finish_reason is not None: - inputs.append(grpcclient.InferInput("output_finish_reason", [1], "BOOL")) - inputs[-1].set_data_from_numpy(np.array([output_finish_reason], dtype=bool)) + if return_finish_reason is not None: + inputs.append(grpcclient.InferInput("return_finish_reason", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([return_finish_reason], dtype=bool)) - if output_cumulative_logprob is not None: + if return_cumulative_logprob is not None: inputs.append( - grpcclient.InferInput("output_cumulative_logprob", [1], "BOOL") + grpcclient.InferInput("return_cumulative_logprob", [1], "BOOL") ) inputs[-1].set_data_from_numpy( - np.array([output_cumulative_logprob], dtype=bool) + np.array([return_cumulative_logprob], dtype=bool) ) - if output_num_token_ids is not None: - inputs.append(grpcclient.InferInput("output_num_token_ids", [1], "BOOL")) - inputs[-1].set_data_from_numpy(np.array([output_num_token_ids], dtype=bool)) + if return_num_token_ids is not None: + inputs.append(grpcclient.InferInput("return_num_token_ids", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([return_num_token_ids], dtype=bool)) return inputs @@ -104,12 +104,12 @@ def _assert_text_output_valid(self): assert len(text_output) > 0, "output is empty" assert text_output.count(" ") > 4, "output is not a sentence" - def _assert_finish_reason(self, output_finish_reason): + def _assert_finish_reason(self, return_finish_reason): for i in range(len(self._responses)): result, error = self._responses[i]["result"], self._responses[i]["error"] assert error is None finish_reason_np = result.as_numpy(name="finish_reason") - if output_finish_reason is None or output_finish_reason == False: + if return_finish_reason is None or return_finish_reason == False: assert finish_reason_np is None continue finish_reason = finish_reason_np[0].decode("utf-8") @@ -118,25 +118,25 @@ def _assert_finish_reason(self, output_finish_reason): else: assert finish_reason == "length" - def _assert_cumulative_logprob(self, output_cumulative_logprob): + def _assert_cumulative_logprob(self, return_cumulative_logprob): prev_cumulative_logprob = 0.0 for response in self._responses: result, error = response["result"], response["error"] assert error is None cumulative_logprob_np = result.as_numpy(name="cumulative_logprob") - if output_cumulative_logprob is None or output_cumulative_logprob == False: + if return_cumulative_logprob is None or return_cumulative_logprob == False: assert cumulative_logprob_np is None continue cumulative_logprob = cumulative_logprob_np[0].astype(float) assert cumulative_logprob != prev_cumulative_logprob prev_cumulative_logprob = cumulative_logprob - def _assert_num_token_ids(self, output_num_token_ids): + def _assert_num_token_ids(self, return_num_token_ids): for response in self._responses: result, error = response["result"], response["error"] assert error is None num_token_ids_np = result.as_numpy(name="num_token_ids") - if output_num_token_ids is None or output_num_token_ids == False: + if return_num_token_ids is None or return_num_token_ids == False: assert num_token_ids_np is None continue num_token_ids = num_token_ids_np[0].astype(int) @@ -160,26 +160,26 @@ def _assert_num_token_ids(self, output_num_token_ids): assert num_token_ids >= 0 @pytest.mark.parametrize("stream", [True, False]) - @pytest.mark.parametrize("output_finish_reason", [None, True, False]) - @pytest.mark.parametrize("output_cumulative_logprob", [None, True, False]) - @pytest.mark.parametrize("output_num_token_ids", [None, True, False]) + @pytest.mark.parametrize("return_finish_reason", [None, True, False]) + @pytest.mark.parametrize("return_cumulative_logprob", [None, True, False]) + @pytest.mark.parametrize("return_num_token_ids", [None, True, False]) def test_additional_outputs( self, stream, - output_finish_reason, - output_cumulative_logprob, - output_num_token_ids, + return_finish_reason, + return_cumulative_logprob, + return_num_token_ids, ): inputs = self._get_inputs( self._prompt, stream=stream, sampling_parameters=self._sampling_parameters, - output_finish_reason=output_finish_reason, - output_cumulative_logprob=output_cumulative_logprob, - output_num_token_ids=output_num_token_ids, + return_finish_reason=return_finish_reason, + return_cumulative_logprob=return_cumulative_logprob, + return_num_token_ids=return_num_token_ids, ) self._llm_infer(inputs) self._assert_text_output_valid() - self._assert_finish_reason(output_finish_reason) - self._assert_cumulative_logprob(output_cumulative_logprob) - self._assert_num_token_ids(output_num_token_ids) + self._assert_finish_reason(return_finish_reason) + self._assert_cumulative_logprob(return_cumulative_logprob) + self._assert_num_token_ids(return_num_token_ids) diff --git a/docs/additional_outputs.md b/docs/additional_outputs.md index c874435e..dcca0dc4 100644 --- a/docs/additional_outputs.md +++ b/docs/additional_outputs.md @@ -43,7 +43,7 @@ The reason why the sequence is finished. See [here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L26) for more details. -To enable, set `output_finish_reason` input tensor to `True`. The reason will be +To enable, set `return_finish_reason` input tensor to `True`. The reason will be sent as a string on the `finish_reason` output tensor. Supported since r24.11. @@ -54,7 +54,7 @@ The cumulative log probability of the generated output text. See [here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L22) for more details. -To enable, set `output_cumulative_logprob` input tensor to `True`. The floating +To enable, set `return_cumulative_logprob` input tensor to `True`. The floating point value will be sent on the `cumulative_logprob` output tensor. Supported since r24.11. @@ -68,7 +68,7 @@ presumed to be zero. See [here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L21) for more details on the token IDs of the generated output text. -To enable, set `output_num_token_ids` input tensor to `True`. The unsigned +To enable, set `return_num_token_ids` input tensor to `True`. The unsigned integer value will be sent on the `num_token_ids` output tensor. Supported since r24.11. @@ -88,7 +88,7 @@ inputs[-1].set_data_from_numpy( np.array(["example prompt".encode("utf-8")], dtype=np.object_) ) -inputs.append(grpcclient.InferInput("output_finish_reason", [1], "BOOL")) +inputs.append(grpcclient.InferInput("return_finish_reason", [1], "BOOL")) inputs[-1].set_data_from_numpy(np.array([True], dtype=bool)) def callback(result, error): diff --git a/src/model.py b/src/model.py index 0b4a6759..dfaebf61 100644 --- a/src/model.py +++ b/src/model.py @@ -89,19 +89,19 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): "optional": True, }, { - "name": "output_finish_reason", + "name": "return_finish_reason", "data_type": "TYPE_BOOL", "dims": [1], "optional": True, }, { - "name": "output_cumulative_logprob", + "name": "return_cumulative_logprob", "data_type": "TYPE_BOOL", "dims": [1], "optional": True, }, { - "name": "output_num_token_ids", + "name": "return_num_token_ids", "data_type": "TYPE_BOOL", "dims": [1], "optional": True, @@ -348,11 +348,11 @@ def _get_input_tensors(self, request): else: parameters = request.parameters() - # output_finish_reason, output_cumulative_logprob, output_num_token_ids + # return_finish_reason, return_cumulative_logprob, return_num_token_ids additional_outputs = { - "output_finish_reason": None, - "output_cumulative_logprob": None, - "output_num_token_ids": None, + "return_finish_reason": None, + "return_cumulative_logprob": None, + "return_num_token_ids": None, } for tensor_name in additional_outputs.keys(): tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) @@ -445,7 +445,7 @@ def _create_response( ) # finish_reason - if additional_outputs["output_finish_reason"]: + if additional_outputs["return_finish_reason"]: finish_reason = [ str(output.finish_reason) for output in request_output.outputs ] @@ -456,7 +456,7 @@ def _create_response( ) # cumulative_logprob - if additional_outputs["output_cumulative_logprob"]: + if additional_outputs["return_cumulative_logprob"]: cumulative_logprob = [ output.cumulative_logprob for output in request_output.outputs ] @@ -468,7 +468,7 @@ def _create_response( ) # num_token_ids - if additional_outputs["output_num_token_ids"]: + if additional_outputs["return_num_token_ids"]: if prev_request_output is None: # this is the first response prev_lens = [0] * len(request_output.outputs) From 457eeaa6f23fb480ba38a84b1426e4d67879e58a Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:16:34 -0800 Subject: [PATCH 10/14] Return token ids instead of number of token ids --- .../additional_outputs_test.py | 32 +++++++++---------- docs/additional_outputs.md | 13 +++----- src/model.py | 20 ++++++------ 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/ci/L0_additional_outputs_vllm/additional_outputs_test.py b/ci/L0_additional_outputs_vllm/additional_outputs_test.py index a8dfb24d..2826a4ca 100644 --- a/ci/L0_additional_outputs_vllm/additional_outputs_test.py +++ b/ci/L0_additional_outputs_vllm/additional_outputs_test.py @@ -44,7 +44,7 @@ def _get_inputs( sampling_parameters=None, return_finish_reason=None, return_cumulative_logprob=None, - return_num_token_ids=None, + return_token_ids=None, ): inputs = [] @@ -76,9 +76,9 @@ def _get_inputs( np.array([return_cumulative_logprob], dtype=bool) ) - if return_num_token_ids is not None: - inputs.append(grpcclient.InferInput("return_num_token_ids", [1], "BOOL")) - inputs[-1].set_data_from_numpy(np.array([return_num_token_ids], dtype=bool)) + if return_token_ids is not None: + inputs.append(grpcclient.InferInput("return_token_ids", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([return_token_ids], dtype=bool)) return inputs @@ -131,15 +131,15 @@ def _assert_cumulative_logprob(self, return_cumulative_logprob): assert cumulative_logprob != prev_cumulative_logprob prev_cumulative_logprob = cumulative_logprob - def _assert_num_token_ids(self, return_num_token_ids): + def _assert_token_ids(self, return_token_ids): for response in self._responses: result, error = response["result"], response["error"] assert error is None - num_token_ids_np = result.as_numpy(name="num_token_ids") - if return_num_token_ids is None or return_num_token_ids == False: - assert num_token_ids_np is None + token_ids_np = result.as_numpy(name="token_ids") + if return_token_ids is None or return_token_ids == False: + assert token_ids_np is None continue - num_token_ids = num_token_ids_np[0].astype(int) + token_ids = token_ids_np[0].astype(int) # TODO: vLLM may return token ids identical to the previous one when # streaming, for example: # @@ -155,20 +155,20 @@ def _assert_num_token_ids(self, return_num_token_ids): # prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48]) # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48]) # - # If this is no longer the case in a future release, change the assert - # to assert num_token_ids > 0. - assert num_token_ids >= 0 + # If this is no longer the case in a future release, change to + # assert len(token_ids) > 0. + assert len(token_ids) >= 0 @pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("return_finish_reason", [None, True, False]) @pytest.mark.parametrize("return_cumulative_logprob", [None, True, False]) - @pytest.mark.parametrize("return_num_token_ids", [None, True, False]) + @pytest.mark.parametrize("return_token_ids", [None, True, False]) def test_additional_outputs( self, stream, return_finish_reason, return_cumulative_logprob, - return_num_token_ids, + return_token_ids, ): inputs = self._get_inputs( self._prompt, @@ -176,10 +176,10 @@ def test_additional_outputs( sampling_parameters=self._sampling_parameters, return_finish_reason=return_finish_reason, return_cumulative_logprob=return_cumulative_logprob, - return_num_token_ids=return_num_token_ids, + return_token_ids=return_token_ids, ) self._llm_infer(inputs) self._assert_text_output_valid() self._assert_finish_reason(return_finish_reason) self._assert_cumulative_logprob(return_cumulative_logprob) - self._assert_num_token_ids(return_num_token_ids) + self._assert_token_ids(return_token_ids) diff --git a/docs/additional_outputs.md b/docs/additional_outputs.md index dcca0dc4..fdc631dd 100644 --- a/docs/additional_outputs.md +++ b/docs/additional_outputs.md @@ -59,17 +59,14 @@ point value will be sent on the `cumulative_logprob` output tensor. Supported since r24.11. -### Number of token IDs +### Token IDs -The number of token IDs of the generated output text sent on this response. It -is the difference in length of the token IDs generated from the last response to -this response. If this is the first response, the last response length is -presumed to be zero. See +The token IDs of the generated output text sent on this response. See [here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L21) -for more details on the token IDs of the generated output text. +for more details. -To enable, set `return_num_token_ids` input tensor to `True`. The unsigned -integer value will be sent on the `num_token_ids` output tensor. +To enable, set `return_token_ids` input tensor to `True`. The array of integer +value will be sent on the `token_ids` output tensor. Supported since r24.11. diff --git a/src/model.py b/src/model.py index dfaebf61..bd073156 100644 --- a/src/model.py +++ b/src/model.py @@ -101,7 +101,7 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): "optional": True, }, { - "name": "return_num_token_ids", + "name": "return_token_ids", "data_type": "TYPE_BOOL", "dims": [1], "optional": True, @@ -111,7 +111,7 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}, {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]}, {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]}, - {"name": "num_token_ids", "data_type": "TYPE_UINT32", "dims": [-1]}, + {"name": "token_ids", "data_type": "TYPE_INT64", "dims": [-1, -1]}, ] # Collect input and output names from the provided model config. @@ -348,11 +348,11 @@ def _get_input_tensors(self, request): else: parameters = request.parameters() - # return_finish_reason, return_cumulative_logprob, return_num_token_ids + # return_finish_reason, return_cumulative_logprob, return_token_ids additional_outputs = { "return_finish_reason": None, "return_cumulative_logprob": None, - "return_num_token_ids": None, + "return_token_ids": None, } for tensor_name in additional_outputs.keys(): tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) @@ -467,8 +467,8 @@ def _create_response( ) ) - # num_token_ids - if additional_outputs["return_num_token_ids"]: + # token_ids + if additional_outputs["return_token_ids"]: if prev_request_output is None: # this is the first response prev_lens = [0] * len(request_output.outputs) @@ -478,14 +478,12 @@ def _create_response( len(prev_output.token_ids) for prev_output in prev_request_output.outputs ] - num_token_ids = [ - (len(output.token_ids) - prev_len) + token_ids = [ + output.token_ids[prev_len:] for output, prev_len in zip(request_output.outputs, prev_lens) ] output_tensors.append( - pb_utils.Tensor( - "num_token_ids", np.asarray(num_token_ids, dtype=np.uint32) - ) + pb_utils.Tensor("token_ids", np.asarray(token_ids, dtype=np.int64)) ) return pb_utils.InferenceResponse(output_tensors=output_tensors) From 5e9b09f21e212e4cc3362cae893d8a25f017a957 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:24:51 -0800 Subject: [PATCH 11/14] Revert "Return token ids instead of number of token ids" This reverts commit 457eeaa6f23fb480ba38a84b1426e4d67879e58a. --- .../additional_outputs_test.py | 32 +++++++++---------- docs/additional_outputs.md | 13 +++++--- src/model.py | 20 ++++++------ 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/ci/L0_additional_outputs_vllm/additional_outputs_test.py b/ci/L0_additional_outputs_vllm/additional_outputs_test.py index 2826a4ca..a8dfb24d 100644 --- a/ci/L0_additional_outputs_vllm/additional_outputs_test.py +++ b/ci/L0_additional_outputs_vllm/additional_outputs_test.py @@ -44,7 +44,7 @@ def _get_inputs( sampling_parameters=None, return_finish_reason=None, return_cumulative_logprob=None, - return_token_ids=None, + return_num_token_ids=None, ): inputs = [] @@ -76,9 +76,9 @@ def _get_inputs( np.array([return_cumulative_logprob], dtype=bool) ) - if return_token_ids is not None: - inputs.append(grpcclient.InferInput("return_token_ids", [1], "BOOL")) - inputs[-1].set_data_from_numpy(np.array([return_token_ids], dtype=bool)) + if return_num_token_ids is not None: + inputs.append(grpcclient.InferInput("return_num_token_ids", [1], "BOOL")) + inputs[-1].set_data_from_numpy(np.array([return_num_token_ids], dtype=bool)) return inputs @@ -131,15 +131,15 @@ def _assert_cumulative_logprob(self, return_cumulative_logprob): assert cumulative_logprob != prev_cumulative_logprob prev_cumulative_logprob = cumulative_logprob - def _assert_token_ids(self, return_token_ids): + def _assert_num_token_ids(self, return_num_token_ids): for response in self._responses: result, error = response["result"], response["error"] assert error is None - token_ids_np = result.as_numpy(name="token_ids") - if return_token_ids is None or return_token_ids == False: - assert token_ids_np is None + num_token_ids_np = result.as_numpy(name="num_token_ids") + if return_num_token_ids is None or return_num_token_ids == False: + assert num_token_ids_np is None continue - token_ids = token_ids_np[0].astype(int) + num_token_ids = num_token_ids_np[0].astype(int) # TODO: vLLM may return token ids identical to the previous one when # streaming, for example: # @@ -155,20 +155,20 @@ def _assert_token_ids(self, return_token_ids): # prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48]) # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48]) # - # If this is no longer the case in a future release, change to - # assert len(token_ids) > 0. - assert len(token_ids) >= 0 + # If this is no longer the case in a future release, change the assert + # to assert num_token_ids > 0. + assert num_token_ids >= 0 @pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("return_finish_reason", [None, True, False]) @pytest.mark.parametrize("return_cumulative_logprob", [None, True, False]) - @pytest.mark.parametrize("return_token_ids", [None, True, False]) + @pytest.mark.parametrize("return_num_token_ids", [None, True, False]) def test_additional_outputs( self, stream, return_finish_reason, return_cumulative_logprob, - return_token_ids, + return_num_token_ids, ): inputs = self._get_inputs( self._prompt, @@ -176,10 +176,10 @@ def test_additional_outputs( sampling_parameters=self._sampling_parameters, return_finish_reason=return_finish_reason, return_cumulative_logprob=return_cumulative_logprob, - return_token_ids=return_token_ids, + return_num_token_ids=return_num_token_ids, ) self._llm_infer(inputs) self._assert_text_output_valid() self._assert_finish_reason(return_finish_reason) self._assert_cumulative_logprob(return_cumulative_logprob) - self._assert_token_ids(return_token_ids) + self._assert_num_token_ids(return_num_token_ids) diff --git a/docs/additional_outputs.md b/docs/additional_outputs.md index fdc631dd..dcca0dc4 100644 --- a/docs/additional_outputs.md +++ b/docs/additional_outputs.md @@ -59,14 +59,17 @@ point value will be sent on the `cumulative_logprob` output tensor. Supported since r24.11. -### Token IDs +### Number of token IDs -The token IDs of the generated output text sent on this response. See +The number of token IDs of the generated output text sent on this response. It +is the difference in length of the token IDs generated from the last response to +this response. If this is the first response, the last response length is +presumed to be zero. See [here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L21) -for more details. +for more details on the token IDs of the generated output text. -To enable, set `return_token_ids` input tensor to `True`. The array of integer -value will be sent on the `token_ids` output tensor. +To enable, set `return_num_token_ids` input tensor to `True`. The unsigned +integer value will be sent on the `num_token_ids` output tensor. Supported since r24.11. diff --git a/src/model.py b/src/model.py index bd073156..dfaebf61 100644 --- a/src/model.py +++ b/src/model.py @@ -101,7 +101,7 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): "optional": True, }, { - "name": "return_token_ids", + "name": "return_num_token_ids", "data_type": "TYPE_BOOL", "dims": [1], "optional": True, @@ -111,7 +111,7 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}, {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]}, {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]}, - {"name": "token_ids", "data_type": "TYPE_INT64", "dims": [-1, -1]}, + {"name": "num_token_ids", "data_type": "TYPE_UINT32", "dims": [-1]}, ] # Collect input and output names from the provided model config. @@ -348,11 +348,11 @@ def _get_input_tensors(self, request): else: parameters = request.parameters() - # return_finish_reason, return_cumulative_logprob, return_token_ids + # return_finish_reason, return_cumulative_logprob, return_num_token_ids additional_outputs = { "return_finish_reason": None, "return_cumulative_logprob": None, - "return_token_ids": None, + "return_num_token_ids": None, } for tensor_name in additional_outputs.keys(): tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) @@ -467,8 +467,8 @@ def _create_response( ) ) - # token_ids - if additional_outputs["return_token_ids"]: + # num_token_ids + if additional_outputs["return_num_token_ids"]: if prev_request_output is None: # this is the first response prev_lens = [0] * len(request_output.outputs) @@ -478,12 +478,14 @@ def _create_response( len(prev_output.token_ids) for prev_output in prev_request_output.outputs ] - token_ids = [ - output.token_ids[prev_len:] + num_token_ids = [ + (len(output.token_ids) - prev_len) for output, prev_len in zip(request_output.outputs, prev_lens) ] output_tensors.append( - pb_utils.Tensor("token_ids", np.asarray(token_ids, dtype=np.int64)) + pb_utils.Tensor( + "num_token_ids", np.asarray(num_token_ids, dtype=np.uint32) + ) ) return pb_utils.InferenceResponse(output_tensors=output_tensors) From dae3c132b25541f82ffe37e5756d9cfb72c4fa22 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:30:45 -0800 Subject: [PATCH 12/14] Rename num_token_ids to num_output_tokens --- .../additional_outputs_test.py | 34 +++++++++++-------- docs/additional_outputs.md | 6 ++-- src/model.py | 16 ++++----- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/ci/L0_additional_outputs_vllm/additional_outputs_test.py b/ci/L0_additional_outputs_vllm/additional_outputs_test.py index a8dfb24d..5a8eefbd 100644 --- a/ci/L0_additional_outputs_vllm/additional_outputs_test.py +++ b/ci/L0_additional_outputs_vllm/additional_outputs_test.py @@ -44,7 +44,7 @@ def _get_inputs( sampling_parameters=None, return_finish_reason=None, return_cumulative_logprob=None, - return_num_token_ids=None, + return_num_output_tokens=None, ): inputs = [] @@ -76,9 +76,13 @@ def _get_inputs( np.array([return_cumulative_logprob], dtype=bool) ) - if return_num_token_ids is not None: - inputs.append(grpcclient.InferInput("return_num_token_ids", [1], "BOOL")) - inputs[-1].set_data_from_numpy(np.array([return_num_token_ids], dtype=bool)) + if return_num_output_tokens is not None: + inputs.append( + grpcclient.InferInput("return_num_output_tokens", [1], "BOOL") + ) + inputs[-1].set_data_from_numpy( + np.array([return_num_output_tokens], dtype=bool) + ) return inputs @@ -131,15 +135,15 @@ def _assert_cumulative_logprob(self, return_cumulative_logprob): assert cumulative_logprob != prev_cumulative_logprob prev_cumulative_logprob = cumulative_logprob - def _assert_num_token_ids(self, return_num_token_ids): + def _assert_num_output_tokens(self, return_num_output_tokens): for response in self._responses: result, error = response["result"], response["error"] assert error is None - num_token_ids_np = result.as_numpy(name="num_token_ids") - if return_num_token_ids is None or return_num_token_ids == False: - assert num_token_ids_np is None + num_output_tokens_np = result.as_numpy(name="num_output_tokens") + if return_num_output_tokens is None or return_num_output_tokens == False: + assert num_output_tokens_np is None continue - num_token_ids = num_token_ids_np[0].astype(int) + num_output_tokens = num_output_tokens_np[0].astype(int) # TODO: vLLM may return token ids identical to the previous one when # streaming, for example: # @@ -156,19 +160,19 @@ def _assert_num_token_ids(self, return_num_token_ids): # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48]) # # If this is no longer the case in a future release, change the assert - # to assert num_token_ids > 0. - assert num_token_ids >= 0 + # to assert num_output_tokens > 0. + assert num_output_tokens >= 0 @pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("return_finish_reason", [None, True, False]) @pytest.mark.parametrize("return_cumulative_logprob", [None, True, False]) - @pytest.mark.parametrize("return_num_token_ids", [None, True, False]) + @pytest.mark.parametrize("return_num_output_tokens", [None, True, False]) def test_additional_outputs( self, stream, return_finish_reason, return_cumulative_logprob, - return_num_token_ids, + return_num_output_tokens, ): inputs = self._get_inputs( self._prompt, @@ -176,10 +180,10 @@ def test_additional_outputs( sampling_parameters=self._sampling_parameters, return_finish_reason=return_finish_reason, return_cumulative_logprob=return_cumulative_logprob, - return_num_token_ids=return_num_token_ids, + return_num_output_tokens=return_num_output_tokens, ) self._llm_infer(inputs) self._assert_text_output_valid() self._assert_finish_reason(return_finish_reason) self._assert_cumulative_logprob(return_cumulative_logprob) - self._assert_num_token_ids(return_num_token_ids) + self._assert_num_output_tokens(return_num_output_tokens) diff --git a/docs/additional_outputs.md b/docs/additional_outputs.md index dcca0dc4..2deb6a9d 100644 --- a/docs/additional_outputs.md +++ b/docs/additional_outputs.md @@ -59,7 +59,7 @@ point value will be sent on the `cumulative_logprob` output tensor. Supported since r24.11. -### Number of token IDs +### Number of Output Tokens The number of token IDs of the generated output text sent on this response. It is the difference in length of the token IDs generated from the last response to @@ -68,8 +68,8 @@ presumed to be zero. See [here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L21) for more details on the token IDs of the generated output text. -To enable, set `return_num_token_ids` input tensor to `True`. The unsigned -integer value will be sent on the `num_token_ids` output tensor. +To enable, set `return_num_output_tokens` input tensor to `True`. The unsigned +integer value will be sent on the `num_output_tokens` output tensor. Supported since r24.11. diff --git a/src/model.py b/src/model.py index dfaebf61..d37ab689 100644 --- a/src/model.py +++ b/src/model.py @@ -101,7 +101,7 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): "optional": True, }, { - "name": "return_num_token_ids", + "name": "return_num_output_tokens", "data_type": "TYPE_BOOL", "dims": [1], "optional": True, @@ -111,7 +111,7 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config): {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}, {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]}, {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]}, - {"name": "num_token_ids", "data_type": "TYPE_UINT32", "dims": [-1]}, + {"name": "num_output_tokens", "data_type": "TYPE_UINT32", "dims": [-1]}, ] # Collect input and output names from the provided model config. @@ -348,11 +348,11 @@ def _get_input_tensors(self, request): else: parameters = request.parameters() - # return_finish_reason, return_cumulative_logprob, return_num_token_ids + # return_finish_reason, return_cumulative_logprob, return_num_output_tokens additional_outputs = { "return_finish_reason": None, "return_cumulative_logprob": None, - "return_num_token_ids": None, + "return_num_output_tokens": None, } for tensor_name in additional_outputs.keys(): tensor = pb_utils.get_input_tensor_by_name(request, tensor_name) @@ -467,8 +467,8 @@ def _create_response( ) ) - # num_token_ids - if additional_outputs["return_num_token_ids"]: + # num_output_tokens + if additional_outputs["return_num_output_tokens"]: if prev_request_output is None: # this is the first response prev_lens = [0] * len(request_output.outputs) @@ -478,13 +478,13 @@ def _create_response( len(prev_output.token_ids) for prev_output in prev_request_output.outputs ] - num_token_ids = [ + num_output_tokens = [ (len(output.token_ids) - prev_len) for output, prev_len in zip(request_output.outputs, prev_lens) ] output_tensors.append( pb_utils.Tensor( - "num_token_ids", np.asarray(num_token_ids, dtype=np.uint32) + "num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32) ) ) From ccb3323d8fc25db300f7b606f6ca1cdf2955b06d Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:47:46 -0800 Subject: [PATCH 13/14] [chore] Fix pre-commit on utils/metrics.py --- src/utils/metrics.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 0504eef9..48b77a2c 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -34,6 +34,7 @@ from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets from vllm.version import __version__ as _VLLM_VERSION + class TritonMetrics: def __init__(self, labels: List[str], max_model_len: int): # Initialize metric families @@ -163,9 +164,11 @@ def __init__(self, labels: List[str], max_model_len: int): ) ) if _VLLM_VERSION < "0.6.3": - self.histogram_best_of_request = self.histogram_best_of_request_family.Metric( - labels=labels, - buckets=[1, 2, 5, 10, 20], + self.histogram_best_of_request = ( + self.histogram_best_of_request_family.Metric( + labels=labels, + buckets=[1, 2, 5, 10, 20], + ) ) self.histogram_n_request = self.histogram_n_request_family.Metric( labels=labels, @@ -254,7 +257,9 @@ def log(self, stats: VllmStats) -> None: (self.metrics.histogram_n_request, stats.n_requests), ] if _VLLM_VERSION < "0.6.3": - histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests)) + histogram_metrics.append( + (self.metrics.histogram_best_of_request, stats.best_of_requests) + ) for metric, data in counter_metrics: self._log_counter(metric, data) for metric, data in histogram_metrics: From 00aa413b19d7ed3fedab1df50bdb198a1e2e5c94 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Mon, 25 Nov 2024 16:40:32 -0800 Subject: [PATCH 14/14] [docs] Update targeted release version --- docs/additional_outputs.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/additional_outputs.md b/docs/additional_outputs.md index 2deb6a9d..5c103e89 100644 --- a/docs/additional_outputs.md +++ b/docs/additional_outputs.md @@ -46,7 +46,7 @@ for more details. To enable, set `return_finish_reason` input tensor to `True`. The reason will be sent as a string on the `finish_reason` output tensor. -Supported since r24.11. +Supported since r24.12. ### Cumulative Log Probabilities @@ -57,7 +57,7 @@ for more details. To enable, set `return_cumulative_logprob` input tensor to `True`. The floating point value will be sent on the `cumulative_logprob` output tensor. -Supported since r24.11. +Supported since r24.12. ### Number of Output Tokens @@ -71,7 +71,7 @@ for more details on the token IDs of the generated output text. To enable, set `return_num_output_tokens` input tensor to `True`. The unsigned integer value will be sent on the `num_output_tokens` output tensor. -Supported since r24.11. +Supported since r24.12. ## Examples