diff --git a/qa/L0_request_cancellation/grpc_cancellation_test.py b/qa/L0_request_cancellation/grpc_cancellation_test.py index 55027d5b084..b65117825a8 100755 --- a/qa/L0_request_cancellation/grpc_cancellation_test.py +++ b/qa/L0_request_cancellation/grpc_cancellation_test.py @@ -50,7 +50,7 @@ def callback(user_data, result, error): user_data._completed_requests.put(result) -class GrpcCancellationTest(unittest.TestCase): +class GrpcCancellationTest(unittest.IsolatedAsyncioTestCase): _model_name = "custom_identity_int32" _model_delay = 10.0 # seconds _grpc_params = {"url": "localhost:8001", "verbose": True} @@ -61,18 +61,11 @@ def setUp(self): self._user_data = UserData() self._callback = partial(callback, self._user_data) self._prepare_request() - self._record_start_time() - - def tearDown(self): - self._record_end_time() - self._assert_max_duration() - self._assert_cancelled_by_client() - - def _record_start_time(self): self._start_time = time.time() # seconds - def _record_end_time(self): + def tearDown(self): self._end_time = time.time() # seconds + self._assert_max_duration() def _prepare_request(self): self._inputs = [] @@ -94,7 +87,7 @@ def _assert_max_duration(self): + "s", ) - def _assert_cancelled_by_client(self): + def _assert_callback_cancelled(self): self.assertFalse(self._user_data._completed_requests.empty()) data_item = self._user_data._completed_requests.get() self.assertIsInstance(data_item, InferenceServerException) @@ -110,6 +103,7 @@ def test_grpc_async_infer(self): time.sleep(2) # ensure the inference has started future.cancel() time.sleep(0.1) # context switch + self._assert_callback_cancelled() def test_grpc_stream_infer(self): self._client.start_stream(callback=self._callback) @@ -118,98 +112,33 @@ def test_grpc_stream_infer(self): ) time.sleep(2) # ensure the inference has started self._client.stop_stream(cancel_requests=True) + self._assert_callback_cancelled() - -# Disabling AsyncIO cancellation testing. Enable once -# DLIS-5476 is implemented. -# def test_aio_grpc_async_infer(self): -# # Sends a request using infer of grpc.aio to a -# # model that takes 10s to execute. Issues -# # a cancellation request after 2s. The client -# # should return with appropriate exception within -# # 5s. -# async def cancel_request(call): -# await asyncio.sleep(2) -# self.assertTrue(call.cancel()) -# -# async def handle_response(generator): -# with self.assertRaises(asyncio.exceptions.CancelledError) as cm: -# _ = await anext(generator) -# -# async def test_aio_infer(self): -# triton_client = grpcclientaio.InferenceServerClient( -# url=self._triton_grpc_url, verbose=True -# ) -# self._prepare_request() -# self._record_start_time_ms() -# -# generator = triton_client.infer( -# model_name=self.model_name_, -# inputs=self.inputs_, -# outputs=self.outputs_, -# get_call_obj=True, -# ) -# grpc_call = await anext(generator) -# -# tasks = [] -# tasks.append(asyncio.create_task(handle_response(generator))) -# tasks.append(asyncio.create_task(cancel_request(grpc_call))) -# -# for task in tasks: -# await task -# -# self._record_end_time_ms() -# self._assert_runtime_duration(5000) -# -# asyncio.run(test_aio_infer(self)) -# -# def test_aio_grpc_stream_infer(self): -# # Sends a request using stream_infer of grpc.aio -# # library model that takes 10s to execute. Issues -# # stream closure with cancel_requests=True. The client -# # should return with appropriate exception within -# # 5s. -# async def test_aio_streaming_infer(self): -# async with grpcclientaio.InferenceServerClient( -# url=self._triton_grpc_url, verbose=True -# ) as triton_client: -# -# async def async_request_iterator(): -# for i in range(1): -# await asyncio.sleep(1) -# yield { -# "model_name": self.model_name_, -# "inputs": self.inputs_, -# "outputs": self.outputs_, -# } -# -# self._prepare_request() -# self._record_start_time_ms() -# response_iterator = triton_client.stream_infer( -# inputs_iterator=async_request_iterator(), get_call_obj=True -# ) -# streaming_call = await anext(response_iterator) -# -# async def cancel_streaming(streaming_call): -# await asyncio.sleep(2) -# streaming_call.cancel() -# -# async def handle_response(response_iterator): -# with self.assertRaises(asyncio.exceptions.CancelledError) as cm: -# async for response in response_iterator: -# self.assertTrue(False, "Received an unexpected response!") -# -# tasks = [] -# tasks.append(asyncio.create_task(handle_response(response_iterator))) -# tasks.append(asyncio.create_task(cancel_streaming(streaming_call))) -# -# for task in tasks: -# await task -# -# self._record_end_time_ms() -# self._assert_runtime_duration(5000) -# -# asyncio.run(test_aio_streaming_infer(self)) + async def test_aio_grpc_async_infer(self): + infer_task = asyncio.create_task( + self._client_aio.infer( + model_name=self._model_name, inputs=self._inputs, outputs=self._outputs + ) + ) + await asyncio.sleep(2) # ensure the inference has started + self.assertTrue(infer_task.cancel()) + with self.assertRaises(asyncio.CancelledError): + response = await infer_task + + async def test_aio_grpc_stream_infer(self): + async def requests_generator(): + yield { + "model_name": self._model_name, + "inputs": self._inputs, + "outputs": self._outputs, + } + + responses_iterator = self._client_aio.stream_infer(requests_generator()) + await asyncio.sleep(2) # ensure the inference has started + self.assertTrue(responses_iterator.cancel()) + with self.assertRaises(asyncio.CancelledError): + async for result, error in responses_iterator: + self._callback(result, error) if __name__ == "__main__": diff --git a/qa/L0_request_cancellation/test.sh b/qa/L0_request_cancellation/test.sh index bc36625bce5..23917ec16f5 100755 --- a/qa/L0_request_cancellation/test.sh +++ b/qa/L0_request_cancellation/test.sh @@ -78,7 +78,7 @@ mkdir -p models/custom_identity_int32/1 && (cd models/custom_identity_int32 && \ echo 'instance_group [{ kind: KIND_CPU }]' >> config.pbtxt && \ echo -e 'parameters [{ key: "execute_delay_ms" \n value: { string_value: "10000" } }]' >> config.pbtxt) -for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer"; do +for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc_async_infer" "test_aio_grpc_stream_infer"; do TEST_LOG="./grpc_cancellation_test.$TEST_CASE.log" SERVER_LOG="grpc_cancellation_test.$TEST_CASE.server.log"