Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use body for streaming instead of params #6098

Merged
merged 35 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
94348f6
fix: add post endpoint for streaming
NarekA Oct 20, 2023
b645afb
test: stream complex docs
NarekA Oct 20, 2023
72037b5
docs: include docstring for issue test
NarekA Oct 20, 2023
99cdd7d
Merge branch 'master' into fix-streaming-6091
NarekA Oct 20, 2023
cc315c9
Merge branch 'master' into fix-streaming-6091
NarekA Oct 24, 2023
8bc5124
fix: use random port
NarekA Oct 24, 2023
804b665
fix: remove import
NarekA Oct 24, 2023
3ce6e1e
fix: simplify client code
NarekA Oct 24, 2023
a5376af
fix: use json field
NarekA Oct 24, 2023
025413f
fix: use to_dict for docarray v1
NarekA Oct 24, 2023
60c9b5e
docs: add docs about http get
NarekA Oct 24, 2023
03cc4c4
fix: typo in deployment
NarekA Oct 24, 2023
b89c01a
fix: don't use body.data
NarekA Oct 24, 2023
cf2f145
fix: do unpack body
NarekA Oct 24, 2023
fd30b0e
fix: docarray v2 cast model
NarekA Oct 24, 2023
f1f9a88
fix: change start time delay
NarekA Oct 24, 2023
3aa6534
Revert "fix: change start time delay"
NarekA Oct 25, 2023
57af50a
fix: use get only
NarekA Oct 25, 2023
5f27f0a
Merge remote-tracking branch 'origin/master' into fix-streaming-2-6091
NarekA Oct 25, 2023
a72a797
fix: delay test
NarekA Oct 25, 2023
ee77d13
fix: fix get and post endpoints
NarekA Oct 25, 2023
524aba4
fix: remove post
NarekA Oct 25, 2023
5cfb423
fix: remove endpoint tags
NarekA Oct 25, 2023
3d1b07f
Merge branch 'master' into fix-streaming-2-6091
NarekA Oct 25, 2023
ca3a707
test: use get with url params
NarekA Oct 26, 2023
92405db
docs: fix docs on streaming endpoints
NarekA Oct 26, 2023
2c626bf
test: increase tolerance
NarekA Oct 26, 2023
38a577b
fix: iteration over chunks
NarekA Oct 26, 2023
c4afd99
fix: gateway forwarding
NarekA Oct 27, 2023
c654c14
fix: pre-commit changes
NarekA Oct 27, 2023
bb90d14
fix: don't check for docarray_v2
NarekA Oct 27, 2023
08bfd6b
fix: remove type-hint
NarekA Oct 27, 2023
3eff072
fix: update test output
NarekA Oct 27, 2023
46af5fa
fix: adding payload
NarekA Oct 27, 2023
bf3cbe2
fix: use json payload
NarekA Oct 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@
'json': doc.dict() if docarray_v2 else doc.to_dict(),
}

async with self.session.post(**request_kwargs) as response:
async with self.session.get(**request_kwargs) as response:

Check warning on line 206 in jina/clients/base/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/helper.py#L206

