Skip to content

Commit

Permalink
test: test batch queue flush all
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 18, 2024
1 parent 16a464d commit 958c5d2
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 24 deletions.
6 changes: 5 additions & 1 deletion jina/serve/executors/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def dynamic_batching(
*,
preferred_batch_size: Optional[int] = None,
timeout: Optional[float] = 10_000,
flush_all: bool = False
):
"""
`@dynamic_batching` defines the dynamic batching behavior of an Executor.
Expand All @@ -426,11 +427,13 @@ def dynamic_batching(
:param func: the method to decorate
:param preferred_batch_size: target number of Documents in a batch. The batcher will collect requests until `preferred_batch_size` is reached,
or until `timeout` is reached. Therefore, the actual batch size can be smaller or larger than `preferred_batch_size`.
or until `timeout` is reached. Therefore, the actual batch size can be smaller than or equal to `preferred_batch_size`, unless `flush_all` is set to True
:param timeout: maximum time in milliseconds to wait for a request to be assigned to a batch.
If the oldest request in the queue reaches a waiting time of `timeout`, the batch will be passed to the Executor,
even if it contains fewer than `preferred_batch_size` Documents.
Default is 10_000ms (10 seconds).
:param flush_all: Determines whether, once a batch is triggered by timeout or by reaching `preferred_batch_size`, the function will receive everything that the batcher has accumulated so far.
If this is true, `preferred_batch_size` is used as a trigger mechanism.
:return: decorated function
"""

Expand Down Expand Up @@ -476,6 +479,7 @@ def _inject_owner_attrs(self, owner, name):
'preferred_batch_size'
] = preferred_batch_size
owner.dynamic_batching[fn_name]['timeout'] = timeout
owner.dynamic_batching[fn_name]['flush_all'] = flush_all
setattr(owner, name, self.fn)

def __set_name__(self, owner, name):
Expand Down
60 changes: 56 additions & 4 deletions tests/integration/dynamic_batching/test_dynamic_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,11 +629,17 @@ def test_failure_propagation():
)


@pytest.mark.repeat(10)
def test_exception_handling_in_dynamic_batch():
@pytest.mark.parametrize(
'flush_all',
[
False,
True
],
)
def test_exception_handling_in_dynamic_batch(flush_all):
class SlowExecutorWithException(Executor):

@dynamic_batching(preferred_batch_size=3, timeout=1000)
@dynamic_batching(preferred_batch_size=3, timeout=5000, flush_all=flush_all)
@requests(on='/foo')
def foo(self, docs, **kwargs):
for doc in docs:
Expand All @@ -659,4 +665,50 @@ def foo(self, docs, **kwargs):
if r.header.status.code == jina_pb2.StatusProto.StatusCode.ERROR:
num_failed_requests += 1

assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing
if not flush_all:
assert 1 <= num_failed_requests <= 3 # 3 requests in the dynamic batch failing
else:
assert 1 <= num_failed_requests <= len(da)  # with flush_all, up to all accumulated requests may land in the failing batch

@pytest.mark.asyncio
@pytest.mark.parametrize(
    'flush_all',
    [
        False,
        True
    ],
)
async def test_num_docs_processed_in_exec(flush_all):
    """Verify the batch sizes the Executor actually sees under dynamic batching.

    The Executor writes the size of the batch it received into each Document's
    ``text``. With ``flush_all=False`` no batch may exceed
    ``preferred_batch_size``; with ``flush_all=True`` every batch is at least
    that size and at least one batch must have accumulated more.
    """

    class DynBatchProcessor(Executor):

        @dynamic_batching(preferred_batch_size=5, timeout=5000, flush_all=flush_all)
        @requests(on='/foo')
        def foo(self, docs, **kwargs):
            # Record the size of the batch each Document was processed in.
            for doc in docs:
                doc.text = f"{len(docs)}"

    depl = Deployment(uses=DynBatchProcessor, protocol='http')

    with depl:
        da = DocumentArray([Document(text='good') for _ in range(50)])
        cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True)
        res = []
        async for r in cl.post(
            on='/foo',
            inputs=da,
            request_size=7,
            continue_on_error=True,
            results_in_order=True,
        ):
            res.extend(r)
        assert len(res) == 50  # one response Document per input Document
        if not flush_all:
            # Batches are capped at preferred_batch_size.
            for d in res:
                assert int(d.text) <= 5
        else:
            # Every batch holds at least preferred_batch_size Documents, and
            # at least one flushed batch must have accumulated more than that.
            larger_than_5 = sum(1 for d in res if int(d.text) > 5)
            for d in res:
                assert int(d.text) >= 5
            assert larger_than_5 > 0
