diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index 8ccf4cab76e36..d438746a4ab2a 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -393,7 +393,7 @@ def __init__( self._add_dynamic_batching(dynamic_batching) self._add_runtime_args(runtime_args) self.logger = JinaLogger(self.__class__.__name__, **vars(self.runtime_args)) - self._validate_sagemaker() + self._validate_csp() self._init_instrumentation(runtime_args) self._init_monitoring() self._init_workspace = workspace @@ -599,14 +599,14 @@ def _add_requests(self, _requests: Optional[Dict]): f'expect {typename(self)}.{func} to be a function, but receiving {typename(_func)}' ) - def _validate_sagemaker(self): - # sagemaker expects the POST /invocations endpoint to be defined. + def _validate_csp(self): + # csp (sagemaker/azure/gcp) expects the POST /invocations endpoint to be defined. # if it is not defined, we check if there is only one endpoint defined, # and if so, we use it as the POST /invocations endpoint, or raise an error if ( not hasattr(self, 'runtime_args') or not hasattr(self.runtime_args, 'provider') - or self.runtime_args.provider != ProviderType.SAGEMAKER.value + or self.runtime_args.provider not in (ProviderType.SAGEMAKER.value, ProviderType.GCP.value) ): return diff --git a/jina/serve/runtimes/asyncio.py b/jina/serve/runtimes/asyncio.py index 8d2fc8beeb8bc..e53f55a5c124d 100644 --- a/jina/serve/runtimes/asyncio.py +++ b/jina/serve/runtimes/asyncio.py @@ -206,6 +206,23 @@ def _get_server(self): cors=getattr(self.args, 'cors', None), is_cancel=self.is_cancel, ) + elif ( + hasattr(self.args, 'provider') + and self.args.provider == ProviderType.GCP + ): + from jina.serve.runtimes.servers.http import GCPHTTPServer + + return GCPHTTPServer( + name=self.args.name, + runtime_args=self.args, + req_handler_cls=self.req_handler_cls, + proxy=getattr(self.args, 'proxy', None), + uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None), + ssl_keyfile=getattr(self.args, 'ssl_keyfile', None), + ssl_certfile=getattr(self.args, 'ssl_certfile', None), + cors=getattr(self.args, 'cors', None), + is_cancel=self.is_cancel, + ) elif not hasattr(self.args, 'protocol') or ( len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC ): diff --git a/jina/serve/runtimes/worker/http_gcp_app.py b/jina/serve/runtimes/worker/http_gcp_app.py index a2f403525679b..37af7897a029c 100644 --- a/jina/serve/runtimes/worker/http_gcp_app.py +++ b/jina/serve/runtimes/worker/http_gcp_app.py @@ -41,7 +41,7 @@ def get_fastapi_app( from jina.serve.runtimes.gateway.models import _to_camel_case if not docarray_v2: - logger.warning('Only docarray v2 is supported with Sagemaker. ') + logger.warning('Only docarray v2 is supported with GCP. ') return class Header(BaseModel): @@ -129,7 +129,6 @@ async def process(body) -> output_model: raise HTTPException(status_code=499, detail=status.description) else: return {"predictions": resp.docs} - return output_model(predictions=resp.docs) @app.api_route(**app_kwargs) async def post(request: Request): @@ -175,7 +174,7 @@ async def post(request: Request): from jina.serve.runtimes.gateway.health_model import JinaHealthModel - # `/ping` route is required by AWS Sagemaker + # `/ping` route is required by GCP @app.get( path='/ping', summary='Get the health of Jina Executor service', diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 57f26a6767da2..f5c74445bf291 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -326,7 +326,7 @@ def _init_monitoring( if metrics_registry: with ImportExtensions( required=True, - help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', + help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina', ): from prometheus_client import Counter, Summary diff --git a/tests/integration/docarray_v2/gcp/test_gcp.py b/tests/integration/docarray_v2/gcp/test_gcp.py index 2fc97a2adffda..9539e8317ea49 100644 --- a/tests/integration/docarray_v2/gcp/test_gcp.py +++ b/tests/integration/docarray_v2/gcp/test_gcp.py @@ -70,5 +70,25 @@ def test_provider_gcp_pod_inference(): assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json['predictions']) == 2 - print(resp_json) + +def test_provider_gcp_deployment_inference(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + dep_port = random_port() + with Deployment(uses='config.yml', provider='gcp', port=dep_port): + # Test the `GET /ping` endpoint (added by jina for gcp) + resp = requests.get(f'http://localhost:{dep_port}/ping') + assert resp.status_code == 200 + assert resp.json() == {} + + # Test the `POST /invocations` endpoint + # Note: this endpoint is not implemented in the sample executor + resp = requests.post( + f'http://localhost:{dep_port}/invocations', + json={ + 'instances': ["hello world", "good apple"] + }, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json['predictions']) == 2