Skip to content

Commit

Permalink
feat: Support sending additional outputs from vLLM inference (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
kthui authored Nov 26, 2024
1 parent 6c066f6 commit ceb5961
Show file tree
Hide file tree
Showing 13 changed files with 610 additions and 135 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ you need to specify a different `shm-region-prefix-name` for each server. See
[here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server)
for more information.

## Additional vLLM outputs

Additional vLLM outputs may be requested optionally on a per-request basis. See
[this docs](docs/additional_outputs.md) for more information.

## Triton Metrics
Starting with the 24.08 release of Triton, users can now obtain specific
vLLM metrics by querying the Triton metrics endpoint (see complete vLLM metrics
Expand Down
189 changes: 189 additions & 0 deletions ci/L0_additional_outputs_vllm/additional_outputs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json

import numpy as np
import pytest
import tritonclient.grpc as grpcclient


class TestAdditionalOutputs:
_grpc_url = "localhost:8001"
_model_name = "vllm_opt"
_sampling_parameters = {"temperature": "0", "top_p": "1"}
_prompt = "In this example,"

def _get_inputs(
self,
prompt,
stream=True,
sampling_parameters=None,
return_finish_reason=None,
return_cumulative_logprob=None,
return_num_output_tokens=None,
):
inputs = []

inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
inputs[-1].set_data_from_numpy(
np.array([prompt.encode("utf-8")], dtype=np.object_)
)

inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([stream], dtype=bool))

if sampling_parameters is not None:
inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
inputs[-1].set_data_from_numpy(
np.array(
[json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
)
)

if return_finish_reason is not None:
inputs.append(grpcclient.InferInput("return_finish_reason", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([return_finish_reason], dtype=bool))

if return_cumulative_logprob is not None:
inputs.append(
grpcclient.InferInput("return_cumulative_logprob", [1], "BOOL")
)
inputs[-1].set_data_from_numpy(
np.array([return_cumulative_logprob], dtype=bool)
)

if return_num_output_tokens is not None:
inputs.append(
grpcclient.InferInput("return_num_output_tokens", [1], "BOOL")
)
inputs[-1].set_data_from_numpy(
np.array([return_num_output_tokens], dtype=bool)
)

return inputs

def _callback(self, result, error):
self._responses.append({"result": result, "error": error})

def _llm_infer(self, inputs):
self._responses = []
with grpcclient.InferenceServerClient(self._grpc_url) as client:
client.start_stream(self._callback)
client.async_stream_infer(
self._model_name, inputs=inputs, parameters=self._sampling_parameters
)
client.stop_stream()
assert len(self._responses) > 0

def _assert_text_output_valid(self):
text_output = ""
for response in self._responses:
result, error = response["result"], response["error"]
assert error is None
text_output += result.as_numpy(name="text_output")[0].decode("utf-8")
assert len(text_output) > 0, "output is empty"
assert text_output.count(" ") > 4, "output is not a sentence"

def _assert_finish_reason(self, return_finish_reason):
for i in range(len(self._responses)):
result, error = self._responses[i]["result"], self._responses[i]["error"]
assert error is None
finish_reason_np = result.as_numpy(name="finish_reason")
if return_finish_reason is None or return_finish_reason == False:
assert finish_reason_np is None
continue
finish_reason = finish_reason_np[0].decode("utf-8")
if i < len(self._responses) - 1:
assert finish_reason == "None"
else:
assert finish_reason == "length"

def _assert_cumulative_logprob(self, return_cumulative_logprob):
prev_cumulative_logprob = 0.0
for response in self._responses:
result, error = response["result"], response["error"]
assert error is None
cumulative_logprob_np = result.as_numpy(name="cumulative_logprob")
if return_cumulative_logprob is None or return_cumulative_logprob == False:
assert cumulative_logprob_np is None
continue
cumulative_logprob = cumulative_logprob_np[0].astype(float)
assert cumulative_logprob != prev_cumulative_logprob
prev_cumulative_logprob = cumulative_logprob

def _assert_num_output_tokens(self, return_num_output_tokens):
for response in self._responses:
result, error = response["result"], response["error"]
assert error is None
num_output_tokens_np = result.as_numpy(name="num_output_tokens")
if return_num_output_tokens is None or return_num_output_tokens == False:
assert num_output_tokens_np is None
continue
num_output_tokens = num_output_tokens_np[0].astype(int)
# TODO: vLLM may return token ids identical to the previous one when
# streaming, for example:
#
# prev: None
# curr: text=' the', token_ids=array('l', [5])
#
# prev: text=' the', token_ids=array('l', [5, 1385])
# curr: text=' the term', token_ids=array('l', [5, 1385])
#
# prev: text=' the term', token_ids=array('l', [5, 1385, 44])
# curr: text=' the term', token_ids=array('l', [5, 1385, 44])
#
# prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48])
# curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
#
# If this is no longer the case in a future release, change the assert
# to assert num_output_tokens > 0.
assert num_output_tokens >= 0

@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("return_finish_reason", [None, True, False])
@pytest.mark.parametrize("return_cumulative_logprob", [None, True, False])
@pytest.mark.parametrize("return_num_output_tokens", [None, True, False])
def test_additional_outputs(
self,
stream,
return_finish_reason,
return_cumulative_logprob,
return_num_output_tokens,
):
inputs = self._get_inputs(
self._prompt,
stream=stream,
sampling_parameters=self._sampling_parameters,
return_finish_reason=return_finish_reason,
return_cumulative_logprob=return_cumulative_logprob,
return_num_output_tokens=return_num_output_tokens,
)
self._llm_infer(inputs)
self._assert_text_output_valid()
self._assert_finish_reason(return_finish_reason)
self._assert_cumulative_logprob(return_cumulative_logprob)
self._assert_num_output_tokens(return_num_output_tokens)
66 changes: 66 additions & 0 deletions ci/L0_additional_outputs_vllm/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

export CUDA_VISIBLE_DEVICES=0
source ../common/util.sh

pip3 install pytest==8.1.1
pip3 install tritonclient[grpc]

# Prepare Model
rm -rf models vllm_baseline_output.pkl && mkdir -p models
SAMPLE_MODELS_REPO="../../samples/model_repository"
cp -r $SAMPLE_MODELS_REPO/vllm_model models/vllm_opt
sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.3/' models/vllm_opt/1/model.json

RET=0

# Test
SERVER_LOG="vllm_opt.server.log"
SERVER_ARGS="--model-repository=models"
run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi
set +e
python3 -m pytest --junitxml=test_additional_outputs.xml -s -v additional_outputs_test.py
if [ $? -ne 0 ]; then
echo -e "\n***\n*** additional_outputs_test FAILED. \n***"
RET=1
fi
set -e
kill $SERVER_PID
wait $SERVER_PID

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
echo -e "\n***\n*** Test FAILED\n***"
fi
exit $RET
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions ci/common/util.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand All @@ -25,7 +25,7 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


SERVER=${SERVER:=/opt/tritonserver/bin/tritonserver}
SERVER_IPADDR=${TRITONSERVER_IPADDR:=localhost}
SERVER_LOG=${SERVER_LOG:=./server.log}
SERVER_TIMEOUT=${SERVER_TIMEOUT:=120}
Expand Down
107 changes: 107 additions & 0 deletions docs/additional_outputs.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
<!--
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

# Additional Outputs from vLLM

The vLLM backend supports sending additional outputs from vLLM on top of the
usual `text_output` when requested.

All additional outputs are disabled by default and they need to be enabled on a
per-request basis. If enabled, the corresponding output tensor will be set for
all responses from the request.

## Supported Additional Outputs

### Finish Reason

The reason why the sequence is finished. See
[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L26)
for more details.

To enable, set `return_finish_reason` input tensor to `True`. The reason will be
sent as a string on the `finish_reason` output tensor.

Supported since r24.12.

### Cumulative Log Probabilities

The cumulative log probability of the generated output text. See
[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L22)
for more details.

To enable, set `return_cumulative_logprob` input tensor to `True`. The floating
point value will be sent on the `cumulative_logprob` output tensor.

Supported since r24.12.

### Number of Output Tokens

The number of token IDs of the generated output text sent on this response. It
is the difference in length of the token IDs generated from the last response to
this response. If this is the first response, the last response length is
presumed to be zero. See
[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L21)
for more details on the token IDs of the generated output text.

To enable, set `return_num_output_tokens` input tensor to `True`. The unsigned
integer value will be sent on the `num_output_tokens` output tensor.

Supported since r24.12.

## Examples

### Add Finish Reason to Outputs

```python
import numpy as np
import tritonclient.grpc as grpcclient

inputs = []

inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
inputs[-1].set_data_from_numpy(
np.array(["example prompt".encode("utf-8")], dtype=np.object_)
)

inputs.append(grpcclient.InferInput("return_finish_reason", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))

def callback(result, error):
...
print(result.as_numpy(name="finish_reason"))

with grpcclient.InferenceServerClient("localhost:8001") as client:
client.start_stream(callback)
client.async_stream_infer("vLLM_model_name", inputs=inputs, ...)
client.stop_stream()
```

## Notes

* Enabling additional outputs may impact performance, only add additional
outputs when necessary.
Loading

0 comments on commit ceb5961

Please sign in to comment.