diff --git a/.gitignore b/.gitignore
index 68bc17f9..9d4769c9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -152,6 +152,15 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
 
+# Test result files
+Miniconda*
+miniconda
+vllm_env.tar.gz
+triton_python_backend_stub
+python_backend
+*results.txt
+*.log
+
 # PyCharm
 #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
diff --git a/ci/L0_backend_vllm/test.sh b/ci/L0_backend_vllm/test.sh
new file mode 100755
index 00000000..ad559c80
--- /dev/null
+++ b/ci/L0_backend_vllm/test.sh
@@ -0,0 +1,111 @@
+#!/bin/bash
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+source ../common/util.sh
+
+TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"}
+SERVER=${TRITON_DIR}/bin/tritonserver
+BACKEND_DIR=${TRITON_DIR}/backends
+SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --log-verbose=1"
+SERVER_LOG="./vllm_backend_server.log"
+CLIENT_LOG="./vllm_backend_client.log"
+TEST_RESULT_FILE='test_results.txt'
+CLIENT_PY="./vllm_backend_test.py"
+EXPECTED_NUM_TESTS=1
+
+mkdir -p models/vllm_opt/1/
+cp ../qa_models/vllm_opt/model.json models/vllm_opt/1/
+cp ../qa_models/vllm_opt/config.pbtxt models/vllm_opt
+
+mkdir -p models/add_sub/1/
+cp ../qa_models/add_sub/model.py models/add_sub/1/
+cp ../qa_models/add_sub/config.pbtxt models/add_sub
+
+pip3 install tritonclient
+pip3 install grpcio
+
+RET=0
+
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+    cat $SERVER_LOG
+    echo -e "\n***\n*** Failed to start $SERVER\n***"
+    exit 1
+fi
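+
+# Run the unit tests with `set +e` so a failing test does not abort the
+# script; the unittest exit status and the JSON summary that
+# TestResultCollector writes to $TEST_RESULT_FILE are checked explicitly.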
+set +e
+python3 -m unittest -v $CLIENT_PY > $CLIENT_LOG 2>&1
+
+if [ $? -ne 0 ]; then
+    cat $CLIENT_LOG
+    echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***"
+    RET=1
+else
+    check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
+    if [ $? -ne 0 ]; then
+        cat $CLIENT_LOG
+        echo -e "\n***\n*** Test Result Verification FAILED.\n***"
+        RET=1
+    fi
+fi
+set -e
+
+kill $SERVER_PID
+wait $SERVER_PID
+
+# Test that Python backend cmdline parameters are propagated to the vllm backend
+SERVER_ARGS="--model-repository=`pwd`/models --backend-directory=${BACKEND_DIR} --backend-config=python,default-max-batch-size=8"
+SERVER_LOG="./vllm_test_cmdline_server.log"
+
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+    cat $SERVER_LOG
+    echo -e "\n***\n*** Failed to start $SERVER\n***"
+    exit 1
+fi
+
+kill $SERVER_PID
+wait $SERVER_PID
+
+COUNT=$(grep -c "default-max-batch-size\":\"8" "$SERVER_LOG")
+if [[ "$COUNT" -ne 2 ]]; then
+    echo "Cmdline parameters verification Failed"
+    RET=1
+fi
+
+rm -rf "./models"
+
+if [ $RET -eq 1 ]; then
+    cat $CLIENT_LOG
+    cat $SERVER_LOG
+    echo -e "\n***\n*** vLLM test FAILED. \n***"
+else
+    echo -e "\n***\n*** vLLM test PASSED. \n***"
+fi
+
+exit $RET
diff --git a/ci/L0_backend_vllm/vllm_backend_test.py b/ci/L0_backend_vllm/vllm_backend_test.py
new file mode 100755
index 00000000..e1839947
--- /dev/null
+++ b/ci/L0_backend_vllm/vllm_backend_test.py
@@ -0,0 +1,176 @@
+#!/usr/bin/env python3
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
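+
+# Tests explicit loading/unloading of the vllm_opt (vLLM backend) and
+# add_sub (Python backend) models on a running Triton server, covering
+# both streaming vLLM inference and standard request/response inference
+# over gRPC.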
+
+import json
+import queue
+import sys
+import unittest
+from functools import partial
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import *
+
+sys.path.append("../common")
+from test_util import TestResultCollector
+
+
+class UserData:
+    def __init__(self):
+        self._completed_requests = queue.Queue()
+
+
+def callback(user_data, result, error):
+    if error:
+        user_data._completed_requests.put(error)
+    else:
+        user_data._completed_requests.put(result)
+
+
+class VLLMTritonBackendTest(TestResultCollector):
+    def setUp(self):
+        self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
+        self.vllm_model_name = "vllm_opt"
+        self.python_model_name = "add_sub"
+
+    def test_vllm_triton_backend(self):
+        # Load both vllm and add_sub models
+        self.triton_client.load_model(self.vllm_model_name)
+        self.assertTrue(self.triton_client.is_model_ready(self.vllm_model_name))
+        self.triton_client.load_model(self.python_model_name)
+        self.assertTrue(self.triton_client.is_model_ready(self.python_model_name))
+
+        # Unload vllm model and test add_sub model
+        self.triton_client.unload_model(self.vllm_model_name)
+        self.assertFalse(self.triton_client.is_model_ready(self.vllm_model_name))
+        self._test_python_model()
+
+        # Load vllm model and unload add_sub model
+        self.triton_client.load_model(self.vllm_model_name)
+        self.triton_client.unload_model(self.python_model_name)
+        self.assertFalse(self.triton_client.is_model_ready(self.python_model_name))
+
+        # Test vllm model and unload vllm model
+        self._test_vllm_model(send_parameters_as_tensor=True)
+        self._test_vllm_model(send_parameters_as_tensor=False)
+        self.triton_client.unload_model(self.vllm_model_name)
+
+    def _test_vllm_model(self, send_parameters_as_tensor):
+        user_data = UserData()
+        stream = False
+        prompts = [
+            "The most dangerous animal is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        number_of_vllm_reqs = len(prompts)
+        sampling_parameters = {"temperature": "0.1", "top_p": "0.95"}
+
+        self.triton_client.start_stream(callback=partial(callback, user_data))
+        for i in range(number_of_vllm_reqs):
+            inputs, outputs = self._create_vllm_request_data(
+                prompts[i], stream, sampling_parameters, send_parameters_as_tensor
+            )
+            self.triton_client.async_stream_infer(
+                model_name=self.vllm_model_name,
+                request_id=str(i),
+                inputs=inputs,
+                outputs=outputs,
+                parameters=sampling_parameters,
+            )
+
+        for i in range(number_of_vllm_reqs):
+            result = user_data._completed_requests.get()
+            self.assertIsNot(type(result), InferenceServerException)
+
+            output = result.as_numpy("TEXT")
+            self.assertIsNotNone(output)
+
+        self.triton_client.stop_stream()
+
+    def _test_python_model(self):
+        shape = [4]
+        input0_data = np.random.rand(*shape).astype(np.float32)
+        input1_data = np.random.rand(*shape).astype(np.float32)
+
+        inputs = [
+            grpcclient.InferInput(
+                "INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
+            ),
+            grpcclient.InferInput(
+                "INPUT1", input1_data.shape, np_to_triton_dtype(input1_data.dtype)
+            ),
+        ]
+
+        inputs[0].set_data_from_numpy(input0_data)
+        inputs[1].set_data_from_numpy(input1_data)
+
+        outputs = [
+            grpcclient.InferRequestedOutput("OUTPUT0"),
+            grpcclient.InferRequestedOutput("OUTPUT1"),
+        ]
+
+        response = self.triton_client.infer(
+            self.python_model_name, inputs, request_id="10", outputs=outputs
+        )
+        self.assertTrue(
+            np.allclose(input0_data + input1_data, response.as_numpy("OUTPUT0"))
+        )
+        self.assertTrue(
+            np.allclose(input0_data - input1_data, response.as_numpy("OUTPUT1"))
+        )
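+
+    # Builds the input tensors expected by the vllm_opt model: PROMPT and
+    # STREAM are required; SAMPLING_PARAMETERS is an optional JSON-encoded
+    # tensor, mirroring the inputs declared in ci/qa_models/vllm_opt/config.pbtxt.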
+    def _create_vllm_request_data(
+        self, prompt, stream, sampling_parameters, send_parameters_as_tensor
+    ):
+        inputs = []
+
+        prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_)
+        inputs.append(grpcclient.InferInput("PROMPT", [1], "BYTES"))
+        inputs[-1].set_data_from_numpy(prompt_data)
+
+        stream_data = np.array([stream], dtype=bool)
+        inputs.append(grpcclient.InferInput("STREAM", [1], "BOOL"))
+        inputs[-1].set_data_from_numpy(stream_data)
+
+        if send_parameters_as_tensor:
+            sampling_parameters_data = np.array(
+                [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
+            )
+            inputs.append(grpcclient.InferInput("SAMPLING_PARAMETERS", [1], "BYTES"))
+            inputs[-1].set_data_from_numpy(sampling_parameters_data)
+
+        outputs = [grpcclient.InferRequestedOutput("TEXT")]
+
+        return inputs, outputs
+
+    def tearDown(self):
+        self.triton_client.close()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/ci/common/test_util.py b/ci/common/test_util.py
new file mode 100755
index 00000000..5ba21b37
--- /dev/null
+++ b/ci/common/test_util.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+
+# Copyright 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+import unittest
+
+
+class TestResultCollector(unittest.TestCase):
+    # TestResultCollector stores test results and writes them to
+    # `test_results.txt`. To use this class, unit tests must inherit from
+    # it. Use the `check_test_results` bash function from `common/util.sh`
+    # to verify the expected number of tests produced by this class.
+
+    @classmethod
+    def setResult(cls, total, errors, failures):
+        cls.total, cls.errors, cls.failures = total, errors, failures
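+
+    # On class teardown the collected counts are written as a single JSON
+    # line, e.g. {"total": 1, "errors": 0, "failures": 0}, which the
+    # check_test_results bash helper parses with jq.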
+    @classmethod
+    def tearDownClass(cls):
+        # this method is called when all the unit tests in a class are
+        # finished.
+        json_res = {"total": cls.total, "errors": cls.errors, "failures": cls.failures}
+        with open("test_results.txt", "w+") as f:
+            f.write(json.dumps(json_res))
+
+    def run(self, result=None):
+        # result argument stores the accumulative test results
+        test_result = super().run(result)
+        total = test_result.testsRun
+        errors = len(test_result.errors)
+        failures = len(test_result.failures)
+        self.setResult(total, errors, failures)
diff --git a/ci/common/util.sh b/ci/common/util.sh
new file mode 100755
index 00000000..18d6a056
--- /dev/null
+++ b/ci/common/util.sh
@@ -0,0 +1,132 @@
+#!/bin/bash
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+SERVER_IPADDR=${TRITONSERVER_IPADDR:=localhost}
+SERVER_LOG=${SERVER_LOG:=./server.log}
+SERVER_TIMEOUT=${SERVER_TIMEOUT:=120}
+SERVER_LD_PRELOAD=${SERVER_LD_PRELOAD:=""}
+
+# Run inference server. Return once server's health endpoint shows
+# ready or timeout expires. Sets SERVER_PID to pid of SERVER, or 0 if
+# error (including expired timeout)
+function run_server () {
+    SERVER_PID=0
+
+    if [ -z "$SERVER" ]; then
+        echo "=== SERVER must be defined"
+        return
+    fi
+
+    if [ ! -f "$SERVER" ]; then
+        echo "=== $SERVER does not exist"
+        return
+    fi
+
+    if [ -z "$SERVER_LD_PRELOAD" ]; then
+        echo "=== Running $SERVER $SERVER_ARGS"
+    else
+        echo "=== Running LD_PRELOAD=$SERVER_LD_PRELOAD $SERVER $SERVER_ARGS"
+    fi
+
+    LD_PRELOAD=$SERVER_LD_PRELOAD:${LD_PRELOAD} $SERVER $SERVER_ARGS > $SERVER_LOG 2>&1 &
+    SERVER_PID=$!
+
+    wait_for_server_ready $SERVER_PID $SERVER_TIMEOUT
+    if [ "$WAIT_RET" != "0" ]; then
+        # Get further debug information about server startup failure
+        gdb_helper || true
+
+        # Cleanup
+        kill $SERVER_PID > /dev/null 2>&1 || true
+        SERVER_PID=0
+    fi
+}
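+
+# Typical usage (see L0_backend_vllm/test.sh):
+#   SERVER=/opt/tritonserver/bin/tritonserver
+#   SERVER_ARGS="--model-repository=`pwd`/models"
+#   run_server
+#   if [ "$SERVER_PID" == "0" ]; then echo "failed to start"; exit 1; fi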
+
+# Wait until server health endpoint shows ready. Sets WAIT_RET to 0 on
+# success, 1 on failure
+function wait_for_server_ready() {
+    local spid="$1"; shift
+    local wait_time_secs="${1:-30}"; shift
+
+    WAIT_RET=0
+
+    local wait_secs=$wait_time_secs
+    until test $wait_secs -eq 0 ; do
+        if ! kill -0 $spid > /dev/null 2>&1; then
+            echo "=== Server not running."
+            WAIT_RET=1
+            return
+        fi
+
+        sleep 1;
+
+        set +e
+        code=`curl -s -w %{http_code} ${SERVER_IPADDR}:8000/v2/health/ready`
+        set -e
+        if [ "$code" == "200" ]; then
+            return
+        fi
+
+        ((wait_secs--));
+    done
+
+    echo "=== Timeout $wait_time_secs secs. Server not ready."
+    WAIT_RET=1
+}
+
+# Check Python unittest results.
+function check_test_results () {
+    local log_file=$1
+    local expected_num_tests=$2
+
+    if [[ -z "$expected_num_tests" ]]; then
+        echo "=== expected number of tests must be defined"
+        return 1
+    fi
+
+    num_failures=`cat $log_file | grep -E ".*total.*errors.*failures.*" | tail -n 1 | jq .failures`
+    num_tests=`cat $log_file | grep -E ".*total.*errors.*failures.*" | tail -n 1 | jq .total`
+    num_errors=`cat $log_file | grep -E ".*total.*errors.*failures.*" | tail -n 1 | jq .errors`
+
+    # Number regular expression
+    re='^[0-9]+$'
+
+    if [[ $? -ne 0 ]] || ! [[ $num_failures =~ $re ]] || ! [[ $num_tests =~ $re ]] || \
+        ! [[ $num_errors =~ $re ]]; then
+        cat $log_file
+        echo -e "\n***\n*** Test Failed: unable to parse test results\n***" >> $log_file
+        return 1
+    fi
+    if [[ $num_errors != "0" ]] || [[ $num_failures != "0" ]] || [[ $num_tests -ne $expected_num_tests ]]; then
+        cat $log_file
+        echo -e "\n***\n*** Test Failed: Expected $expected_num_tests test(s), $num_tests test(s) executed, $num_errors test(s) had error, and $num_failures test(s) failed. \n***" >> $log_file
+        return 1
+    fi
+
+    return 0
+}
diff --git a/ci/qa_models/add_sub/config.pbtxt b/ci/qa_models/add_sub/config.pbtxt
new file mode 100644
index 00000000..b1f04387
--- /dev/null
+++ b/ci/qa_models/add_sub/config.pbtxt
@@ -0,0 +1,59 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
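+
+# Minimal Python-backend model used by the CI tests: two FP32 inputs of
+# shape [4]; OUTPUT0 is their elementwise sum and OUTPUT1 their difference
+# (see qa_models/add_sub/model.py).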
+
+name: "add_sub"
+backend: "python"
+
+input [
+  {
+    name: "INPUT0"
+    data_type: TYPE_FP32
+    dims: [ 4 ]
+  }
+]
+input [
+  {
+    name: "INPUT1"
+    data_type: TYPE_FP32
+    dims: [ 4 ]
+  }
+]
+output [
+  {
+    name: "OUTPUT0"
+    data_type: TYPE_FP32
+    dims: [ 4 ]
+  }
+]
+output [
+  {
+    name: "OUTPUT1"
+    data_type: TYPE_FP32
+    dims: [ 4 ]
+  }
+]
+
+instance_group [{ kind: KIND_CPU }]
diff --git a/ci/qa_models/add_sub/model.py b/ci/qa_models/add_sub/model.py
new file mode 100644
index 00000000..882a5a73
--- /dev/null
+++ b/ci/qa_models/add_sub/model.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        """This function allows the model to initialize any state associated with this model.
+
+        Parameters
+        ----------
+        args : dict
+          Both keys and values are strings. The dictionary keys and values are:
+          * model_config: A JSON string containing the model configuration
+          * model_instance_kind: A string containing model instance kind
+          * model_instance_device_id: A string containing model instance device ID
+          * model_repository: Model repository path
+          * model_version: Model version
+          * model_name: Model name
+        """
+
+        self.model_config = model_config = json.loads(args["model_config"])
+
+        # Get OUTPUT0 configuration
+        output0_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT0")
+
+        # Get OUTPUT1 configuration
+        output1_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT1")
+
+        # Convert Triton types to numpy types
+        self.output0_dtype = pb_utils.triton_string_to_numpy(
+            output0_config["data_type"]
+        )
+        self.output1_dtype = pb_utils.triton_string_to_numpy(
+            output1_config["data_type"]
+        )
+
+    def execute(self, requests):
+        """This function is called when an inference request is made
+        for this model.
+
+        Parameters
+        ----------
+        requests : list
+          A list of pb_utils.InferenceRequest
+
+        Returns
+        -------
+        list
+          A list of pb_utils.InferenceResponse. The length of this list must
+          be the same as `requests`
+        """
+
+        output0_dtype = self.output0_dtype
+        output1_dtype = self.output1_dtype
+        responses = []
+
+        for request in requests:
+            in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
+            in_1 = pb_utils.get_input_tensor_by_name(request, "INPUT1")
+
+            out_0, out_1 = (
+                in_0.as_numpy() + in_1.as_numpy(),
+                in_0.as_numpy() - in_1.as_numpy(),
+            )
+
+            # Create output tensors.
+            out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0.astype(output0_dtype))
+            out_tensor_1 = pb_utils.Tensor("OUTPUT1", out_1.astype(output1_dtype))
+
+            # Create InferenceResponse.
+            inference_response = pb_utils.InferenceResponse(
+                output_tensors=[out_tensor_0, out_tensor_1]
+            )
+            responses.append(inference_response)
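+
+        # Triton expects one response per request, returned in the same
+        # order as the incoming `requests` list.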
+        return responses
+
+    def finalize(self):
+        """`finalize` is called only once when the model is being unloaded."""
+        print("Cleaning up...")
diff --git a/ci/qa_models/vllm_opt/config.pbtxt b/ci/qa_models/vllm_opt/config.pbtxt
new file mode 100644
index 00000000..0f4cedb8
--- /dev/null
+++ b/ci/qa_models/vllm_opt/config.pbtxt
@@ -0,0 +1,67 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
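+
+# Test model configuration for the vLLM backend. The backend streams
+# responses, so the transaction policy must be decoupled (src/model.py
+# asserts this), and max_batch_size is 0 because vLLM batches requests
+# internally.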
+
+name: "vllm_opt"
+backend: "vllm"
+max_batch_size: 0
+
+model_transaction_policy {
+  decoupled: True
+}
+
+input [
+  {
+    name: "PROMPT"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  },
+  {
+    name: "STREAM"
+    data_type: TYPE_BOOL
+    dims: [ 1 ]
+  },
+  {
+    name: "SAMPLING_PARAMETERS"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+    optional: true
+  }
+]
+
+output [
+  {
+    name: "TEXT"
+    data_type: TYPE_STRING
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_MODEL
+  }
+]
diff --git a/ci/qa_models/vllm_opt/model.json b/ci/qa_models/vllm_opt/model.json
new file mode 100644
index 00000000..083aebd5
--- /dev/null
+++ b/ci/qa_models/vllm_opt/model.json
@@ -0,0 +1,4 @@
+{
+    "model": "facebook/opt-125m",
+    "disable_log_requests": "true"
+}
diff --git a/src/model.py b/src/model.py
index 827c106d..0313da9d 100644
--- a/src/model.py
+++ b/src/model.py
@@ -56,7 +56,7 @@ def initialize(self, args):
         ), "vLLM Triton backend must be configured to use decoupled model transaction policy"
 
         engine_args_filepath = os.path.join(
-            pb_utils.get_model_dir(), _VLLM_ENGINE_ARGS_FILENAME
+            pb_utils.get_model_dir(), _VLLM_ENGINE_ARGS_FILENAME
         )
         assert os.path.isfile(
             engine_args_filepath
@@ -132,7 +132,13 @@ def get_sampling_params_dict(self, params_json):
             if k in params_dict:
                 params_dict[k] = bool(params_dict[k])
 
-        float_keys = ["frequency_penalty", "length_penalty", "presence_penalty", "temperature", "top_p"]
+        float_keys = [
+            "frequency_penalty",
+            "length_penalty",
+            "presence_penalty",
+            "temperature",
+            "top_p",
+        ]
         for k in float_keys:
             if k in params_dict:
                 params_dict[k] = float(params_dict[k])
@@ -175,7 +181,9 @@ async def generate(self, request):
         # BLS. Provide an optional mechanism to receive serialized
         # parameters as an input tensor until support is added
 
-        parameters_input_tensor = pb_utils.get_input_tensor_by_name(request, "SAMPLING_PARAMETERS")
+        parameters_input_tensor = pb_utils.get_input_tensor_by_name(
+            request, "SAMPLING_PARAMETERS"
+        )
         if parameters_input_tensor:
             parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
         else: