From d1450b07017abb8dd35c9fa7a4e348b8eee28447 Mon Sep 17 00:00:00 2001 From: nareka Date: Thu, 7 Dec 2023 14:21:49 -0800 Subject: [PATCH] feat: pass request-headers as metadata --- jina/__init__.py | 4 +- .../gateway/http_fastapi_app_docarrayv2.py | 8 +- .../serve/runtimes/worker/http_fastapi_app.py | 3 +- .../docarray_v2/test_metadata_headers.py | 150 ++++++++++++++++++ 4 files changed, 157 insertions(+), 8 deletions(-) create mode 100644 tests/integration/docarray_v2/test_metadata_headers.py diff --git a/jina/__init__.py b/jina/__init__.py index 778e7767ecd35..9e8793a8e68c3 100644 --- a/jina/__init__.py +++ b/jina/__init__.py @@ -35,7 +35,7 @@ def _ignore_google_warnings(): 'ignore', category=DeprecationWarning, message='Deprecated call to `pkg_resources.declare_namespace(\'google\')`.', - append=True + append=True, ) @@ -81,7 +81,7 @@ def _ignore_google_warnings(): # do not change this line manually # this is managed by proto/build-proto.sh and updated on every execution -__proto_version__ = '0.1.27' +__proto_version__ = '0.1.28' try: __docarray_version__ = _docarray.__version__ diff --git a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py index 91a915073b5fa..50c2e5c3c35ab 100644 --- a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py +++ b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py @@ -210,11 +210,7 @@ async def post(body: input_model, response: Response, request: Request): docs, exec_endpoint=endpoint_path, parameters=body.parameters, - metadata=dict( - request.headers or { - "no_headers": "true" - } - ), + metadata=dict(request.headers or {"no_headers": "true"}), target_executor=target_executor, request_id=req_id, return_results=True, @@ -252,6 +248,8 @@ def add_streaming_routes( endpoint_path, input_doc_model=None, ): + from fastapi import Request + @app.api_route( path=f'/{endpoint_path.strip("/")}', methods=['GET'], diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index be5ee41412cac..f2078c44f3e2c 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -97,7 +97,7 @@ async def post(body: input_model, response: Response, request: Request): if body.parameters is not None: req.parameters = body.parameters - req.metadata = dict(request.headers or {"no_headers": "true"}) + req.metadata = dict(request.headers or {}) req.header.exec_endpoint = endpoint_path data = body.data if isinstance(data, list): @@ -152,6 +152,7 @@ async def streaming_get(request: Request = None, body: input_doc_model = None): body = Document.from_pydantic_model(body) req = DataRequest() req.header.exec_endpoint = endpoint_path + req.metadata = dict(request.headers or {}) if not docarray_v2: req.data.docs = DocumentArray([body]) else: diff --git a/tests/integration/docarray_v2/test_metadata_headers.py b/tests/integration/docarray_v2/test_metadata_headers.py new file mode 100644 index 0000000000000..cdac87885e326 --- /dev/null +++ b/tests/integration/docarray_v2/test_metadata_headers.py @@ -0,0 +1,150 @@ +import logging +from typing import Dict, List, Literal, Optional + +import pytest +from docarray import BaseDoc, DocList + +from jina import Client, Deployment, Executor, requests +from jina.helper import random_port + + +class PortGetter: + def __init__(self): + self.ports = { + "http": { + True: random_port(), + False: random_port(), + }, + "grpc": { + True: random_port(), + False: random_port(), + }, + } + + def get_port(self, protocol: Literal["http", "grpc"], include_gateway: bool) -> int: + return self.ports[protocol][include_gateway] + + @property + def gateway_ports(self) -> List[int]: + return [self.ports["http"][True], self.ports["grpc"][True]] + + @property + def no_gateway_ports(self) -> List[int]: + return [self.ports["http"][False], self.ports["grpc"][False]] + + +@pytest.fixture(scope='module') +def port_getter() -> callable: + getter = PortGetter() + return getter + + +class DictDoc(BaseDoc): + data: dict + + +class MetadataExecutor(Executor): + @requests(on="/get-metadata-headers") + def post_endpoint( + self, + docs: DocList[DictDoc], + parameters: Optional[Dict] = None, + metadata: Optional[Dict] = None, + **kwargs, + ) -> DocList[DictDoc]: + return DocList[DictDoc]([DictDoc(data=metadata)]) + + @requests(on='/stream-metadata-headers') + async def stream_task( + self, doc: DictDoc, metadata: Optional[dict] = None, **kwargs + ) -> DictDoc: + for k, v in sorted((metadata or {}).items()): + yield DictDoc(data={k: v}) + + yield DictDoc(data={"DONE": "true"}) + + +@pytest.fixture(scope='module') +def deployment_no_gateway(port_getter: PortGetter) -> Deployment: + + with Deployment( + uses=MetadataExecutor, + protocol=["http", "grpc"], + port=port_getter.no_gateway_ports, + include_gateway=False, + ) as dep: + yield dep + + +@pytest.fixture(scope='module') +def deployment_gateway(port_getter: PortGetter) -> Deployment: + + with Deployment( + uses=MetadataExecutor, + protocol=["http", "grpc"], + port=port_getter.gateway_ports, + include_gateway=False, + ) as dep: + yield dep + + +@pytest.fixture(scope='module') +def deployments(deployment_gateway, deployment_no_gateway) -> Dict[bool, Deployment]: + return { + True: deployment_gateway, + False: deployment_no_gateway, + } + + +@pytest.mark.parametrize('include_gateway', [False, True]) +def test_headers_in_http_metadata( + include_gateway, port_getter: PortGetter, deployments +): + port = port_getter.get_port("http", include_gateway) + data = { + "data": [{"text": "test"}], + "parameters": { + "parameter1": "value1", + }, + } + logging.info(f"Posting to {port}") + client = Client(port=port, protocol="http") + resp = client.post( + on=f'/get-metadata-headers', + inputs=DocList([DictDoc(data=data)]), + headers={ + "header1": "value1", + "header2": "value2", + }, + return_type=DocList[DictDoc], + ) + assert resp[0].data['header1'] == 'value1' + + +@pytest.mark.asyncio +@pytest.mark.parametrize('include_gateway', [False, True]) +async def test_headers_in_http_metadata_streaming( + include_gateway, port_getter: PortGetter, deployments +): + client = Client( + port=port_getter.get_port("http", include_gateway), + protocol="http", + asyncio=True, + ) + data = {"data": [{"text": "test"}], "parameters": {"parameter1": "value1"}} + chunks = [] + + async for doc in client.stream_doc( + on=f'/stream-metadata-headers', + inputs=DictDoc(data=data), + headers={ + "header1": "value1", + "header2": "value2", + }, + return_type=DictDoc, + ): + chunks.append(doc) + assert len(chunks) > 2 + + assert DictDoc(data={'header1': 'value1'}) in chunks + assert DictDoc(data={'header2': 'value2'}) in chunks