Add ability to cancel AsyncIO gRPC stream requests (#417)
* Add ability to cancel AsyncIO gRPC stream or non-stream requests

* Add docs on AsyncIO request cancellation

* Improve AsyncIO docs

* Improve example code styling

Co-authored-by: Ryan McCormick <[email protected]>

* Skip await on documentation

---------

Co-authored-by: Ryan McCormick <[email protected]>
kthui and rmccorm4 committed Oct 13, 2023
1 parent de762b6 commit aac4b27
Showing 2 changed files with 51 additions and 12 deletions.
19 changes: 19 additions & 0 deletions README.md
@@ -550,6 +550,25 @@ sent via this stream.
See more details about these APIs in
[grpc/\_client.py](src/python/library/tritonclient/grpc/_client.py).

For gRPC AsyncIO requests, an AsyncIO task wrapping an `infer()` coroutine can
be safely cancelled.

```python
infer_task = asyncio.create_task(aio_client.infer(...))
infer_task.cancel()
```
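Cancelling the task raises `asyncio.CancelledError` at the `await` point, so callers typically catch it when awaiting the cancelled task. A minimal runnable sketch of that pattern, using a hypothetical stand-in coroutine (`fake_infer` is not part of the client API):

```python
import asyncio

async def fake_infer():
    # Stand-in for aio_client.infer(...): pretend to wait on a server reply.
    await asyncio.sleep(10)
    return "response"

async def main():
    infer_task = asyncio.create_task(fake_infer())
    await asyncio.sleep(0)  # give the task a chance to start
    infer_task.cancel()     # request cancellation
    try:
        return await infer_task
    except asyncio.CancelledError:
        return "cancelled"

print(asyncio.run(main()))  # prints "cancelled"
```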

For gRPC AsyncIO streaming requests, `cancel()` can be called on the
asynchronous iterator returned by the `stream_infer()` API.

```python
responses_iterator = aio_client.stream_infer(...)
responses_iterator.cancel()
```
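After `cancel()`, the iterator stops producing responses, so an `async for` loop over it simply ends. A self-contained sketch with a stand-in class (`FakeResponseIterator` is hypothetical; it only mimics the cancel-then-stop behavior of the real stream):

```python
import asyncio

class FakeResponseIterator:
    """Stand-in for the object returned by stream_infer(): iterable + cancel()."""
    def __init__(self, num_responses):
        self._remaining = num_responses
        self._cancelled = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._cancelled or self._remaining == 0:
            raise StopAsyncIteration
        self._remaining -= 1
        await asyncio.sleep(0)   # pretend to wait on the network
        return ("result", None)  # mimics the (InferResult, error) tuple

    def cancel(self):
        self._cancelled = True

async def main():
    responses_iterator = FakeResponseIterator(10)
    count = 0
    async for result, error in responses_iterator:
        count += 1
        if count == 3:
            responses_iterator.cancel()  # stop consuming after three responses
    return count

print(asyncio.run(main()))  # prints 3
```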

See more details about these APIs in
[grpc/aio/\_\_init\_\_.py](src/python/library/tritonclient/grpc/aio/__init__.py).

See [request_cancellation](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/request_cancellation.md)
in the server user-guide to learn about how this is handled on the
server side.
44 changes: 32 additions & 12 deletions src/python/library/tritonclient/grpc/aio/__init__.py
@@ -624,7 +624,7 @@ async def infer(
         except grpc.RpcError as rpc_error:
             raise_error_grpc(rpc_error)
 
-    async def stream_infer(
+    def stream_infer(
         self,
         inputs_iterator,
         stream_timeout=None,
@@ -636,7 +636,7 @@ async def stream_infer(
         Parameters
         ----------
-        inputs_iterator : async_generator
+        inputs_iterator : asynchronous iterator
             Async iterator that yields a dict(s) consists of the input
             parameters to the async_stream_infer function defined in
             tritonclient.grpc.InferenceServerClient.
@@ -653,9 +653,15 @@ async def stream_infer(
         Returns
         -------
-        async_generator
+        asynchronous iterator
             Yield tuple holding (InferResult, InferenceServerException) objects.
+            This object can be used to cancel the inference request like below:
+
+            ----------
+            it = stream_infer(...)
+            ret = it.cancel()
+            ----------
 
         Raises
         ------
         InferenceServerException
@@ -708,21 +714,35 @@ async def _request_iterator(inputs_iterator):
                 parameters=inputs["parameters"],
             )
 
-        try:
-            response_iterator = self._client_stub.ModelStreamInfer(
-                _request_iterator(inputs_iterator),
-                metadata=metadata,
-                timeout=stream_timeout,
-                compression=_grpc_compression_type(compression_algorithm),
-            )
-            async for response in response_iterator:
+        class _ResponseIterator:
+            def __init__(self, grpc_call, verbose):
+                self._grpc_call = grpc_call
+                self._verbose = verbose
+
+            def __aiter__(self):
+                return self
+
+            async def __anext__(self):
+                response = await self._grpc_call.__aiter__().__anext__()
                 if self._verbose:
                     print(response)
                 result = error = None
                 if response.error_message != "":
                     error = InferenceServerException(msg=response.error_message)
                 else:
                     result = InferResult(response.infer_response)
-                yield (result, error)
+                return result, error
+
+            def cancel(self):
+                return self._grpc_call.cancel()
 
+        try:
+            grpc_call = self._client_stub.ModelStreamInfer(
+                _request_iterator(inputs_iterator),
+                metadata=metadata,
+                timeout=stream_timeout,
+                compression=_grpc_compression_type(compression_algorithm),
+            )
+            return _ResponseIterator(grpc_call, self._verbose)
         except grpc.RpcError as rpc_error:
             raise_error_grpc(rpc_error)
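The committed `_ResponseIterator` is an instance of a common pattern: wrap an async iterable in a plain class implementing `__aiter__`/`__anext__` so that per-item post-processing and extra methods such as `cancel()` can be added without changing the consumer's `async for` loop. A stand-alone sketch of that delegating-wrapper pattern, using a plain async generator in place of the real `grpc.aio` call (all names here are illustrative):

```python
import asyncio

async def source():
    # Stand-in for the gRPC call's response stream.
    for i in range(3):
        await asyncio.sleep(0)
        yield i

class Wrapper:
    """Delegates iteration to an inner async iterator, like _ResponseIterator."""
    def __init__(self, inner):
        self._inner = inner.__aiter__()

    def __aiter__(self):
        return self

    async def __anext__(self):
        value = await self._inner.__anext__()
        return value * 10  # per-item post-processing (the real code unpacks a response)

    async def cancel(self):
        # The real wrapper forwards to grpc.aio.Call.cancel(); this stand-in
        # just closes the underlying async generator instead.
        await self._inner.aclose()

async def main():
    return [v async for v in Wrapper(source())]

print(asyncio.run(main()))  # prints [0, 10, 20]
```

Because `__anext__` simply awaits the inner iterator, `StopAsyncIteration` propagates naturally and ends the consumer's loop when either the stream finishes or cancellation makes the inner call stop yielding.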
