From c11a12a5221eb4fdeb6ecc2466e19c4e12bbd3e2 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 20 Dec 2024 17:01:59 -0800 Subject: [PATCH] Move the gRPC client initialization into aio loop --- samples/client.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/samples/client.py b/samples/client.py index 390a3657..61dc5938 100755 --- a/samples/client.py +++ b/samples/client.py @@ -38,13 +38,21 @@ class LLMClient: def __init__(self, flags: argparse.Namespace): - self._client = grpcclient.InferenceServerClient( - url=flags.url, verbose=flags.verbose - ) self._flags = flags - self._loop = asyncio.get_event_loop() self._results_dict = {} + def get_triton_client(self): + try: + triton_client = grpcclient.InferenceServerClient( + url=self._flags.url, + verbose=self._flags.verbose, + ) + except Exception as e: + print("channel creation failed: " + str(e)) + sys.exit() + + return triton_client + async def async_request_iterator( self, prompts, sampling_parameters, exclude_input_in_output ): @@ -65,8 +73,9 @@ async def async_request_iterator( async def stream_infer(self, prompts, sampling_parameters, exclude_input_in_output): try: + triton_client = self.get_triton_client() # Start streaming - response_iterator = self._client.stream_infer( + response_iterator = triton_client.stream_infer( inputs_iterator=self.async_request_iterator( prompts, sampling_parameters, exclude_input_in_output ), @@ -138,7 +147,7 @@ async def run(self): print("FAIL: vLLM example") def run_async(self): - self._loop.run_until_complete(self.run()) + asyncio.run(self.run()) def create_request( self,