Commit 44edd6e: Switch to pytest

kthui committed Nov 6, 2024
1 parent e6e6404

Showing 2 changed files with 26 additions and 40 deletions.
59 changes: 23 additions & 36 deletions ci/L0_additional_outputs_vllm/additional_outputs_test.py
@@ -25,13 +25,13 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import json
-import unittest
 
 import numpy as np
+import pytest
 import tritonclient.grpc as grpcclient
 
 
-class InferTest(unittest.TestCase):
+class TestAdditionalOutputs:
     _grpc_url = "localhost:8001"
     _model_name = "vllm_opt"
     _sampling_parameters = {"temperature": "0", "top_p": "1"}
@@ -93,51 +93,51 @@ def _llm_infer(self, inputs):
             self._model_name, inputs=inputs, parameters=self._sampling_parameters
         )
         client.stop_stream()
-        self.assertGreater(len(self._responses), 0)
+        assert len(self._responses) > 0
 
     def _assert_text_output_valid(self):
         text_output = ""
         for response in self._responses:
             result, error = response["result"], response["error"]
-            self.assertIsNone(error)
+            assert error is None
             text_output += result.as_numpy(name="text_output")[0].decode("utf-8")
-        self.assertGreater(len(text_output), 0, "output is empty")
-        self.assertGreater(text_output.count(" "), 4, "output is not a sentence")
+        assert len(text_output) > 0, "output is empty"
+        assert text_output.count(" ") > 4, "output is not a sentence"
 
     def _assert_finish_reason(self, output_finish_reason):
         for i in range(len(self._responses)):
             result, error = self._responses[i]["result"], self._responses[i]["error"]
-            self.assertIsNone(error)
+            assert error is None
             finish_reason_np = result.as_numpy(name="finish_reason")
             if output_finish_reason is None or output_finish_reason == False:
-                self.assertIsNone(finish_reason_np)
+                assert finish_reason_np is None
                 continue
             finish_reason = finish_reason_np[0].decode("utf-8")
             if i < len(self._responses) - 1:
-                self.assertEqual(finish_reason, "None")
+                assert finish_reason == "None"
             else:
-                self.assertEqual(finish_reason, "length")
+                assert finish_reason == "length"
 
     def _assert_cumulative_logprob(self, output_cumulative_logprob):
         prev_cumulative_logprob = 0.0
         for response in self._responses:
             result, error = response["result"], response["error"]
-            self.assertIsNone(error)
+            assert error is None
             cumulative_logprob_np = result.as_numpy(name="cumulative_logprob")
             if output_cumulative_logprob is None or output_cumulative_logprob == False:
-                self.assertIsNone(cumulative_logprob_np)
+                assert cumulative_logprob_np is None
                 continue
             cumulative_logprob = cumulative_logprob_np[0].astype(float)
-            self.assertNotEqual(cumulative_logprob, prev_cumulative_logprob)
+            assert cumulative_logprob != prev_cumulative_logprob
             prev_cumulative_logprob = cumulative_logprob
 
     def _assert_num_token_ids(self, output_num_token_ids):
         for response in self._responses:
             result, error = response["result"], response["error"]
-            self.assertIsNone(error)
+            assert error is None
             num_token_ids_np = result.as_numpy(name="num_token_ids")
             if output_num_token_ids is None or output_num_token_ids == False:
-                self.assertIsNone(num_token_ids_np)
+                assert num_token_ids_np is None
                 continue
             num_token_ids = num_token_ids_np[0].astype(int)
             # TODO: vLLM may return token ids identical to the previous one when
@@ -156,10 +156,14 @@ def _assert_num_token_ids(self, output_num_token_ids):
             # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
             #
             # If this is no longer the case in a future release, change the assert
-            # to assertGreater().
-            self.assertGreaterEqual(num_token_ids, 0)
+            # to assert num_token_ids > 0.
+            assert num_token_ids >= 0
 
-    def _assert_additional_outputs_valid(
+    @pytest.mark.parametrize("stream", [True, False])
+    @pytest.mark.parametrize("output_finish_reason", [None, True, False])
+    @pytest.mark.parametrize("output_cumulative_logprob", [None, True, False])
+    @pytest.mark.parametrize("output_num_token_ids", [None, True, False])
+    def test_additional_outputs(
         self,
         stream,
         output_finish_reason,
@@ -179,20 +183,3 @@ def _assert_additional_outputs_valid(
         self._assert_finish_reason(output_finish_reason)
         self._assert_cumulative_logprob(output_cumulative_logprob)
         self._assert_num_token_ids(output_num_token_ids)
-
-    def test_additional_outputs(self):
-        for stream in [True, False]:
-            choices = [None, False, True]
-            for output_finish_reason in choices:
-                for output_cumulative_logprob in choices:
-                    for output_num_token_ids in choices:
-                        self._assert_additional_outputs_valid(
-                            stream,
-                            output_finish_reason,
-                            output_cumulative_logprob,
-                            output_num_token_ids,
-                        )
-
-
-if __name__ == "__main__":
-    unittest.main()
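Note: the four stacked @pytest.mark.parametrize decorators replace the deleted nested loops. pytest expands stacked decorators into their full cross-product, so test_additional_outputs now runs as 2 x 3 x 3 x 3 = 54 individually collected and reported cases rather than one monolithic unittest method. A minimal standalone sketch of the same pattern, with hypothetical file and test names not taken from this commit:

    # sketch_parametrize.py - hypothetical illustration, not part of this commit.
    # pytest generates one test per combination of the stacked parameter lists.
    import pytest

    @pytest.mark.parametrize("stream", [True, False])
    @pytest.mark.parametrize("flag", [None, True, False])
    def test_cross_product(stream, flag):
        # 2 x 3 = 6 generated cases, reported with ids such as
        # test_cross_product[None-True]
        assert stream in (True, False)
        assert flag in (None, True, False)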
7 changes: 3 additions & 4 deletions ci/L0_additional_outputs_vllm/test.sh
@@ -28,6 +28,7 @@
 export CUDA_VISIBLE_DEVICES=0
 source ../common/util.sh
 
+pip3 install pytest==8.1.1
 pip3 install tritonclient[grpc]
 
 # Prepare Model
@@ -38,8 +39,7 @@ sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.3/' models/v
 
 RET=0
 
-# Infer Test
-CLIENT_LOG="vllm_opt.log"
+# Test
 SERVER_LOG="vllm_opt.server.log"
 SERVER_ARGS="--model-repository=models"
 run_server
@@ -49,9 +49,8 @@ if [ "$SERVER_PID" == "0" ]; then
     exit 1
 fi
 set +e
-python3 additional_outputs_test.py > $CLIENT_LOG 2>&1
+python3 -m pytest -s -v additional_outputs_test.py
 if [ $? -ne 0 ]; then
-    cat $CLIENT_LOG
     echo -e "\n***\n*** additional_outputs_test FAILED. \n***"
     RET=1
 fi
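Note: with pytest driving the run, -s disables output capture and -v prints each parametrized case as it executes, so the CLIENT_LOG redirect and the cat on failure are no longer needed; failures show up directly in the CI console. A hypothetical local invocation, assuming the Triton server is already up, that lists the 54 generated cases without executing them:

    python3 -m pytest --collect-only -q additional_outputs_test.py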
