diff --git a/jina/serve/executors/decorators.py b/jina/serve/executors/decorators.py index 4034439102478..7c7a6e4031bcf 100644 --- a/jina/serve/executors/decorators.py +++ b/jina/serve/executors/decorators.py @@ -416,6 +416,7 @@ def dynamic_batching( *, preferred_batch_size: Optional[int] = None, timeout: Optional[float] = 10_000, + flush_all: bool = False ): """ `@dynamic_batching` defines the dynamic batching behavior of an Executor. @@ -426,11 +427,13 @@ def dynamic_batching( :param func: the method to decorate :param preferred_batch_size: target number of Documents in a batch. The batcher will collect requests until `preferred_batch_size` is reached, - or until `timeout` is reached. Therefore, the actual batch size can be smaller or larger than `preferred_batch_size`. + or until `timeout` is reached. Therefore, the actual batch size can be smaller or equal to `preferred_batch_size`, except if `flush_all` is set to True. :param timeout: maximum time in milliseconds to wait for a request to be assigned to a batch. If the oldest request in the queue reaches a waiting time of `timeout`, the batch will be passed to the Executor, even if it contains fewer than `preferred_batch_size` Documents. Default is 10_000ms (10 seconds). + :param flush_all: Determines whether, once the batch is triggered by timeout or preferred_batch_size, the function will receive everything that the batcher has accumulated or not. + If this is true, `preferred_batch_size` is used as a trigger mechanism. 
:return: decorated function """ @@ -476,6 +479,7 @@ def _inject_owner_attrs(self, owner, name): 'preferred_batch_size' ] = preferred_batch_size owner.dynamic_batching[fn_name]['timeout'] = timeout + owner.dynamic_batching[fn_name]['flush_all'] = flush_all setattr(owner, name, self.fn) def __set_name__(self, owner, name): diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py index 90126a82700f5..355e771c52fc7 100644 --- a/tests/integration/dynamic_batching/test_dynamic_batching.py +++ b/tests/integration/dynamic_batching/test_dynamic_batching.py @@ -629,11 +629,17 @@ def test_failure_propagation(): ) -@pytest.mark.repeat(10) -def test_exception_handling_in_dynamic_batch(): +@pytest.mark.parametrize( + 'flush_all', + [ + False, + True + ], +) +def test_exception_handling_in_dynamic_batch(flush_all): class SlowExecutorWithException(Executor): - @dynamic_batching(preferred_batch_size=3, timeout=1000) + @dynamic_batching(preferred_batch_size=3, timeout=5000, flush_all=flush_all) @requests(on='/foo') def foo(self, docs, **kwargs): for doc in docs: @@ -659,4 +665,50 @@ def foo(self, docs, **kwargs): if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR: num_failed_requests += 1 - assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing + if not flush_all: + assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing + else: + assert 1 <= num_failed_requests <= len(da) # 3 requests in the dynamic batch failing + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'flush_all', + [ + False, + True + ], +) +async def test_num_docs_processed_in_exec(flush_all): + class DynBatchProcessor(Executor): + + @dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=flush_all) + @requests(on='/foo') + def foo(self, docs, **kwargs): + for doc in docs: + doc.text = f"{len(docs)}" + + depl = Deployment(uses=DynBatchProcessor, protocol='http') + + 
with depl: + da = DocumentArray([Document(text='good') for _ in range(50)]) + cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True) + res = [] + async for r in cl.post( + on='/foo', + inputs=da, + request_size=7, + continue_on_error=True, + results_in_order=True, + ): + res.extend(r) + assert len(res) == 50 # 1 request per input + if not flush_all: + for d in res: + assert int(d.text) <= 5 + else: + larger_than_5 = 0 + for d in res: + if int(d.text) > 5: + larger_than_5 += 1 + assert int(d.text) >= 5 + assert larger_than_5 > 0 diff --git a/tests/unit/serve/dynamic_batching/test_batch_queue.py b/tests/unit/serve/dynamic_batching/test_batch_queue.py index 22758995d7270..8cf902cbf38f9 100644 --- a/tests/unit/serve/dynamic_batching/test_batch_queue.py +++ b/tests/unit/serve/dynamic_batching/test_batch_queue.py @@ -9,7 +9,8 @@ @pytest.mark.asyncio -async def test_batch_queue_timeout(): +@pytest.mark.parametrize('flush_all', [False, True]) +async def test_batch_queue_timeout(flush_all): async def foo(docs, **kwargs): await asyncio.sleep(0.1) return DocumentArray([Document(text='Done') for _ in docs]) @@ -20,6 +21,7 @@ async def foo(docs, **kwargs): response_docarray_cls=DocumentArray, preferred_batch_size=4, timeout=2000, + flush_all=flush_all, ) three_data_requests = [DataRequest() for _ in range(3)] @@ -59,7 +61,8 @@ async def process_request(req): @pytest.mark.asyncio -async def test_batch_queue_timeout_does_not_wait_previous_batch(): +@pytest.mark.parametrize('flush_all', [False, True]) +async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all): batches_lengths_computed = [] async def foo(docs, **kwargs): @@ -73,6 +76,7 @@ async def foo(docs, **kwargs): response_docarray_cls=DocumentArray, preferred_batch_size=5, timeout=3000, + flush_all=flush_all ) data_requests = [DataRequest() for _ in range(3)] @@ -93,19 +97,28 @@ async def process_request(req, sleep=0): init_time = time.time() tasks = [asyncio.create_task(process_request(req)) for 
req in data_requests] tasks.append(asyncio.create_task(process_request(extra_data_request, sleep=2))) - responses = await asyncio.gather(*tasks) + _ = await asyncio.gather(*tasks) time_spent = (time.time() - init_time) * 1000 - # TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired inmediately - # BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late) - assert time_spent >= 12000 - assert time_spent <= 12500 - assert batches_lengths_computed == [5, 1, 2] + + if flush_all is False: + # TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired inmediately + # BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late) + assert time_spent >= 12000 + assert time_spent <= 12500 + else: + assert time_spent >= 8000 + assert time_spent <= 8500 + if flush_all is False: + assert batches_lengths_computed == [5, 1, 2] + else: + assert batches_lengths_computed == [6, 2] await bq.close() @pytest.mark.asyncio -async def test_batch_queue_req_length_larger_than_preferred(): +@pytest.mark.parametrize('flush_all', [False, True]) +async def test_batch_queue_req_length_larger_than_preferred(flush_all): async def foo(docs, **kwargs): await asyncio.sleep(0.1) return DocumentArray([Document(text='Done') for _ in docs]) @@ -116,6 +129,7 @@ async def foo(docs, **kwargs): response_docarray_cls=DocumentArray, preferred_batch_size=4, timeout=2000, + flush_all=flush_all, ) data_requests = [DataRequest() for _ in range(3)] @@ -240,7 +254,8 @@ async def process_request(req): @pytest.mark.asyncio -async def test_exception_all(): +@pytest.mark.parametrize('flush_all', [False, True]) +async def test_exception_all(flush_all): async def foo(docs, **kwargs): raise AssertionError @@ -249,6 +264,7 @@ async def foo(docs, **kwargs): request_docarray_cls=DocumentArray, 
response_docarray_cls=DocumentArray, preferred_batch_size=2, + flush_all=flush_all, timeout=500, ) @@ -287,8 +303,9 @@ async def foo(docs, **kwargs): @pytest.mark.parametrize('num_requests', [61, 127, 100]) @pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100]) @pytest.mark.parametrize('timeout', [0.3, 500]) +@pytest.mark.parametrize('flush_all', [False, True]) @pytest.mark.asyncio -async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout): +async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all): import random async def foo(docs, **kwargs): @@ -301,6 +318,7 @@ async def foo(docs, **kwargs): request_docarray_cls=DocumentArray, response_docarray_cls=DocumentArray, preferred_batch_size=preferred_batch_size, + flush_all=flush_all, timeout=timeout, ) @@ -331,3 +349,58 @@ async def process_request(req): assert len(resp.docs) == length for j, d in enumerate(resp.docs): assert d.text == f'Text {j} from request {i} with len {length} Processed' + + +@pytest.mark.parametrize('num_requests', [61, 127, 100]) +@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100]) +@pytest.mark.parametrize('timeout', [0.3, 500]) +@pytest.mark.parametrize('flush_all', [False, True]) +@pytest.mark.asyncio +async def test_length_processed_in_func(num_requests, preferred_batch_size, timeout, flush_all): + import random + + async def foo(docs, **kwargs): + if not flush_all: + assert len(docs) <= preferred_batch_size + else: + assert len(docs) >= preferred_batch_size + await asyncio.sleep(0.1) + for doc in docs: + doc.text += ' Processed' + + bq: BatchQueue = BatchQueue( + foo, + request_docarray_cls=DocumentArray, + response_docarray_cls=DocumentArray, + preferred_batch_size=preferred_batch_size, + flush_all=flush_all, + timeout=timeout, + ) + + data_requests = [DataRequest() for _ in range(num_requests)] + len_requests = [] + for i, req in enumerate(data_requests): + len_request = 
random.randint(preferred_batch_size, preferred_batch_size * 10) + len_requests.append(len_request) + req.data.docs = DocumentArray( + [ + Document(text=f'Text {j} from request {i} with len {len_request}') + for j in range(len_request) + ] + ) + + async def process_request(req): + q = await bq.push(req) + item = await q.get() + q.task_done() + return item + + tasks = [asyncio.create_task(process_request(req)) for req in data_requests] + items = await asyncio.gather(*tasks) + for i, item in enumerate(items): + assert item is None + + for i, (resp, length) in enumerate(zip(data_requests, len_requests)): + assert len(resp.docs) == length + for j, d in enumerate(resp.docs): + assert d.text == f'Text {j} from request {i} with len {length} Processed' diff --git a/tests/unit/serve/executors/test_executor.py b/tests/unit/serve/executors/test_executor.py index 3bb6f1769ceff..a6d902421ae83 100644 --- a/tests/unit/serve/executors/test_executor.py +++ b/tests/unit/serve/executors/test_executor.py @@ -614,15 +614,15 @@ class C(B): [ ( dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000), + dict(preferred_batch_size=4, timeout=5_000, flush_all=False), ), ( - dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000), + dict(preferred_batch_size=4, timeout=5_000, flush_all=True), + dict(preferred_batch_size=4, timeout=5_000, flush_all=True), ), ( dict(preferred_batch_size=4), - dict(preferred_batch_size=4, timeout=10_000), + dict(preferred_batch_size=4, timeout=10_000, flush_all=False), ), ], ) @@ -641,15 +641,15 @@ def foo(self, docs, **kwargs): [ ( dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000), + dict(preferred_batch_size=4, timeout=5_000, flush_all=False), ), ( - dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000), + dict(preferred_batch_size=4, timeout=5_000, flush_all=True), + dict(preferred_batch_size=4, timeout=5_000, 
flush_all=True), ), ( dict(preferred_batch_size=4), - dict(preferred_batch_size=4, timeout=10_000), + dict(preferred_batch_size=4, timeout=10_000, flush_all=False), ), ], )