diff --git a/README.md b/README.md
index 13953f58..802f4f4c 100644
--- a/README.md
+++ b/README.md
@@ -111,7 +111,9 @@ container with the following commands:
 
 ```
 mkdir -p /opt/tritonserver/backends/vllm
-wget -P /opt/tritonserver/backends/vllm https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/src/model.py
+git clone https://github.com/triton-inference-server/vllm_backend.git /opt/tritonserver/backends/vllm/vllm_backend
+cp -r /opt/tritonserver/backends/vllm/vllm_backend/src/* /opt/tritonserver/backends/vllm
+rm -rf /opt/tritonserver/backends/vllm/vllm_backend
 ```
 
 ## Using the vLLM Backend
diff --git a/ci/L0_backend_vllm/metrics_test/test.sh b/ci/L0_backend_vllm/metrics_test/test.sh
new file mode 100755
index 00000000..6509b13c
--- /dev/null
+++ b/ci/L0_backend_vllm/metrics_test/test.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+source ../../common/util.sh
+
+TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
+SERVER=${TRITON_DIR}/bin/tritonserver
+BACKEND_DIR=${TRITON_DIR}/backends
+SERVER_ARGS="--model-repository=$(pwd)/models --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --load-model=vllm_opt --log-verbose=1"
+SERVER_LOG="./vllm_metrics_server.log"
+CLIENT_LOG="./vllm_metrics_client.log"
+TEST_RESULT_FILE='test_results.txt'
+CLIENT_PY="./vllm_metrics_test.py"
+SAMPLE_MODELS_REPO="../../../samples/model_repository"
+EXPECTED_NUM_TESTS=1
+
+# Helpers =======================================
+function assert_curl_success {
+    message="${1}"
+    if [ "$code" != "200" ]; then
+        cat ./curl.out
+        echo -e "\n***\n*** ${message} : line ${BASH_LINENO}\n***"
+        RET=1
+    fi
+}
+
+rm -rf models && mkdir -p models
+cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt
+# The `vllm_opt` model will be loaded on server start and stay loaded
+# throughout unit testing. To ensure that vLLM's memory profiler will not
+# error out on `vllm_load_test` load, we reduce "gpu_memory_utilization"
+# for `vllm_opt`, so that at least 60% of GPU memory is available for
+# other models.
+sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.4/' models/vllm_opt/1/model.json
+
+RET=0
+
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+    cat $SERVER_LOG
+    echo -e "\n***\n*** Failed to start $SERVER\n***"
+    exit 1
+fi
+
+set +e
+python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1
+
+if [ $? -ne 0 ]; then
+    cat $CLIENT_LOG
+    echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***"
+    RET=1
+else
+    check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
+    if [ $? -ne 0 ]; then
+        cat $CLIENT_LOG
+        echo -e "\n***\n*** Test Result Verification FAILED.\n***"
+        RET=1
+    fi
+fi
+set -e
+
+kill $SERVER_PID
+wait $SERVER_PID
+rm -rf "./models"
+
+if [ $RET -eq 1 ]; then
+    cat $CLIENT_LOG
+    cat $SERVER_LOG
+    echo -e "\n***\n*** vLLM test FAILED. \n***"
+else
+    echo -e "\n***\n*** vLLM test PASSED. \n***"
+fi
+
+collect_artifacts_from_subdir
+exit $RET
diff --git a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
new file mode 100644
index 00000000..f6df2340
--- /dev/null
+++ b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -0,0 +1,141 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
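+
+# This test sends a few streaming inference requests to the `vllm_opt` model
+# and then verifies that the vllm:time_to_first_token_seconds and
+# vllm:time_per_output_token_seconds histograms published on Triton's metrics
+# endpoint (port 8002) report the expected observation counts and sums.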
+
+import os
+import re
+import sys
+import unittest
+from functools import partial
+
+import requests
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import *
+
+sys.path.append("../../common")
+from test_util import TestResultCollector, UserData, callback, create_vllm_request
+
+
+class VLLMTritonMetricsTest(TestResultCollector):
+    def setUp(self):
+        self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
+        self.tritonserver_ipaddr = os.environ.get("TRITONSERVER_IPADDR", "localhost")
+        self.vllm_model_name = "vllm_opt"
+        self.prompts = [
+            "The most dangerous animal is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        self.sampling_parameters = {"temperature": "0", "top_p": "1"}
+
+    def get_metrics(self):
+        """
+        Scrape Triton's metrics endpoint and store the vLLM metrics
+        in a dictionary keyed by metric name.
+        """
+        r = requests.get(f"http://{self.tritonserver_ipaddr}:8002/metrics")
+        r.raise_for_status()
+
+        # Regular expression matching `vllm:<name>{<labels>} <value>` lines,
+        # with the labels part optional.
+        pattern = r"^(vllm:[^ {]+)(?:{.*})? ([0-9.-]+)$"
+        vllm_dict = {}
+
+        # Find all matches in the response body
+        matches = re.findall(pattern, r.text, re.MULTILINE)
+
+        for match in matches:
+            key, value = match
+            vllm_dict[key] = float(value) if "." in value else int(value)
+
+        return vllm_dict
+
+    def vllm_async_stream_infer(
+        self,
+        prompts,
+        sampling_parameters,
+        stream,
+        send_parameters_as_tensor,
+        model_name,
+    ):
+        """
+        Helper function to send async stream infer requests to vLLM.
+        """
+        user_data = UserData()
+        number_of_vllm_reqs = len(prompts)
+
+        self.triton_client.start_stream(callback=partial(callback, user_data))
+        for i in range(number_of_vllm_reqs):
+            request_data = create_vllm_request(
+                prompts[i],
+                i,
+                stream,
+                sampling_parameters,
+                model_name,
+                send_parameters_as_tensor,
+            )
+            self.triton_client.async_stream_infer(
+                model_name=model_name,
+                request_id=request_data["request_id"],
+                inputs=request_data["inputs"],
+                outputs=request_data["outputs"],
+                parameters=sampling_parameters,
+            )
+
+        for _ in range(number_of_vllm_reqs):
+            result = user_data._completed_requests.get()
+            if type(result) is InferenceServerException:
+                print(result.message())
+            self.assertIsNot(type(result), InferenceServerException, str(result))
+
+            output = result.as_numpy("text_output")
+            self.assertIsNotNone(output, "`text_output` should not be None")
+
+        self.triton_client.stop_stream()
+
+    def test_vllm_metrics(self):
+        # Test vLLM metrics
+        self.vllm_async_stream_infer(
+            prompts=self.prompts,
+            sampling_parameters=self.sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=True,
+            model_name=self.vllm_model_name,
+        )
+        metrics_dict = self.get_metrics()
+
+        # Time to first token is observed once per request; 3 prompts were sent.
+        self.assertEqual(metrics_dict["vllm:time_to_first_token_seconds_count"], 3)
+        self.assertTrue(
+            0.0001 < metrics_dict["vllm:time_to_first_token_seconds_sum"] < 0.0003
+        )
+        # Each output token after the first contributes one observation.
+        self.assertEqual(metrics_dict["vllm:time_per_output_token_seconds_count"], 45)
+        self.assertTrue(
+            0.001 <= metrics_dict["vllm:time_per_output_token_seconds_sum"] <= 0.003
+        )
+
+    def tearDown(self):
+        self.triton_client.close()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/ci/L0_backend_vllm/test.sh b/ci/L0_backend_vllm/test.sh
index 93d065c8..a9f89894 100755
--- a/ci/L0_backend_vllm/test.sh
+++ b/ci/L0_backend_vllm/test.sh
@@ -26,7 +26,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 RET=0
-SUBTESTS="accuracy_test request_cancellation enabled_stream vllm_backend"
+SUBTESTS="accuracy_test request_cancellation enabled_stream vllm_backend metrics_test"
 
 python3 -m pip install --upgrade pip && pip3 install tritonclient[grpc]
diff --git a/src/model.py b/src/model.py
index 3fe7cd1e..f250f86c 100644
--- a/src/model.py
+++ b/src/model.py
@@ -39,6 +39,8 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
+from utils.metrics import VllmStatLogger
+
 _VLLM_ENGINE_ARGS_FILENAME = "model.json"
 _MULTI_LORA_ARGS_FILENAME = "multi_lora.json"
 
@@ -151,6 +153,14 @@ def init_engine(self):
             AsyncEngineArgs(**self.vllm_engine_config)
         )
 
+        # Create vLLM custom metrics and attach the stat logger to the engine
+        labels = {
+            "model": self.args["model_name"],
+            "version": self.args["model_version"],
+        }
+        logger = VllmStatLogger(labels=labels)
+        self.llm_engine.add_logger("triton", logger)
+
     def setup_lora(self):
         self.enable_lora = False
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
new file mode 100644
index 00000000..cb7b4ff4
--- /dev/null
+++ b/src/utils/metrics.py
@@ -0,0 +1,137 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
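+
+# This module adapts vLLM's stats interface to Triton's metrics API:
+# TritonMetrics declares the histogram metric families through
+# pb_utils.MetricFamily, and VllmStatLogger implements vLLM's StatLoggerBase
+# so that per-iteration engine stats are observed into those histograms and
+# exposed on Triton's /metrics endpoint.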
+
+from typing import Dict, List, Union
+
+import triton_python_backend_utils as pb_utils
+from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
+from vllm.engine.metrics import Stats as VllmStats
+from vllm.engine.metrics import SupportsMetricsInfo
+
+
+class TritonMetrics:
+    def __init__(self, labels):
+        # Initialize metric families
+        # Iteration stats
+        self.histogram_time_to_first_token_family = pb_utils.MetricFamily(
+            name="vllm:time_to_first_token_seconds",
+            description="Histogram of time to first token in seconds.",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
+        self.histogram_time_per_output_token_family = pb_utils.MetricFamily(
+            name="vllm:time_per_output_token_seconds",
+            description="Histogram of time per output token in seconds.",
+            kind=pb_utils.MetricFamily.HISTOGRAM,
+        )
+
+        # Initialize metrics
+        # Iteration stats
+        self.histogram_time_to_first_token = (
+            self.histogram_time_to_first_token_family.Metric(
+                labels=labels,
+                buckets=[
+                    0.001,
+                    0.005,
+                    0.01,
+                    0.02,
+                    0.04,
+                    0.06,
+                    0.08,
+                    0.1,
+                    0.25,
+                    0.5,
+                    0.75,
+                    1.0,
+                    2.5,
+                    5.0,
+                    7.5,
+                    10.0,
+                ],
+            )
+        )
+        self.histogram_time_per_output_token = (
+            self.histogram_time_per_output_token_family.Metric(
+                labels=labels,
+                buckets=[
+                    0.01,
+                    0.025,
+                    0.05,
+                    0.075,
+                    0.1,
+                    0.15,
+                    0.2,
+                    0.3,
+                    0.4,
+                    0.5,
+                    0.75,
+                    1.0,
+                    2.5,
+                ],
+            )
+        )
+
+
+class VllmStatLogger(VllmStatLoggerBase):
+    """StatLogger is used as an adapter between the vLLM stats collector and the Triton metrics provider."""
+
+    # local_interval is not used by this logger; vLLM uses it to throttle
+    # its own periodic logging to stdout.
+    def __init__(self, labels: Dict, local_interval: float = 0) -> None:
+        # Tracked stats over current local logging interval.
+        super().__init__(local_interval)
+        self.metrics = TritonMetrics(labels=labels)
+
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        raise NotImplementedError
+
+    def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
+        """Convenience function for logging a list of values to a histogram.
+
+        Args:
+            histogram: A histogram metric instance.
+            data: A list of int or float values to observe into the histogram.
+
+        Returns:
+            None
+        """
+        for datum in data:
+            histogram.observe(datum)
+
+    def log(self, stats: VllmStats) -> None:
+        """Logs tracked stats to the Triton metrics server every iteration.
+
+        Args:
+            stats: Created by LLMEngine for use by VllmStatLogger.
+
+        Returns:
+            None
+        """
+        self._log_histogram(
+            self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter
+        )
+        self._log_histogram(
+            self.metrics.histogram_time_per_output_token,
+            stats.time_per_output_tokens_iter,
+        )
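
For quick manual verification, the sketch below mirrors `get_metrics()` from `vllm_metrics_test.py`: it scrapes Triton's metrics endpoint and parses out the `vllm:*` values. The `localhost:8002` URL assumes a locally running server on the default metrics port (the same endpoint the test queries); the regex is copied from the test.

```
# Minimal sketch for manually checking the new vLLM metrics. Assumes a
# locally running tritonserver with the vLLM backend loaded and metrics
# exposed on the default port 8002.
import re

import requests

response = requests.get("http://localhost:8002/metrics")
response.raise_for_status()

# Match `vllm:<name>{<labels>} <value>` lines, with the labels part optional
# (same pattern as vllm_metrics_test.py).
pattern = r"^(vllm:[^ {]+)(?:{.*})? ([0-9.-]+)$"
for name, value in re.findall(pattern, response.text, re.MULTILINE):
    print(name, float(value) if "." in value else int(value))
```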