diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py
index 06416159bf3fa..c821c5c1d0936 100644
--- a/jina/serve/executors/__init__.py
+++ b/jina/serve/executors/__init__.py
@@ -634,25 +634,26 @@ def _validate_sagemaker(self):
             and self.runtime_args.provider_endpoint
         ):
             endpoint_to_use = ('/' + self.runtime_args.provider_endpoint).lower()
-            if endpoint_to_use in list(self.requests.keys()):
-                self.logger.warning(
-                    f'Using "{endpoint_to_use}" as "/invocations" route'
-                )
-                self.requests['/invocations'] = self.requests[endpoint_to_use]
-                for k in remove_keys:
-                    self.requests.pop(k)
-                return
-
-        if len(self.requests) == 1:
-            route = list(self.requests.keys())[0]
-            self.logger.warning(f'Using "{route}" as "/invocations" route')
-            self.requests['/invocations'] = self.requests[route]
+        elif len(self.requests) == 1:
+            endpoint_to_use = list(self.requests.keys())[0]
+        else:
+            raise ValueError('Cannot identify the endpoint to use for "/invocations"')
+
+        if endpoint_to_use in list(self.requests.keys()):
+            self.logger.warning(f'Using "{endpoint_to_use}" as "/invocations" route')
+            self.requests['/invocations'] = self.requests[endpoint_to_use]
+            if (
+                getattr(self, 'dynamic_batching', {}).get(endpoint_to_use, None)
+                is not None
+            ):
+                self.dynamic_batching['/invocations'] = self.dynamic_batching[
+                    endpoint_to_use
+                ]
+                self.dynamic_batching.pop(endpoint_to_use)
             for k in remove_keys:
                 self.requests.pop(k)
             return

-        raise ValueError('Cannot identify the endpoint to use for "/invocations"')
-
     def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]):
         if _dynamic_batching:
             self.dynamic_batching = getattr(self, 'dynamic_batching', {})
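Reviewer note: the hunk above reorders the "/invocations" resolution so that an explicit `--provider-endpoint` wins, a single registered endpoint is the fallback, and anything else fails fast; the dynamic-batching config now follows whichever endpoint is chosen. A minimal, self-contained sketch of that resolution logic for reference (the function and argument names are illustrative, not part of the jina API):

def resolve_invocations(requests_map, dynamic_batching, provider_endpoint=None):
    # 1) explicit provider endpoint, 2) the only registered endpoint, 3) fail fast
    if provider_endpoint:
        endpoint_to_use = ('/' + provider_endpoint).lower()
    elif len(requests_map) == 1:
        endpoint_to_use = list(requests_map.keys())[0]
    else:
        raise ValueError('Cannot identify the endpoint to use for "/invocations"')

    if endpoint_to_use in requests_map:
        requests_map['/invocations'] = requests_map[endpoint_to_use]
        # a dynamic-batching config declared on the chosen endpoint is carried over too
        if endpoint_to_use in dynamic_batching:
            dynamic_batching['/invocations'] = dynamic_batching.pop(endpoint_to_use)
    return requests_map, dynamic_batching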
diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py
index 08d404fb686d1..af3786f2886d3 100644
--- a/jina/serve/runtimes/worker/request_handling.py
+++ b/jina/serve/runtimes/worker/request_handling.py
@@ -265,7 +265,16 @@ def _init_batchqueue_dict(self):
         # Endpoints allow specific configurations while functions allow configs to be applied to all endpoints of the function
         dbatch_endpoints = []
         dbatch_functions = []
+        request_models_map = self._executor._get_endpoint_models_dict()
+
         for key, dbatch_config in self._executor.dynamic_batching.items():
+            if request_models_map.get(key, {}).get('parameters', {}).get('model', None) is not None:
+                error_msg = f'Executor Dynamic Batching cannot be used for endpoint {key} because it depends on parameters.'
+                self.logger.error(
+                    error_msg
+                )
+                raise Exception(error_msg)
+
             if key.startswith('/'):
                 dbatch_endpoints.append((key, dbatch_config))
             else:
diff --git a/tests/integration/docarray_v2/csp/SampleExecutor/executor.py b/tests/integration/docarray_v2/csp/SampleExecutor/executor.py
index 1e0b4afc129c2..e9a45c6757cc5 100644
--- a/tests/integration/docarray_v2/csp/SampleExecutor/executor.py
+++ b/tests/integration/docarray_v2/csp/SampleExecutor/executor.py
@@ -1,7 +1,7 @@
 import numpy as np
 from docarray import BaseDoc, DocList
 from docarray.typing import NdArray
-from pydantic import Field
+from pydantic import Field, BaseModel

 from jina import Executor, requests

@@ -19,6 +19,11 @@ class Config(BaseDoc.Config):
         json_encoders = {NdArray: lambda v: v.tolist()}


+class Parameters(BaseModel):
+    emb_dim: int
+
+
+
 class SampleExecutor(Executor):
     @requests(on="/encode")
     def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseModel]:
@@ -32,3 +37,16 @@ def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseMode
                 )
             )
         return DocList[EmbeddingResponseModel](ret)
+
+    @requests(on="/encode_parameter")
+    def bar(self, docs: DocList[TextDoc], parameters: Parameters, **kwargs) -> DocList[EmbeddingResponseModel]:
+        ret = []
+        for doc in docs:
+            ret.append(
+                EmbeddingResponseModel(
+                    id=doc.id,
+                    text=doc.text,
+                    embeddings=np.random.random((1, parameters.emb_dim)),
+                )
+            )
+        return DocList[EmbeddingResponseModel](ret)
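Reviewer note: the worker-side guard above refuses to build a batch queue for any endpoint whose request model declares a `parameters` schema, because requests carrying different parameters (for example different `emb_dim` values against the `/encode_parameter` endpoint added to the sample executor) cannot be merged into a single batch. A rough, self-contained sketch of the check; the dict literals stand in for the worker's real `request_models_map` and `dynamic_batching` state:

# Hypothetical state: one plain endpoint and one endpoint that declares a parameters model.
request_models_map = {
    '/encode': {'parameters': {'model': None}},
    '/encode_parameter': {'parameters': {'model': 'Parameters'}},  # any non-None model
}
dynamic_batching = {'/encode_parameter': {'preferred_batch_size': 20, 'timeout': 50}}

for key in dynamic_batching:
    if request_models_map.get(key, {}).get('parameters', {}).get('model', None) is not None:
        raise Exception(
            f'Executor Dynamic Batching cannot be used for endpoint {key} '
            'because it depends on parameters.'
        )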
diff --git a/tests/integration/docarray_v2/csp/test_sagemaker_embedding.py b/tests/integration/docarray_v2/csp/test_sagemaker_embedding.py
index a2233f0789dbe..8ad4ad3861586 100644
--- a/tests/integration/docarray_v2/csp/test_sagemaker_embedding.py
+++ b/tests/integration/docarray_v2/csp/test_sagemaker_embedding.py
@@ -35,6 +35,8 @@ def test_provider_sagemaker_pod_inference():
             os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
             '--provider',
             'sagemaker',
+            "--provider-endpoint",
+            "encode",
             'serve',  # This is added by sagemaker
         ]
     )
@@ -60,6 +62,43 @@ def test_provider_sagemaker_pod_inference():
         assert len(resp_json['data'][0]['embeddings'][0]) == 64


+def test_provider_sagemaker_pod_inference_parameters():
+    args, _ = set_pod_parser().parse_known_args(
+        [
+            '--uses',
+            os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
+            '--provider',
+            'sagemaker',
+            "--provider-endpoint",
+            "encode_parameter",
+            'serve',  # This is added by sagemaker
+        ]
+    )
+    with Pod(args):
+        # Test the `GET /ping` endpoint (added by jina for sagemaker)
+        resp = requests.get(f'http://localhost:{sagemaker_port}/ping')
+        assert resp.status_code == 200
+        assert resp.json() == {}
+        for emb_dim in {32, 64, 128}:
+
+            # Test the `POST /invocations` endpoint for inference
+            # Note: this endpoint is not implemented in the sample executor
+            resp = requests.post(
+                f'http://localhost:{sagemaker_port}/invocations',
+                json={
+                    'data': [
+                        {'text': 'hello world'},
+                    ],
+                    'parameters': {'emb_dim': emb_dim}
+                },
+            )
+            assert resp.status_code == 200
+            resp_json = resp.json()
+            assert len(resp_json['data']) == 1
+            assert len(resp_json['data'][0]['embeddings'][0]) == emb_dim
+
+
+
 @pytest.mark.parametrize(
     "filename",
     [
@@ -74,6 +113,8 @@ def test_provider_sagemaker_pod_batch_transform_valid(filename):
             os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
             '--provider',
             'sagemaker',
+            "--provider-endpoint",
+            "encode",
             'serve',  # This is added by sagemaker
         ]
     )
@@ -114,6 +155,8 @@ def test_provider_sagemaker_pod_batch_transform_invalid():
             os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
             '--provider',
             'sagemaker',
+            "--provider-endpoint",
+            "encode",
             'serve',  # This is added by sagemaker
         ]
     )
@@ -145,6 +188,7 @@ def test_provider_sagemaker_deployment_inference():
     with Deployment(
         uses=os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
         provider='sagemaker',
+        provider_endpoint='encode',
         port=dep_port,
     ):
         # Test the `GET /ping` endpoint (added by jina for sagemaker)
@@ -171,7 +215,7 @@ def test_provider_sagemaker_deployment_inference_docker(replica_docker_image_built):
     dep_port = random_port()

     with Deployment(
-        uses='docker://sampler-executor', provider='sagemaker', port=dep_port
+        uses='docker://sampler-executor', provider='sagemaker', provider_endpoint='encode', port=dep_port
     ):
         # Test the `GET /ping` endpoint (added by jina for sagemaker)
         rsp = requests.get(f'http://localhost:{dep_port}/ping')
@@ -200,6 +244,7 @@ def test_provider_sagemaker_deployment_batch():
     with Deployment(
         uses=os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
         provider='sagemaker',
+        provider_endpoint='encode',
         port=dep_port,
     ):
         # Test the `POST /invocations` endpoint for batch-transform
@@ -230,6 +275,24 @@ def test_provider_sagemaker_deployment_wrong_port():
                 os.path.dirname(__file__), "SampleExecutor", "config.yml"
             ),
             provider='sagemaker',
+            provider_endpoint='encode',
             port=8080,
         ):
             pass
+
+
+def test_provider_sagemaker_deployment_wrong_dynamic_batching():
+    # Dynamic batching cannot be combined with an endpoint that depends on
+    # per-request parameters, so the deployment should fail to start.
+    from jina.excepts import RuntimeFailToStart
+
+    with pytest.raises(RuntimeFailToStart) as exc:
+        with Deployment(
+            uses=os.path.join(
+                os.path.dirname(__file__), "SampleExecutor", "config.yml"
+            ),
+            provider='sagemaker',
+            provider_endpoint='encode_parameter',
+            uses_dynamic_batching={'/encode_parameter': {'preferred_batch_size': 20, 'timeout': 50}},
+        ):
+            pass
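Reviewer note: taken together, the changes mean a SageMaker deployment of an executor with more than one `@requests` endpoint must name the endpoint that backs `POST /invocations` via `provider_endpoint` (or `--provider-endpoint` on the CLI), and must not enable dynamic batching on an endpoint that takes per-request parameters. A rough usage sketch; the config path and port are illustrative, while the keyword arguments mirror the tests above:

import requests

from jina import Deployment

with Deployment(
    uses='SampleExecutor/config.yml',        # illustrative path to the executor config
    provider='sagemaker',
    provider_endpoint='encode_parameter',    # this endpoint is exposed as POST /invocations
    port=12345,
):
    resp = requests.post(
        'http://localhost:12345/invocations',
        json={'data': [{'text': 'hello world'}], 'parameters': {'emb_dim': 32}},
    )
    assert resp.status_code == 200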