diff --git a/samples/client.py b/samples/client.py
index 390a3657..61dc5938 100755
--- a/samples/client.py
+++ b/samples/client.py
@@ -38,13 +38,21 @@ class LLMClient:
     def __init__(self, flags: argparse.Namespace):
-        self._client = grpcclient.InferenceServerClient(
-            url=flags.url, verbose=flags.verbose
-        )
         self._flags = flags
-        self._loop = asyncio.get_event_loop()
         self._results_dict = {}
 
+    def get_triton_client(self):
+        try:
+            triton_client = grpcclient.InferenceServerClient(
+                url=self._flags.url,
+                verbose=self._flags.verbose,
+            )
+        except Exception as e:
+            print("channel creation failed: " + str(e))
+            sys.exit()
+
+        return triton_client
+
     async def async_request_iterator(
         self, prompts, sampling_parameters, exclude_input_in_output
     ):
@@ -65,8 +73,9 @@ async def stream_infer(self, prompts, sampling_parameters, exclude_input_in_output):
         try:
+            triton_client = self.get_triton_client()
             # Start streaming
-            response_iterator = self._client.stream_infer(
+            response_iterator = triton_client.stream_infer(
                 inputs_iterator=self.async_request_iterator(
                     prompts, sampling_parameters, exclude_input_in_output
                 ),
@@ -138,7 +147,7 @@ async def run(self):
             print("FAIL: vLLM example")
 
     def run_async(self):
-        self._loop.run_until_complete(self.run())
+        asyncio.run(self.run())
 
     def create_request(
         self,
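
For reference, a minimal standalone sketch of the two patterns this patch adopts: deferring gRPC client construction into a method called from inside the coroutine that uses it, and driving the coroutine with asyncio.run() instead of a loop cached in __init__. The DemoClient class and the flag defaults below are illustrative assumptions, not part of samples/client.py; the tritonclient.grpc.aio module is the async client the sample's calls (InferenceServerClient, stream_infer with inputs_iterator) correspond to.

import argparse
import asyncio
import sys

import tritonclient.grpc.aio as grpcclient


class DemoClient:  # hypothetical stand-in for the sample's LLMClient
    def __init__(self, flags: argparse.Namespace):
        # No channel is opened here; construction is deferred so that a
        # connection failure surfaces at call time, not at object creation.
        self._flags = flags

    def get_triton_client(self):
        try:
            return grpcclient.InferenceServerClient(
                url=self._flags.url, verbose=self._flags.verbose
            )
        except Exception as e:
            print("channel creation failed: " + str(e))
            sys.exit()

    async def run(self):
        triton_client = self.get_triton_client()
        # Any awaitable RPC works here; is_server_live() keeps the sketch small.
        print("server live:", await triton_client.is_server_live())

    def run_async(self):
        # asyncio.run() creates, runs, and closes its own event loop, replacing
        # the asyncio.get_event_loop()/run_until_complete pair removed above.
        asyncio.run(self.run())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="localhost:8001")
    parser.add_argument("--verbose", action="store_true")
    DemoClient(parser.parse_args()).run_async()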