Skip to content

Commit

Permalink
Add gRPC AsyncIO cancellation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kthui committed Oct 11, 2023
1 parent 85be200 commit 051e3d4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 103 deletions.
133 changes: 31 additions & 102 deletions qa/L0_request_cancellation/grpc_cancellation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def callback(user_data, result, error):
user_data._completed_requests.put(result)


class GrpcCancellationTest(unittest.TestCase):
class GrpcCancellationTest(unittest.IsolatedAsyncioTestCase):
_model_name = "custom_identity_int32"
_model_delay = 10.0 # seconds
_grpc_params = {"url": "localhost:8001", "verbose": True}
Expand All @@ -61,18 +61,11 @@ def setUp(self):
self._user_data = UserData()
self._callback = partial(callback, self._user_data)
self._prepare_request()
self._record_start_time()

def tearDown(self):
self._record_end_time()
self._assert_max_duration()
self._assert_cancelled_by_client()

def _record_start_time(self):
self._start_time = time.time() # seconds

def _record_end_time(self):
def tearDown(self):
self._end_time = time.time() # seconds
self._assert_max_duration()

def _prepare_request(self):
self._inputs = []
Expand All @@ -94,7 +87,7 @@ def _assert_max_duration(self):
+ "s",
)

def _assert_cancelled_by_client(self):
def _assert_callback_cancelled(self):
self.assertFalse(self._user_data._completed_requests.empty())
data_item = self._user_data._completed_requests.get()
self.assertIsInstance(data_item, InferenceServerException)
Expand All @@ -110,6 +103,7 @@ def test_grpc_async_infer(self):
time.sleep(2) # ensure the inference has started
future.cancel()
time.sleep(0.1) # context switch
self._assert_callback_cancelled()

def test_grpc_stream_infer(self):
self._client.start_stream(callback=self._callback)
Expand All @@ -118,98 +112,33 @@ def test_grpc_stream_infer(self):
)
time.sleep(2) # ensure the inference has started
self._client.stop_stream(cancel_requests=True)
self._assert_callback_cancelled()


# Disabling AsyncIO cancellation testing. Enable once
# DLIS-5476 is implemented.
# def test_aio_grpc_async_infer(self):
# # Sends a request using infer of grpc.aio to a
# # model that takes 10s to execute. Issues
# # a cancellation request after 2s. The client
# # should return with appropriate exception within
# # 5s.
# async def cancel_request(call):
# await asyncio.sleep(2)
# self.assertTrue(call.cancel())
#
# async def handle_response(generator):
# with self.assertRaises(asyncio.exceptions.CancelledError) as cm:
# _ = await anext(generator)
#
# async def test_aio_infer(self):
# triton_client = grpcclientaio.InferenceServerClient(
# url=self._triton_grpc_url, verbose=True
# )
# self._prepare_request()
# self._record_start_time_ms()
#
# generator = triton_client.infer(
# model_name=self.model_name_,
# inputs=self.inputs_,
# outputs=self.outputs_,
# get_call_obj=True,
# )
# grpc_call = await anext(generator)
#
# tasks = []
# tasks.append(asyncio.create_task(handle_response(generator)))
# tasks.append(asyncio.create_task(cancel_request(grpc_call)))
#
# for task in tasks:
# await task
#
# self._record_end_time_ms()
# self._assert_runtime_duration(5000)
#
# asyncio.run(test_aio_infer(self))
#
# def test_aio_grpc_stream_infer(self):
# # Sends a request using stream_infer of grpc.aio
# # library model that takes 10s to execute. Issues
# # stream closure with cancel_requests=True. The client
# # should return with appropriate exception within
# # 5s.
# async def test_aio_streaming_infer(self):
# async with grpcclientaio.InferenceServerClient(
# url=self._triton_grpc_url, verbose=True
# ) as triton_client:
#
# async def async_request_iterator():
# for i in range(1):
# await asyncio.sleep(1)
# yield {
# "model_name": self.model_name_,
# "inputs": self.inputs_,
# "outputs": self.outputs_,
# }
#
# self._prepare_request()
# self._record_start_time_ms()
# response_iterator = triton_client.stream_infer(
# inputs_iterator=async_request_iterator(), get_call_obj=True
# )
# streaming_call = await anext(response_iterator)
#
# async def cancel_streaming(streaming_call):
# await asyncio.sleep(2)
# streaming_call.cancel()
#
# async def handle_response(response_iterator):
# with self.assertRaises(asyncio.exceptions.CancelledError) as cm:
# async for response in response_iterator:
# self.assertTrue(False, "Received an unexpected response!")
#
# tasks = []
# tasks.append(asyncio.create_task(handle_response(response_iterator)))
# tasks.append(asyncio.create_task(cancel_streaming(streaming_call)))
#
# for task in tasks:
# await task
#
# self._record_end_time_ms()
# self._assert_runtime_duration(5000)
#
# asyncio.run(test_aio_streaming_infer(self))
async def test_aio_grpc_async_infer(self):
infer_task = asyncio.create_task(
self._client_aio.infer(
model_name=self._model_name, inputs=self._inputs, outputs=self._outputs
)
)
await asyncio.sleep(2) # ensure the inference has started
self.assertTrue(infer_task.cancel())
with self.assertRaises(asyncio.CancelledError):
response = await infer_task

async def test_aio_grpc_stream_infer(self):
async def requests_generator():
yield {
"model_name": self._model_name,
"inputs": self._inputs,
"outputs": self._outputs,
}

responses_iterator = self._client_aio.stream_infer(requests_generator())
await asyncio.sleep(2) # ensure the inference has started
self.assertTrue(responses_iterator.cancel())
with self.assertRaises(asyncio.CancelledError):
async for result, error in responses_iterator:
self._callback(result, error)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion qa/L0_request_cancellation/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ mkdir -p models/custom_identity_int32/1 && (cd models/custom_identity_int32 && \
echo 'instance_group [{ kind: KIND_CPU }]' >> config.pbtxt && \
echo -e 'parameters [{ key: "execute_delay_ms" \n value: { string_value: "10000" } }]' >> config.pbtxt)

for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer"; do
for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc_async_infer" "test_aio_grpc_stream_infer"; do

TEST_LOG="./grpc_cancellation_test.$TEST_CASE.log"
SERVER_LOG="grpc_cancellation_test.$TEST_CASE.server.log"
Expand Down

0 comments on commit 051e3d4

Please sign in to comment.