Skip to content

Commit

Permalink
fix: fix start timer before getting data lock
Browse files Browse the repository at this point in the history
Signed-off-by: Joan Fontanals Martinez <[email protected]>
  • Loading branch information
JoanFM committed Oct 6, 2023
1 parent af7fdfe commit b022bdd
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
12 changes: 8 additions & 4 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(
self._preferred_batch_size: int = preferred_batch_size
self._timeout: int = timeout
self._reset()
self._flush_trigger: Event = Event()
self._timer_task: Optional[Task] = None

def __repr__(self) -> str:
return f'{self.__class__.__name__}(preferred_batch_size={self._preferred_batch_size}, timeout={self._timeout})'
Expand All @@ -57,9 +59,7 @@ def _reset(self) -> None:
else:
self._big_doc = self._request_docarray_cls()

self._flush_trigger: Event = Event()
self._flush_task: Optional[Task] = None
self._timer_task: Optional[Task] = None

def _cancel_timer_if_pending(self):
if (
Expand Down Expand Up @@ -97,11 +97,12 @@ async def push(self, request: DataRequest) -> asyncio.Queue:
docs = request.docs

# writes to shared data between tasks need to be mutually exclusive
if not self._timer_task:
self._start_timer()
async with self._data_lock:
if not self._flush_task:
self._flush_task = asyncio.create_task(self._await_then_flush())
if not self._timer_task:
self._start_timer()

self._big_doc.extend(docs)
next_req_idx = len(self._requests)
num_docs = len(docs)
Expand Down Expand Up @@ -218,6 +219,9 @@ def batch(iterable_1, iterable_2, n=1):
# self._requests with its lengths stored in self._requests_len. For each requests, there is a queue to
# communicate that the request has been processed properly. At this stage the data_lock is ours and
# therefore noone can add requests to this list.
self._flush_trigger: Event = Event()
self._cancel_timer_if_pending()
self._timer_task = None
try:
if not docarray_v2:
non_assigned_to_response_docs: DocumentArray = DocumentArray.empty()
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/serve/dynamic_batching/test_batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,49 @@ async def process_request(req):
await bq.close()


@pytest.mark.asyncio
async def test_batch_queue_timeout_does_not_wait_previous_batch():
batches_lengths_computed = []
async def foo(docs, **kwargs):
await asyncio.sleep(4)
batches_lengths_computed.append(len(docs))
return DocumentArray([Document(text='Done') for _ in docs])

bq: BatchQueue = BatchQueue(
foo,
request_docarray_cls=DocumentArray,
response_docarray_cls=DocumentArray,
preferred_batch_size=5,
timeout=3000,
)

data_requests = [DataRequest() for _ in range(3)]
for req in data_requests:
req.data.docs = DocumentArray([Document(text=''), Document(text='')])

extra_data_request = DataRequest()
extra_data_request.data.docs = DocumentArray([Document(text=''), Document(text='')])

async def process_request(req, sleep=0):
if sleep > 0:
await asyncio.sleep(sleep)
q = await bq.push(req)
_ = await q.get()
q.task_done()
return req
init_time = time.time()
tasks = [asyncio.create_task(process_request(req)) for req in data_requests]
tasks.append(asyncio.create_task(process_request(extra_data_request, sleep=2)))
responses = await asyncio.gather(*tasks)
time_spent = (time.time() - init_time) * 1000
# TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired inmediately
assert time_spent >= 12000
assert time_spent <= 12500
assert batches_lengths_computed == [5, 1, 2]

await bq.close()


@pytest.mark.asyncio
async def test_batch_queue_req_length_larger_than_preferred():
async def foo(docs, **kwargs):
Expand Down

0 comments on commit b022bdd

Please sign in to comment.