Fix client I/O and model names
dyastremsky committed Oct 12, 2023
1 parent a50ae8d commit edaff54
Showing 2 changed files with 7 additions and 6 deletions.
samples/client.py (6 additions & 6 deletions)
@@ -85,7 +85,7 @@ async def process_stream(self, prompts, sampling_parameters):
             if error:
                 print(f"Encountered error while processing: {error}")
             else:
-                output = result.as_numpy("TEXT")
+                output = result.as_numpy("text_output")
                 for i in output:
                     self._results_dict[result.get_response().id].append(i)

@@ -126,13 +126,13 @@ def create_request(
     inputs = []
     prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_)
     try:
-        inputs.append(grpcclient.InferInput("PROMPT", [1], "BYTES"))
+        inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
         inputs[-1].set_data_from_numpy(prompt_data)
     except Exception as error:
         print(f"Encountered an error during request creation: {error}")

     stream_data = np.array([stream], dtype=bool)
-    inputs.append(grpcclient.InferInput("STREAM", [1], "BOOL"))
+    inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
     inputs[-1].set_data_from_numpy(stream_data)

     # Request parameters are not yet supported via BLS. Provide an
@@ -143,12 +143,12 @@
     sampling_parameters_data = np.array(
         [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
     )
-    inputs.append(grpcclient.InferInput("SAMPLING_PARAMETERS", [1], "BYTES"))
+    inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
     inputs[-1].set_data_from_numpy(sampling_parameters_data)

     # Add requested outputs
     outputs = []
-    outputs.append(grpcclient.InferRequestedOutput("TEXT"))
+    outputs.append(grpcclient.InferRequestedOutput("text_output"))

     # Issue the asynchronous sequence inference.
     return {
@@ -167,7 +167,7 @@ def create_request(
"--model",
type=str,
required=False,
default="vllm",
default="vllm_model",
help="Model name",
)
parser.add_argument(
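For context, here is a minimal end-to-end sketch (not part of this commit) showing the renamed tensors in use with the synchronous gRPC streaming client. It assumes a Triton server at localhost:8001 serving "vllm_model"; the sample itself uses the asyncio client, but the tensor names are the same either way.

# Minimal sketch; server URL, prompt, and sampling parameters are assumptions.
import json
import queue

import numpy as np
import tritonclient.grpc as grpcclient

responses = queue.Queue()

def callback(result, error):
    # Each streamed response carries one "text_output" tensor.
    responses.put(error if error else result.as_numpy("text_output"))

client = grpcclient.InferenceServerClient(url="localhost:8001")

inputs = [
    grpcclient.InferInput("text_input", [1], "BYTES"),
    grpcclient.InferInput("stream", [1], "BOOL"),
    grpcclient.InferInput("sampling_parameters", [1], "BYTES"),
]
inputs[0].set_data_from_numpy(
    np.array(["Hello, my name is".encode("utf-8")], dtype=np.object_)
)
inputs[1].set_data_from_numpy(np.array([False], dtype=bool))
inputs[2].set_data_from_numpy(
    np.array([json.dumps({"temperature": 0.1}).encode("utf-8")], dtype=np.object_)
)
outputs = [grpcclient.InferRequestedOutput("text_output")]

client.start_stream(callback=callback)
client.async_stream_infer(model_name="vllm_model", inputs=inputs, outputs=outputs)
client.stop_stream()  # closes the stream after in-flight requests finish

print(responses.get())

The streaming API is used here because the vLLM model is served in decoupled mode, where a plain blocking infer() call is not supported.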
samples/model_repository/vllm_model/config.pbtxt (1 addition & 0 deletions)
@@ -26,6 +26,7 @@

 # Note: You do not need to change any fields in this configuration.

+name: "vllm_model"
 backend: "vllm"

 # Disabling batching in Triton, let vLLM handle the batching on its own.
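A quick way to confirm the rename took effect (a sketch, again assuming a local server on the default gRPC port): the name field added above must match both the string passed by clients and the model's directory in the repository, model_repository/vllm_model.

import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")
# "vllm_model" must match config.pbtxt's name field and the repository directory.
assert client.is_model_ready("vllm_model")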
