Add exclude_input_in_output option to vllm backend (#35)
oandreeva-nv authored Mar 1, 2024
1 parent 6f0afff commit 5c03411
Showing 7 changed files with 300 additions and 38 deletions.
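For context, here is a minimal client-side sketch of the new option. It is not part of the commit itself: it builds the same inputs as the `create_vllm_request` helper changed below, and assumes a server with the sample `vllm_opt` model reachable at `localhost:8001` (the CI setup these tests use). Because the vLLM model is decoupled, the streaming gRPC API is used even for a single non-streaming request, mirroring the tests.

import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


def callback(result_queue, result, error):
    # Collect either the result or the error for later inspection.
    result_queue.put(error if error is not None else result)


results = queue.Queue()
client = grpcclient.InferenceServerClient(url="localhost:8001")

inputs = [
    grpcclient.InferInput("text_input", [1], "BYTES"),
    grpcclient.InferInput("stream", [1], "BOOL"),
    grpcclient.InferInput("exclude_input_in_output", [1], "BOOL"),
]
inputs[0].set_data_from_numpy(
    np.array(["The capital of France is".encode("utf-8")], dtype=np.object_)
)
inputs[1].set_data_from_numpy(np.array([False], dtype=bool))  # non-streaming
inputs[2].set_data_from_numpy(np.array([True], dtype=bool))  # drop the prompt from the response

client.start_stream(callback=partial(callback, results))
client.async_stream_infer(
    model_name="vllm_opt",
    inputs=inputs,
    outputs=[grpcclient.InferRequestedOutput("text_output")],
    request_id="0",
)

response = results.get()
client.stop_stream()
client.close()

if isinstance(response, Exception):
    raise response
print(response.as_numpy("text_output"))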
98 changes: 87 additions & 11 deletions ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py
@@ -34,37 +34,113 @@
sys.path.append("../../common")
from test_util import AsyncTestResultCollector, create_vllm_request

PROMPTS = ["The most dangerous animal is"]
SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"}


class VLLMTritonStreamTest(AsyncTestResultCollector):
async def _test_vllm_model(
self,
prompts=PROMPTS,
sampling_parameters=SAMPLING_PARAMETERS,
stream=True,
exclude_input_in_output=None,
expected_output=None,
expect_error=False,
):
async with grpcclient.InferenceServerClient(
url="localhost:8001"
) as triton_client:
model_name = "vllm_opt"

async def request_iterator():
for i, prompt in enumerate(prompts):
yield create_vllm_request(
prompt,
i,
stream,
sampling_parameters,
model_name,
exclude_input_in_output=exclude_input_in_output,
)

response_iterator = triton_client.stream_infer(
inputs_iterator=request_iterator()
)

final_response = []
async for response in response_iterator:
result, error = response
if expect_error:
self.assertIsInstance(error, InferenceServerException)
                    self.assertEqual(
error.message(),
"Error generating stream: When streaming, `exclude_input_in_output` = False is not allowed.",
error,
)
return

self.assertIsNone(error, error)
self.assertIsNotNone(result, result)
output = result.as_numpy("text_output")
self.assertIsNotNone(output, "`text_output` should not be None")
final_response.append(str(output[0], encoding="utf-8"))
if expected_output is not None:
self.assertEqual(
final_response,
expected_output,
'Expected to receive the following response: "{}",\
but received "{}".'.format(
expected_output, final_response
),
)

async def test_vllm_model_enabled_stream(self):
"""
        Verifying that a request with multiple prompts runs successfully.
"""
prompts = [
"The most dangerous animal is",
"The future of AI is",
]

await self._test_vllm_model(prompts=prompts)

async def test_vllm_model_enabled_stream_exclude_input_in_output_default(self):
"""
        Verifying that a streaming request returns only the generated diffs,
        which is the default behaviour for `stream=True`.
"""
expected_output = [
" the",
" one",
" that",
" is",
" most",
" likely",
" to",
" be",
" killed",
" by",
" a",
" car",
".",
"\n",
"I",
"'m",
]
await self._test_vllm_model(expected_output=expected_output)

async def test_vllm_model_enabled_stream_exclude_input_in_output_false(self):
"""
        Verifying that a streaming request is rejected with an error when
        `exclude_input_in_output` is explicitly set to False.
"""
expected_output = "Error generating stream: When streaming, `exclude_input_in_output` = False is not allowed."
await self._test_vllm_model(
exclude_input_in_output=False,
expected_output=expected_output,
expect_error=True,
)


if __name__ == "__main__":
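The streaming tests above pin down the intended contract: with `stream=True`, leaving `exclude_input_in_output` unset (or setting it to True) returns only the newly generated text, while explicitly setting it to False is rejected. The snippet below is a plain-Python illustration of that rule, not the backend's actual model.py code; the "Error generating stream:" prefix in the asserted message is presumably added when the server propagates the failure.

def validate_stream_flags(stream, exclude_input_in_output):
    # Illustrative only: mirrors the restriction asserted by the tests above.
    # `exclude_input_in_output` is None when the client did not send the tensor.
    if stream and exclude_input_in_output is False:
        # Streaming responses carry only newly generated tokens, so echoing
        # the prompt back is not supported.
        raise ValueError(
            "When streaming, `exclude_input_in_output` = False is not allowed."
        )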
2 changes: 1 addition & 1 deletion ci/L0_backend_vllm/enabled_stream/test.sh
@@ -36,7 +36,7 @@ CLIENT_LOG="./enabled_stream_client.log"
TEST_RESULT_FILE='test_results.txt'
CLIENT_PY="./enabled_stream_test.py"
SAMPLE_MODELS_REPO="../../../samples/model_repository"
EXPECTED_NUM_TESTS=3

rm -rf models && mkdir -p models
cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt
2 changes: 1 addition & 1 deletion ci/L0_backend_vllm/vllm_backend/test.sh
@@ -36,7 +36,7 @@ CLIENT_LOG="./vllm_backend_client.log"
TEST_RESULT_FILE='test_results.txt'
CLIENT_PY="./vllm_backend_test.py"
SAMPLE_MODELS_REPO="../../../samples/model_repository"
EXPECTED_NUM_TESTS=6

# Helpers =======================================
function assert_curl_success {
117 changes: 109 additions & 8 deletions ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py
@@ -35,6 +35,13 @@
sys.path.append("../../common")
from test_util import TestResultCollector, UserData, callback, create_vllm_request

PROMPTS = [
"The most dangerous animal is",
"The capital of France is",
"The future of AI is",
]
SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"}


class VLLMTritonBackendTest(TestResultCollector):
def setUp(self):
@@ -60,8 +67,18 @@ def test_vllm_triton_backend(self):
self.assertFalse(self.triton_client.is_model_ready(self.python_model_name))

# Test vllm model and unload vllm model
self._test_vllm_model(
prompts=PROMPTS,
sampling_parameters=SAMPLING_PARAMETERS,
stream=False,
send_parameters_as_tensor=True,
)
self._test_vllm_model(
prompts=PROMPTS,
sampling_parameters=SAMPLING_PARAMETERS,
stream=False,
send_parameters_as_tensor=False,
)
self.triton_client.unload_model(self.vllm_model_name)

def test_model_with_invalid_attributes(self):
Expand All @@ -74,16 +91,90 @@ def test_vllm_invalid_model_name(self):
with self.assertRaises(InferenceServerException):
self.triton_client.load_model(model_name)

def test_exclude_input_in_output_default(self):
"""
Verifying default behavior for `exclude_input_in_output`
in non-streaming mode.
Expected result: prompt is returned with diffs.
"""
self.triton_client.load_model(self.vllm_model_name)
        prompts = [
            "The capital of France is",
        ]
expected_output = [
b"The capital of France is the capital of the French Republic.\n\nThe capital of France is the capital"
]
sampling_parameters = {"temperature": "0", "top_p": "1"}
self._test_vllm_model(
prompts,
sampling_parameters,
stream=False,
send_parameters_as_tensor=True,
expected_output=expected_output,
)
self.triton_client.unload_model(self.vllm_model_name)

def test_exclude_input_in_output_false(self):
"""
Verifying behavior for `exclude_input_in_output` = False
in non-streaming mode.
Expected result: prompt is returned with diffs.
"""
self.triton_client.load_model(self.vllm_model_name)
# Test vllm model and unload vllm model
prompts = [
"The capital of France is",
]
expected_output = [
b"The capital of France is the capital of the French Republic.\n\nThe capital of France is the capital"
]
sampling_parameters = {"temperature": "0", "top_p": "1"}
self._test_vllm_model(
prompts,
sampling_parameters,
stream=False,
send_parameters_as_tensor=True,
exclude_input_in_output=False,
expected_output=expected_output,
)
self.triton_client.unload_model(self.vllm_model_name)

def test_exclude_input_in_output_true(self):
"""
Verifying behavior for `exclude_input_in_output` = True
in non-streaming mode.
Expected result: only diffs are returned.
"""
self.triton_client.load_model(self.vllm_model_name)
# Test vllm model and unload vllm model
prompts = [
"The capital of France is",
]
expected_output = [
b" the capital of the French Republic.\n\nThe capital of France is the capital"
]
sampling_parameters = {"temperature": "0", "top_p": "1"}
self._test_vllm_model(
prompts,
sampling_parameters,
stream=False,
send_parameters_as_tensor=True,
exclude_input_in_output=True,
expected_output=expected_output,
)
self.triton_client.unload_model(self.vllm_model_name)

def _test_vllm_model(
self,
prompts,
sampling_parameters,
stream,
send_parameters_as_tensor,
exclude_input_in_output=None,
expected_output=None,
):
user_data = UserData()
number_of_vllm_reqs = len(prompts)

self.triton_client.start_stream(callback=partial(callback, user_data))
for i in range(number_of_vllm_reqs):
@@ -94,6 +185,7 @@ def _test_vllm_model(self, send_parameters_as_tensor):
sampling_parameters,
self.vllm_model_name,
send_parameters_as_tensor,
exclude_input_in_output=exclude_input_in_output,
)
self.triton_client.async_stream_infer(
model_name=self.vllm_model_name,
@@ -111,6 +203,15 @@ def _test_vllm_model(self, send_parameters_as_tensor):

output = result.as_numpy("text_output")
self.assertIsNotNone(output, "`text_output` should not be None")
if expected_output is not None:
self.assertEqual(
output,
expected_output[i],
                    'Actual and expected outputs do not match.\n'
                    'Expected: "{}"\nActual: "{}"'.format(
                        expected_output[i], output
),
)

self.triton_client.stop_stream()

5 changes: 5 additions & 0 deletions ci/common/test_util.py
@@ -95,6 +95,7 @@ def create_vllm_request(
sampling_parameters,
model_name,
send_parameters_as_tensor=True,
exclude_input_in_output=None,
):
inputs = []

@@ -111,6 +112,10 @@ def create_vllm_request(
inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
inputs[-1].set_data_from_numpy(sampling_parameters_data)

if exclude_input_in_output is not None:
inputs.append(grpcclient.InferInput("exclude_input_in_output", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([exclude_input_in_output], dtype=bool))

outputs = [grpcclient.InferRequestedOutput("text_output")]

return {
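On the model side, an optional BOOL input like this is typically read with `triton_python_backend_utils`. The sketch below is an assumption about that pattern, not the model.py change from this commit, which may pick its default differently (for example, excluding the prompt by default when streaming).

import triton_python_backend_utils as pb_utils


def get_exclude_input_in_output(request, default=False):
    # Returns the value of the optional `exclude_input_in_output` tensor,
    # falling back to `default` when the client did not send it.
    tensor = pb_utils.get_input_tensor_by_name(request, "exclude_input_in_output")
    if tensor is None:
        return default
    return bool(tensor.as_numpy()[0])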