95 changes: 84 additions & 11 deletions tests/unit/serve/dynamic_batching/test_batch_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


@pytest.mark.asyncio
async def test_batch_queue_timeout():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_timeout(flush_all):
async def foo(docs, **kwargs):
await asyncio.sleep(0.1)
return DocumentArray([Document(text='Done') for _ in docs])
Expand All @@ -20,6 +21,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=4,
timeout=2000,
flush_all=flush_all,
)

three_data_requests = [DataRequest() for _ in range(3)]
Expand Down Expand Up @@ -59,7 +61,8 @@ async def process_request(req):


@pytest.mark.asyncio
async def test_batch_queue_timeout_does_not_wait_previous_batch():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_timeout_does_not_wait_previous_batch(flush_all):
batches_lengths_computed = []

async def foo(docs, **kwargs):
Expand All @@ -73,6 +76,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=5,
timeout=3000,
flush_all=flush_all
)

data_requests = [DataRequest() for _ in range(3)]
Expand All @@ -93,19 +97,28 @@ async def process_request(req, sleep=0):
init_time = time.time()
tasks = [asyncio.create_task(process_request(req)) for req in data_requests]
tasks.append(asyncio.create_task(process_request(extra_data_request, sleep=2)))
responses = await asyncio.gather(*tasks)
_ = await asyncio.gather(*tasks)
time_spent = (time.time() - init_time) * 1000
# TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired immediately
# BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late)
assert time_spent >= 12000
assert time_spent <= 12500
assert batches_lengths_computed == [5, 1, 2]

if flush_all is False:
# TIME TAKEN: 8000 for first batch of requests, plus 4000 for second batch that is fired immediately
# BEFORE FIX in https://github.com/jina-ai/jina/pull/6071, this would take: 8000 + 3000 + 4000 (Timeout would start counting too late)
assert time_spent >= 12000
assert time_spent <= 12500
else:
assert time_spent >= 8000
assert time_spent <= 8500
if flush_all is False:
assert batches_lengths_computed == [5, 1, 2]
else:
assert batches_lengths_computed == [6, 2]

await bq.close()


@pytest.mark.asyncio
async def test_batch_queue_req_length_larger_than_preferred():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_batch_queue_req_length_larger_than_preferred(flush_all):
async def foo(docs, **kwargs):
await asyncio.sleep(0.1)
return DocumentArray([Document(text='Done') for _ in docs])
Expand All @@ -116,6 +129,7 @@ async def foo(docs, **kwargs):
response_docarray_cls=DocumentArray,
preferred_batch_size=4,
timeout=2000,
flush_all=flush_all,
)

data_requests = [DataRequest() for _ in range(3)]
Expand Down Expand Up @@ -240,7 +254,8 @@ async def process_request(req):


@pytest.mark.asyncio
async def test_exception_all():
@pytest.mark.parametrize('flush_all', [False, True])
async def test_exception_all(flush_all):
async def foo(docs, **kwargs):
raise AssertionError

Expand All @@ -249,6 +264,7 @@ async def foo(docs, **kwargs):
request_docarray_cls=DocumentArray,
response_docarray_cls=DocumentArray,
preferred_batch_size=2,
flush_all=flush_all,
timeout=500,
)

