Renaming the tensors and removing tools
tanmayv25 committed Oct 10, 2023
1 parent 2e1fa35 commit 065f2dc
Showing 4 changed files with 16 additions and 166 deletions.
ci/L0_backend_vllm/vllm_backend_test.py (10 changes: 5 additions & 5 deletions)
@@ -107,7 +107,7 @@ def _test_vllm_model(self, send_parameters_as_tensor):
         result = user_data._completed_requests.get()
         self.assertIsNot(type(result), InferenceServerException)
 
-        output = result.as_numpy("TEXT")
+        output = result.as_numpy("text_output")
         self.assertIsNotNone(output)
 
         self.triton_client.stop_stream()
@@ -150,21 +150,21 @@ def _create_vllm_request_data(
         inputs = []
 
         prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_)
-        inputs.append(grpcclient.InferInput("PROMPT", [1], "BYTES"))
+        inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
         inputs[-1].set_data_from_numpy(prompt_data)
 
         stream_data = np.array([stream], dtype=bool)
-        inputs.append(grpcclient.InferInput("STREAM", [1], "BOOL"))
+        inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
         inputs[-1].set_data_from_numpy(stream_data)
 
         if send_parameters_as_tensor:
             sampling_parameters_data = np.array(
                 [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
             )
-            inputs.append(grpcclient.InferInput("SAMPLING_PARAMETERS", [1], "BYTES"))
+            inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
             inputs[-1].set_data_from_numpy(sampling_parameters_data)
 
-        outputs = [grpcclient.InferRequestedOutput("TEXT")]
+        outputs = [grpcclient.InferRequestedOutput("text_output")]
 
         return inputs, outputs

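For orientation, here is a minimal streaming-client sketch exercising the renamed tensors end to end. It assumes a Triton server on localhost:8001 serving this model as vllm_opt; the server address, prompt, and sampling value are illustrative, not part of this commit.

import json
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


def callback(responses, result, error):
    # Stream callback: enqueue either the response or the error.
    responses.put(error if error is not None else result)


responses = queue.Queue()
client = grpcclient.InferenceServerClient(url="localhost:8001")

# Inputs use the lower-case names introduced by this commit.
inputs = [grpcclient.InferInput("text_input", [1], "BYTES")]
inputs[0].set_data_from_numpy(np.array([b"Hello, my name is"], dtype=np.object_))

inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
inputs[1].set_data_from_numpy(np.array([False], dtype=bool))

inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
inputs[2].set_data_from_numpy(
    np.array([json.dumps({"temperature": 0.1}).encode("utf-8")], dtype=np.object_)
)

outputs = [grpcclient.InferRequestedOutput("text_output")]

# The model is decoupled, so requests go over the streaming API.
client.start_stream(callback=partial(callback, responses))
client.async_stream_infer(model_name="vllm_opt", inputs=inputs, outputs=outputs)
client.stop_stream()

result = responses.get()
if isinstance(result, Exception):
    raise result
print(result.as_numpy("text_output"))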
ci/qa_models/vllm_opt/config.pbtxt (9 changes: 5 additions & 4 deletions)
@@ -34,17 +34,18 @@ model_transaction_policy {
 
 input [
   {
-    name: "PROMPT"
+    name: "text_input"
     data_type: TYPE_STRING
     dims: [ 1 ]
   },
   {
-    name: "STREAM"
+    name: "stream"
     data_type: TYPE_BOOL
     dims: [ 1 ]
+    optional: true
   },
   {
-    name: "SAMPLING_PARAMETERS"
+    name: "sampling_parameters"
     data_type: TYPE_STRING
     dims: [ 1 ]
     optional: true
@@ -53,7 +54,7 @@ input [
 
 output [
   {
-    name: "TEXT"
+    name: "text_output"
     data_type: TYPE_STRING
     dims: [ -1 ]
   }
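With optional: true on both stream and sampling_parameters, a client may omit either input and send only text_input. A hedged model-side sketch of how an omitted optional input is typically handled in the Python backend (the False default is hypothetical, not taken from this diff):

# Inside TritonPythonModel.generate (sketch): get_input_tensor_by_name
# returns None when an optional input is omitted, so guard before calling
# .as_numpy(), mirroring the parameters_input_tensor check in src/model.py.
stream_tensor = pb_utils.get_input_tensor_by_name(request, "stream")
stream = bool(stream_tensor.as_numpy()[0]) if stream_tensor else False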
src/model.py (12 changes: 6 additions & 6 deletions)
@@ -69,7 +69,7 @@ def initialize(self, args):
             AsyncEngineArgs(**vllm_engine_config)
         )
 
-        output_config = pb_utils.get_output_config_by_name(self.model_config, "TEXT")
+        output_config = pb_utils.get_output_config_by_name(self.model_config, "text_output")
         self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
 
         # Counter to keep track of ongoing request counts
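A side note on the dtype lookup above: pb_utils.triton_string_to_numpy converts the config's data_type string into a numpy dtype, so with text_output declared TYPE_STRING, self.output_dtype becomes numpy's object dtype. A small sketch, assuming it runs inside the Python backend environment where triton_python_backend_utils is importable:

import numpy as np
import triton_python_backend_utils as pb_utils

# TYPE_STRING tensors are represented as arrays of Python bytes objects,
# hence the numpy object dtype.
output_dtype = pb_utils.triton_string_to_numpy("TYPE_STRING")
assert output_dtype == np.object_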
@@ -160,7 +160,7 @@ def create_response(self, vllm_output):
             (prompt + output.text).encode("utf-8") for output in vllm_output.outputs
         ]
         triton_output_tensor = pb_utils.Tensor(
-            "TEXT", np.asarray(text_outputs, dtype=self.output_dtype)
+            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
         )
         return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
@@ -172,17 +172,17 @@ async def generate(self, request):
         self.ongoing_request_count += 1
         try:
             request_id = random_uuid()
-            prompt = pb_utils.get_input_tensor_by_name(request, "PROMPT").as_numpy()[0]
+            prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0]
             if isinstance(prompt, bytes):
                 prompt = prompt.decode("utf-8")
-            stream = pb_utils.get_input_tensor_by_name(request, "STREAM").as_numpy()[0]
+            stream = pb_utils.get_input_tensor_by_name(request, "stream").as_numpy()[0]
 
             # Request parameters are not yet supported via
             # BLS. Provide an optional mechanism to receive serialized
             # parameters as an input tensor until support is added
 
             parameters_input_tensor = pb_utils.get_input_tensor_by_name(
-                request, "SAMPLING_PARAMETERS"
+                request, "sampling_parameters"
             )
             if parameters_input_tensor:
                 parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
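The sampling_parameters tensor carries plain JSON. A hedged sketch of the round trip: the client serializes a dict (as the test above does with json.dumps) and the model expands it into vLLM's SamplingParams; the particular option values here are illustrative.

import json

from vllm import SamplingParams

# Client side: the options travel as a serialized JSON string tensor.
parameters = json.dumps({"temperature": 0.1, "top_p": 0.95, "max_tokens": 64})

# Model side: decode and expand into vLLM's SamplingParams.
sampling_params = SamplingParams(**json.loads(parameters))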
@@ -211,7 +211,7 @@ async def generate(self, request):
             self.logger.log_info(f"Error generating stream: {e}")
             error = pb_utils.TritonError(f"Error generating stream: {e}")
             triton_output_tensor = pb_utils.Tensor(
-                "TEXT", np.asarray(["N/A"], dtype=self.output_dtype)
+                "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
             )
             response = pb_utils.InferenceResponse(
                 output_tensors=[triton_output_tensor], error=error
tools/environment.yml (151 changes: 0 additions & 151 deletions)

This file was deleted.
