diff --git a/jina/clients/base/__init__.py b/jina/clients/base/__init__.py index 3cfe9b49dd8a1..41ec147fbd74b 100644 --- a/jina/clients/base/__init__.py +++ b/jina/clients/base/__init__.py @@ -65,6 +65,7 @@ def __init__( async def close(self): """Closes the potential resources of the Client. + :return: Return whatever a close method may return """ return self.teardown_instrumentation() diff --git a/jina/clients/base/helper.py b/jina/clients/base/helper.py index f44f995a6cb94..50f43ae69e264 100644 --- a/jina/clients/base/helper.py +++ b/jina/clients/base/helper.py @@ -128,7 +128,6 @@ async def start(self): """ with ImportExtensions(required=True): import aiohttp - self.session = aiohttp.ClientSession( **self._session_kwargs, trace_configs=self._trace_config ) @@ -154,6 +153,7 @@ class HTTPClientlet(AioHttpClientlet): async def send_message(self, url, request: 'Request'): """Sends a POST request to the server + :param url: the URL where to send the message :param request: request as dict :return: send post message """ @@ -170,14 +170,15 @@ async def send_message(self, url, request: 'Request'): from docarray.base_doc.io.json import orjson_dumps request_kwargs['data'] = JinaJsonPayload(value=req_dict) + async with self.session.post(**request_kwargs) as response: try: r_str = await response.json() except aiohttp.ContentTypeError: r_str = await response.text() r_status = response.status - handle_response_status(response.status, r_str, url) - return r_status, r_str + handle_response_status(r_status, r_str, url) + return r_status, r_str except (ValueError, ConnectionError, BadClient, aiohttp.ClientError, aiohttp.ClientConnectionError) as err: self.logger.debug(f'Got an error: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}') await retry.wait_or_raise_err( @@ -196,6 +197,7 @@ async def send_message(self, url, request: 'Request'): async def send_streaming_message(self, url, doc: 'Document', on: str): """Sends a GET SSE request to the server + :param url: the URL where to send the message :param doc: Request Document :param on: Request endpoint :yields: responses @@ -218,6 +220,7 @@ async def send_streaming_message(self, url, doc: 'Document', on: str): async def send_dry_run(self, url, **kwargs): """Query the dry_run endpoint from Gateway + :param url: the URL where to send the message :param kwargs: keyword arguments to make sure compatible API with other clients :return: send get message """ @@ -264,8 +267,9 @@ async def __anext__(self): class WebsocketClientlet(AioHttpClientlet): """Websocket Client to be used with the streamer""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, url, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + self.url = url self.websocket = None self.response_iter = None diff --git a/jina/clients/base/http.py b/jina/clients/base/http.py index b43f6f7be78d8..eaac304695403 100644 --- a/jina/clients/base/http.py +++ b/jina/clients/base/http.py @@ -23,12 +23,18 @@ class HTTPBaseClient(BaseClient): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._endpoints = [] + self.reuse_session = False + self._lock = AsyncExitStack() self.iolet = None async def close(self): - await super().close() + """Closes the potential resources of the Client. + :return: Return whatever a close method may return + """ + ret = super().close() if self.iolet is not None: - await self.iolet.__aexit__() + await self.iolet.__aexit__(None, None, None) + return ret async def _get_endpoints_from_openapi(self, **kwargs): def extract_paths_by_method(spec): @@ -76,25 +82,26 @@ async def _is_flow_ready(self, **kwargs) -> bool: proto = 'https' if self.args.tls else 'http' url = f'{proto}://{self.args.host}:{self.args.port}/dry_run' - if self.iolet is not None and self.args.reuse_session: - iolet = self.iolet - else: - iolet = HTTPClientlet( - logger=self.logger, - tracer_provider=self.tracer_provider, - **kwargs, - ) - - if self.args.reuse_session and self.iolet is None: - self.iolet = iolet - await self.iolet.__aenter__() - - if not self.args.reuse_session: + if not self.reuse_session: iolet = await stack.enter_async_context( - iolet + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) + await self.iolet.__aenter__() + iolet = self.iolet - response = await iolet.send_dry_run(**kwargs) + response = await iolet.send_dry_run(url=url, **kwargs) r_status = response.status r_str = await response.json() @@ -112,20 +119,20 @@ async def _is_flow_ready(self, **kwargs) -> bool: return False async def _get_results( - self, - inputs: 'InputType', - on_done: 'CallbackFnType', - on_error: Optional['CallbackFnType'] = None, - on_always: Optional['CallbackFnType'] = None, - max_attempts: int = 1, - initial_backoff: float = 0.5, - max_backoff: float = 0.1, - backoff_multiplier: float = 1.5, - results_in_order: bool = False, - prefetch: Optional[int] = None, - timeout: Optional[int] = None, - return_type: Type[DocumentArray] = DocumentArray, - **kwargs, + self, + inputs: 'InputType', + on_done: 'CallbackFnType', + on_error: Optional['CallbackFnType'] = None, + on_always: Optional['CallbackFnType'] = None, + max_attempts: int = 1, + initial_backoff: float = 0.5, + max_backoff: float = 0.1, + backoff_multiplier: float = 1.5, + results_in_order: bool = False, + prefetch: Optional[int] = None, + timeout: Optional[int] = None, + return_type: Type[DocumentArray] = DocumentArray, + **kwargs, ): """ :param inputs: the callable @@ -168,30 +175,27 @@ async def _get_results( else: url = f'{proto}://{self.args.host}:{self.args.port}/post' - if self.iolet is not None and self.args.reuse_session: - iolet = self.iolet - else: - iolet = HTTPClientlet( - logger=self.logger, - tracer_provider=self.tracer_provider, - max_attempts=max_attempts, - initial_backoff=initial_backoff, - max_backoff=max_backoff, - backoff_multiplier=backoff_multiplier, - timeout=timeout, - **kwargs, - ) - if self.args.reuse_session and self.iolet is None: - self.iolet = iolet - await self.iolet.__aenter__() - - if not self.args.reuse_session: + if not self.reuse_session: iolet = await stack.enter_async_context( - iolet + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + **kwargs, + ) + self.iolet = await self.iolet.__aenter__() + iolet = self.iolet def _request_handler( - request: 'Request', **kwargs + request: 'Request', **kwargs ) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]': """ For HTTP Client, for each request in the iterator, we `send_message` using @@ -215,7 +219,7 @@ def _result_handler(result): **streamer_args, ) async for response in streamer.stream( - request_iterator=request_iterator, results_in_order=results_in_order + request_iterator=request_iterator, results_in_order=results_in_order ): r_status, r_str = response handle_response_status(r_status, r_str, url) @@ -256,13 +260,13 @@ def _result_handler(result): yield resp async def _get_streaming_results( - self, - on: str, - inputs: 'Document', - parameters: Optional[Dict] = None, - return_type: Type[Document] = Document, - timeout: Optional[int] = None, - **kwargs, + self, + on: str, + inputs: 'Document', + parameters: Optional[Dict] = None, + return_type: Type[Document] = Document, + timeout: Optional[int] = None, + **kwargs, ): proto = 'https' if self.args.tls else 'http' endpoint = on.strip('/') @@ -272,15 +276,27 @@ async def _get_streaming_results( url = f'{proto}://{self.args.host}:{self.args.port}/{endpoint}' else: url = f'{proto}://{self.args.host}:{self.args.port}/default' - - iolet = HTTPClientlet( - logger=self.logger, - tracer_provider=self.tracer_provider, - timeout=timeout, - **kwargs, - ) - - async with iolet: + async with AsyncExitStack() as stack: + if not self.reuse_session: + iolet = await stack.enter_async_context( + HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + **kwargs, + ) + ) + else: + async with self._lock: + if self.iolet is None: + self.iolet = HTTPClientlet( + logger=self.logger, + tracer_provider=self.tracer_provider, + timeout=timeout, + **kwargs, + ) + await self.iolet.__aenter__() + iolet = self.iolet async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on): if not docarray_v2: yield Document.from_dict(json.loads(doc)) diff --git a/jina/clients/http.py b/jina/clients/http.py index e811d75f46ec5..2698e316d007a 100644 --- a/jina/clients/http.py +++ b/jina/clients/http.py @@ -9,6 +9,7 @@ PostMixin, ProfileMixin, ) +import asyncio class HTTPClient( @@ -80,3 +81,8 @@ async def async_inputs(): print(resp) """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._lock = asyncio.Lock() + self.reuse_session = self.args.reuse_session diff --git a/tests/integration/concurrent_clients/test_concurrent_clients.py b/tests/integration/concurrent_clients/test_concurrent_clients.py index 546fe94da3314..898916ffe03cb 100644 --- a/tests/integration/concurrent_clients/test_concurrent_clients.py +++ b/tests/integration/concurrent_clients/test_concurrent_clients.py @@ -23,19 +23,23 @@ def ping(self, **kwargs): @pytest.mark.parametrize('prefetch', [1, 10]) @pytest.mark.parametrize('concurrent', [15]) @pytest.mark.parametrize('use_stream', [False, True]) +@pytest.mark.parametrize('reuse_session', [True, False]) def test_concurrent_clients( - concurrent, protocol, shards, polling, prefetch, reraise, use_stream + concurrent, protocol, shards, polling, prefetch, reraise, use_stream, reuse_session ): if not use_stream and protocol != 'grpc': return + if reuse_session and protocol != 'http': + return + def pong(peer_hash, queue, resp: Response): for d in resp.docs: queue.put((peer_hash, d.text)) def peer_client(port, protocol, peer_hash, queue): - c = Client(protocol=protocol, port=port) + c = Client(protocol=protocol, port=port, reuse_session=reuse_session) for _ in range(NUM_REQUESTS): c.post( '/ping', diff --git a/tests/integration/docarray_v2/test_singleton.py b/tests/integration/docarray_v2/test_singleton.py index e8cd663eb10d5..7405df29a4792 100644 --- a/tests/integration/docarray_v2/test_singleton.py +++ b/tests/integration/docarray_v2/test_singleton.py @@ -13,7 +13,10 @@ ) @pytest.mark.parametrize('return_type', ['batch', 'singleton']) @pytest.mark.parametrize('include_gateway', [True, False]) -def test_singleton_return(ctxt_manager, protocols, return_type, include_gateway): +@pytest.mark.parametrize('reuse_session', [True, False]) +def test_singleton_return(ctxt_manager, protocols, return_type, include_gateway, reuse_session): + if reuse_session and 'http' not in protocols: + return if 'websocket' in protocols and ctxt_manager != 'flow': return if not include_gateway and ctxt_manager == 'flow': @@ -63,7 +66,7 @@ def foo_single( with ctxt: for port, protocol in zip(ports, protocols): - c = Client(port=port, protocol=protocol) + c = Client(port=port, protocol=protocol, reuse_session=reuse_session) docs = c.post( on='/foo', inputs=MySingletonReturnInputDoc(text='hello', price=2), @@ -102,7 +105,10 @@ def foo_single( 'protocols', [['grpc'], ['http'], ['websocket'], ['grpc', 'http']] ) @pytest.mark.parametrize('return_type', ['batch', 'singleton']) -def test_singleton_return_async(ctxt_manager, protocols, return_type): +@pytest.mark.parametrize('reuse_session', [True, False]) +def test_singleton_return_async(ctxt_manager, protocols, return_type, reuse_session): + if reuse_session and 'http' not in protocols: + return if 'websocket' in protocols and ctxt_manager != 'flow': return @@ -149,7 +155,7 @@ async def foo_single( with ctxt: for port, protocol in zip(ports, protocols): - c = Client(port=port, protocol=protocol) + c = Client(port=port, protocol=protocol, reuse_session=reuse_session) docs = c.post( on='/foo', inputs=MySingletonReturnInputDoc(text='hello', price=2), diff --git a/tests/integration/streaming/test_streaming.py b/tests/integration/streaming/test_streaming.py index 5d2f6e4af848b..7f6675ec57e1f 100644 --- a/tests/integration/streaming/test_streaming.py +++ b/tests/integration/streaming/test_streaming.py @@ -23,7 +23,10 @@ async def non_gen_task(self, docs: DocumentArray, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize('protocol', ['http', 'grpc']) @pytest.mark.parametrize('include_gateway', [False, True]) -async def test_streaming_deployment(protocol, include_gateway): +@pytest.mark.parametrize('reuse_session', [False, True]) +async def test_streaming_deployment(protocol, include_gateway, reuse_session): + if reuse_session and protocol != 'http': + return port = random_port() docs = [] @@ -35,7 +38,7 @@ async def test_streaming_deployment(protocol, include_gateway): port=port, include_gateway=include_gateway, ): - client = Client(port=port, protocol=protocol, asyncio=True) + client = Client(port=port, protocol=protocol, asyncio=True, reuse_session=reuse_session) i = 0 async for doc in client.stream_doc( on='/hello', @@ -60,7 +63,10 @@ async def task(self, doc: Document, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize('protocol', ['http', 'grpc']) @pytest.mark.parametrize('include_gateway', [False, True]) -async def test_streaming_delay(protocol, include_gateway): +@pytest.mark.parametrize('reuse_session', [False, True]) +async def test_streaming_delay(protocol, include_gateway, reuse_session): + if reuse_session and protocol != 'http': + return from jina import Deployment port = random_port() @@ -72,7 +78,7 @@ async def test_streaming_delay(protocol, include_gateway): port=port, include_gateway=include_gateway, ): - client = Client(port=port, protocol=protocol, asyncio=True) + client = Client(port=port, protocol=protocol, asyncio=True, reuse_session=reuse_session) i = 0 start_time = time.time() async for doc in client.stream_doc( diff --git a/tests/unit/clients/python/test_client.py b/tests/unit/clients/python/test_client.py index 85d7371e52d31..addcb07dadffb 100644 --- a/tests/unit/clients/python/test_client.py +++ b/tests/unit/clients/python/test_client.py @@ -156,7 +156,6 @@ def test_all_sync_clients(protocol, mocker, use_stream): m3.assert_called_once() m4.assert_called() - @pytest.mark.slow @pytest.mark.parametrize('use_stream', [True, False]) def test_deployment_sync_client(mocker, use_stream): diff --git a/tests/unit/orchestrate/flow/flow-async/test_asyncflow.py b/tests/unit/orchestrate/flow/flow-async/test_asyncflow.py index df98e4cc14214..00901322ee50b 100644 --- a/tests/unit/orchestrate/flow/flow-async/test_asyncflow.py +++ b/tests/unit/orchestrate/flow/flow-async/test_asyncflow.py @@ -41,11 +41,14 @@ def documents(start_index, end_index): 'return_responses, return_class', [(True, Request), (False, DocumentArray)] ) @pytest.mark.parametrize('use_stream', [False, True]) +@pytest.mark.parametrize('reuse_session', [False, True]) async def test_run_async_flow( - protocol, mocker, flow_cls, return_responses, return_class, use_stream + protocol, mocker, flow_cls, return_responses, return_class, use_stream, reuse_session ): + if reuse_session and protocol != 'http': + return r_val = mocker.Mock() - with flow_cls(protocol=protocol, asyncio=True).add() as f: + with flow_cls(protocol=protocol, asyncio=True, reuse_session=reuse_session).add() as f: async for r in f.index( from_ndarray(np.random.random([num_docs, 4])), on_done=r_val, @@ -155,8 +158,11 @@ async def test_run_async_flow_other_task_concurrent(protocol): @pytest.mark.parametrize('protocol', ['websocket', 'grpc', 'http']) @pytest.mark.parametrize('flow_cls', [Flow, AsyncFlow]) @pytest.mark.parametrize('use_stream', [False, True]) -async def test_return_results_async_flow(protocol, flow_cls, use_stream): - with flow_cls(protocol=protocol, asyncio=True).add() as f: +@pytest.mark.parametrize('reuse_session', [False, True]) +async def test_return_results_async_flow(protocol, flow_cls, use_stream, reuse_session): + if reuse_session and protocol != 'http': + return + with flow_cls(protocol=protocol, asyncio=True, reuse_session=reuse_session).add() as f: async for r in f.index( from_ndarray(np.random.random([10, 2])), stream=use_stream ): @@ -169,8 +175,9 @@ async def test_return_results_async_flow(protocol, flow_cls, use_stream): @pytest.mark.parametrize('flow_api', ['delete', 'index', 'update', 'search']) @pytest.mark.parametrize('flow_cls', [Flow, AsyncFlow]) @pytest.mark.parametrize('use_stream', [False, True]) -async def test_return_results_async_flow_crud(protocol, flow_api, flow_cls, use_stream): - with flow_cls(protocol=protocol, asyncio=True).add() as f: +@pytest.mark.parametrize('reuse_session', [False, True]) +async def test_return_results_async_flow_crud(protocol, flow_api, flow_cls, use_stream, reuse_session): + with flow_cls(protocol=protocol, asyncio=True, reuse_session=reuse_session).add() as f: async for r in getattr(f, flow_api)(documents(0, 10), stream=use_stream): assert isinstance(r, DocumentArray)