diff --git a/jina/clients/__init__.py b/jina/clients/__init__.py
index 6484240d771dd..bf31c068f1cee 100644
--- a/jina/clients/__init__.py
+++ b/jina/clients/__init__.py
@@ -30,6 +30,7 @@ def Client(
     prefetch: Optional[int] = 1000,
     protocol: Optional[Union[str, List[str]]] = 'GRPC',
     proxy: Optional[bool] = False,
+    reuse_session: Optional[bool] = False,
     suppress_root_logging: Optional[bool] = False,
     tls: Optional[bool] = False,
     traces_exporter_host: Optional[str] = None,
@@ -59,6 +60,7 @@ def Client(
           Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
     :param protocol: Communication protocol between server and client.
     :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
+    :param reuse_session: If set, the HTTPClient will reuse a single ClientSession for all requests; the user is then responsible for closing it
     :param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
     :param tls: If set, connect to gateway using tls encryption
     :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
@@ -113,6 +115,7 @@ def Client(args: Optional['argparse.Namespace'] = None, **kwargs) -> Union[
           Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
     :param protocol: Communication protocol between server and client.
     :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
+    :param reuse_session: If set, the HTTPClient will reuse a single ClientSession for all requests; the user is then responsible for closing it
     :param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
     :param tls: If set, connect to gateway using tls encryption
     :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
diff --git a/jina/clients/base/__init__.py b/jina/clients/base/__init__.py
index 7fe60b15f74e7..41ec147fbd74b 100644
--- a/jina/clients/base/__init__.py
+++ b/jina/clients/base/__init__.py
@@ -29,9 +29,9 @@ class BaseClient(InstrumentationMixin, ABC):
     """

     def __init__(
-        self,
-        args: Optional['argparse.Namespace'] = None,
-        **kwargs,
+            self,
+            args: Optional['argparse.Namespace'] = None,
+            **kwargs,
     ):
         if args and isinstance(args, argparse.Namespace):
             self.args = args
@@ -63,6 +63,12 @@ def __init__(
         )
         send_telemetry_event(event='start', obj_cls_name=self.__class__.__name__)

+    async def close(self):
+        """Close any resources held by the Client.
+        :return: the result of the underlying close method
+        """
+        return self.teardown_instrumentation()
+
     def teardown_instrumentation(self):
         """Shut down the OpenTelemetry tracer and meter if available. This ensures that the daemon threads for
         exporting metrics data is properly cleaned up.
@@ -118,7 +124,7 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None:
                 raise BadClientInput from ex

     def _get_requests(
-        self, **kwargs
+            self, **kwargs
     ) -> Union[Iterator['Request'], AsyncIterator['Request']]:
         """
        Get request in generator.
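Reviewer note: taken together, the new `reuse_session` flag and the `close()` coroutine change the HTTP client lifecycle, so a minimal usage sketch may help. This is not part of the diff; the port and endpoint are hypothetical, and the `Document` import assumes a docarray v1 setup:

```python
import asyncio

from jina import Client, Document


async def main():
    # reuse_session=True routes every request through one shared
    # aiohttp.ClientSession instead of creating a fresh one per call
    client = Client(protocol='http', port=12345, asyncio=True, reuse_session=True)

    for _ in range(3):
        # each post() now reuses the same underlying connection pool
        async for _ in client.post(on='/', inputs=Document(text='hi')):
            pass

    # per the new docstring, the caller owns the session: close() tears
    # down instrumentation and exits the shared HTTPClientlet
    await client.close()


asyncio.run(main())
```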
@@ -177,13 +183,14 @@ def inputs(self, bytes_gen: 'InputType') -> None:

     @abc.abstractmethod
     async def _get_results(
-        self,
-        inputs: 'InputType',
-        on_done: 'CallbackFnType',
-        on_error: Optional['CallbackFnType'] = None,
-        on_always: Optional['CallbackFnType'] = None,
-        **kwargs,
-    ): ...
+            self,
+            inputs: 'InputType',
+            on_done: 'CallbackFnType',
+            on_error: Optional['CallbackFnType'] = None,
+            on_always: Optional['CallbackFnType'] = None,
+            **kwargs,
+    ):
+        ...

     @abc.abstractmethod
     def _is_flow_ready(self, **kwargs) -> bool:
diff --git a/jina/clients/base/helper.py b/jina/clients/base/helper.py
index 620513ceec460..50f43ae69e264 100644
--- a/jina/clients/base/helper.py
+++ b/jina/clients/base/helper.py
@@ -48,7 +48,6 @@ class AioHttpClientlet(ABC):

     def __init__(
         self,
-        url: str,
         logger: 'JinaLogger',
         max_attempts: int = 1,
         initial_backoff: float = 0.5,
@@ -59,7 +58,6 @@ def __init__(
     ) -> None:
         """HTTP Client to be used with the streamer

-        :param url: url to send http/websocket request to
        :param logger: jina logger
        :param max_attempts: Number of sending attempts, including the original request.
        :param initial_backoff: The first retry will happen with a delay of random(0, initial_backoff)
@@ -68,7 +66,6 @@ def __init__(
        :param tracer_provider: Optional tracer_provider that will be used to configure aiohttp tracing.
        :param kwargs: kwargs which will be forwarded to the `aiohttp.Session` instance. Used to pass headers to requests
        """
-        self.url = url
         self.logger = logger
         self.msg_recv = 0
         self.msg_sent = 0
@@ -131,7 +128,6 @@ async def start(self):
         """
         with ImportExtensions(required=True):
             import aiohttp
-
         self.session = aiohttp.ClientSession(
             **self._session_kwargs, trace_configs=self._trace_config
         )
@@ -154,9 +150,10 @@ class HTTPClientlet(AioHttpClientlet):

     UPDATE_EVENT_PREFIX = 14  # the update event has the following format: "event: update: {document_json}"

-    async def send_message(self, request: 'Request'):
+    async def send_message(self, url, request: 'Request'):
         """Sends a POST request to the server

+        :param url: the URL to send the message to
         :param request: request as dict
         :return: send post message
         """
@@ -166,23 +163,24 @@ async def send_message(self, request: 'Request'):
             req_dict['target_executor'] = req_dict['header']['target_executor']
         for attempt in range(1, self.max_attempts + 1):
             try:
-                request_kwargs = {'url': self.url}
+                request_kwargs = {'url': url}
                 if not docarray_v2:
                     request_kwargs['json'] = req_dict
                 else:
                     from docarray.base_doc.io.json import orjson_dumps

                     request_kwargs['data'] = JinaJsonPayload(value=req_dict)
+
                 async with self.session.post(**request_kwargs) as response:
                     try:
                         r_str = await response.json()
                     except aiohttp.ContentTypeError:
                         r_str = await response.text()
                     r_status = response.status
-                    handle_response_status(response.status, r_str, self.url)
-                    return r_status, r_str
+                    handle_response_status(r_status, r_str, url)
+                    return r_status, r_str
             except (ValueError, ConnectionError, BadClient, aiohttp.ClientError, aiohttp.ClientConnectionError) as err:
-                self.logger.debug(f'Got an error: {err} sending POST to {self.url} in attempt {attempt}/{self.max_attempts}')
+                self.logger.debug(f'Got an error: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}')
                 await retry.wait_or_raise_err(
                     attempt=attempt,
                     err=err,
@@ -193,19 +191,20 @@ async def send_message(self, request: 'Request'):
                 )
             except Exception as exc:
                 self.logger.debug(
-                    f'Got a non-retried error: {exc} sending POST to {self.url}')
+                    f'Got a non-retried error: {exc} sending POST to {url}')
                raise exc

-    async def send_streaming_message(self, doc: 'Document', on: str):
+    async def send_streaming_message(self, url, doc: 'Document', on: str):
         """Sends a GET SSE request to the server

+        :param url: the URL to send the message to
         :param doc: Request Document
         :param on: Request endpoint
         :yields: responses
         """
         req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict()
         request_kwargs = {
-            'url': self.url,
+            'url': url,
             'headers': {'Accept': 'text/event-stream'},
             'json': req_dict,
         }
@@ -219,13 +218,14 @@ async def send_streaming_message(self, doc: 'Document', on: str):
             elif event.startswith(b'end'):
                 pass

-    async def send_dry_run(self, **kwargs):
+    async def send_dry_run(self, url, **kwargs):
         """Query the dry_run endpoint from Gateway
+        :param url: the URL to send the message to
         :param kwargs: keyword arguments to make sure compatible API with other clients
         :return: send get message
         """
         return await self.session.get(
-            url=self.url, timeout=kwargs.get('timeout', None)
+            url=url, timeout=kwargs.get('timeout', None)
         ).__aenter__()

     async def recv_message(self):
@@ -267,8 +267,9 @@ async def __anext__(self):
 class WebsocketClientlet(AioHttpClientlet):
     """Websocket Client to be used with the streamer"""

-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, url, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
+        self.url = url
         self.websocket = None
         self.response_iter = None

diff --git a/jina/clients/base/http.py b/jina/clients/base/http.py
index 746bdf0e0acfd..eaac304695403 100644
--- a/jina/clients/base/http.py
+++ b/jina/clients/base/http.py
@@ -23,6 +23,18 @@ class HTTPBaseClient(BaseClient):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._endpoints = []
+        self.reuse_session = False
+        self._lock = AsyncExitStack()
+        self.iolet = None
+
+    async def close(self):
+        """Close any resources held by the Client.
+        :return: the result of the underlying close method
+        """
+        ret = await super().close()
+        if self.iolet is not None:
+            await self.iolet.__aexit__(None, None, None)
+        return ret

     async def _get_endpoints_from_openapi(self, **kwargs):
         def extract_paths_by_method(spec):
@@ -69,16 +81,27 @@ async def _is_flow_ready(self, **kwargs) -> bool:
             try:
                 proto = 'https' if self.args.tls else 'http'
                 url = f'{proto}://{self.args.host}:{self.args.port}/dry_run'
-                iolet = await stack.enter_async_context(
-                    HTTPClientlet(
-                        url=url,
-                        logger=self.logger,
-                        tracer_provider=self.tracer_provider,
-                        **kwargs,
-                    )
-                )
+
+                if not self.reuse_session:
+                    iolet = await stack.enter_async_context(
+                        HTTPClientlet(
+                            logger=self.logger,
+                            tracer_provider=self.tracer_provider,
+                            **kwargs,
+                        )
+                    )
+                else:
+                    async with self._lock:
+                        if self.iolet is None:
+                            self.iolet = HTTPClientlet(
+                                logger=self.logger,
+                                tracer_provider=self.tracer_provider,
+                                **kwargs,
+                            )
+                            await self.iolet.__aenter__()
+                    iolet = self.iolet

-                response = await iolet.send_dry_run(**kwargs)
+                response = await iolet.send_dry_run(url=url, **kwargs)
                 r_status = response.status

                 r_str = await response.json()
@@ -96,20 +119,20 @@ async def _is_flow_ready(self, **kwargs) -> bool:
         return False

     async def _get_results(
-        self,
-        inputs: 'InputType',
-        on_done: 'CallbackFnType',
-        on_error: Optional['CallbackFnType'] = None,
-        on_always: Optional['CallbackFnType'] = None,
-        max_attempts: int = 1,
-        initial_backoff: float = 0.5,
-        max_backoff: float = 0.1,
-        backoff_multiplier: float = 1.5,
-        results_in_order: bool = False,
-        prefetch: Optional[int] = None,
-        timeout: Optional[int] = None,
-        return_type: Type[DocumentArray] = DocumentArray,
-        **kwargs,
+            self,
+            inputs: 'InputType',
+            on_done: 'CallbackFnType',
+            on_error: Optional['CallbackFnType'] = None,
+            on_always: Optional['CallbackFnType'] = None,
+            max_attempts: int = 1,
+            initial_backoff: float = 0.5,
+            max_backoff: float = 0.1,
+            backoff_multiplier: float = 1.5,
+            results_in_order: bool = False,
+            prefetch: Optional[int] = None,
+            timeout: Optional[int] = None,
+            return_type: Type[DocumentArray] = DocumentArray,
+            **kwargs,
     ):
         """
         :param inputs: the callable
@@ -152,22 +175,32 @@ async def _get_results(
             else:
                 url = f'{proto}://{self.args.host}:{self.args.port}/post'

-            iolet = await stack.enter_async_context(
-                HTTPClientlet(
-                    url=url,
-                    logger=self.logger,
-                    tracer_provider=self.tracer_provider,
-                    max_attempts=max_attempts,
-                    initial_backoff=initial_backoff,
-                    max_backoff=max_backoff,
-                    backoff_multiplier=backoff_multiplier,
-                    timeout=timeout,
-                    **kwargs,
-                )
-            )
+            if not self.reuse_session:
+                iolet = await stack.enter_async_context(
+                    HTTPClientlet(
+                        logger=self.logger,
+                        tracer_provider=self.tracer_provider,
+                        max_attempts=max_attempts,
+                        initial_backoff=initial_backoff,
+                        max_backoff=max_backoff,
+                        backoff_multiplier=backoff_multiplier,
+                        timeout=timeout,
+                        **kwargs,
+                    )
+                )
+            else:
+                async with self._lock:
+                    if self.iolet is None:
+                        self.iolet = HTTPClientlet(
+                            logger=self.logger,
+                            tracer_provider=self.tracer_provider,
+                            **kwargs,
+                        )
+                        self.iolet = await self.iolet.__aenter__()
+                iolet = self.iolet

             def _request_handler(
-                request: 'Request', **kwargs
+                    request: 'Request', **kwargs
             ) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]':
                 """
                 For HTTP Client, for each request in the iterator, we `send_message` using
@@ -176,7 +209,7 @@ def _request_handler(
                 :param kwargs: kwargs
                 :return: asyncio Task for sending message
                 """
-                return asyncio.ensure_future(iolet.send_message(request=request)), None
+                return asyncio.ensure_future(iolet.send_message(url=url, request=request)), None

             def _result_handler(result):
                 return result
@@ -191,7 +224,7 @@ def _result_handler(result):
                 **streamer_args,
             )
             async for response in streamer.stream(
-                request_iterator=request_iterator, results_in_order=results_in_order
+                    request_iterator=request_iterator, results_in_order=results_in_order
             ):
                 r_status, r_str = response
                 handle_response_status(r_status, r_str, url)
@@ -232,13 +265,13 @@ def _result_handler(result):
             yield resp

     async def _get_streaming_results(
-        self,
-        on: str,
-        inputs: 'Document',
-        parameters: Optional[Dict] = None,
-        return_type: Type[Document] = Document,
-        timeout: Optional[int] = None,
-        **kwargs,
+            self,
+            on: str,
+            inputs: 'Document',
+            parameters: Optional[Dict] = None,
+            return_type: Type[Document] = Document,
+            timeout: Optional[int] = None,
+            **kwargs,
     ):
         proto = 'https' if self.args.tls else 'http'
         endpoint = on.strip('/')
@@ -248,17 +281,28 @@ async def _get_streaming_results(
             url = f'{proto}://{self.args.host}:{self.args.port}/{endpoint}'
         else:
             url = f'{proto}://{self.args.host}:{self.args.port}/default'
-
-        iolet = HTTPClientlet(
-            url=url,
-            logger=self.logger,
-            tracer_provider=self.tracer_provider,
-            timeout=timeout,
-            **kwargs,
-        )
-
-        async with iolet:
-            async for doc in iolet.send_streaming_message(doc=inputs, on=on):
+        async with AsyncExitStack() as stack:
+            if not self.reuse_session:
+                iolet = await stack.enter_async_context(
+                    HTTPClientlet(
+                        logger=self.logger,
+                        tracer_provider=self.tracer_provider,
+                        timeout=timeout,
+                        **kwargs,
+                    )
+                )
+            else:
+                async with self._lock:
+                    if self.iolet is None:
+                        self.iolet = HTTPClientlet(
+                            logger=self.logger,
+                            tracer_provider=self.tracer_provider,
+                            timeout=timeout,
+                            **kwargs,
+                        )
+                        await self.iolet.__aenter__()
+                    iolet = self.iolet
+            async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on):
                 if not docarray_v2:
                     yield Document.from_dict(json.loads(doc))
                 else:
diff --git a/jina/clients/http.py b/jina/clients/http.py
index e811d75f46ec5..2698e316d007a 100644
--- a/jina/clients/http.py
+++ b/jina/clients/http.py
@@ -9,6 +9,7 @@
     PostMixin,
     ProfileMixin,
 )
+import asyncio


 class HTTPClient(
@@ -80,3 +81,8 @@ async def async_inputs():

             print(resp)
         """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._lock = asyncio.Lock()
+        self.reuse_session = self.args.reuse_session
diff --git a/jina/orchestrate/flow/base.py b/jina/orchestrate/flow/base.py
index 9aec82444a280..0b0a36d47b3c7 100644
--- a/jina/orchestrate/flow/base.py
+++ b/jina/orchestrate/flow/base.py
@@ -133,6 +133,7 @@ def __init__(
         prefetch: Optional[int] = 1000,
         protocol: Optional[Union[str, List[str]]] = 'GRPC',
         proxy: Optional[bool] = False,
+        reuse_session: Optional[bool] = False,
         suppress_root_logging: Optional[bool] = False,
         tls: Optional[bool] = False,
         traces_exporter_host: Optional[str] = None,
@@ -155,6 +156,7 @@ def __init__(
           Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
         :param protocol: Communication protocol between server and client.
         :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
+        :param reuse_session: If set, the HTTPClient will reuse a single ClientSession for all requests; the user is then responsible for closing it
         :param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
         :param tls: If set, connect to gateway using tls encryption
        :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
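Reviewer note on the reuse branch in jina/clients/base/http.py above: all three call sites repeat the same lazy, lock-guarded creation of a single HTTPClientlet. A simplified sketch of that pattern follows; it is not part of the diff, and `SharedClientletHolder`/`make_clientlet` are hypothetical names standing in for the real classes. Also worth double-checking: `HTTPBaseClient.__init__` seeds `_lock` with an `AsyncExitStack` (a no-op async context manager that provides no mutual exclusion), and only the sync `HTTPClient` swaps in a real `asyncio.Lock`; if `AsyncHTTPClient` also derives from `HTTPBaseClient`, it may need the same override.

```python
import asyncio


class SharedClientletHolder:
    """Lazily create and share one async-context-managed clientlet."""

    def __init__(self):
        self._lock = asyncio.Lock()
        self._iolet = None

    async def get(self, make_clientlet):
        # only the first caller constructs and enters the clientlet;
        # concurrent callers wait on the lock and then reuse the instance
        async with self._lock:
            if self._iolet is None:
                self._iolet = make_clientlet()
                await self._iolet.__aenter__()
        return self._iolet

    async def close(self):
        # mirrors HTTPBaseClient.close(): exit the shared clientlet once
        if self._iolet is not None:
            await self._iolet.__aexit__(None, None, None)
            self._iolet = None
```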
@@ -417,6 +419,7 @@ def __init__(
           Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
         :param protocol: Communication protocol between server and client.
         :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
+        :param reuse_session: If set, the HTTPClient will reuse a single ClientSession for all requests; the user is then responsible for closing it
         :param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
         :param tls: If set, connect to gateway using tls encryption
         :param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
diff --git a/jina/parsers/client.py b/jina/parsers/client.py
index 4f46ac390fe29..8f22375c8d24b 100644
--- a/jina/parsers/client.py
+++ b/jina/parsers/client.py
@@ -81,3 +81,9 @@ def mixin_client_features_parser(parser):
         default='default',
         help='The config name or the absolute path to the YAML config file of the logger used in this object.',
     )
+    parser.add_argument(
+        '--reuse-session',
+        action='store_true',
+        default=False,
+        help='If set, the HTTPClient will reuse a single ClientSession for all requests; the user is then responsible for closing it',
+    )
diff --git a/jina_cli/autocomplete.py b/jina_cli/autocomplete.py
index e3f85ff9fc5d3..c5dda85e72573 100644
--- a/jina_cli/autocomplete.py
+++ b/jina_cli/autocomplete.py
@@ -564,6 +564,7 @@
         '--metrics-exporter-host',
         '--metrics-exporter-port',
         '--log-config',
+        '--reuse-session',
         '--protocol',
         '--grpc-channel-options',
         '--prefetch',
diff --git a/tests/integration/concurrent_clients/test_concurrent_clients.py b/tests/integration/concurrent_clients/test_concurrent_clients.py
index 546fe94da3314..898916ffe03cb 100644
--- a/tests/integration/concurrent_clients/test_concurrent_clients.py
+++ b/tests/integration/concurrent_clients/test_concurrent_clients.py
@@ -23,19 +23,23 @@ def ping(self, **kwargs):
 @pytest.mark.parametrize('prefetch', [1, 10])
 @pytest.mark.parametrize('concurrent', [15])
 @pytest.mark.parametrize('use_stream', [False, True])
+@pytest.mark.parametrize('reuse_session', [True, False])
 def test_concurrent_clients(
-    concurrent, protocol, shards, polling, prefetch, reraise, use_stream
+    concurrent, protocol, shards, polling, prefetch, reraise, use_stream, reuse_session
 ):
     if not use_stream and protocol != 'grpc':
         return
+
+    if reuse_session and protocol != 'http':
+        return

     def pong(peer_hash, queue, resp: Response):
         for d in resp.docs:
             queue.put((peer_hash, d.text))

     def peer_client(port, protocol, peer_hash, queue):
-        c = Client(protocol=protocol, port=port)
+        c = Client(protocol=protocol, port=port, reuse_session=reuse_session)
         for _ in range(NUM_REQUESTS):
             c.post(
                 '/ping',
diff --git a/tests/integration/docarray_v2/test_singleton.py b/tests/integration/docarray_v2/test_singleton.py
index e8cd663eb10d5..7405df29a4792 100644
--- a/tests/integration/docarray_v2/test_singleton.py
+++ b/tests/integration/docarray_v2/test_singleton.py
@@ -13,7 +13,10 @@
 )
 @pytest.mark.parametrize('return_type', ['batch', 'singleton'])
 @pytest.mark.parametrize('include_gateway', [True, False])
-def test_singleton_return(ctxt_manager, protocols, return_type, include_gateway):
+@pytest.mark.parametrize('reuse_session', [True, False])
+def test_singleton_return(ctxt_manager, protocols, return_type, include_gateway, reuse_session):
+    if reuse_session and 'http' not in protocols:
+        return
     if 'websocket' in protocols and ctxt_manager != 'flow':
        return
     if not include_gateway and ctxt_manager == 'flow':
         return
@@ -63,7 +66,7 @@ def foo_single(
     with ctxt:
         for port, protocol in zip(ports, protocols):
-            c = Client(port=port, protocol=protocol)
+            c = Client(port=port, protocol=protocol, reuse_session=reuse_session)
             docs = c.post(
                 on='/foo',
                 inputs=MySingletonReturnInputDoc(text='hello', price=2),
@@ -102,7 +105,10 @@ def foo_single(
     'protocols', [['grpc'], ['http'], ['websocket'], ['grpc', 'http']]
 )
 @pytest.mark.parametrize('return_type', ['batch', 'singleton'])
-def test_singleton_return_async(ctxt_manager, protocols, return_type):
+@pytest.mark.parametrize('reuse_session', [True, False])
+def test_singleton_return_async(ctxt_manager, protocols, return_type, reuse_session):
+    if reuse_session and 'http' not in protocols:
+        return
     if 'websocket' in protocols and ctxt_manager != 'flow':
         return

@@ -149,7 +155,7 @@ async def foo_single(
     with ctxt:
         for port, protocol in zip(ports, protocols):
-            c = Client(port=port, protocol=protocol)
+            c = Client(port=port, protocol=protocol, reuse_session=reuse_session)
             docs = c.post(
                 on='/foo',
                 inputs=MySingletonReturnInputDoc(text='hello', price=2),
diff --git a/tests/integration/streaming/test_streaming.py b/tests/integration/streaming/test_streaming.py
index 5d2f6e4af848b..7f6675ec57e1f 100644
--- a/tests/integration/streaming/test_streaming.py
+++ b/tests/integration/streaming/test_streaming.py
@@ -23,7 +23,10 @@ async def non_gen_task(self, docs: DocumentArray, **kwargs):
 @pytest.mark.asyncio
 @pytest.mark.parametrize('protocol', ['http', 'grpc'])
 @pytest.mark.parametrize('include_gateway', [False, True])
-async def test_streaming_deployment(protocol, include_gateway):
+@pytest.mark.parametrize('reuse_session', [False, True])
+async def test_streaming_deployment(protocol, include_gateway, reuse_session):
+    if reuse_session and protocol != 'http':
+        return
     port = random_port()

     docs = []
@@ -35,7 +38,7 @@ async def test_streaming_deployment(protocol, include_gateway):
         port=port,
         include_gateway=include_gateway,
     ):
-        client = Client(port=port, protocol=protocol, asyncio=True)
+        client = Client(port=port, protocol=protocol, asyncio=True, reuse_session=reuse_session)
         i = 0
         async for doc in client.stream_doc(
             on='/hello',
@@ -60,7 +63,10 @@ async def task(self, doc: Document, **kwargs):
 @pytest.mark.asyncio
 @pytest.mark.parametrize('protocol', ['http', 'grpc'])
 @pytest.mark.parametrize('include_gateway', [False, True])
-async def test_streaming_delay(protocol, include_gateway):
+@pytest.mark.parametrize('reuse_session', [False, True])
+async def test_streaming_delay(protocol, include_gateway, reuse_session):
+    if reuse_session and protocol != 'http':
+        return
     from jina import Deployment

     port = random_port()
@@ -72,7 +78,7 @@ async def test_streaming_delay(protocol, include_gateway):
         port=port,
         include_gateway=include_gateway,
     ):
-        client = Client(port=port, protocol=protocol, asyncio=True)
+        client = Client(port=port, protocol=protocol, asyncio=True, reuse_session=reuse_session)
         i = 0
         start_time = time.time()
         async for doc in client.stream_doc(
diff --git a/tests/unit/clients/python/test_client.py b/tests/unit/clients/python/test_client.py
index 85d7371e52d31..addcb07dadffb 100644
--- a/tests/unit/clients/python/test_client.py
+++ b/tests/unit/clients/python/test_client.py
@@ -156,7 +156,6 @@ def test_all_sync_clients(protocol, mocker, use_stream):
     m3.assert_called_once()
     m4.assert_called()

-
 @pytest.mark.slow
 @pytest.mark.parametrize('use_stream', [True, False])
 def test_deployment_sync_client(mocker, use_stream):
diff --git a/tests/unit/clients/test_helper.py b/tests/unit/clients/test_helper.py
index 66ae0d9081f38..55182b066616e 100644
--- a/tests/unit/clients/test_helper.py
+++ b/tests/unit/clients/test_helper.py
@@ -33,11 +33,11 @@ async def test_http_clientlet():
     port = random_port()
     with Flow(port=port, protocol='http').add():
         async with HTTPClientlet(
-            url=f'http://localhost:{port}/post', logger=logger
+            logger=logger
         ) as iolet:
             request = _new_data_request('/', None, {'a': 'b'})
             assert request.header.target_executor == ''
-            r_status, r_json = await iolet.send_message(request)
+            r_status, r_json = await iolet.send_message(url=f'http://localhost:{port}/post', request=request)
             response = DataRequest(r_json)
             assert response.header.exec_endpoint == '/'
             assert response.parameters == {'a': 'b'}
@@ -50,11 +50,11 @@ async def test_http_clientlet_target():
     port = random_port()
     with Flow(port=port, protocol='http').add():
         async with HTTPClientlet(
-            url=f'http://localhost:{port}/post', logger=logger
+            logger=logger
         ) as iolet:
             request = _new_data_request('/', 'nothing', {'a': 'b'})
             assert request.header.target_executor == 'nothing'
-            r = await iolet.send_message(request)
+            r = await iolet.send_message(url=f'http://localhost:{port}/post', request=request)
             r_status, r_json = r
             response = DataRequest(r_json)
             assert response.header.exec_endpoint == '/'
diff --git a/tests/unit/orchestrate/flow/flow-async/test_asyncflow.py b/tests/unit/orchestrate/flow/flow-async/test_asyncflow.py
index df98e4cc14214..00901322ee50b 100644
--- a/tests/unit/orchestrate/flow/flow-async/test_asyncflow.py
+++ b/tests/unit/orchestrate/flow/flow-async/test_asyncflow.py
@@ -41,11 +41,14 @@ def documents(start_index, end_index):
     'return_responses, return_class', [(True, Request), (False, DocumentArray)]
 )
 @pytest.mark.parametrize('use_stream', [False, True])
+@pytest.mark.parametrize('reuse_session', [False, True])
 async def test_run_async_flow(
-    protocol, mocker, flow_cls, return_responses, return_class, use_stream
+    protocol, mocker, flow_cls, return_responses, return_class, use_stream, reuse_session
 ):
+    if reuse_session and protocol != 'http':
+        return
     r_val = mocker.Mock()
-    with flow_cls(protocol=protocol, asyncio=True).add() as f:
+    with flow_cls(protocol=protocol, asyncio=True, reuse_session=reuse_session).add() as f:
         async for r in f.index(
             from_ndarray(np.random.random([num_docs, 4])),
             on_done=r_val,
@@ -155,8 +158,11 @@ async def test_run_async_flow_other_task_concurrent(protocol):
 @pytest.mark.parametrize('protocol', ['websocket', 'grpc', 'http'])
 @pytest.mark.parametrize('flow_cls', [Flow, AsyncFlow])
 @pytest.mark.parametrize('use_stream', [False, True])
-async def test_return_results_async_flow(protocol, flow_cls, use_stream):
-    with flow_cls(protocol=protocol, asyncio=True).add() as f:
+@pytest.mark.parametrize('reuse_session', [False, True])
+async def test_return_results_async_flow(protocol, flow_cls, use_stream, reuse_session):
+    if reuse_session and protocol != 'http':
+        return
+    with flow_cls(protocol=protocol, asyncio=True, reuse_session=reuse_session).add() as f:
         async for r in f.index(
             from_ndarray(np.random.random([10, 2])), stream=use_stream
         ):
@@ -169,8 +175,9 @@ async def test_return_results_async_flow(protocol, flow_cls, use_stream):
 @pytest.mark.parametrize('flow_api', ['delete', 'index', 'update', 'search'])
 @pytest.mark.parametrize('flow_cls', [Flow, AsyncFlow])
 @pytest.mark.parametrize('use_stream', [False, True])
-async def test_return_results_async_flow_crud(protocol, flow_api, flow_cls, use_stream):
-    with flow_cls(protocol=protocol, asyncio=True).add() as f:
+@pytest.mark.parametrize('reuse_session', [False, True])
+async def test_return_results_async_flow_crud(protocol, flow_api, flow_cls, use_stream, reuse_session):
+    with flow_cls(protocol=protocol, asyncio=True, reuse_session=reuse_session).add() as f:
         async for r in getattr(f, flow_api)(documents(0, 10), stream=use_stream):
             assert isinstance(r, DocumentArray)
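Reviewer note on the tests: the new parametrizations bail out with a bare `return` when `reuse_session` is combined with a non-HTTP protocol, so those cases are reported as passed rather than skipped; `pytest.skip` would make them visible in the report. Also, `test_return_results_async_flow_crud` is the only reuse_session-parametrized test without that protocol guard, so its grpc/websocket runs exercise the flag as a no-op. A sketch of the skip-based guard (the test name is hypothetical and the body is elided):

```python
import pytest


@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
@pytest.mark.parametrize('reuse_session', [False, True])
def test_reuse_session_only_for_http(protocol, reuse_session):
    if reuse_session and protocol != 'http':
        # shows up as 's' in the report instead of a silent pass
        pytest.skip('reuse_session only applies to the HTTP client')
    ...  # the real test body would go here
```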