diff --git a/ci/L0_accuracy_test/test.sh b/ci/L0_accuracy_test/test.sh
index 13c804c7..e72a0859 100644
--- a/ci/L0_accuracy_test/test.sh
+++ b/ci/L0_accuracy_test/test.sh
@@ -77,9 +77,9 @@ rm -rf models/
 if [ $RET -eq 1 ]; then
     cat $CLIENT_LOG
     cat $SERVER_LOG
-    echo -e "\n***\n*** vLLM test FAILED. \n***"
+    echo -e "\n***\n*** Accuracy test FAILED. \n***"
 else
-    echo -e "\n***\n*** vLLM test PASSED. \n***"
+    echo -e "\n***\n*** Accuracy test PASSED. \n***"
 fi
 
 exit $RET
diff --git a/ci/L0_accuracy_test/vllm_accuracy_test.py b/ci/L0_accuracy_test/vllm_accuracy_test.py
index f2d69a4e..398d7ad3 100644
--- a/ci/L0_accuracy_test/vllm_accuracy_test.py
+++ b/ci/L0_accuracy_test/vllm_accuracy_test.py
@@ -24,7 +24,6 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import queue
 import sys
 import unittest
 from functools import partial
@@ -39,19 +38,7 @@ import asyncio
 
 sys.path.append("../common")
-from test_util import TestResultCollector, create_vllm_request
-
-
-class UserData:
-    def __init__(self):
-        self._completed_requests = queue.Queue()
-
-
-def callback(user_data, result, error):
-    if error:
-        user_data._completed_requests.put(error)
-    else:
-        user_data._completed_requests.put(result)
+from test_util import TestResultCollector, create_vllm_request, UserData, callback
 
 
 async def generate_python_vllm_output(prompt, llm_engine):
