diff --git a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py index bfc23dcbb726f..ac2c3a50ddee2 100644 --- a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py +++ b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py @@ -254,7 +254,11 @@ def add_streaming_routes( ) async def streaming_get(request: Request, body: input_doc_model = None): body = body or dict(request.query_params) - body = input_doc_model.parse_obj(body) if docarray_v2 else Document.from_dict(body) + body = ( + input_doc_model.parse_obj(body) + if docarray_v2 + else Document.from_dict(body) + ) async def event_generator(): async for doc, error in streamer.stream_doc( diff --git a/jina/serve/runtimes/gateway/request_handling.py b/jina/serve/runtimes/gateway/request_handling.py index fd8a4293d1edb..316efbaa9d8fb 100644 --- a/jina/serve/runtimes/gateway/request_handling.py +++ b/jina/serve/runtimes/gateway/request_handling.py @@ -175,14 +175,13 @@ def _http_fastapi_default_app( ) ) - async def _load_balance(self, request: 'aiohttp.web_request.Request'): import aiohttp from aiohttp import web + target_server = next(self.load_balancer_servers) target_url = f'{target_server}{request.path_qs}' - try: async with aiohttp.ClientSession() as session: @@ -192,8 +191,9 @@ async def _load_balance(self, request: 'aiohttp.web_request.Request'): if payload: request_kwargs['json'] = payload - - async with session.get(url=target_url, **request_kwargs) as response: + async with session.get( + url=target_url, **request_kwargs + ) as response: # Create a StreamResponse with the same headers and status as the target response stream_response = web.StreamResponse( status=response.status, diff --git a/tests/integration/streaming/test_streaming.py b/tests/integration/streaming/test_streaming.py index 3cc72566f3785..5d2f6e4af848b 100644 --- a/tests/integration/streaming/test_streaming.py +++ b/tests/integration/streaming/test_streaming.py @@ -38,7 +38,10 @@ async def test_streaming_deployment(protocol, include_gateway): client = Client(port=port, protocol=protocol, asyncio=True) i = 0 async for doc in client.stream_doc( - on='/hello', inputs=Document(text='hello world'), return_type=Document, input_type=Document + on='/hello', + inputs=Document(text='hello world'), + return_type=Document, + input_type=Document, ): docs.append(doc.text) i += 1