diff --git a/src/python/library/tritonclient/grpc/aio/__init__.py b/src/python/library/tritonclient/grpc/aio/__init__.py
index 0c19eadff..fc5eaccdb 100755
--- a/src/python/library/tritonclient/grpc/aio/__init__.py
+++ b/src/python/library/tritonclient/grpc/aio/__init__.py
@@ -586,35 +586,8 @@ async def infer(
         headers=None,
         compression_algorithm=None,
         parameters=None,
-        get_call_obj=False,
     ):
-        """Refer to tritonclient.grpc.InferenceServerClient
-        The additional parameters for this functions are
-        described below:
-
-        Parameters
-        ----------
-        get_call_obj : bool
-            If set True, then this function will yield
-            grpc.aio.call object first bfore the
-            InferResult.
-            This object can be used to issue request
-            cancellation if required. This can be attained
-            by following:
-            -------
-            generator = client.infer(..., get_call_obj=True)
-            grpc_call = await anext(generator)
-            grpc_call.cancel()
-            -------
-
-        Returns
-        -------
-        async_generator
-            If get_call_obj is set True, then it generates the
-            streaming_call object before generating the inference
-            results.
-
-        """
+        """Refer to tritonclient.grpc.InferenceServerClient"""
 
         metadata = self._get_metadata(headers)
 
@@ -636,20 +609,18 @@ async def infer(
         )
         if self._verbose:
             print("infer, metadata {}\n{}".format(metadata, request))
+
         try:
-            call = self._client_stub.ModelInfer(
+            response = await self._client_stub.ModelInfer(
                 request=request,
                 metadata=metadata,
                 timeout=client_timeout,
                 compression=_grpc_compression_type(compression_algorithm),
             )
-            if get_call_obj:
-                yield call
-            response = await call
             if self._verbose:
                 print(response)
             result = InferResult(response)
-            yield result
+            return result
         except grpc.RpcError as rpc_error:
             raise_error_grpc(rpc_error)
 
@@ -659,7 +630,6 @@ async def stream_infer(
         stream_timeout=None,
         headers=None,
         compression_algorithm=None,
-        get_call_obj=False,
     ):
         """Runs an asynchronous inference over gRPC bi-directional streaming
         API.
@@ -680,23 +650,11 @@
             Optional grpc compression algorithm to be used on client side.
             Currently supports "deflate", "gzip" and None. By default, no
             compression is used.
-        get_call_obj : bool
-            If set True, then the async_generator will first generate
-            grpc.aio.call object and then generate rest of the results.
-            The call object can be used to cancel the execution of the
-            ongoing stream and exit. This can be done like below:
-            -------
-            async_generator = client.stream_infer(..., get_call_obj=True)
-            streaming_call = await anext(response_iterator)
-            streaming_call.cancel()
-            -------
 
         Returns
         -------
         async_generator
             Yield tuple holding (InferResult, InferenceServerException) objects.
-            If get_call_obj is set True, then it first generates streaming_call
-            object associated with the call before generating these tuples.
 
         Raises
         ------
@@ -751,17 +709,13 @@ async def _request_iterator(inputs_iterator):
             )
 
         try:
-            streaming_call = self._client_stub.ModelStreamInfer(
+            response_iterator = self._client_stub.ModelStreamInfer(
                 _request_iterator(inputs_iterator),
                 metadata=metadata,
                 timeout=stream_timeout,
                 compression=_grpc_compression_type(compression_algorithm),
            )
-
-            if get_call_obj:
-                yield streaming_call
-
-            async for response in streaming_call:
+            async for response in response_iterator:
                 if self._verbose:
                     print(response)
                 result = error = None
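
For reference, a minimal usage sketch of the coroutine-style API after this change: `infer` is simply awaited and returns an `InferResult`, instead of being consumed as an async generator that could first yield the grpc.aio call object. The server URL, model name, and tensor names/shape below are placeholders, not taken from the diff.

```python
import asyncio

import numpy as np
import tritonclient.grpc.aio as aioclient
from tritonclient.grpc import InferInput


async def main():
    # Placeholder URL and model/tensor names; substitute your own deployment.
    client = aioclient.InferenceServerClient(url="localhost:8001")
    try:
        data = np.zeros((1, 16), dtype=np.float32)
        infer_input = InferInput("INPUT0", list(data.shape), "FP32")
        infer_input.set_data_from_numpy(data)

        # With this change, infer() is a plain coroutine: awaiting it returns
        # an InferResult directly, rather than iterating an async generator
        # (and get_call_obj is no longer accepted).
        result = await client.infer(model_name="example_model", inputs=[infer_input])
        print(result.as_numpy("OUTPUT0"))
    finally:
        await client.close()


asyncio.run(main())
```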