From c817617fdf18edc62e579623fdd319becc40af5d Mon Sep 17 00:00:00 2001 From: nareka Date: Fri, 20 Oct 2023 14:14:43 -0700 Subject: [PATCH] fix: add post endpoint for streaming --- jina/clients/base/helper.py | 10 +++---- .../gateway/http_fastapi_app_docarrayv2.py | 27 +++++++++++++++++-- .../serve/runtimes/worker/http_fastapi_app.py | 26 +++++++++++++++--- 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/jina/clients/base/helper.py b/jina/clients/base/helper.py index 78ebd7f8dd37c..2da159a6a64a9 100644 --- a/jina/clients/base/helper.py +++ b/jina/clients/base/helper.py @@ -197,19 +197,15 @@ async def send_streaming_message(self, doc: 'Document', on: str): :param on: Request endpoint :yields: responses """ - if docarray_v2: - req_dict = doc.dict() - else: - req_dict = doc.to_dict() + req_dict = doc.json() request_kwargs = { 'url': self.url, 'headers': {'Accept': 'text/event-stream'}, } - req_dict = {key: value for key, value in req_dict.items() if value is not None} - request_kwargs['params'] = req_dict + request_kwargs['data'] = req_dict - async with self.session.get(**request_kwargs) as response: + async with self.session.post(**request_kwargs) as response: async for chunk in response.content.iter_any(): events = chunk.split(b'event: ')[1:] for event in events: diff --git a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py index 517365303769f..3a1fb1c115bcc 100644 --- a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py +++ b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py @@ -241,7 +241,7 @@ async def post(body: input_model, response: Response): ) return result - def add_streaming_get_route( + def add_streaming_routes( endpoint_path, input_doc_model=None, ): @@ -258,6 +258,29 @@ async def streaming_get(request: Request): async def event_generator(): async for doc, error in streamer.stream_doc( doc=input_doc_model(**query_params), exec_endpoint=endpoint_path + ): + if error: + raise HTTPException(status_code=499, detail=str(error)) + yield { + 'event': 'update', + 'data': doc.dict() + } + yield { + 'event': 'end' + } + + return EventSourceResponse(event_generator()) + + @app.api_route( + path=f'/{endpoint_path.strip("/")}', + methods=['POST'], + summary=f'Streaming Endpoint {endpoint_path}', + ) + async def streaming_post(body: input_doc_model, request: Request): + + async def event_generator(): + async for doc, error in streamer.stream_doc( + doc=input_doc_model(**body.data), exec_endpoint=endpoint_path ): if error: raise HTTPException(status_code=499, detail=str(error)) @@ -293,7 +316,7 @@ async def event_generator(): ) if is_generator: - add_streaming_get_route( + add_streaming_routes( endpoint, input_doc_model=input_doc_model, ) diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index fb3470425fcc4..7e35b8e1c0b37 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -122,7 +122,7 @@ async def post(body: input_model, response: Response): ret = output_model(data=docs_response, parameters=resp.parameters) return ret - def add_streaming_get_route( + def add_streaming_routes( endpoint_path, input_doc_model=None, ): @@ -143,7 +143,27 @@ async def streaming_get(request: Request): req.data.docs = DocumentArray([Document.from_dict(query_params)]) else: req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_model]([input_doc_model(**query_params)]) + req.data.docs = DocList[input_doc_model]( + [input_doc_model(**query_params)] + ) + event_generator = _gen_dict_documents(await caller(req)) + return EventSourceResponse(event_generator) + + @app.api_route( + path=f'/{endpoint_path.strip("/")}', + methods=['POST'], + summary=f'Streaming Endpoint {endpoint_path}', + ) + async def streaming_post(body: input_doc_model, request: Request): + req = DataRequest() + req.header.exec_endpoint = endpoint_path + if not docarray_v2: + from docarray import Document + + req.data.docs = DocumentArray([body]) + else: + req.document_array_cls = DocList[input_doc_model] + req.data.docs = DocList[input_doc_model]([body]) event_generator = _gen_dict_documents(await caller(req)) return EventSourceResponse(event_generator) @@ -176,7 +196,7 @@ async def streaming_get(request: Request): ) if is_generator: - add_streaming_get_route( + add_streaming_routes( endpoint, input_doc_model=input_doc_model, )