Skip to content

Commit

Permalink
test: test batch queue flush all
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 18, 2024
1 parent 16a464d commit f0fc5e4
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 12 deletions.
6 changes: 5 additions & 1 deletion jina/serve/executors/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def dynamic_batching(
*,
preferred_batch_size: Optional[int] = None,
timeout: Optional[float] = 10_000,
flush_all: bool = False
):
"""
`@dynamic_batching` defines the dynamic batching behavior of an Executor.
Expand All @@ -426,11 +427,13 @@ def dynamic_batching(
:param func: the method to decorate
:param preferred_batch_size: target number of Documents in a batch. The batcher will collect requests until `preferred_batch_size` is reached,
or until `timeout` is reached. Therefore, the actual batch size can be smaller or larger than `preferred_batch_size`.
or until `timeout` is reached. Therefore, the actual batch size can be smaller than or equal to `preferred_batch_size`, except if `flush_all` is set to True
:param timeout: maximum time in milliseconds to wait for a request to be assigned to a batch.
If the oldest request in the queue reaches a waiting time of `timeout`, the batch will be passed to the Executor,
even if it contains fewer than `preferred_batch_size` Documents.
Default is 10_000ms (10 seconds).
:param flush_all: Determines whether, once a batch is triggered by timeout or by reaching `preferred_batch_size`, the function receives everything the batcher has accumulated so far, rather than just one batch.
If this is True, `preferred_batch_size` is used only as a trigger mechanism and does not cap the batch size.
:return: decorated function
"""

Expand Down Expand Up @@ -476,6 +479,7 @@ def _inject_owner_attrs(self, owner, name):
'preferred_batch_size'
] = preferred_batch_size
owner.dynamic_batching[fn_name]['timeout'] = timeout
owner.dynamic_batching[fn_name]['flush_all'] = flush_all

Check warning on line 482 in jina/serve/executors/decorators.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/executors/decorators.py#L482

Added line #L482 was not covered by tests
setattr(owner, name, self.fn)

def __set_name__(self, owner, name):
Expand Down
53 changes: 49 additions & 4 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,11 +629,17 @@ def test_failure_propagation():
)


@pytest.mark.repeat(10)
def test_exception_handling_in_dynamic_batch():
@pytest.mark.parametrize(
'flush_all',
[
False,
True
],
)
def test_exception_handling_in_dynamic_batch(flush_all):
class SlowExecutorWithException(Executor):

@dynamic_batching(preferred_batch_size=3, timeout=1000)
@dynamic_batching(preferred_batch_size=3, timeout=5000, flush_all=flush_all)
@requests(on='/foo')
def foo(self, docs, **kwargs):
for doc in docs:
Expand All @@ -659,4 +665,43 @@ def foo(self, docs, **kwargs):
if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR:
num_failed_requests += 1

assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing
if not flush_all:
assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing
else:
        assert 1 <= num_failed_requests <= len(da)  # with flush_all, up to every accumulated request may land in the failing batch


@pytest.mark.parametrize('flush_all', [False, True])
def test_num_docs_processed_in_exec(flush_all):
    """Verify how many Documents each dynamic batch hands to the Executor.

    With ``flush_all=False`` every batch is capped at ``preferred_batch_size``;
    with ``flush_all=True`` the queue is drained on trigger, so all inputs are
    expected to arrive in a single call.
    """

    class DynBatchProcessor(Executor):

        @dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=flush_all)
        @requests(on='/foo')
        def foo(self, docs, **kwargs):
            # Record the observed batch size inside each document's text.
            for doc in docs:
                doc.text += f" processed with {len(docs)} docs"

    deployment = Deployment(uses=DynBatchProcessor)

    with deployment:
        inputs = DocumentArray([Document(text='good') for _ in range(50)])
        responses = deployment.post(
            on='/foo',
            inputs=inputs,
            request_size=1,
            continue_on_error=True,
            results_in_order=True,
        )
        assert len(responses) == 50  # 1 request per input
        # NOTE(review): the flush_all branch assumes all 50 docs accumulate
        # before the first flush — depends on queueing being faster than exec.
        expected_fragment = 'with 50 docs' if flush_all else 'with 5 docs'
        for doc in responses:
            assert expected_fragment in doc.text
83 changes: 76 additions & 7 deletions tests/unit/serve/dynamic_batching/test_batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


@pytest.mark.asyncio
async def test_batch_queue_timeout():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_timeout(flush_all):
async def foo(docs, **kwargs):
await asyncio.sleep(0.1)
return DocumentArray([Document(text='Done') for _ in docs])
Expand All @@ -20,6 +21,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=4,
timeout=2000,
flush_all=flush_all,
)

three_data_requests = [DataRequest() for _ in range(3)]
Expand Down Expand Up @@ -59,7 +61,8 @@ async def process_request(req):


@pytest.mark.asyncio
async def test_batch_queue_timeout_does_not_wait_previous_batch():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all):
batches_lengths_computed = []

async def foo(docs, **kwargs):
Expand All @@ -73,6 +76,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=5,
timeout=3000,
flush_all=flush_all
)

data_requests = [DataRequest() for _ in range(3)]
Expand Down Expand Up @@ -105,7 +109,8 @@ async def process_request(req, sleep=0):


