feat: Support sending additional outputs from vLLM inference #70

Merged · 15 commits · Nov 26, 2024
Changes from 1 commit
59 changes: 23 additions & 36 deletions ci/L0_additional_outputs_vllm/additional_outputs_test.py
@@ -25,13 +25,13 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 import json
-import unittest

 import numpy as np
+import pytest
 import tritonclient.grpc as grpcclient


-class InferTest(unittest.TestCase):
+class TestAdditionalOutputs:
     _grpc_url = "localhost:8001"
     _model_name = "vllm_opt"
     _sampling_parameters = {"temperature": "0", "top_p": "1"}
@@ -93,51 +93,51 @@ def _llm_infer(self, inputs):
             self._model_name, inputs=inputs, parameters=self._sampling_parameters
         )
         client.stop_stream()
-        self.assertGreater(len(self._responses), 0)
+        assert len(self._responses) > 0

     def _assert_text_output_valid(self):
         text_output = ""
         for response in self._responses:
             result, error = response["result"], response["error"]
-            self.assertIsNone(error)
+            assert error is None
             text_output += result.as_numpy(name="text_output")[0].decode("utf-8")
-        self.assertGreater(len(text_output), 0, "output is empty")
-        self.assertGreater(text_output.count(" "), 4, "output is not a sentence")
+        assert len(text_output) > 0, "output is empty"
+        assert text_output.count(" ") > 4, "output is not a sentence"

     def _assert_finish_reason(self, output_finish_reason):
         for i in range(len(self._responses)):
             result, error = self._responses[i]["result"], self._responses[i]["error"]
-            self.assertIsNone(error)
+            assert error is None
             finish_reason_np = result.as_numpy(name="finish_reason")
             if output_finish_reason is None or output_finish_reason == False:
-                self.assertIsNone(finish_reason_np)
+                assert finish_reason_np is None
                 continue
             finish_reason = finish_reason_np[0].decode("utf-8")
             if i < len(self._responses) - 1:
-                self.assertEqual(finish_reason, "None")
+                assert finish_reason == "None"
             else:
-                self.assertEqual(finish_reason, "length")
+                assert finish_reason == "length"

     def _assert_cumulative_logprob(self, output_cumulative_logprob):
         prev_cumulative_logprob = 0.0
         for response in self._responses:
             result, error = response["result"], response["error"]
-            self.assertIsNone(error)
+            assert error is None
             cumulative_logprob_np = result.as_numpy(name="cumulative_logprob")
             if output_cumulative_logprob is None or output_cumulative_logprob == False:
-                self.assertIsNone(cumulative_logprob_np)
+                assert cumulative_logprob_np is None
                 continue
             cumulative_logprob = cumulative_logprob_np[0].astype(float)
-            self.assertNotEqual(cumulative_logprob, prev_cumulative_logprob)
+            assert cumulative_logprob != prev_cumulative_logprob
             prev_cumulative_logprob = cumulative_logprob

     def _assert_num_token_ids(self, output_num_token_ids):
         for response in self._responses:
             result, error = response["result"], response["error"]
-            self.assertIsNone(error)
+            assert error is None
             num_token_ids_np = result.as_numpy(name="num_token_ids")
             if output_num_token_ids is None or output_num_token_ids == False:
-                self.assertIsNone(num_token_ids_np)
+                assert num_token_ids_np is None
                 continue
             num_token_ids = num_token_ids_np[0].astype(int)
             # TODO: vLLM may return token ids identical to the previous one when
@@ -156,10 +156,14 @@ def _assert_num_token_ids(self, output_num_token_ids):
             # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
             #
             # If this is no longer the case in a future release, change the assert
-            # to assertGreater().
-            self.assertGreaterEqual(num_token_ids, 0)
-
-    def _assert_additional_outputs_valid(
+            # to assert num_token_ids > 0.
+            assert num_token_ids >= 0
+
+    @pytest.mark.parametrize("stream", [True, False])
+    @pytest.mark.parametrize("output_finish_reason", [None, True, False])
+    @pytest.mark.parametrize("output_cumulative_logprob", [None, True, False])
+    @pytest.mark.parametrize("output_num_token_ids", [None, True, False])
+    def test_additional_outputs(
         self,
         stream,
         output_finish_reason,
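A note on the `assert num_token_ids >= 0` above (rather than a strict `> 0`): it tolerates the case described in the TODO, where vLLM returns token ids identical to the previous response's. The sketch below is a hypothetical illustration of why a zero count can occur, assuming the per-response count is derived as a delta between cumulative token-id lists; this commit does not confirm how the backend actually computes it.

    # Hypothetical sketch, not the backend's actual code: if each streamed
    # response reports only tokens added since the previous response, a
    # response whose token id list repeats the previous one adds nothing.
    prev_token_ids = [5, 1385, 44, 48]
    curr_token_ids = [5, 1385, 44, 48]  # identical to prev, per the TODO
    num_token_ids = len(curr_token_ids) - len(prev_token_ids)
    assert num_token_ids >= 0  # 0 here, so >= 0 is the safe assertion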
@@ -179,20 +183,3 @@
         self._assert_finish_reason(output_finish_reason)
         self._assert_cumulative_logprob(output_cumulative_logprob)
         self._assert_num_token_ids(output_num_token_ids)
-
-    def test_additional_outputs(self):
-        for stream in [True, False]:
-            choices = [None, False, True]
-            for output_finish_reason in choices:
-                for output_cumulative_logprob in choices:
-                    for output_num_token_ids in choices:
-                        self._assert_additional_outputs_valid(
-                            stream,
-                            output_finish_reason,
-                            output_cumulative_logprob,
-                            output_num_token_ids,
-                        )
-
-
-if __name__ == "__main__":
-    unittest.main()
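Side note on the conversion: stacking `pytest.mark.parametrize` decorators, as the new `test_additional_outputs` does, makes pytest collect the full cross-product of parameter values as individually reported test cases, so the removed hand-written nested loops are no longer needed. A minimal standalone sketch of that behavior (not tied to this repo):

    import pytest

    @pytest.mark.parametrize("stream", [True, False])
    @pytest.mark.parametrize("flag", [None, True, False])
    def test_cross_product(stream, flag):
        # pytest collects 2 * 3 = 6 cases here; the four stacked decorators
        # in the diff above collect 2 * 3 * 3 * 3 = 54 cases.
        assert stream in (True, False)
        assert flag in (None, True, False)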
7 changes: 3 additions & 4 deletions ci/L0_additional_outputs_vllm/test.sh
@@ -28,6 +28,7 @@
 export CUDA_VISIBLE_DEVICES=0
 source ../common/util.sh

+pip3 install pytest==8.1.1
 pip3 install tritonclient[grpc]

 # Prepare Model
@@ -38,8 +39,7 @@ sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.3/' models/v

 RET=0

-# Infer Test
-CLIENT_LOG="vllm_opt.log"
+# Test
 SERVER_LOG="vllm_opt.server.log"
 SERVER_ARGS="--model-repository=models"
 run_server
@@ -49,9 +49,8 @@ if [ "$SERVER_PID" == "0" ]; then
     exit 1
 fi
 set +e
-python3 additional_outputs_test.py > $CLIENT_LOG 2>&1
+python3 -m pytest -s -v additional_outputs_test.py
 if [ $? -ne 0 ]; then
-    cat $CLIENT_LOG
     echo -e "\n***\n*** additional_outputs_test FAILED. \n***"
     RET=1
 fi
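For context on what the test exercises: a client opts in to each additional output per request, mirroring the test's None / False / True flag values. Below is a rough sketch of how the gRPC inputs might be assembled. The tensor names `text_input`, `stream`, and `return_finish_reason` are assumptions, since this commit's visible diff does not show how the test builds its inputs; check the model config shipped with the backend for the authoritative names.

    import numpy as np
    import tritonclient.grpc as grpcclient

    def build_inputs(prompt, stream, return_finish_reason):
        # Tensor names below are assumed, not confirmed by this diff.
        text = grpcclient.InferInput("text_input", [1], "BYTES")
        text.set_data_from_numpy(np.array([prompt.encode("utf-8")], dtype=np.object_))

        stream_flag = grpcclient.InferInput("stream", [1], "BOOL")
        stream_flag.set_data_from_numpy(np.array([stream], dtype=bool))

        # Presumably, omitting this optional input requests the default
        # behavior (no finish_reason tensor in the responses).
        finish_flag = grpcclient.InferInput("return_finish_reason", [1], "BOOL")
        finish_flag.set_data_from_numpy(np.array([return_finish_reason], dtype=bool))

        return [text, stream_flag, finish_flag]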