Add logprobs additional output
kthui committed Dec 3, 2024
1 parent 2e1a223 commit 1e2675e
Showing 2 changed files with 127 additions and 44 deletions.
92 changes: 70 additions & 22 deletions ci/L0_additional_outputs_vllm/additional_outputs_test.py
@@ -37,13 +37,20 @@ class TestAdditionalOutputs:
_sampling_parameters = {"temperature": "0", "top_p": "1"}
_prompt = "In this example,"

def _get_sampling_parameters(self, logprobs=None):
sampling_parameters = self._sampling_parameters.copy()
if logprobs is not None:
sampling_parameters["logprobs"] = logprobs
return sampling_parameters

def _get_inputs(
self,
prompt,
stream=True,
sampling_parameters=None,
return_finish_reason=None,
return_cumulative_logprob=None,
return_logprobs=None,
return_num_input_tokens=None,
return_num_output_tokens=None,
):
@@ -77,6 +84,10 @@ def _get_inputs(
np.array([return_cumulative_logprob], dtype=bool)
)

if return_logprobs is not None:
inputs.append(grpcclient.InferInput("return_logprobs", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([return_logprobs], dtype=bool))

if return_num_input_tokens is not None:
inputs.append(grpcclient.InferInput("return_num_input_tokens", [1], "BOOL"))
inputs[-1].set_data_from_numpy(
@@ -96,12 +107,12 @@ def _get_inputs(
def _callback(self, result, error):
self._responses.append({"result": result, "error": error})

def _llm_infer(self, inputs):
def _llm_infer(self, inputs, sampling_parameters):
self._responses = []
with grpcclient.InferenceServerClient(self._grpc_url) as client:
client.start_stream(self._callback)
client.async_stream_infer(
self._model_name, inputs=inputs, parameters=self._sampling_parameters
self._model_name, inputs=inputs, parameters=sampling_parameters
)
client.stop_stream()
assert len(self._responses) > 0
Expand Down Expand Up @@ -142,6 +153,51 @@ def _assert_cumulative_logprob(self, return_cumulative_logprob):
assert cumulative_logprob != prev_cumulative_logprob
prev_cumulative_logprob = cumulative_logprob

def _assert_logprobs(
self, stream, sampling_parameters, return_logprobs, return_num_output_tokens
):
for response in self._responses:
result, error = response["result"], response["error"]
assert error is None
logprobs_np = result.as_numpy(name="logprobs")
if return_logprobs is None or return_logprobs == False:
assert logprobs_np is None
continue
logprobs = json.loads(logprobs_np[0].decode("utf-8"))
if "logprobs" not in sampling_parameters:
assert logprobs is None
continue
assert isinstance(logprobs, list)
assert len(logprobs) >= 1
if return_num_output_tokens == True:
num_output_tokens = result.as_numpy(name="num_output_tokens")[0].astype(
int
)
assert len(logprobs) == num_output_tokens
text_output_logprobs = ""
for logprobs_d in logprobs:
assert isinstance(logprobs_d, dict)
assert len(logprobs_d) >= 1
assert len(logprobs_d) <= sampling_parameters["logprobs"] + 1
rank_one_found = False
for token_id, logprob_d in logprobs_d.items():
assert isinstance(token_id, str)
assert len(logprob_d) == 3
assert isinstance(logprob_d["logprob"], float)
assert isinstance(logprob_d["rank"], int)
assert isinstance(logprob_d["decoded_token"], str)
if logprob_d["rank"] == 1:
assert not rank_one_found
rank_one_found = True
text_output_logprobs += logprob_d["decoded_token"]
assert rank_one_found
text_output = result.as_numpy(name="text_output")[0].decode("utf-8")
if not stream:
# Since exclude_input_in_output is not set, the prompt is prepended to the
# output when not streaming and omitted when streaming.
text_output_logprobs = self._prompt + text_output_logprobs
assert text_output_logprobs == text_output

def _assert_num_input_tokens(self, return_num_input_tokens):
for response in self._responses:
result, error = response["result"], response["error"]
@@ -163,50 +219,42 @@ def _assert_num_output_tokens(self, return_num_output_tokens):
assert num_output_tokens_np is None
continue
num_output_tokens = num_output_tokens_np[0].astype(int)
# TODO: vLLM may return token ids identical to the previous one when
# streaming, for example:
#
# prev: None
# curr: text=' the', token_ids=array('l', [5])
#
# prev: text=' the', token_ids=array('l', [5, 1385])
# curr: text=' the term', token_ids=array('l', [5, 1385])
#
# prev: text=' the term', token_ids=array('l', [5, 1385, 44])
# curr: text=' the term', token_ids=array('l', [5, 1385, 44])
#
# prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48])
# curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
#
# If this is no longer the case in a future release, change the assert
# to assert num_output_tokens > 0.
assert num_output_tokens >= 0
assert num_output_tokens > 0

@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("return_finish_reason", [None, True, False])
@pytest.mark.parametrize("return_cumulative_logprob", [None, True, False])
@pytest.mark.parametrize("logprobs", [None, 0, 2])
@pytest.mark.parametrize("return_logprobs", [None, True, False])
@pytest.mark.parametrize("return_num_input_tokens", [None, True, False])
@pytest.mark.parametrize("return_num_output_tokens", [None, True, False])
def test_additional_outputs(
self,
stream,
return_finish_reason,
return_cumulative_logprob,
logprobs,
return_logprobs,
return_num_input_tokens,
return_num_output_tokens,
):
sampling_parameters = self._get_sampling_parameters(logprobs=logprobs)
inputs = self._get_inputs(
self._prompt,
stream=stream,
sampling_parameters=self._sampling_parameters,
sampling_parameters=sampling_parameters,
return_finish_reason=return_finish_reason,
return_cumulative_logprob=return_cumulative_logprob,
return_logprobs=return_logprobs,
return_num_input_tokens=return_num_input_tokens,
return_num_output_tokens=return_num_output_tokens,
)
self._llm_infer(inputs)
self._llm_infer(inputs, sampling_parameters)
self._assert_text_output_valid()
self._assert_finish_reason(return_finish_reason)
self._assert_cumulative_logprob(return_cumulative_logprob)
self._assert_logprobs(
stream, sampling_parameters, return_logprobs, return_num_output_tokens
)
self._assert_num_input_tokens(return_num_input_tokens)
self._assert_num_output_tokens(return_num_output_tokens)
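
For context, the new flag can be exercised outside the test harness with a small standalone client. The following is a minimal sketch, not part of this commit: the server URL, model name ("vllm_opt"), prompt, and the "text_input"/"stream" input names are assumptions about a typical deployment of this backend, and the calls mirror the streaming gRPC pattern used by the test above.

import json

import numpy as np
import tritonclient.grpc as grpcclient

responses = []

def callback(result, error):
    responses.append((result, error))

# Hypothetical request: names and values below are illustrative, not from the commit.
inputs = [grpcclient.InferInput("text_input", [1], "BYTES")]
inputs[-1].set_data_from_numpy(np.array(["In this example,"], dtype=np.object_))
inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))
inputs.append(grpcclient.InferInput("return_logprobs", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))

with grpcclient.InferenceServerClient("localhost:8001") as client:
    client.start_stream(callback)
    client.async_stream_infer(
        "vllm_opt",
        inputs=inputs,
        # The "logprobs" sampling parameter asks vLLM to record top-2 logprobs per token.
        parameters={"temperature": "0", "top_p": "1", "logprobs": 2},
    )
    client.stop_stream()

for result, error in responses:
    assert error is None
    # Each streamed response carries a JSON string with one entry per new output token.
    logprobs_np = result.as_numpy("logprobs")
    if logprobs_np is not None:
        print(json.loads(logprobs_np[0].decode("utf-8")))
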
79 changes: 57 additions & 22 deletions src/model.py
@@ -104,6 +104,12 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
"dims": [1],
"optional": True,
},
{
"name": "return_logprobs",
"data_type": "TYPE_BOOL",
"dims": [1],
"optional": True,
},
{
"name": "return_num_input_tokens",
"data_type": "TYPE_BOOL",
@@ -131,6 +137,7 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
{"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]},
{"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]},
{"name": "logprobs", "data_type": "TYPE_STRING", "dims": [-1]},
{"name": "num_input_tokens", "data_type": "TYPE_UINT32", "dims": [1]},
{"name": "num_output_tokens", "data_type": "TYPE_UINT32", "dims": [-1]},
]
Expand Down Expand Up @@ -388,6 +395,7 @@ def _get_input_tensors(self, request):
additional_outputs = {
"return_finish_reason": None,
"return_cumulative_logprob": None,
"return_logprobs": None,
"return_num_input_tokens": None,
"return_num_output_tokens": None,
}
@@ -455,26 +463,27 @@ def response_loop(self):
self.ongoing_request_count -= 1

def _create_response(
self, prev_request_output, request_output, prepend_input, additional_outputs
self, request_output_state, request_output, prepend_input, additional_outputs
):
output_tensors = []

# text_output
prepend_prompt = ""
if prev_request_output is None:
if "prev_lens_text_output" not in request_output_state:
# this is the first response
if prepend_input:
prepend_prompt = request_output.prompt
prev_lens = [0] * len(request_output.outputs)
else:
# this is a subsequent response
prev_lens = [
len(prev_output.text) for prev_output in prev_request_output.outputs
]
request_output_state["prev_lens_text_output"] = [0] * len(
request_output.outputs
)
prev_lens = request_output_state["prev_lens_text_output"]
text_output = [
(prepend_prompt + output.text[prev_len:]).encode("utf-8")
for output, prev_len in zip(request_output.outputs, prev_lens)
]
request_output_state["prev_lens_text_output"] = [
len(output.text) for output in request_output.outputs
]
output_tensors.append(
pb_utils.Tensor(
"text_output", np.asarray(text_output, dtype=self.output_dtype)
@@ -504,6 +513,35 @@ def _create_response(
)
)

# logprobs
if additional_outputs["return_logprobs"]:
if "prev_lens_logprobs" not in request_output_state:
request_output_state["prev_lens_logprobs"] = [0] * len(
request_output.outputs
)
logprobs = []
for i in range(len(request_output.outputs)):
output = request_output.outputs[i]
if output.logprobs is None:
logprobs.append("null".encode("utf-8"))
continue
prev_len = request_output_state["prev_lens_logprobs"][i]
request_output_state["prev_lens_logprobs"][i] = len(output.logprobs)
logprobs_py = []
for logprob_d_vllm in output.logprobs[prev_len:]:
logprob_d_py = {}
for token_id, logprob_vllm in logprob_d_vllm.items():
logprob_d_py[token_id] = {
"logprob": logprob_vllm.logprob,
"rank": logprob_vllm.rank,
"decoded_token": logprob_vllm.decoded_token,
}
logprobs_py.append(logprob_d_py)
logprobs.append(json.dumps(logprobs_py).encode("utf-8"))
output_tensors.append(
pb_utils.Tensor("logprobs", np.asarray(logprobs, dtype=np.object_))
)

# num_input_tokens
if additional_outputs["return_num_input_tokens"]:
num_input_tokens = len(request_output.prompt_token_ids)
@@ -515,19 +553,18 @@

# num_output_tokens
if additional_outputs["return_num_output_tokens"]:
if prev_request_output is None:
# this is the first response
prev_lens = [0] * len(request_output.outputs)
else:
# this is a subsequent response
prev_lens = [
len(prev_output.token_ids)
for prev_output in prev_request_output.outputs
]
if "prev_lens_num_output_tokens" not in request_output_state:
request_output_state["prev_lens_num_output_tokens"] = [0] * len(
request_output.outputs
)
prev_lens = request_output_state["prev_lens_num_output_tokens"]
num_output_tokens = [
(len(output.token_ids) - prev_len)
for output, prev_len in zip(request_output.outputs, prev_lens)
]
request_output_state["prev_lens_num_output_tokens"] = [
len(output.token_ids) for output in request_output.outputs
]
output_tensors.append(
pb_utils.Tensor(
"num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32)
@@ -572,7 +609,7 @@ async def generate(self, request):
request_id, prompt, sampling_params, lora_request=lora_request
)

prev_request_output = None
request_output_state = {}
async for request_output in response_iterator:
# Cancellation state will be checked by the response loop and written to
# the response state if streaming. If not streaming, cancellation state
@@ -605,7 +642,7 @@
# Send each response if streaming.
if stream:
response = self._create_response(
prev_request_output,
request_output_state,
request_output,
prepend_input=False,
additional_outputs=additional_outputs,
@@ -617,13 +654,11 @@
decrement_ongoing_request_count = False
self._response_queue.put_nowait((response_state, response, flags))

prev_request_output = request_output

# Send the last response which contains all the outputs if not streaming.
if not stream:
response_sender.send(
self._create_response(
prev_request_output=None,
request_output_state={},
request_output=request_output,
prepend_input=prepend_input,
additional_outputs=additional_outputs,
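
To make the new output's wire format concrete, the following is a hand-written illustration (values invented, not taken from a real run) of what one element of the logprobs output tensor decodes to after UTF-8 decoding and json.loads, matching what _create_response serializes above for a streamed response that produced two new tokens with the logprobs sampling parameter set to 2.

# Illustrative only: the outer list has one dict per new output token position;
# each inner dict maps a vLLM token id (stringified by JSON) to its record.
[
    {
        "5": {"logprob": -0.31, "rank": 1, "decoded_token": " the"},
        "1385": {"logprob": -1.92, "rank": 2, "decoded_token": " a"},
    },
    {
        "44": {"logprob": -0.07, "rank": 1, "decoded_token": " term"},
        "48": {"logprob": -2.85, "rank": 2, "decoded_token": " word"},
    },
]

When return_logprobs is requested but the logprobs sampling parameter is not set, vLLM reports no logprobs and the element is the JSON literal null instead, which is the case _assert_logprobs checks for in the test above.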