@pytest.mark.asyncio
async def test_batch_queue_req_length_larger_than_preferred():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_req_length_larger_than_preferred(flush_all):
async def foo(docs, **kwargs):
await asyncio.sleep(0.1)
return DocumentArray([Document(text='Done') for _ in docs])
Expand All @@ -116,6 +121,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=4,
timeout=2000,
flush_all=flush_all,
)

data_requests = [DataRequest() for _ in range(3)]
Expand All @@ -142,7 +148,8 @@ async def process_request(req):


@pytest.mark.asyncio
async def test_exception():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_exception(flush_all):
BAD_REQUEST_IDX = [2, 6]

async def foo(docs, **kwargs):
Expand All @@ -159,6 +166,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=1,
timeout=500,
flush_all=flush_all,
)

data_requests = [DataRequest() for _ in range(35)]
Expand Down Expand Up @@ -188,7 +196,8 @@ async def process_request(req):


@pytest.mark.asyncio
async def test_exception_more_complex():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_exception_more_complex(flush_all):
TRIGGER_BAD_REQUEST_IDX = [2, 6]
EXPECTED_BAD_REQUESTS = [2, 3, 6, 7]

Expand All @@ -208,6 +217,7 @@ async def foo(docs, **kwargs):
request_docarray_cls=DocumentArray,
response_docarray_cls=DocumentArray,
preferred_batch_size=2,
flush_all=flush_all,
timeout=500,
)

Expand Down Expand Up @@ -240,7 +250,8 @@ async def process_request(req):


@pytest.mark.asyncio
async def test_exception_all():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_exception_all(flush_all):
async def foo(docs, **kwargs):
raise AssertionError

Expand All @@ -249,6 +260,7 @@ async def foo(docs, **kwargs):
request_docarray_cls=DocumentArray,
response_docarray_cls=DocumentArray,
preferred_batch_size=2,
flush_all=flush_all,
timeout=500,
)

Expand Down Expand Up @@ -287,8 +299,9 @@ async def foo(docs, **kwargs):
@pytest.mark.parametrize('num_requests', [61, 127, 100])
@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
@pytest.mark.parametrize('timeout', [0.3, 500])
@pytest.mark.parametrize('flush_all', [False, True])
@pytest.mark.asyncio
async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout):
async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all):
import random

async def foo(docs, **kwargs):
Expand All @@ -301,6 +314,7 @@ async def foo(docs, **kwargs):
request_docarray_cls=DocumentArray,
response_docarray_cls=DocumentArray,
preferred_batch_size=preferred_batch_size,
flush_all=flush_all,
timeout=timeout,
)

Expand Down Expand Up @@ -331,3 +345,58 @@ async def process_request(req):
assert len(resp.docs) == length
for j, d in enumerate(resp.docs):
assert d.text == f'Text {j} from request {i} with len {length} Processed'


@pytest.mark.parametrize('num_requests', [61, 127, 100])
@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
@pytest.mark.parametrize('timeout', [0.3, 500])
@pytest.mark.parametrize('flush_all', [False, True])
@pytest.mark.asyncio
async def test_length_processed_in_func(num_requests, preferred_batch_size, timeout, flush_all):
    """Check the batch sizes seen by the wrapped function under both flush modes."""
    import random

    async def foo(docs, **kwargs):
        # Batch size must respect the flush mode: capped by the preferred
        # size normally, at least that big when the queue is drained.
        if flush_all:
            assert len(docs) >= preferred_batch_size
        else:
            assert len(docs) <= preferred_batch_size
        await asyncio.sleep(0.1)
        for doc in docs:
            doc.text += ' Processed'

    bq: BatchQueue = BatchQueue(
        foo,
        request_docarray_cls=DocumentArray,
        response_docarray_cls=DocumentArray,
        preferred_batch_size=preferred_batch_size,
        flush_all=flush_all,
        timeout=timeout,
    )

    # Each request carries at least preferred_batch_size docs, so every
    # flush (even a timeout-triggered one) holds at least one full batch.
    incoming = [DataRequest() for _ in range(num_requests)]
    expected_lengths = []
    for req_idx, request in enumerate(incoming):
        n_docs = random.randint(preferred_batch_size, preferred_batch_size * 10)
        expected_lengths.append(n_docs)
        request.data.docs = DocumentArray(
            [
                Document(text=f'Text {j} from request {req_idx} with len {n_docs}')
                for j in range(n_docs)
            ]
        )

    async def _push_and_wait(request):
        queue = await bq.push(request)
        outcome = await queue.get()
        queue.task_done()
        return outcome

    outcomes = await asyncio.gather(
        *(asyncio.create_task(_push_and_wait(request)) for request in incoming)
    )
    assert all(outcome is None for outcome in outcomes)

    # Responses are written back into the requests in place.
    for req_idx, (request, n_docs) in enumerate(zip(incoming, expected_lengths)):
        assert len(request.docs) == n_docs
        for j, doc in enumerate(request.docs):
            assert doc.text == f'Text {j} from request {req_idx} with len {n_docs} Processed'

0 comments on commit f0fc5e4

Please sign in to comment.