Expand Down Expand Up @@ -287,8 +303,9 @@ async def foo(docs, **kwargs):
@pytest.mark.parametrize('num_requests', [61, 127, 100])
@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
@pytest.mark.parametrize('timeout', [0.3, 500])
@pytest.mark.parametrize('flush_all', [False, True])
@pytest.mark.asyncio
async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout):
async def test_return_proper_assignment(num_requests, preferred_batch_size, timeout, flush_all):
import random

async def foo(docs, **kwargs):
Expand All @@ -301,6 +318,7 @@ async def foo(docs, **kwargs):
request_docarray_cls=DocumentArray,
response_docarray_cls=DocumentArray,
preferred_batch_size=preferred_batch_size,
flush_all=flush_all,
timeout=timeout,
)

Expand Down Expand Up @@ -331,3 +349,58 @@ async def process_request(req):
assert len(resp.docs) == length
for j, d in enumerate(resp.docs):
assert d.text == f'Text {j} from request {i} with len {length} Processed'


@pytest.mark.parametrize('num_requests', [61, 127, 100])
@pytest.mark.parametrize('preferred_batch_size', [7, 27, 61, 73, 100])
@pytest.mark.parametrize('timeout', [0.3, 500])
@pytest.mark.parametrize('flush_all', [False, True])
@pytest.mark.asyncio
async def test_length_processed_in_func(num_requests, preferred_batch_size, timeout, flush_all):
    """Check the batch lengths seen inside the batched function.

    With ``flush_all=False`` the function never receives more than
    ``preferred_batch_size`` Documents; with ``flush_all=True`` it receives at
    least that many. Responses must still map back to their requests intact.
    """
    import random

    async def foo(docs, **kwargs):
        # The batch-size contract under test, asserted where batches arrive.
        if not flush_all:
            assert len(docs) <= preferred_batch_size
        else:
            assert len(docs) >= preferred_batch_size
        await asyncio.sleep(0.1)
        for doc in docs:
            doc.text += ' Processed'

    bq: BatchQueue = BatchQueue(
        foo,
        request_docarray_cls=DocumentArray,
        response_docarray_cls=DocumentArray,
        preferred_batch_size=preferred_batch_size,
        flush_all=flush_all,
        timeout=timeout,
    )

    data_requests = [DataRequest() for _ in range(num_requests)]
    len_requests = []
    for i, req in enumerate(data_requests):
        # Each request holds at least one full preferred batch of Documents.
        len_request = random.randint(preferred_batch_size, preferred_batch_size * 10)
        len_requests.append(len_request)
        req.data.docs = DocumentArray(
            [
                Document(text=f'Text {j} from request {i} with len {len_request}')
                for j in range(len_request)
            ]
        )

    async def process_request(req):
        q = await bq.push(req)
        item = await q.get()
        q.task_done()
        return item

    tasks = [asyncio.create_task(process_request(req)) for req in data_requests]
    items = await asyncio.gather(*tasks)
    # The queue yields None on success (an Exception instance on failure).
    for item in items:
        assert item is None

    # Responses are written back into the requests; every Document must be
    # processed exactly once and stay in order within its request.
    for i, (resp, length) in enumerate(zip(data_requests, len_requests)):
        assert len(resp.docs) == length
        for j, d in enumerate(resp.docs):
            assert d.text == f'Text {j} from request {i} with len {length} Processed'
16 changes: 8 additions & 8 deletions tests/unit/serve/executors/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,15 +614,15 @@ class C(B):
[
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False),
),
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
),
(
dict(preferred_batch_size=4),
dict(preferred_batch_size=4, timeout=10_000),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False),
),
],
)
Expand All @@ -641,15 +641,15 @@ def foo(self, docs, **kwargs):
[
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=False),
),
(
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
dict(preferred_batch_size=4, timeout=5_000, flush_all=True),
),
(
dict(preferred_batch_size=4),
dict(preferred_batch_size=4, timeout=10_000),
dict(preferred_batch_size=4, timeout=10_000, flush_all=False),
),
],
)
Expand Down

0 comments on commit 958c5d2

Please sign in to comment.