Minor updates
yinggeh committed Aug 6, 2024
1 parent 6f97f6f commit bf7669e
Showing 1 changed file with 36 additions and 45 deletions.
81 changes: 36 additions & 45 deletions ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -70,46 +70,17 @@ def get_metrics(self):

         return vllm_dict

-    def test_vllm_metrics(self):
-        # All vLLM metrics from tritonserver
-        expected_metrics_dict = {
-            "vllm:prompt_tokens_total": 0,
-            "vllm:generation_tokens_total": 0,
-        }
-
-        # Test vLLM metrics
-        self._test_vllm_model(
-            prompts=self.prompts,
-            sampling_parameters=self.sampling_parameters,
-            stream=False,
-            send_parameters_as_tensor=True,
-            model_name=self.vllm_model_name,
-        )
-        expected_metrics_dict["vllm:prompt_tokens_total"] = 18
-        expected_metrics_dict["vllm:generation_tokens_total"] = 48
-        self.assertEqual(self.get_metrics(), expected_metrics_dict)
-
-        self._test_vllm_model(
-            prompts=self.prompts,
-            sampling_parameters=self.sampling_parameters,
-            stream=False,
-            send_parameters_as_tensor=False,
-            model_name=self.vllm_model_name,
-        )
-        expected_metrics_dict["vllm:prompt_tokens_total"] = 36
-        expected_metrics_dict["vllm:generation_tokens_total"] = 96
-        self.assertEqual(self.get_metrics(), expected_metrics_dict)
-
-    def _test_vllm_model(
+    def vllm_async_stream_infer(
         self,
         prompts,
         sampling_parameters,
         stream,
         send_parameters_as_tensor,
-        exclude_input_in_output=None,
-        expected_output=None,
-        model_name="vllm_opt",
+        model_name,
     ):
+        """
+        Helper function to send async stream infer requests to vLLM.
+        """
         user_data = UserData()
         number_of_vllm_reqs = len(prompts)

@@ -122,7 +93,6 @@ def _test_vllm_model(
                 sampling_parameters,
                 model_name,
                 send_parameters_as_tensor,
-                exclude_input_in_output=exclude_input_in_output,
             )
             self.triton_client.async_stream_infer(
                 model_name=model_name,
@@ -132,26 +102,47 @@ def _test_vllm_model(
                 parameters=sampling_parameters,
             )

-        for i in range(number_of_vllm_reqs):
+        for _ in range(number_of_vllm_reqs):
             result = user_data._completed_requests.get()
             if type(result) is InferenceServerException:
                 print(result.message())
             self.assertIsNot(type(result), InferenceServerException, str(result))

             output = result.as_numpy("text_output")
             self.assertIsNotNone(output, "`text_output` should not be None")
-            if expected_output is not None:
-                self.assertEqual(
-                    output,
-                    expected_output[i],
-                    'Actual and expected outputs do not match.\n \
-                    Expected "{}" \n Actual:"{}"'.format(
-                        output, expected_output[i]
-                    ),
-                )

         self.triton_client.stop_stream()

+    def test_vllm_metrics(self):
+        # All vLLM metrics from tritonserver
+        expected_metrics_dict = {
+            "vllm:prompt_tokens_total": 0,
+            "vllm:generation_tokens_total": 0,
+        }
+
+        # Test vLLM metrics
+        self.vllm_async_stream_infer(
+            prompts=self.prompts,
+            sampling_parameters=self.sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=True,
+            model_name=self.vllm_model_name,
+        )
+        expected_metrics_dict["vllm:prompt_tokens_total"] = 18
+        expected_metrics_dict["vllm:generation_tokens_total"] = 48
+        self.assertEqual(self.get_metrics(), expected_metrics_dict)
+
+        self.vllm_async_stream_infer(
+            prompts=self.prompts,
+            sampling_parameters=self.sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=False,
+            model_name=self.vllm_model_name,
+        )
+        expected_metrics_dict["vllm:prompt_tokens_total"] = 36
+        expected_metrics_dict["vllm:generation_tokens_total"] = 96
+        self.assertEqual(self.get_metrics(), expected_metrics_dict)
+
     def tearDown(self):
         self.triton_client.close()

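Note: the first hunk opens partway through get_metrics(), so that helper's body is not shown in this diff. A minimal, hypothetical sketch of how such a helper could collect the two counters the test compares against, assuming Triton's default Prometheus metrics endpoint on port 8002 (the URL, port, and label handling are assumptions, not taken from this change):

# Hypothetical get_metrics() sketch, not part of this commit.
# Scrapes Triton's Prometheus endpoint and keeps only the vLLM counters.
import requests

def get_metrics(metrics_url="http://localhost:8002/metrics"):
    response = requests.get(metrics_url)
    response.raise_for_status()
    vllm_dict = {}
    for line in response.text.splitlines():
        # Exposition lines look roughly like:
        #   vllm:prompt_tokens_total{model="vllm_opt",version="1"} 18
        if line.startswith("vllm:"):
            name, value = line.rsplit(" ", 1)
            vllm_dict[name.split("{")[0]] = int(float(value))
    return vllm_dict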

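The expected values in the relocated test_vllm_metrics imply that the vLLM token counters are cumulative across requests: a second, identical round of inference doubles both totals. A quick sanity check of that arithmetic, using only figures that appear in the diff:

# Each round sends the same prompts, so round two doubles the cumulative totals.
prompt_tokens_per_round = 18
generation_tokens_per_round = 48

assert 2 * prompt_tokens_per_round == 36  # vllm:prompt_tokens_total after round two
assert 2 * generation_tokens_per_round == 96  # vllm:generation_tokens_total after round two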