diff --git a/src/model.py b/src/model.py index de03fa36..7cbc2888 100644 --- a/src/model.py +++ b/src/model.py @@ -69,7 +69,9 @@ def initialize(self, args): AsyncEngineArgs(**vllm_engine_config) ) - output_config = pb_utils.get_output_config_by_name(self.model_config, "text_output") + output_config = pb_utils.get_output_config_by_name( + self.model_config, "text_output" + ) self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) # Counter to keep track of ongoing request counts @@ -172,7 +174,9 @@ async def generate(self, request): self.ongoing_request_count += 1 try: request_id = random_uuid() - prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0] + prompt = pb_utils.get_input_tensor_by_name( + request, "text_input" + ).as_numpy()[0] if isinstance(prompt, bytes): prompt = prompt.decode("utf-8") stream = pb_utils.get_input_tensor_by_name(request, "stream").as_numpy()[0]