diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py index b0ad4beeda05e..e3d7747ac49f7 100644 --- a/jina/serve/runtimes/worker/batch_queue.py +++ b/jina/serve/runtimes/worker/batch_queue.py @@ -38,6 +38,8 @@ def __init__( self._preferred_batch_size: int = preferred_batch_size self._timeout: int = timeout self._reset() + self._flush_trigger: Event = Event() + self._timer_task: Optional[Task] = None def __repr__(self) -> str: return f'{self.__class__.__name__}(preferred_batch_size={self._preferred_batch_size}, timeout={self._timeout})' @@ -57,9 +59,7 @@ def _reset(self) -> None: else: self._big_doc = self._request_docarray_cls() - self._flush_trigger: Event = Event() self._flush_task: Optional[Task] = None - self._timer_task: Optional[Task] = None def _cancel_timer_if_pending(self): if ( @@ -97,11 +97,12 @@ async def push(self, request: DataRequest) -> asyncio.Queue: docs = request.docs # writes to shared data between tasks need to be mutually exclusive + if not self._timer_task: + self._start_timer() async with self._data_lock: if not self._flush_task: self._flush_task = asyncio.create_task(self._await_then_flush()) - if not self._timer_task: - self._start_timer() + self._big_doc.extend(docs) next_req_idx = len(self._requests) num_docs = len(docs) @@ -218,6 +219,9 @@ def batch(iterable_1, iterable_2, n=1): # self._requests with its lengths stored in self._requests_len. For each requests, there is a queue to # communicate that the request has been processed properly. At this stage the data_lock is ours and # therefore noone can add requests to this list. + self._flush_trigger: Event = Event() + self._cancel_timer_if_pending() + self._timer_task = None try: if not docarray_v2: non_assigned_to_response_docs: DocumentArray = DocumentArray.empty() diff --git a/tests/unit/serve/dynamic_batching/test_batch_queue.py b/tests/unit/serve/dynamic_batching/test_batch_queue.py index 1fdd1a79cf123..057ae40dceaf1 100644 --- a/tests/unit/serve/dynamic_batching/test_batch_queue.py +++ b/tests/unit/serve/dynamic_batching/test_batch_queue.py @@ -58,6 +58,49 @@ async def process_request(req): await bq.close() +@pytest.mark.asyncio +async def test_batch_queue_timeout_does_not_wait_previous_batch(): + batches_lengths_computed = [] + async def foo(docs, **kwargs): + await asyncio.sleep(4) + batches_lengths_computed.append(len(docs)) + return DocumentArray([Document(text='Done') for _ in docs]) + + bq: BatchQueue = BatchQueue( + foo, + request_docarray_cls=DocumentArray, + response_docarray_cls=DocumentArray, + preferred_batch_size=5, + timeout=3000, + ) + + data_requests = [DataRequest() for _ in range(3)] + for req in data_requests: + req.data.docs = DocumentArray([Document(text=''), Document(text='')]) + + extra_data_request = DataRequest() + extra_data_request.data.docs = DocumentArray([Document(text=''), Document(text='')]) + + async def process_request(req, sleep=0): + if sleep > 0: + await asyncio.sleep(sleep) + q = await bq.push(req) + _ = await q.get() + q.task_done() + return req + init_time = time.time() + tasks = [asyncio.create_task(process_request(req)) for req in data_requests] + tasks.append(asyncio.create_task(process_request(extra_data_request, sleep=2))) + responses = await asyncio.gather(*tasks) + time_spent = (time.time() - init_time) * 1000 + # TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired inmediately + assert time_spent >= 12000 + assert time_spent <= 12500 + assert batches_lengths_computed == [5, 1, 2] + + await bq.close() + + @pytest.mark.asyncio async def test_batch_queue_req_length_larger_than_preferred(): async def foo(docs, **kwargs):