diff --git a/docs/concepts/serving/executor/add-endpoints.md b/docs/concepts/serving/executor/add-endpoints.md index e539659fb2bcb..cf66b12cf3fa3 100644 --- a/docs/concepts/serving/executor/add-endpoints.md +++ b/docs/concepts/serving/executor/add-endpoints.md @@ -381,6 +381,11 @@ Streaming endpoints receive one Document as input and yields one Document at a t :class: note Streaming endpoints are only supported for HTTP and gRPC protocols and for Deployment and Flow with one single Executor. + +For HTTP deployments, streaming executors generate both a GET and a POST endpoint. +The GET endpoint supports documents with string, integer, or float fields only, +whereas POST requests support all docarrays. +The Jina client uses the POST endpoints. ``` A streaming endpoint has the following signature: diff --git a/jina/clients/base/helper.py b/jina/clients/base/helper.py index 78ebd7f8dd37c..d3b76dc014e72 100644 --- a/jina/clients/base/helper.py +++ b/jina/clients/base/helper.py @@ -197,19 +197,13 @@ async def send_streaming_message(self, doc: 'Document', on: str): :param on: Request endpoint :yields: responses """ - if docarray_v2: - req_dict = doc.dict() - else: - req_dict = doc.to_dict() - request_kwargs = { 'url': self.url, 'headers': {'Accept': 'text/event-stream'}, + 'json': doc.dict() if docarray_v2 else doc.to_dict(), } - req_dict = {key: value for key, value in req_dict.items() if value is not None} - request_kwargs['params'] = req_dict - async with self.session.get(**request_kwargs) as response: + async with self.session.post(**request_kwargs) as response: async for chunk in response.content.iter_any(): events = chunk.split(b'event: ')[1:] for event in events: diff --git a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py index 517365303769f..6feb721ab0794 100644 --- a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py +++ b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py @@ 
-241,7 +241,7 @@ async def post(body: input_model, response: Response): ) return result - def add_streaming_get_route( + def add_streaming_routes( endpoint_path, input_doc_model=None, ): @@ -258,6 +258,29 @@ async def streaming_get(request: Request): async def event_generator(): async for doc, error in streamer.stream_doc( doc=input_doc_model(**query_params), exec_endpoint=endpoint_path + ): + if error: + raise HTTPException(status_code=499, detail=str(error)) + yield { + 'event': 'update', + 'data': doc.dict() + } + yield { + 'event': 'end' + } + + return EventSourceResponse(event_generator()) + + @app.api_route( + path=f'/{endpoint_path.strip("/")}', + methods=['POST'], + summary=f'Streaming Endpoint {endpoint_path}', + ) + async def streaming_post(body: dict): + + async def event_generator(): + async for doc, error in streamer.stream_doc( + doc=input_doc_model.parse_obj(body), exec_endpoint=endpoint_path ): if error: raise HTTPException(status_code=499, detail=str(error)) @@ -293,7 +316,7 @@ async def event_generator(): ) if is_generator: - add_streaming_get_route( + add_streaming_routes( endpoint, input_doc_model=input_doc_model, ) diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index fb3470425fcc4..0407bb87d8760 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -122,7 +122,7 @@ async def post(body: input_model, response: Response): ret = output_model(data=docs_response, parameters=resp.parameters) return ret - def add_streaming_get_route( + def add_streaming_routes( endpoint_path, input_doc_model=None, ): @@ -138,12 +138,28 @@ async def streaming_get(request: Request): req = DataRequest() req.header.exec_endpoint = endpoint_path if not docarray_v2: - from docarray import Document - req.data.docs = DocumentArray([Document.from_dict(query_params)]) else: req.document_array_cls = DocList[input_doc_model] - req.data.docs = 
DocList[input_doc_model]([input_doc_model(**query_params)]) + req.data.docs = DocList[input_doc_model]( + [input_doc_model(**query_params)] + ) + event_generator = _gen_dict_documents(await caller(req)) + return EventSourceResponse(event_generator) + + @app.api_route( + path=f'/{endpoint_path.strip("/")}', + methods=['POST'], + summary=f'Streaming Endpoint {endpoint_path}', + ) + async def streaming_post(body: input_doc_model, request: Request): + req = DataRequest() + req.header.exec_endpoint = endpoint_path + if not docarray_v2: + req.data.docs = DocumentArray([body]) + else: + req.document_array_cls = DocList[input_doc_model] + req.data.docs = DocList[input_doc_model]([body]) event_generator = _gen_dict_documents(await caller(req)) return EventSourceResponse(event_generator) @@ -176,7 +192,7 @@ async def streaming_get(request: Request): ) if is_generator: - add_streaming_get_route( + add_streaming_routes( endpoint, input_doc_model=input_doc_model, ) diff --git a/tests/integration/docarray_v2/test_issues.py b/tests/integration/docarray_v2/test_issues.py index 0584078a50909..a7757b7516e2c 100644 --- a/tests/integration/docarray_v2/test_issues.py +++ b/tests/integration/docarray_v2/test_issues.py @@ -1,8 +1,10 @@ -from typing import List, Optional +from typing import List, Optional, Dict +import pytest from docarray import BaseDoc, DocList +from pydantic import Field -from jina import Executor, Flow, requests +from jina import Executor, Flow, requests, Deployment, Client class Nested2Doc(BaseDoc): @@ -92,3 +94,50 @@ def foo(self, docs: DocList[A], **kwargs) -> DocList[A]: f = Flow().add(uses=MyIssue6084Exec).add(uses=MyIssue6084Exec) with f: pass + + +@pytest.mark.asyncio +async def test_issue_6090(): + """Tests if streaming works with pydantic models with complex fields which are not + str, int, or float. 
+ """ + + class NestedFieldSchema(BaseDoc): + name: str = "test_name" + dict_field: Dict = Field(default_factory=dict) + + class InputWithComplexFields(BaseDoc): + text: str = "test" + nested_field: NestedFieldSchema = Field(default_factory=NestedFieldSchema) + dict_field: Dict = Field(default_factory=dict) + bool_field: bool = False + + class MyExecutor(Executor): + @requests(on="/stream") + async def stream( + self, doc: InputWithComplexFields, parameters: Optional[Dict] = None, **kwargs + ) -> InputWithComplexFields: + for i in range(4): + yield InputWithComplexFields(text=f"hello world {doc.text} {i}") + + docs = [] + protocol = "http" + with Deployment(uses=MyExecutor, protocol=protocol) as dep: + client = Client(port=dep.port, protocol=protocol, asyncio=True) + example_doc = InputWithComplexFields(text="my input text") + async for doc in client.stream_doc( + on="/stream", + inputs=example_doc, + input_type=InputWithComplexFields, + return_type=InputWithComplexFields, + ): + docs.append(doc) + + assert [d.text for d in docs] == [ + "hello world my input text 0", + "hello world my input text 1", + "hello world my input text 2", + "hello world my input text 3", + ] + assert docs[0].nested_field.name == "test_name" + diff --git a/tests/integration/docarray_v2/test_streaming.py b/tests/integration/docarray_v2/test_streaming.py index 457391b747278..25befe58d814a 100644 --- a/tests/integration/docarray_v2/test_streaming.py +++ b/tests/integration/docarray_v2/test_streaming.py @@ -143,17 +143,19 @@ async def test_streaming_delay(protocol, include_gateway): ): client = Client(port=port, protocol=protocol, asyncio=True) i = 0 - start_time = time.time() - async for doc in client.stream_doc( + stream = client.stream_doc( on='/hello', inputs=MyDocument(text='hello world', number=i), return_type=MyDocument, - ): + ) + start_time = None + async for doc in stream: + start_time = start_time or time.time() assert doc.text == f'hello world {i}' i += 1 - + delay = time.time() - 
start_time # 0.5 seconds between each request + 0.5 seconds tolerance interval - assert time.time() - start_time < (0.5 * i) + 0.5 + assert delay < (0.5 * i), f'Expected delay to be less than {0.5 * i}, got {delay} on iteration {i}' @pytest.mark.asyncio