From 16a464df7549ea1794fbd007c36729f55c37c2bc Mon Sep 17 00:00:00 2001 From: Joan Martinez Date: Thu, 18 Jul 2024 13:05:00 +0200 Subject: [PATCH] feat: add flush all option --- jina/serve/runtimes/worker/batch_queue.py | 9 +++++++-- jina/serve/runtimes/worker/http_fastapi_app.py | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py index 530f5f58d3a81..8f7e0d283b413 100644 --- a/jina/serve/runtimes/worker/batch_queue.py +++ b/jina/serve/runtimes/worker/batch_queue.py @@ -23,6 +23,7 @@ def __init__( response_docarray_cls, output_array_type: Optional[str] = None, params: Optional[Dict] = None, + flush_all: bool = False, preferred_batch_size: int = 4, timeout: int = 10_000, ) -> None: @@ -35,6 +36,7 @@ def __init__( self.params = params self._request_docarray_cls = request_docarray_cls self._response_docarray_cls = response_docarray_cls + self._flush_all = flush_all self._preferred_batch_size: int = preferred_batch_size self._timeout: int = timeout self._reset() @@ -205,7 +207,10 @@ async def _assign_results( return num_assigned_docs - def batch(iterable_1, iterable_2, n=1): + def batch(iterable_1, iterable_2, n:Optional[int] = 1): + if n is None: + yield iterable_1, iterable_2 + return items = len(iterable_1) for ndx in range(0, items, n): yield iterable_1[ndx : min(ndx + n, items)], iterable_2[ @@ -229,7 +234,7 @@ def batch(iterable_1, iterable_2, n=1): non_assigned_to_response_request_idxs = [] sum_from_previous_first_req_idx = 0 for docs_inner_batch, req_idxs in batch( - self._big_doc, self._request_idxs, self._preferred_batch_size + self._big_doc, self._request_idxs, self._preferred_batch_size if not self._flush_all else None ): involved_requests_min_indx = req_idxs[0] involved_requests_max_indx = req_idxs[-1] diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index fd4aeabf8c79c..b45b94f7c62cf 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -7,6 +7,7 @@ from jina.serve.networking.sse import EventSourceResponse from jina.types.request.data import DataRequest + if TYPE_CHECKING: from jina.logging.logger import JinaLogger @@ -88,7 +89,6 @@ def add_post_route( @app.api_route(**app_kwargs) async def post(body: input_model, response: Response): - req = DataRequest() if body.header is not None: req.header.request_id = body.header.request_id @@ -122,7 +122,9 @@ async def post(body: input_model, response: Response): docs_response = resp.docs.to_dict() else: docs_response = resp.docs + ret = output_model(data=docs_response, parameters=resp.parameters) + return ret def add_streaming_routes(