diff --git a/jina/clients/base/__init__.py b/jina/clients/base/__init__.py index 3cfe9b49dd8a1..41ec147fbd74b 100644 --- a/jina/clients/base/__init__.py +++ b/jina/clients/base/__init__.py @@ -65,6 +65,7 @@ def __init__( async def close(self): """Closes the potential resources of the Client. + :return: Return whatever a close method may return """ return self.teardown_instrumentation() diff --git a/jina/clients/base/helper.py b/jina/clients/base/helper.py index f44f995a6cb94..f1b85dca5322f 100644 --- a/jina/clients/base/helper.py +++ b/jina/clients/base/helper.py @@ -154,6 +154,7 @@ class HTTPClientlet(AioHttpClientlet): async def send_message(self, url, request: 'Request'): """Sends a POST request to the server + :param url: the URL where to send the message :param request: request as dict :return: send post message """ @@ -196,6 +197,7 @@ async def send_message(self, url, request: 'Request'): async def send_streaming_message(self, url, doc: 'Document', on: str): """Sends a GET SSE request to the server + :param url: the URL where to send the message :param doc: Request Document :param on: Request endpoint :yields: responses @@ -218,6 +220,7 @@ async def send_streaming_message(self, url, doc: 'Document', on: str): async def send_dry_run(self, url, **kwargs): """Query the dry_run endpoint from Gateway + :param url: the URL where to send the message :param kwargs: keyword arguments to make sure compatible API with other clients :return: send get message """ @@ -264,8 +267,9 @@ async def __anext__(self): class WebsocketClientlet(AioHttpClientlet): """Websocket Client to be used with the streamer""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, url, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + self.url = url self.websocket = None self.response_iter = None diff --git a/jina/clients/base/http.py b/jina/clients/base/http.py index b43f6f7be78d8..3bf996c8bf009 100644 --- a/jina/clients/base/http.py +++ b/jina/clients/base/http.py @@ -23,6 +23,7 @@ class HTTPBaseClient(BaseClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._endpoints = [] + self._lock = asyncio.Lock() self.iolet = None async def close(self): @@ -76,23 +77,23 @@ async def _is_flow_ready(self, **kwargs) -> bool: proto = 'https' if self.args.tls else 'http' url = f'{proto}://{self.args.host}:{self.args.port}/dry_run' - if self.iolet is not None and self.args.reuse_session: - iolet = self.iolet - else: - iolet = HTTPClientlet( - logger=self.logger, - tracer_provider=self.tracer_provider, - **kwargs, - ) - - if self.args.reuse_session and self.iolet is None: - self.iolet = iolet - await self.iolet.__aenter__() - if not self.args.reuse_session: iolet = await stack.enter_async_context( - iolet + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) + await self.iolet.__aenter__() response = await iolet.send_dry_run(**kwargs) r_status = response.status @@ -112,20 +113,20 @@ async def _is_flow_ready(self, **kwargs) -> bool: return False async def _get_results( - self, - inputs: 'InputType', - on_done: 'CallbackFnType', - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, - max_attempts: int = 1, - initial_backoff: float = 0.5, - max_backoff: float = 0.1, - backoff_multiplier: float = 1.5, - results_in_order: bool = False, - prefetch: Optional[int] = None, - timeout: Optional[int] = None, - return_type: Type[DocumentArray] = DocumentArray, - **kwargs, + self, + inputs: 'InputType', + on_done: 'CallbackFnType', + on_error: Optional['CallbackFnType'] = None, + on_always: Optional['CallbackFnType'] = None, + max_attempts: int = 1, + initial_backoff: float = 0.5, + max_backoff: float = 0.1, + backoff_multiplier: float = 1.5, + results_in_order: bool = False, + prefetch: Optional[int] = None, + timeout: Optional[int] = None, + return_type: Type[DocumentArray] = DocumentArray, + **kwargs, ): """ :param inputs: the callable @@ -168,30 +169,26 @@ async def _get_results( else: url = f'{proto}://{self.args.host}:{self.args.port}/post' - if self.iolet is not None and self.args.reuse_session: - iolet = self.iolet - else: - iolet = HTTPClientlet( - logger=self.logger, - tracer_provider=self.tracer_provider, - max_attempts=max_attempts, - initial_backoff=initial_backoff, - max_backoff=max_backoff, - backoff_multiplier=backoff_multiplier, - timeout=timeout, - **kwargs, - ) - if self.args.reuse_session and self.iolet is None: - self.iolet = iolet - await self.iolet.__aenter__() - if not self.args.reuse_session: iolet = await stack.enter_async_context( - iolet + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) + await self.iolet.__aenter__() def _request_handler( - request: 'Request', **kwargs + request: 'Request', **kwargs ) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]': """ For HTTP Client, for each request in the iterator, we `send_message` using @@ -215,7 +212,7 @@ def _result_handler(result): **streamer_args, ) async for response in streamer.stream( - request_iterator=request_iterator, results_in_order=results_in_order + request_iterator=request_iterator, results_in_order=results_in_order ): r_status, r_str = response handle_response_status(r_status, r_str, url) @@ -256,13 +253,13 @@ def _result_handler(result): yield resp async def _get_streaming_results( - self, - on: str, - inputs: 'Document', - parameters: Optional[Dict] = None, - return_type: Type[Document] = Document, - timeout: Optional[int] = None, - **kwargs, + self, + on: str, + inputs: 'Document', + parameters: Optional[Dict] = None, + return_type: Type[Document] = Document, + timeout: Optional[int] = None, + **kwargs, ): proto = 'https' if self.args.tls else 'http' endpoint = on.strip('/') @@ -272,15 +269,27 @@ async def _get_streaming_results( url = f'{proto}://{self.args.host}:{self.args.port}/{endpoint}' else: url = f'{proto}://{self.args.host}:{self.args.port}/default' + async with AsyncExitStack() as stack: + if not self.args.reuse_session: + iolet = await stack.enter_async_context( + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + **kwargs, + ) + ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + **kwargs, + ) + await self.iolet.__aenter__() - iolet = HTTPClientlet( - logger=self.logger, - tracer_provider=self.tracer_provider, - timeout=timeout, - **kwargs, - ) - - async with iolet: async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on): if not docarray_v2: yield Document.from_dict(json.loads(doc))