From 338ac3f3ae40d21278fd79fe11b710df646acdff Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Mon, 23 Sep 2024 16:19:06 +0200 Subject: [PATCH] feat: use dynamic batching param (#6203) --- extra-requirements.txt | 2 +- jina/clients/base/__init__.py | 3 +- jina/serve/executors/decorators.py | 3 ++ jina/serve/runtimes/worker/batch_queue.py | 5 +- .../serve/runtimes/worker/request_handling.py | 9 ++-- .../dynamic_batching/test_dynamic_batching.py | 49 +++++++++++++++++-- tests/k8s/conftest.py | 19 +++---- tests/k8s/test_k8s_deployment.py | 1 - tests/unit/serve/executors/test_executor.py | 12 ++--- 9 files changed, 75 insertions(+), 28 deletions(-) diff --git a/extra-requirements.txt b/extra-requirements.txt index 2c8de1b058d8d..025ccc10625f7 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -71,7 +71,7 @@ mock: test requests-mock: test pytest-custom_exit_code: test black==24.3.0: test -kubernetes>=18.20.0: test +kubernetes>=18.20.0,<31.0.0: test pytest-kind==22.11.1: test pytest-lazy-fixture: test torch: cicd diff --git a/jina/clients/base/__init__.py b/jina/clients/base/__init__.py index 41ec147fbd74b..51845502f49a9 100644 --- a/jina/clients/base/__init__.py +++ b/jina/clients/base/__init__.py @@ -48,6 +48,7 @@ def __init__( os.unsetenv('http_proxy') os.unsetenv('https_proxy') self._inputs = None + self._inputs_length = None self._setup_instrumentation( name=( self.args.name @@ -144,8 +145,6 @@ def _get_requests( else: total_docs = None - self._inputs_length = None - if total_docs: self._inputs_length = max(1, total_docs / _kwargs['request_size']) diff --git a/jina/serve/executors/decorators.py b/jina/serve/executors/decorators.py index b9072929cbed7..49fb6f4e17681 100644 --- a/jina/serve/executors/decorators.py +++ b/jina/serve/executors/decorators.py @@ -419,6 +419,7 @@ def dynamic_batching( flush_all: bool = False, custom_metric: Optional[Callable[['DocumentArray'], Union[float, int]]] = None, use_custom_metric: bool = False, + 
use_dynamic_batching: bool = True, ): """ `@dynamic_batching` defines the dynamic batching behavior of an Executor. @@ -438,6 +439,7 @@ def dynamic_batching( If this is true, `preferred_batch_size` is used as a trigger mechanism. :param custom_metric: Potential lambda function to measure the "weight" of each request. :param use_custom_metric: Determines if we need to use the `custom_metric` to determine preferred_batch_size. + :param use_dynamic_batching: Determines if we should apply dynamic batching for this method. :return: decorated function """ @@ -486,6 +488,7 @@ def _inject_owner_attrs(self, owner, name): owner.dynamic_batching[fn_name]['flush_all'] = flush_all owner.dynamic_batching[fn_name]['use_custom_metric'] = use_custom_metric owner.dynamic_batching[fn_name]['custom_metric'] = custom_metric + owner.dynamic_batching[fn_name]['use_dynamic_batching'] = use_dynamic_batching setattr(owner, name, self.fn) def __set_name__(self, owner, name): diff --git a/jina/serve/runtimes/worker/batch_queue.py b/jina/serve/runtimes/worker/batch_queue.py index 56ba81e61e2a7..ac63f2d2c2dae 100644 --- a/jina/serve/runtimes/worker/batch_queue.py +++ b/jina/serve/runtimes/worker/batch_queue.py @@ -29,6 +29,7 @@ def __init__( timeout: int = 10_000, custom_metric: Optional[Callable[['DocumentArray'], Union[int, float]]] = None, use_custom_metric: bool = False, + **kwargs, ) -> None: # To keep old user behavior, we use data lock when flush_all is true and no allow_concurrent self.func = func @@ -285,7 +286,8 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option sum_from_previous_first_req_idx = 0 for docs_inner_batch, req_idxs in batch( big_doc_in_batch, requests_idxs_in_batch, - self._preferred_batch_size if not self._flush_all else None, docs_metrics_in_batch if self._custom_metric is not None else None + self._preferred_batch_size if not self._flush_all else None, + docs_metrics_in_batch if self._custom_metric is not None else None ): 
involved_requests_min_indx = req_idxs[0] involved_requests_max_indx = req_idxs[-1] @@ -360,7 +362,6 @@ def batch(iterable_1, iterable_2, n: Optional[int] = 1, iterable_metrics: Option requests_completed_in_batch, ) - async def close(self): """Closes the batch queue by flushing pending requests.""" if not self._is_closed: diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 456c94a7bdf41..b6edd7cddc090 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -275,10 +275,11 @@ def _init_batchqueue_dict(self): ) raise Exception(error_msg) - if key.startswith('/'): - dbatch_endpoints.append((key, dbatch_config)) - else: - dbatch_functions.append((key, dbatch_config)) + if dbatch_config.get("use_dynamic_batching", True): + if key.startswith('/'): + dbatch_endpoints.append((key, dbatch_config)) + else: + dbatch_functions.append((key, dbatch_config)) # Specific endpoint configs take precedence over function configs for endpoint, dbatch_config in dbatch_endpoints: diff --git a/tests/integration/dynamic_batching/test_dynamic_batching.py b/tests/integration/dynamic_batching/test_dynamic_batching.py index f7940289d6154..8f08d364899a4 100644 --- a/tests/integration/dynamic_batching/test_dynamic_batching.py +++ b/tests/integration/dynamic_batching/test_dynamic_batching.py @@ -706,8 +706,8 @@ def foo(self, docs, **kwargs): @pytest.mark.asyncio -@pytest.mark.parametrize('use_custom_metric', [True]) -@pytest.mark.parametrize('flush_all', [True]) +@pytest.mark.parametrize('use_custom_metric', [True, False]) +@pytest.mark.parametrize('flush_all', [True, False]) async def test_dynamic_batching_custom_metric(use_custom_metric, flush_all): class DynCustomBatchProcessor(Executor): @@ -719,7 +719,9 @@ def foo(self, docs, **kwargs): for doc in docs: doc.text = f"{total_len}" - depl = Deployment(uses=DynCustomBatchProcessor, uses_dynamic_batching={'foo': 
{"preferred_batch_size": 10, "timeout": 2000, "use_custom_metric": use_custom_metric, "flush_all": flush_all}}) + depl = Deployment(uses=DynCustomBatchProcessor, uses_dynamic_batching={ + 'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_custom_metric": use_custom_metric, + "flush_all": flush_all}}) da = DocumentArray([Document(text='aaaaa') for i in range(50)]) with depl: cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True) @@ -733,3 +735,44 @@ def foo(self, docs, **kwargs): ): res.extend(r) assert len(res) == 50 # 1 request per input + + +@pytest.mark.asyncio +@pytest.mark.parametrize('use_dynamic_batching', [True, False]) +async def test_use_dynamic_batching(use_dynamic_batching): + class UseDynBatchProcessor(Executor): + + @dynamic_batching(preferred_batch_size=10) + @requests(on='/foo') + def foo(self, docs, **kwargs): + print(f'len docs {len(docs)}') + for doc in docs: + doc.text = f"{len(docs)}" + + depl = Deployment(uses=UseDynBatchProcessor, uses_dynamic_batching={ + 'foo': {"preferred_batch_size": 10, "timeout": 2000, "use_dynamic_batching": use_dynamic_batching, + "flush_all": False}}) + da = DocumentArray([Document(text='aaaaa') for _ in range(50)]) + with depl: + cl = Client(protocol=depl.protocol, port=depl.port, asyncio=True) + res = [] + async for r in cl.post( + on='/foo', + inputs=da, + request_size=1, + continue_on_error=True, + results_in_order=True, + ): + res.extend(r) + assert len(res) == 50 # 1 request per input + num_10 = 0 # tally of docs processed in a full batch of 10; kept outside the loop so it is not reset per doc + for doc in res: + if doc.text == "10": + num_10 += 1 + if not use_dynamic_batching: + assert doc.text == "1" + + if use_dynamic_batching: + assert num_10 > 0 + else: + assert num_10 == 0 diff --git a/tests/k8s/conftest.py b/tests/k8s/conftest.py index 886cd7e4de473..80f9bed5283c1 100644 --- a/tests/k8s/conftest.py +++ b/tests/k8s/conftest.py @@ -30,14 +30,14 @@ def __init__(self, kind_cluster: KindCluster, logger: JinaLogger) -> None: self._loaded_images = set() def _linkerd_install_cmd( - self, 
kind_cluster: KindCluster, cmd, tool_name: str + self, kind_cluster: KindCluster, cmd, tool_name: str ) -> None: self._log.info(f'Installing {tool_name} to Cluster...') kube_out = subprocess.check_output( (str(kind_cluster.kubectl_path), 'version'), env=os.environ, ) - self._log.info(f'kuberbetes versions: {kube_out}') + self._log.info(f'kubernetes versions: {kube_out}') # since we need to pipe to commands and the linkerd output can bee too long # there is a risk of deadlock and hanging tests: https://docs.python.org/3/library/subprocess.html#popen-objects @@ -86,7 +86,7 @@ def _install_linkerd(self, kind_cluster: KindCluster) -> None: print(f'linkerd check yields {out.decode() if out else "nothing"}') except subprocess.CalledProcessError as e: print( - f'linkerd check failed with error code { e.returncode } and output { e.output }, and stderr { e.stderr }' + f'linkerd check failed with error code {e.returncode} and output {e.output}, and stderr {e.stderr}' ) raise @@ -125,8 +125,9 @@ def install_linkerd_smi(self) -> None: print(f'linkerd check yields {out.decode() if out else "nothing"}') except subprocess.CalledProcessError as e: print( - f'linkerd check failed with error code { e.returncode } and output { e.output }' + f'linkerd check failed with error code {e.returncode} and output {e.output}, and stderr {e.stderr}' ) + raise def _set_kube_config(self): self._log.info(f'Setting KUBECONFIG to {self._kube_config_path}') @@ -134,7 +135,7 @@ def _set_kube_config(self): load_cluster_config() def load_docker_images( - self, images: List[str], image_tag_map: Dict[str, str] + self, images: List[str], image_tag_map: Dict[str, str] ) -> None: for image in images: full_image_name = image + ':' + image_tag_map[image] @@ -213,9 +214,9 @@ def load_cluster_config() -> None: @pytest.fixture def docker_images( - request: FixtureRequest, - image_name_tag_map: Dict[str, str], - k8s_cluster: KindClusterWrapper, + request: FixtureRequest, + image_name_tag_map: Dict[str, str], + 
k8s_cluster: KindClusterWrapper, ) -> List[str]: image_names: List[str] = request.param k8s_cluster.load_docker_images(image_names, image_name_tag_map) @@ -227,7 +228,7 @@ def docker_images( @contextlib.contextmanager def shell_portforward( - kubectl_path, pod_or_service, port1, port2, namespace, waiting: float = 1 + kubectl_path, pod_or_service, port1, port2, namespace, waiting: float = 1 ): try: proc = subprocess.Popen( diff --git a/tests/k8s/test_k8s_deployment.py b/tests/k8s/test_k8s_deployment.py index 2f1fd9691fc94..1ab58d0accccc 100644 --- a/tests/k8s/test_k8s_deployment.py +++ b/tests/k8s/test_k8s_deployment.py @@ -8,7 +8,6 @@ from jina.serve.runtimes.servers import BaseServer from jina import Deployment, Client -from jina.helper import random_port from tests.k8s.conftest import shell_portforward cluster.KIND_VERSION = 'v0.11.1' diff --git a/tests/unit/serve/executors/test_executor.py b/tests/unit/serve/executors/test_executor.py index 344ebcaab7254..5c71b18a9f8e9 100644 --- a/tests/unit/serve/executors/test_executor.py +++ b/tests/unit/serve/executors/test_executor.py @@ -614,15 +614,15 @@ class C(B): [ ( dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None), + dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True), ), ( dict(preferred_batch_size=4, timeout=5_000, flush_all=True), - dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None), + dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True), ), ( dict(preferred_batch_size=4), - dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None), + dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None, 
use_dynamic_batching=True), ), ], ) @@ -641,15 +641,15 @@ def foo(self, docs, **kwargs): [ ( dict(preferred_batch_size=4, timeout=5_000), - dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None), + dict(preferred_batch_size=4, timeout=5_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True), ), ( dict(preferred_batch_size=4, timeout=5_000, flush_all=True), - dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None), + dict(preferred_batch_size=4, timeout=5_000, flush_all=True, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True), ), ( dict(preferred_batch_size=4), - dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None), + dict(preferred_batch_size=4, timeout=10_000, flush_all=False, use_custom_metric=False, custom_metric=None, use_dynamic_batching=True), ), ], )