@@ -78,7 +65,7 @@ def setUp(self):
         self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
         vllm_engine_config = {
             "model": "facebook/opt-125m",
-            "gpu_memory_utilization": 0.25,
+            "gpu_memory_utilization": 0.3,
         }
 
         self.llm_engine = AsyncLLMEngine.from_engine_args(
diff --git a/ci/L0_backend_vllm/vllm_backend_test.py b/ci/L0_backend_vllm/vllm_backend_test.py
index cdc6d43c..f666d4b7 100755
--- a/ci/L0_backend_vllm/vllm_backend_test.py
+++ b/ci/L0_backend_vllm/vllm_backend_test.py
@@ -24,7 +24,6 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import queue
 import sys
 import unittest
 from functools import partial
@@ -34,19 +33,7 @@ from tritonclient.utils import *
 
 sys.path.append("../common")
-from test_util import TestResultCollector, create_vllm_request
-
-
-class UserData:
-    def __init__(self):
-        self._completed_requests = queue.Queue()
-
-
-def callback(user_data, result, error):
-    if error:
-        user_data._completed_requests.put(error)
-    else:
-        user_data._completed_requests.put(result)
+from test_util import TestResultCollector, create_vllm_request, UserData, callback
 
 
 class VLLMTritonBackendTest(TestResultCollector):
diff --git a/ci/L0_multi_gpu/test.sh b/ci/L0_multi_gpu/test.sh
new file mode 100644
index 00000000..5879da41
--- /dev/null
+++ b/ci/L0_multi_gpu/test.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+source ../common/util.sh
+
+TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
+SERVER=${TRITON_DIR}/bin/tritonserver
+BACKEND_DIR=${TRITON_DIR}/backends
+SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --log-verbose=1"
+SERVER_LOG="./vllm_multi_gpu_test_server.log"
+CLIENT_LOG="./vllm_multi_gpu_test_client.log"
+TEST_RESULT_FILE='test_results.txt'
+CLIENT_PY="./vllm_multi_gpu_test.py"
+EXPECTED_NUM_TESTS=1
+
+mkdir -p models/vllm_opt/1/
+echo '{"model":"facebook/opt-125m", "disable_log_requests": "true", "gpu_memory_utilization":0.5, "tensor_parallel_size":2}' > models/vllm_opt/1/model.json
+cp ../qa_models/vllm_opt/config.pbtxt models/vllm_opt
+
+pip3 install tritonclient
+pip3 install grpcio
+pip3 install nvidia-ml-py3
+
+RET=0
+
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+    cat $SERVER_LOG
+    echo -e "\n***\n*** Failed to start $SERVER\n***"
+    exit 1
+fi
+
+set +e
+python3 -m unittest -v $CLIENT_PY > $CLIENT_LOG 2>&1
+
+if [ $? -ne 0 ]; then
+    cat $CLIENT_LOG
+    echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***"
+    RET=1
+else
+    check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
+    if [ $? -ne 0 ]; then
+        cat $CLIENT_LOG
+        echo -e "\n***\n*** Test Result Verification FAILED.\n***"
+        RET=1
+    fi
+fi
+set -e
+
+kill $SERVER_PID
+wait $SERVER_PID
+rm -rf models/
+
+if [ $RET -eq 1 ]; then
+    cat $CLIENT_LOG
+    cat $SERVER_LOG
+    echo -e "\n***\n*** Multi GPU Utilization test FAILED. \n***"
+else
+    echo -e "\n***\n*** Multi GPU Utilization test PASSED. \n***"
+fi
+
+exit $RET
diff --git a/ci/L0_multi_gpu/vllm_multi_gpu_test.py b/ci/L0_multi_gpu/vllm_multi_gpu_test.py
new file mode 100644
index 00000000..ae4944c7
--- /dev/null
+++ b/ci/L0_multi_gpu/vllm_multi_gpu_test.py
@@ -0,0 +1,126 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import sys
+import nvidia_smi
+import tritonclient.grpc as grpcclient
+from functools import partial
+import unittest
+from tritonclient.utils import *
+
+sys.path.append("../common")
+from test_util import TestResultCollector, create_vllm_request, UserData, callback
+
+
+class VLLMMultiGPUTest(TestResultCollector):
+    def setUp(self):
+        nvidia_smi.nvmlInit()
+        self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
+        self.vllm_model_name = "vllm_opt"
+
+    def get_gpu_memory_utilization(self, gpu_id):
+        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(gpu_id)
+        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
+        return info.used
+
+    def get_available_gpu_ids(self):
+        device_count = nvidia_smi.nvmlDeviceGetCount()
+        available_gpus = []
+        for gpu_id in range(device_count):
+            handle = nvidia_smi.nvmlDeviceGetHandleByIndex(gpu_id)
+            if handle:
+                available_gpus.append(gpu_id)
+        return available_gpus
+
+    def test_vllm_multi_gpu_utilization(self):
+        gpu_ids = self.get_available_gpu_ids()
+        self.assertGreaterEqual(len(gpu_ids), 2, "Error: Detected single GPU")
+
+        print("=============== Before Loading vLLM Model ===============")
+        mem_util_before_loading_model = {}
+        for gpu_id in gpu_ids:
+            memory_utilization = self.get_gpu_memory_utilization(gpu_id)
+            print(f"GPU {gpu_id} Memory Utilization: {memory_utilization} bytes")
+            mem_util_before_loading_model[gpu_id] = memory_utilization
+
+        self.triton_client.load_model(self.vllm_model_name)
+        self._test_vllm_model()
+
+        print("=============== After Loading vLLM Model ===============")
+        vllm_model_used_gpus = 0
+        for gpu_id in gpu_ids:
+            memory_utilization = self.get_gpu_memory_utilization(gpu_id)
+            print(f"GPU {gpu_id} Memory Utilization: {memory_utilization} bytes")
+            if memory_utilization > mem_util_before_loading_model[gpu_id]:
+                vllm_model_used_gpus += 1
+
+        self.assertGreaterEqual(vllm_model_used_gpus, 2)
+
+    def _test_vllm_model(self, send_parameters_as_tensor=True):
+        user_data = UserData()
+        stream = False
+        prompts = [
+            "The most dangerous animal is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        number_of_vllm_reqs = len(prompts)
+        sampling_parameters = {"temperature": "0.1", "top_p": "0.95"}
+
+        self.triton_client.start_stream(callback=partial(callback, user_data))
+        for i in range(number_of_vllm_reqs):
+            request_data = create_vllm_request(
+                prompts[i],
+                i,
+                stream,
+                sampling_parameters,
+                self.vllm_model_name,
+                send_parameters_as_tensor,
+            )
+            self.triton_client.async_stream_infer(
+                model_name=self.vllm_model_name,
+                request_id=request_data["request_id"],
+                inputs=request_data["inputs"],
+                outputs=request_data["outputs"],
+                parameters=sampling_parameters,
+            )
+
+        for i in range(number_of_vllm_reqs):
+            result = user_data._completed_requests.get()
+            self.assertIsNot(type(result), InferenceServerException)
+
+            output = result.as_numpy("text_output")
+            self.assertIsNotNone(output)
+
+        self.triton_client.stop_stream()
+
+    def tearDown(self):
+        nvidia_smi.nvmlShutdown()
+        self.triton_client.close()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/ci/L0_stream_enabled/test.sh b/ci/L0_stream_enabled/test.sh
index 4eaed7d3..83513cf4 100755
--- a/ci/L0_stream_enabled/test.sh
+++ b/ci/L0_stream_enabled/test.sh
@@ -77,9 +77,9 @@ rm -rf models/
 if [ $RET -eq 1 ]; then
     cat $CLIENT_LOG
     cat $SERVER_LOG
-    echo -e "\n***\n*** vLLM test FAILED. \n***"
+    echo -e "\n***\n*** Stream Enabled test FAILED. \n***"
 else
-    echo -e "\n***\n*** vLLM test PASSED. \n***"
+    echo -e "\n***\n*** Stream Enabled test PASSED. \n***"
 fi
 
 exit $RET
diff --git a/ci/common/test_util.py b/ci/common/test_util.py
index a5bb6025..2f760884 100755
--- a/ci/common/test_util.py
+++ b/ci/common/test_util.py
@@ -28,7 +28,7 @@
 
 import json
 import unittest
-
+import queue
 import numpy as np
 import tritonclient.grpc as grpcclient
 
@@ -119,3 +119,15 @@ def create_vllm_request(
         "request_id": str(request_id),
         "parameters": sampling_parameters,
     }
+
+
+class UserData:
+    def __init__(self):
+        self._completed_requests = queue.Queue()
+
+
+def callback(user_data, result, error):
+    if error:
+        user_data._completed_requests.put(error)
+    else:
+        user_data._completed_requests.put(result)