Skip to content

Commit

Permalink
feat: add flush all option to dynamic batching configuration (#6179)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Jul 18, 2024
1 parent e0f620d commit cf7284f
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 29 deletions.
6 changes: 5 additions & 1 deletion jina/serve/executors/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def dynamic_batching(
*,
preferred_batch_size: Optional[int] = None,
timeout: Optional[float] = 10_000,
flush_all: bool = False
):
"""
`@dynamic_batching` defines the dynamic batching behavior of an Executor.
Expand All @@ -426,11 +427,13 @@ def dynamic_batching(
:param func: the method to decorate
:param preferred_batch_size: target number of Documents in a batch. The batcher will collect requests until `preferred_batch_size` is reached,
or until `timeout` is reached. Therefore, the actual batch size can be smaller or larger than `preferred_batch_size`.
or until `timeout` is reached. Therefore, the actual batch size can be smaller than or equal to `preferred_batch_size`, unless `flush_all` is set to True.
:param timeout: maximum time in milliseconds to wait for a request to be assigned to a batch.
If the oldest request in the queue reaches a waiting time of `timeout`, the batch will be passed to the Executor,
even if it contains fewer than `preferred_batch_size` Documents.
Default is 10_000ms (10 seconds).
:param flush_all: Determines whether, once a batch is triggered by `timeout` or `preferred_batch_size`, the function receives everything that the batcher has accumulated.
If this is True, `preferred_batch_size` is used only as a trigger mechanism and not as an upper bound on the batch size.
:return: decorated function
"""

Expand Down Expand Up @@ -476,6 +479,7 @@ def _inject_owner_attrs(self, owner, name):
'preferred_batch_size'
] = preferred_batch_size
owner.dynamic_batching[fn_name]['timeout'] = timeout
owner.dynamic_batching[fn_name]['flush_all'] = flush_all
setattr(owner, name, self.fn)

def __set_name__(self, owner, name):
Expand Down
9 changes: 7 additions & 2 deletions jina/serve/runtimes/worker/batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
response_docarray_cls,
output_array_type: Optional[str] = None,
params: Optional[Dict] = None,
flush_all: bool = False,
preferred_batch_size: int = 4,
timeout: int = 10_000,
) -> None:
Expand All @@ -35,6 +36,7 @@ def __init__(
self.params = params
self._request_docarray_cls = request_docarray_cls
self._response_docarray_cls = response_docarray_cls
self._flush_all = flush_all
self._preferred_batch_size: int = preferred_batch_size
self._timeout: int = timeout
self._reset()
Expand Down Expand Up @@ -205,7 +207,10 @@ async def _assign_results(

return num_assigned_docs

def batch(iterable_1, iterable_2, n=1):
def batch(iterable_1, iterable_2, n:Optional[int] = 1):
if n is None:
yield iterable_1, iterable_2
return
items = len(iterable_1)
for ndx in range(0, items, n):
yield iterable_1[ndx : min(ndx + n, items)], iterable_2[
Expand All @@ -229,7 +234,7 @@ def batch(iterable_1, iterable_2, n=1):
non_assigned_to_response_request_idxs = []
sum_from_previous_first_req_idx = 0
for docs_inner_batch, req_idxs in batch(
self._big_doc, self._request_idxs, self._preferred_batch_size
self._big_doc, self._request_idxs, self._preferred_batch_size if not self._flush_all else None
):
involved_requests_min_indx = req_idxs[0]
involved_requests_max_indx = req_idxs[-1]
Expand Down
4 changes: 3 additions & 1 deletion jina/serve/runtimes/worker/http_fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jina.serve.networking.sse import EventSourceResponse
from jina.types.request.data import DataRequest


if TYPE_CHECKING:
from jina.logging.logger import JinaLogger

Expand Down Expand Up @@ -88,7 +89,6 @@ def add_post_route(

@app.api_route(**app_kwargs)
async def post(body: input_model, response: Response):

req = DataRequest()
if body.header is not None:
req.header.request_id = body.header.request_id
Expand Down Expand Up @@ -122,7 +122,9 @@ async def post(body: input_model, response: Response):
docs_response = resp.docs.to_dict()
else:
docs_response = resp.docs

ret = output_model(data=docs_response, parameters=resp.parameters)

return ret

def add_streaming_routes(
Expand Down
60 changes: 56 additions & 4 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,11 +629,17 @@ def test_failure_propagation():
)


@pytest.mark.repeat(10)
def test_exception_handling_in_dynamic_batch():
@pytest.mark.parametrize(
'flush_all',
[
False,
True
],
)
def test_exception_handling_in_dynamic_batch(flush_all):
class SlowExecutorWithException(Executor):

@dynamic_batching(preferred_batch_size=3, timeout=1000)
@dynamic_batching(preferred_batch_size=3, timeout=5000, flush_all=flush_all)
@requests(on='/foo')
def foo(self, docs, **kwargs):
for doc in docs:
Expand All @@ -659,4 +665,50 @@ def foo(self, docs, **kwargs):
if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR:
num_failed_requests += 1

assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing
if not flush_all:
assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing
else:
assert 1 <= num_failed_requests <= len(da) # with flush_all, the failing batch is not capped at 3 requests

# Integration test: checks how many Documents each call to the dynamically
# batched endpoint receives, with and without `flush_all`.
# `foo` writes len(docs) of the batch it was called with into every Document's
# text, so the batch size can be read back from the returned Documents.
@pytest.mark.asyncio
@pytest.mark.parametrize(
'flush_all',
[
False,
True
],
)
async def test_num_docs_processed_in_exec(flush_all):
class DynBatchProcessor(Executor):

@dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=flush_all)
@requests(on='/foo')
def foo(self, docs, **kwargs):
for doc in docs:
doc.text = f"{len(docs)}"

depl = Deployment(uses=DynBatchProcessor, protocol='http')

with depl:
# 50 Documents are sent with request_size=7, while preferred_batch_size is 5.
da = DocumentArray([Document(text='good') for _ in range(50)])
cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
res = []
async for r in cl.post(
on='/foo',
inputs=da,
request_size=7,
continue_on_error=True,
results_in_order=True,
):
res.extend(r)
assert len(res) == 50 # one result per input Document
# Without flush_all, no call may have received more than 5 Documents.
if not flush_all:
for d in res:
assert int(d.text) <= 5
# With flush_all, at least one call must have received more than 5 Documents.
else:
larger_than_5 = 0
for d in res:
if int(d.text) > 5:
larger_than_5 += 1
assert int(d.text) >= 5
assert larger_than_5 > 0
48 changes: 35 additions & 13 deletions tests/unit/serve/dynamic_batching/test_batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


@pytest.mark.asyncio
async def test_batch_queue_timeout():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_timeout(flush_all):
async def foo(docs, **kwargs):
await asyncio.sleep(0.1)
return DocumentArray([Document(text='Done') for _ in docs])
Expand All @@ -20,6 +21,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=4,
timeout=2000,
flush_all=flush_all,
)

three_data_requests = [DataRequest() for _ in range(3)]
Expand Down Expand Up @@ -59,7 +61,8 @@ async def process_request(req):


@pytest.mark.asyncio
async def test_batch_queue_timeout_does_not_wait_previous_batch():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all):
batches_lengths_computed = []

async def foo(docs, **kwargs):
Expand All @@ -73,6 +76,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=5,
timeout=3000,
flush_all=flush_all
)