Added line #L206 was not covered by tests
async for chunk in response.content.iter_any():
events = chunk.split(b'event: ')[1:]
for event in events:
Expand Down
31 changes: 5 additions & 26 deletions jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,35 +252,14 @@
methods=['GET'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_get(request: Request):
query_params = dict(request.query_params)
async def streaming_get(request: Request, body: input_doc_model = None):
if not body:
Copy link
Contributor Author

@NarekA NarekA Oct 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can and request.method == 'GET' to the condition here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may not need POST at all then.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the current state of this PR, and it fixes everything. Can't tell if tests need to be re-run or if they're actually failing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed that issue, I think the current failures just need re-runs, can you take a look?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevermind, found one that's not flakyness

query_params = dict(request.query_params)
body = input_doc_model.parse_obj(query_params)

Check warning on line 258 in jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py#L255-L258

Added lines #L255 - L258 were not covered by tests

async def event_generator():
async for doc, error in streamer.stream_doc(
doc=input_doc_model(**query_params), exec_endpoint=endpoint_path
):
if error:
raise HTTPException(status_code=499, detail=str(error))
yield {
'event': 'update',
'data': doc.dict()
}
yield {
'event': 'end'
}

return EventSourceResponse(event_generator())

@app.api_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_post(body: dict):

async def event_generator():
async for doc, error in streamer.stream_doc(
doc=input_doc_model.parse_obj(body), exec_endpoint=endpoint_path
doc=body, exec_endpoint=endpoint_path
):
if error:
raise HTTPException(status_code=499, detail=str(error))
Expand Down
65 changes: 29 additions & 36 deletions jina/serve/runtimes/worker/http_fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from jina import DocumentArray, Document
from jina import Document, DocumentArray
from jina._docarray import docarray_v2
from jina.importer import ImportExtensions
from jina.serve.networking.sse import EventSourceResponse
Expand All @@ -11,15 +11,15 @@
from jina.logging.logger import JinaLogger

if docarray_v2:
from docarray import DocList, BaseDoc
from docarray import BaseDoc, DocList

Check warning on line 14 in jina/serve/runtimes/worker/http_fastapi_app.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/http_fastapi_app.py#L14

Added line #L14 was not covered by tests


def get_fastapi_app(
request_models_map: Dict,
caller: Callable,
logger: 'JinaLogger',
cors: bool = False,
**kwargs,
request_models_map: Dict,
caller: Callable,
logger: 'JinaLogger',
cors: bool = False,
**kwargs,
):
"""
Get the app from FastAPI as the REST interface.
Expand All @@ -35,15 +35,18 @@
from fastapi import FastAPI, Response, HTTPException
import pydantic
from fastapi.middleware.cors import CORSMiddleware
import os

from pydantic import BaseModel, Field
from pydantic.config import BaseConfig, inherit_config

from jina.proto import jina_pb2
from jina.serve.runtimes.gateway.models import _to_camel_case
import os

class Header(BaseModel):
request_id: Optional[str] = Field(description='Request ID', example=os.urandom(16).hex())
request_id: Optional[str] = Field(
description='Request ID', example=os.urandom(16).hex()
)

class Config(BaseConfig):
alias_generator = _to_camel_case
Expand All @@ -66,11 +69,11 @@
logger.warning('CORS is enabled. This service is accessible from any website!')

def add_post_route(
endpoint_path,
input_model,
output_model,
input_doc_list_model=None,
output_doc_list_model=None,
endpoint_path,
input_model,
output_model,
input_doc_list_model=None,
output_doc_list_model=None,
):
app_kwargs = dict(
path=f'/{endpoint_path.strip("/")}',
Expand Down Expand Up @@ -123,8 +126,8 @@
return ret

def add_streaming_routes(
endpoint_path,
input_doc_model=None,
endpoint_path,
input_doc_model=None,
):
from fastapi import Request

Expand All @@ -133,26 +136,14 @@
methods=['GET'],
Copy link
Contributor Author

@NarekA NarekA Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following code works, but I did not allow get since it doesn't work with gateway deployments.

methods=['GET', 'POST'],

summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_get(request: Request):
query_params = dict(request.query_params)
req = DataRequest()
req.header.exec_endpoint = endpoint_path
if not docarray_v2:
req.data.docs = DocumentArray([Document.from_dict(query_params)])
else:
req.document_array_cls = DocList[input_doc_model]
req.data.docs = DocList[input_doc_model](
[input_doc_model(**query_params)]
async def streaming_get(request: Request, body: input_doc_model = None):
if not body:
query_params = dict(request.query_params)
body = (

Check warning on line 142 in jina/serve/runtimes/worker/http_fastapi_app.py

View check run for this annotation

Codecov / codecov/patch

jina/serve/runtimes/worker/http_fastapi_app.py#L139-L142

Added lines #L139 - L142 were not covered by tests
input_doc_model.parse_obj(query_params)
if docarray_v2
else Document.from_dict(query_params)
)
event_generator = _gen_dict_documents(await caller(req))
return EventSourceResponse(event_generator)

@app.api_route(
path=f'/{endpoint_path.strip("/")}',
methods=['POST'],
summary=f'Streaming Endpoint {endpoint_path}',
)
async def streaming_post(body: input_doc_model, request: Request):
req = DataRequest()
req.header.exec_endpoint = endpoint_path
if not docarray_v2:
Expand All @@ -169,7 +160,9 @@
output_doc_model = input_output_map['output']['model']
is_generator = input_output_map['is_generator']
parameters_model = input_output_map['parameters']['model'] or Optional[Dict]
default_parameters = ... if input_output_map['parameters']['model'] else None
default_parameters = (
... if input_output_map['parameters']['model'] else None
)

if docarray_v2:
_config = inherit_config(InnerConfig, BaseDoc.__config__)
Expand Down
20 changes: 11 additions & 9 deletions tests/integration/docarray_v2/test_issues.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import List, Optional, Dict
from typing import Dict, List, Optional

import pytest
from docarray import BaseDoc, DocList
from pydantic import Field

from jina import Executor, Flow, requests, Deployment, Client
from jina import Client, Deployment, Executor, Flow, requests


class Nested2Doc(BaseDoc):
Expand Down Expand Up @@ -78,6 +78,7 @@ def test_issue_6019_with_nested_list():
assert res[0].text == 'hello world'
assert res[0].nested[0].nested.value == 'test'


def test_issue_6084():
class EnvInfo(BaseDoc):
history: str = ''
Expand All @@ -86,7 +87,6 @@ class A(BaseDoc):
b: EnvInfo

class MyIssue6084Exec(Executor):

@requests
def foo(self, docs: DocList[A], **kwargs) -> DocList[A]:
pass
Expand Down Expand Up @@ -115,7 +115,10 @@ class InputWithComplexFields(BaseDoc):
class MyExecutor(Executor):
@requests(on="/stream")
async def stream(
self, doc: InputWithComplexFields, parameters: Optional[Dict] = None, **kwargs
self,
doc: InputWithComplexFields,
parameters: Optional[Dict] = None,
**kwargs,
) -> InputWithComplexFields:
for i in range(4):
yield InputWithComplexFields(text=f"hello world {doc.text} {i}")
Expand All @@ -134,10 +137,9 @@ async def stream(
docs.append(doc)

assert [d.text for d in docs] == [
"hello world my input text 0",
"hello world my input text 1",
"hello world my input text 2",
"hello world my input text 3",
'hello world test 0',
'hello world test 1',
'hello world test 2',
'hello world test 3',
]
assert docs[0].nested_field.name == "test_name"

12 changes: 5 additions & 7 deletions tests/integration/docarray_v2/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,17 @@ async def test_streaming_delay(protocol, include_gateway):
):
client = Client(port=port, protocol=protocol, asyncio=True)
i = 0
stream = client.stream_doc(
start_time = time.time()
async for doc in client.stream_doc(
on='/hello',
inputs=MyDocument(text='hello world', number=i),
return_type=MyDocument,
)
start_time = None
async for doc in stream:
start_time = start_time or time.time()
):
assert doc.text == f'hello world {i}'
i += 1
delay = time.time() - start_time

# 0.5 seconds between each request + 0.5 seconds tolerance interval
assert delay < (0.5 * i), f'Expected delay to be less than {0.5 * i}, got {delay} on iteration {i}'
assert time.time() - start_time < (0.5 * i) + 0.5
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the very least, we should merge this and fix the gateway.



@pytest.mark.asyncio
Expand Down
Loading