diff --git a/jina/serve/executors/decorators.py b/jina/serve/executors/decorators.py index 4034439102478..7c7a6e4031bcf 100644 --- a/jina/serve/executors/decorators.py +++ b/jina/serve/executors/decorators.py @@ -416,6 +416,7 @@ def dynamic_batching( *, preferred_batch_size: Optional[int] = None, timeout: Optional[float] = 10_000, + flush_all: bool = False ): """ `@dynamic_batching` defines the dynamic batching behavior of an Executor. @@ -426,11 +427,13 @@ def dynamic_batching( :param func: the method to decorate :param preferred_batch_size: target number of Documents in a batch. The batcher will collect requests until `preferred_batch_size` is reached, - or until `timeout` is reached. Therefore, the actual batch size can be smaller or larger than `preferred_batch_size`. + or until `timeout` is reached. Therefore, the actual batch size can be smaller or equal to `preferred_batch_size`, except if `flush_all` is set to True :param timeout: maximum time in milliseconds to wait for a request to be assigned to a batch. If the oldest request in the queue reaches a waiting time of `timeout`, the batch will be passed to the Executor, even if it contains fewer than `preferred_batch_size` Documents. Default is 10_000ms (10 seconds). + :param flush_all: Determines if once the batches is triggered by timeout or preferred_batch_size, the function will receive everything that the batcher has accumulated or not. + If this is true, `preferred_batch_size` is used as a trigger mechanism. :return: decorated function """ @@ -476,6 +479,7 @@ def _inject_owner_attrs(self, owner, name): 'preferred_batch_size' ] = preferred_batch_size owner.dynamic_batching[fn_name]['timeout'] = timeout + owner.dynamic_batching[fn_name]['flush_all'] = flush_all setattr(owner, name, self.fn) def __set_name__(self, owner, name): diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py index 530f5f58d3a81..8f7e0d283b413 100644 --- a/jina/serve/runtimes/worker/batch_queue.py +++ b/jina/serve/runtimes/worker/batch_queue.py @@ -23,6 +23,7 @@ def __init__( response_docarray_cls, output_array_type: Optional[str] = None, params: Optional[Dict] = None, + flush_all: bool = False, preferred_batch_size: int = 4, timeout: int = 10_000, ) -> None: @@ -35,6 +36,7 @@ def __init__( self.params = params self._request_docarray_cls = request_docarray_cls self._response_docarray_cls = response_docarray_cls + self._flush_all = flush_all self._preferred_batch_size: int = preferred_batch_size self._timeout: int = timeout self._reset() @@ -205,7 +207,10 @@ async def _assign_results( return num_assigned_docs - def batch(iterable_1, iterable_2, n=1): + def batch(iterable_1, iterable_2, n:Optional[int] = 1): + if n is None: + yield iterable_1, iterable_2 + return items = len(iterable_1) for ndx in range(0, items, n): yield iterable_1[ndx : min(ndx + n, items)], iterable_2[ @@ -229,7 +234,7 @@ def batch(iterable_1, iterable_2, n=1): non_assigned_to_response_request_idxs = [] sum_from_previous_first_req_idx = 0 for docs_inner_batch, req_idxs in batch( - self._big_doc, self._request_idxs, self._preferred_batch_size + self._big_doc, self._request_idxs, self._preferred_batch_size if not self._flush_all else None ): involved_requests_min_indx = req_idxs[0] involved_requests_max_indx = req_idxs[-1] diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index fd4aeabf8c79c..b45b94f7c62cf 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -7,6 +7,7 @@ from jina.serve.networking.sse import EventSourceResponse from jina.types.request.data import DataRequest + if TYPE_CHECKING: from jina.logging.logger import JinaLogger @@ -88,7 +89,6 @@ def add_post_route( @app.api_route(**app_kwargs) async def post(body: input_model, response: Response): - req = DataRequest() if body.header is not None: req.header.request_id = body.header.request_id @@ -122,7 +122,9 @@ async def post(body: input_model, response: Response): docs_response = resp.docs.to_dict() else: docs_response = resp.docs + ret = output_model(data=docs_response, parameters=resp.parameters) + return ret def add_streaming_routes( diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py index 90126a82700f5..355e771c52fc7 100644 --- a/tests/integration/dynamic_batching/test_dynamic_batching.py +++ b/tests/integration/dynamic_batching/test_dynamic_batching.py @@ -629,11 +629,17 @@ def test_failure_propagation(): ) -@pytest.mark.repeat(10) -def test_exception_handling_in_dynamic_batch(): +@pytest.mark.parametrize( + 'flush_all', + [ + False, + True + ], +) +def test_exception_handling_in_dynamic_batch(flush_all): class SlowExecutorWithException(Executor): - @dynamic_batching(preferred_batch_size=3, timeout=1000) + @dynamic_batching(preferred_batch_size=3, timeout=5000, flush_all=flush_all) @requests(on='/foo') def foo(self, docs, **kwargs): for doc in docs: @@ -659,4 +665,50 @@ def foo(self, docs, **kwargs): if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR: num_failed_requests += 1 - assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing + if not flush_all: + assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing + else: + assert 1 <= num_failed_requests <= len(da) # 3 requests in the dynamic batch failing + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'flush_all', + [ + False, + True + ], +) +async def test_num_docs_processed_in_exec(flush_all): + class DynBatchProcessor(Executor): + + @dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=flush_all) + @requests(on='/foo') + def foo(self, docs, **kwargs): + for doc in docs: + doc.text = f"{len(docs)}" + + depl = Deployment(uses=DynBatchProcessor, protocol='http') + + with depl: + da = DocumentArray([Document(text='good') for _ in range(50)]) + cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True) + res = [] + async for r in cl.post( + on='/foo', + inputs=da, + request_size=7, + continue_on_error=True, + results_in_order=True, + ): + res.extend(r) + assert len(res) == 50 # 1 request per input + if not flush_all: + for d in res: + assert int(d.text) <= 5 + else: + larger_than_5 = 0 + for d in res: + if int(d.text) > 5: + larger_than_5 += 1 + assert int(d.text) >= 5 + assert larger_than_5 > 0 diff --git a/tests/unit/serve/dynamic_batching/test_batch_queue.py b/tests/unit/serve/dynamic_batching/test_batch_queue.py index 22758995d7270..2d0a172ca5a27 100644 --- a/tests/unit/serve/dynamic_batching/test_batch_queue.py +++ b/tests/unit/serve/dynamic_batching/test_batch_queue.py @@ -9,7 +9,8 @@ @pytest.mark.asyncio -async def test_batch_queue_timeout(): +@pytest.mark.parametrize('flush_all', [False, True]) +async def test_batch_queue_timeout(flush_all): async def foo(docs, **kwargs): await asyncio.sleep(0.1) return DocumentArray([Document(text='Done') for _ in docs]) @@ -20,6 +21,7 @@ async def foo(docs, **kwargs): response_docarray_cls=DocumentArray, preferred_batch_size=4, timeout=2000, + flush_all=flush_all, ) three_data_requests = [DataRequest() for _ in range(3)] @@ -59,7 +61,8 @@ async def process_request(req): @pytest.mark.asyncio -async def test_batch_queue_timeout_does_not_wait_previous_batch(): +@pytest.mark.parametrize('flush_all', [False, True]) +async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all): batches_lengths_computed = [] async def foo(docs, **kwargs): @@ -73,6 +76,7 @@ async def foo(docs, **kwargs): response_docarray_cls=DocumentArray, preferred_batch_size=5, timeout=3000, + flush_all=flush_all ) data_requests = [DataRequest() for _ in range(3)] @@ -93,19 +97,28 @@ async def process_request(req, sleep=0): init_time = time.time() tasks = [asyncio.create_task(process_request(req)) for req in data_requests] tasks.append(asyncio.create_task(process_request(extra_data_request, sleep=2))) - responses = await asyncio.gather(*tasks) + _ = await asyncio.gather(*tasks) time_spent = (time.time() - init_time) * 1000 - # TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired inmediately - # BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late) - assert time_spent >= 12000 - assert time_spent <= 12500 - assert batches_lengths_computed == [5, 1, 2] + + if flush_all is False: + # TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired inmediately + # BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late) + assert time_spent >= 12000 + assert time_spent <= 12500 + else: + assert time_spent >= 8000 + assert time_spent <= 8500 + if flush_all is False: + assert batches_lengths_computed == [5, 1, 2] + else: + assert batches_lengths_computed == [6, 2] await bq.close() @pytest.mark.asyncio -async def test_batch_queue_req_length_larger_than_preferred(): +@pytest.mark.parametrize('flush_all', [False, True]) +async def test_batch_queue_req_length_larger_than_preferred(flush_all): async def foo(docs, **kwargs): await asyncio.sleep(0.1) return DocumentArray([Document(text='Done') for _ in docs]) @@ -116,6 +129,7 @@ async def foo(docs, **kwargs): response_docarray_cls=DocumentArray, preferred_batch_size=4, timeout=2000, + flush_all=flush_all, ) data_requests = [DataRequest() for _ in range(3)] @@ -240,7 +254,8 @@ async def process_request(req): @pytest.mark.asyncio -async def test_exception_all(): +@pytest.mark.parametrize('flush_all', [False, True]) +async def test_exception_all(flush_all): async def foo(docs, **kwargs): raise AssertionError @@ -249,6 +264,7 @@ async def foo(docs, **kwargs): request_docarray_cls=DocumentArray, response_docarray_cls=DocumentArray, preferred_batch_size=2, + flush_all=flush_all, timeout=500, ) @@ -284,14 +300,19 @@ async def foo(docs, **kwargs): assert repr(bq) == str(bq) -@pytest.mark.parametrize('num_requests', [61, 127, 100]) -@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100]) +@pytest.mark.parametrize('num_requests', [33, 127, 100]) +@pytest.mark.parametrize('preferred_batch_size', [7, 61, 100]) @pytest.mark.parametrize('timeout', [0.3, 500]) +@pytest.mark.parametrize('flush_all', [False, True]) @pytest.mark.asyncio -async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout): +async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all): import random async def foo(docs, **kwargs): + if not flush_all: + assert len(docs) <= preferred_batch_size + else: + assert len(docs) >= preferred_batch_size await asyncio.sleep(0.1) for doc in docs: doc.text += ' Processed' @@ -301,6 +322,7 @@ async def foo(docs, **kwargs): request_docarray_cls=DocumentArray, response_docarray_cls=DocumentArray, preferred_batch_size=preferred_batch_size, + flush_all=flush_all, timeout=timeout, ) diff --git a/tests/unit/serve/executors/test_executor.py b/tests/unit/serve/executors/test_executor.py index 3bb6f1769ceff..a6d902421ae83 100644 --- a/tests/unit/serve/executors/test_executor.py +++ b/tests/unit/serve/executors/test_executor.py @@ -614,15 +614,15 @@ class C(B): [ ( dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000), + dict(preferred_batch_size=4, timeout=5_000, flush_all=False), ), ( - dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000), + dict(preferred_batch_size=4, timeout=5_000, flush_all=True), + dict(preferred_batch_size=4, timeout=5_000, flush_all=True), ), ( dict(preferred_batch_size=4), - dict(preferred_batch_size=4, timeout=10_000), + dict(preferred_batch_size=4, timeout=10_000, flush_all=False), ), ], ) @@ -641,15 +641,15 @@ def foo(self, docs, **kwargs): [ ( dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000), + dict(preferred_batch_size=4, timeout=5_000, flush_all=False), ), ( - dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000), + dict(preferred_batch_size=4, timeout=5_000, flush_all=True), + dict(preferred_batch_size=4, timeout=5_000, flush_all=True), ), ( dict(preferred_batch_size=4), - dict(preferred_batch_size=4, timeout=10_000), + dict(preferred_batch_size=4, timeout=10_000, flush_all=False), ), ], )