data_requests = [DataRequest() for _ in range(3)]
Expand All @@ -93,19 +97,28 @@ async def process_request(req, sleep=0):
init_time = time.time()
tasks = [asyncio.create_task(process_request(req)) for req in data_requests]
tasks.append(asyncio.create_task(process_request(extra_data_request, sleep=2)))
responses = await asyncio.gather(*tasks)
_ = await asyncio.gather(*tasks)
time_spent = (time.time() - init_time) * 1000
# TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired immediately
# BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late)
assert time_spent >= 12000
assert time_spent <= 12500
assert batches_lengths_computed == [5, 1, 2]

if flush_all is False:
# TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired immediately
# BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late)
assert time_spent >= 12000
assert time_spent <= 12500
else:
assert time_spent >= 8000
assert time_spent <= 8500
if flush_all is False:
assert batches_lengths_computed == [5, 1, 2]
else:
assert batches_lengths_computed == [6, 2]

await bq.close()


@pytest.mark.asyncio
async def test_batch_queue_req_length_larger_than_preferred():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_req_length_larger_than_preferred(flush_all):
async def foo(docs, **kwargs):
await asyncio.sleep(0.1)
return DocumentArray([Document(text='Done') for _ in docs])
Expand All @@ -116,6 +129,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=4,
timeout=2000,
flush_all=flush_all,
)

data_requests = [DataRequest() for _ in range(3)]
Expand Down Expand Up @@ -240,7 +254,8 @@ async def process_request(req):


@pytest.mark.asyncio
async def test_exception_all():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_exception_all(flush_all):
async def foo(docs, **kwargs):
raise AssertionError

Expand All @@ -249,6 +264,7 @@ async def foo(docs, **kwargs):
request_docarray_cls=DocumentArray,
response_docarray_cls=DocumentArray,
preferred_batch_size=2,
flush_all=flush_all,
timeout=500,
)

Expand Down Expand Up @@ -284,14 +300,19 @@ async def foo(docs, **kwargs):
assert repr(bq) == str(bq)


@pytest.mark.parametrize('num_requests', [61, 127, 100])
@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
@pytest.mark.parametrize('num_requests', [33, 127, 100])
@pytest.mark.parametrize('preferred_batch_size', [7, 61, 100])
@pytest.mark.parametrize('timeout', [0.3, 500])
@pytest.mark.parametrize('flush_all', [False, True])
@pytest.mark.asyncio
async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout):
async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all):
import random

async def foo(docs, **kwargs):
if not flush_all:
assert len(docs) <= preferred_batch_size
else:
assert len(docs) >= preferred_batch_size
await asyncio.sleep(0.1)
for doc in docs:
doc.text += ' Processed'
Expand All @@ -301,6 +322,7 @@ async def foo(docs, **kwargs):
request_docarray_cls=DocumentArray,
response_docarray_cls=DocumentArray,
preferred_batch_size=preferred_batch_size,
flush_all=flush_all,
timeout=timeout,
)

Expand Down
16 changes: 8 additions & 8 deletions tests/unit/serve/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,15 +614,15 @@ class C(B):
[
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False),
),
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
),
(
dict(preferred_batch_size=4),
dict(preferred_batch_size=4, timeout=10_000),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False),
),
],
)
Expand All @@ -641,15 +641,15 @@ def foo(self, docs, **kwargs):
[
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False),
),
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
),
(
dict(preferred_batch_size=4),
dict(preferred_batch_size=4, timeout=10_000),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False),
),
],
)
Expand Down

0 comments on commit cf7284f

Please sign in to comment.