diff --git a/jina/clients/base/__init__.py b/jina/clients/base/__init__.py index 7fe60b15f74e7..3cfe9b49dd8a1 100644 --- a/jina/clients/base/__init__.py +++ b/jina/clients/base/__init__.py @@ -29,9 +29,9 @@ class BaseClient(InstrumentationMixin, ABC): """ def __init__( - self, - args: Optional['argparse.Namespace'] = None, - **kwargs, + self, + args: Optional['argparse.Namespace'] = None, + **kwargs, ): if args and isinstance(args, argparse.Namespace): self.args = args @@ -63,6 +63,11 @@ def __init__( ) send_telemetry_event(event='start', obj_cls_name=self.__class__.__name__) + async def close(self): + """Closes the potential resources of the Client. + """ + return self.teardown_instrumentation() + def teardown_instrumentation(self): """Shut down the OpenTelemetry tracer and meter if available. This ensures that the daemon threads for exporting metrics data is properly cleaned up. @@ -118,7 +123,7 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None: raise BadClientInput from ex def _get_requests( - self, **kwargs + self, **kwargs ) -> Union[Iterator['Request'], AsyncIterator['Request']]: """ Get request in generator. @@ -177,13 +182,14 @@ def inputs(self, bytes_gen: 'InputType') -> None: @abc.abstractmethod async def _get_results( - self, - inputs: 'InputType', - on_done: 'CallbackFnType', - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, - **kwargs, - ): ... + self, + inputs: 'InputType', + on_done: 'CallbackFnType', + on_error: Optional['CallbackFnType'] = None, + on_always: Optional['CallbackFnType'] = None, + **kwargs, + ): + ... @abc.abstractmethod def _is_flow_ready(self, **kwargs) -> bool: diff --git a/jina/clients/base/helper.py b/jina/clients/base/helper.py index 620513ceec460..f44f995a6cb94 100644 --- a/jina/clients/base/helper.py +++ b/jina/clients/base/helper.py @@ -48,7 +48,6 @@ class AioHttpClientlet(ABC): def __init__( self, - url: str, logger: 'JinaLogger', max_attempts: int = 1, initial_backoff: float = 0.5, @@ -59,7 +58,6 @@ def __init__( ) -> None: """HTTP Client to be used with the streamer - :param url: url to send http/websocket request to :param logger: jina logger :param max_attempts: Number of sending attempts, including the original request. :param initial_backoff: The first retry will happen with a delay of random(0, initial_backoff) @@ -68,7 +66,6 @@ def __init__( :param tracer_provider: Optional tracer_provider that will be used to configure aiohttp tracing. :param kwargs: kwargs which will be forwarded to the `aiohttp.Session` instance. Used to pass headers to requests """ - self.url = url self.logger = logger self.msg_recv = 0 self.msg_sent = 0 @@ -154,7 +151,7 @@ class HTTPClientlet(AioHttpClientlet): UPDATE_EVENT_PREFIX = 14 # the update event has the following format: "event: update: {document_json}" - async def send_message(self, request: 'Request'): + async def send_message(self, url, request: 'Request'): """Sends a POST request to the server :param request: request as dict @@ -166,7 +163,7 @@ async def send_message(self, request: 'Request'): req_dict['target_executor'] = req_dict['header']['target_executor'] for attempt in range(1, self.max_attempts + 1): try: - request_kwargs = {'url': self.url} + request_kwargs = {'url': url} if not docarray_v2: request_kwargs['json'] = req_dict else: @@ -179,10 +176,10 @@ async def send_message(self, request: 'Request'): except aiohttp.ContentTypeError: r_str = await response.text() r_status = response.status - handle_response_status(response.status, r_str, self.url) + handle_response_status(response.status, r_str, url) return r_status, r_str except (ValueError, ConnectionError, BadClient, aiohttp.ClientError, aiohttp.ClientConnectionError) as err: - self.logger.debug(f'Got an error: {err} sending POST to {self.url} in attempt {attempt}/{self.max_attempts}') + self.logger.debug(f'Got an error: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}') await retry.wait_or_raise_err( attempt=attempt, err=err, @@ -193,10 +190,10 @@ async def send_message(self, request: 'Request'): ) except Exception as exc: self.logger.debug( - f'Got a non-retried error: {exc} sending POST to {self.url}') + f'Got a non-retried error: {exc} sending POST to {url}') raise exc - async def send_streaming_message(self, doc: 'Document', on: str): + async def send_streaming_message(self, url, doc: 'Document', on: str): """Sends a GET SSE request to the server :param doc: Request Document @@ -205,7 +202,7 @@ async def send_streaming_message(self, doc: 'Document', on: str): """ req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict() request_kwargs = { - 'url': self.url, + 'url': url, 'headers': {'Accept': 'text/event-stream'}, 'json': req_dict, } @@ -219,13 +216,13 @@ async def send_streaming_message(self, doc: 'Document', on: str): elif event.startswith(b'end'): pass - async def send_dry_run(self, **kwargs): + async def send_dry_run(self, url, **kwargs): """Query the dry_run endpoint from Gateway :param kwargs: keyword arguments to make sure compatible API with other clients :return: send get message """ return await self.session.get( - url=self.url, timeout=kwargs.get('timeout', None) + url=url, timeout=kwargs.get('timeout', None) ).__aenter__() async def recv_message(self): diff --git a/jina/clients/base/http.py b/jina/clients/base/http.py index 746bdf0e0acfd..b43f6f7be78d8 100644 --- a/jina/clients/base/http.py +++ b/jina/clients/base/http.py @@ -23,6 +23,12 @@ class HTTPBaseClient(BaseClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._endpoints = [] + self.iolet = None + + async def close(self): + await super().close() + if self.iolet is not None: + await self.iolet.__aexit__() async def _get_endpoints_from_openapi(self, **kwargs): def extract_paths_by_method(spec): @@ -69,14 +75,24 @@ async def _is_flow_ready(self, **kwargs) -> bool: try: proto = 'https' if self.args.tls else 'http' url = f'{proto}://{self.args.host}:{self.args.port}/dry_run' - iolet = await stack.enter_async_context( - HTTPClientlet( - url=url, + + if self.iolet is not None and self.args.reuse_session: + iolet = self.iolet + else: + iolet = HTTPClientlet( logger=self.logger, tracer_provider=self.tracer_provider, **kwargs, ) - ) + + if self.args.reuse_session and self.iolet is None: + self.iolet = iolet + await self.iolet.__aenter__() + + if not self.args.reuse_session: + iolet = await stack.enter_async_context( + iolet + ) response = await iolet.send_dry_run(**kwargs) r_status = response.status @@ -152,9 +168,10 @@ async def _get_results( else: url = f'{proto}://{self.args.host}:{self.args.port}/post' - iolet = await stack.enter_async_context( - HTTPClientlet( - url=url, + if self.iolet is not None and self.args.reuse_session: + iolet = self.iolet + else: + iolet = HTTPClientlet( logger=self.logger, tracer_provider=self.tracer_provider, max_attempts=max_attempts, @@ -164,7 +181,14 @@ async def _get_results( timeout=timeout, **kwargs, ) - ) + if self.args.reuse_session and self.iolet is None: + self.iolet = iolet + await self.iolet.__aenter__() + + if not self.args.reuse_session: + iolet = await stack.enter_async_context( + iolet + ) def _request_handler( request: 'Request', **kwargs @@ -176,7 +200,7 @@ def _request_handler( :param kwargs: kwargs :return: asyncio Task for sending message """ - return asyncio.ensure_future(iolet.send_message(request=request)), None + return asyncio.ensure_future(iolet.send_message(url=url, request=request)), None def _result_handler(result): return result @@ -250,7 +274,6 @@ async def _get_streaming_results( url = f'{proto}://{self.args.host}:{self.args.port}/default' iolet = HTTPClientlet( - url=url, logger=self.logger, tracer_provider=self.tracer_provider, timeout=timeout, @@ -258,7 +281,7 @@ async def _get_streaming_results( ) async with iolet: - async for doc in iolet.send_streaming_message(doc=inputs, on=on): + async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on): if not docarray_v2: yield Document.from_dict(json.loads(doc)) else: diff --git a/tests/unit/clients/test_helper.py b/tests/unit/clients/test_helper.py index 66ae0d9081f38..55182b066616e 100644 --- a/tests/unit/clients/test_helper.py +++ b/tests/unit/clients/test_helper.py @@ -33,11 +33,11 @@ async def test_http_clientlet(): port = random_port() with Flow(port=port, protocol='http').add(): async with HTTPClientlet( - url=f'http://localhost:{port}/post', logger=logger + logger=logger ) as iolet: request = _new_data_request('/', None, {'a': 'b'}) assert request.header.target_executor == '' - r_status, r_json = await iolet.send_message(request) + r_status, r_json = await iolet.send_message(url=f'http://localhost:{port}/post', request=request) response = DataRequest(r_json) assert response.header.exec_endpoint == '/' assert response.parameters == {'a': 'b'} @@ -50,11 +50,11 @@ async def test_http_clientlet_target(): port = random_port() with Flow(port=port, protocol='http').add(): async with HTTPClientlet( - url=f'http://localhost:{port}/post', logger=logger + logger=logger ) as iolet: request = _new_data_request('/', 'nothing', {'a': 'b'}) assert request.header.target_executor == 'nothing' - r = await iolet.send_message(request) + r = await iolet.send_message(url=f'http://localhost:{port}/post', request=request) r_status, r_json = r response = DataRequest(r_json) assert response.header.exec_endpoint == '/'