diff --git a/jina/constants.py b/jina/constants.py index 1ccc5d18ed43a..9866bc31f236f 100644 --- a/jina/constants.py +++ b/jina/constants.py @@ -1,7 +1,7 @@ +import datetime as _datetime import os as _os import sys as _sys from pathlib import Path as _Path -import datetime as _datetime __windows__ = _sys.platform == 'win32' __uptime__ = _datetime.datetime.now().isoformat() @@ -53,6 +53,7 @@ __args_executor_func__ = { 'docs', 'parameters', + 'headers', 'docs_matrix', } __args_executor_init__ = {'metas', 'requests', 'runtime_args'} diff --git a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py index c4153ec3480fc..7dc513e7fd855 100644 --- a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py +++ b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py @@ -187,8 +187,10 @@ def add_post_route( ) app_kwargs['response_class'] = DocArrayResponse + from fastapi import Request + @app.api_route(**app_kwargs) - async def post(body: input_model, response: Response): + async def post(body: input_model, response: Response, request: Request): target_executor = None req_id = None if body.header is not None: @@ -208,6 +210,7 @@ async def post(body: input_model, response: Response): docs, exec_endpoint=endpoint_path, parameters=body.parameters, + headers=request.headers, target_executor=target_executor, request_id=req_id, return_results=True, diff --git a/jina/serve/runtimes/gateway/streamer.py b/jina/serve/runtimes/gateway/streamer.py index 959438254fed9..3f26237806b61 100644 --- a/jina/serve/runtimes/gateway/streamer.py +++ b/jina/serve/runtimes/gateway/streamer.py @@ -6,6 +6,7 @@ AsyncIterator, Dict, List, + Mapping, Optional, Sequence, Tuple, @@ -209,6 +210,7 @@ async def stream( exec_endpoint: Optional[str] = None, target_executor: Optional[str] = None, parameters: Optional[Dict] = None, + headers: Optional[Mapping[str, str]] = None, results_in_order: bool = False, return_type: Type[DocumentArray] = DocumentArray, ) -> AsyncIterator[Tuple[Union[DocumentArray, 'Request'], 'ExecutorError']]: @@ -221,6 +223,7 @@ async def stream( :param exec_endpoint: The Executor endpoint to which to send the Documents :param target_executor: A regex expression indicating the Executors that should receive the Request :param parameters: Parameters to be attached to the Requests + :param headers: Http request headers :param results_in_order: return the results in the same order as the request_iterator :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`. :yield: tuple of Documents or Responses and unpacked error from Executors if any @@ -232,6 +235,7 @@ async def stream( exec_endpoint=exec_endpoint, target_executor=target_executor, parameters=parameters, + headers=headers, results_in_order=results_in_order, return_type=return_type, ): @@ -256,6 +260,7 @@ async def stream_doc( exec_endpoint: Optional[str] = None, target_executor: Optional[str] = None, parameters: Optional[Dict] = None, + headers: Optional[Mapping[str, str]] = None, request_id: Optional[str] = None, return_type: Type[DocumentArray] = DocumentArray, ) -> AsyncIterator[Tuple[Union[DocumentArray, 'Request'], 'ExecutorError']]: @@ -267,6 +272,7 @@ async def stream_doc( :param exec_endpoint: The Executor endpoint to which to send the Documents :param target_executor: A regex expression indicating the Executors that should receive the Request :param parameters: Parameters to be attached to the Requests + :param headers: Http request headers :param request_id: Request ID to add to the request streamed to Executor. Only applicable if request_size is equal or less to the length of the docs :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`. :yield: tuple of Documents or Responses and unpacked error from Executors if any @@ -282,6 +288,8 @@ async def stream_doc( req.header.target_executor = target_executor if parameters: req.parameters = parameters + if headers: + req.headers = headers async for result in self.rpc_stream_doc(request=req, return_type=return_type): error = None @@ -306,6 +314,7 @@ async def stream_docs( exec_endpoint: Optional[str] = None, target_executor: Optional[str] = None, parameters: Optional[Dict] = None, + headers: Optional[Mapping[str, str]] = None, results_in_order: bool = False, request_id: Optional[str] = None, return_type: Type[DocumentArray] = DocumentArray, @@ -319,6 +328,7 @@ async def stream_docs( :param exec_endpoint: The Executor endpoint to which to send the Documents :param target_executor: A regex expression indicating the Executors that should receive the Request :param parameters: Parameters to be attached to the Requests + :param headers: Http request headers :param results_in_order: return the results in the same order as the request_iterator :param request_id: Request ID to add to the request streamed to Executor. Only applicable if request_size is equal or less to the length of the docs :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`. @@ -339,6 +349,8 @@ def _req_generator(): req.header.target_executor = target_executor if parameters: req.parameters = parameters + if headers: + req.headers = headers yield req else: from docarray import BaseDoc @@ -361,6 +373,8 @@ def batch(iterable, n=1): req.header.target_executor = target_executor if parameters: req.parameters = parameters + if headers: + req.headers = headers yield req else: req = DataRequest() @@ -374,6 +388,8 @@ def batch(iterable, n=1): req.header.target_executor = target_executor if parameters: req.parameters = parameters + if headers: + req.headers = headers yield req async for resp in self.rpc_stream( @@ -438,6 +454,7 @@ async def post( request_size: int = 100, on: Optional[str] = None, parameters: Optional[Dict] = None, + headers: Optional[Mapping[str, str]] = None, return_type: Type[DocumentArray] = DocumentArray, **kwargs, ): @@ -505,6 +522,7 @@ async def stream_doc( inputs: 'Document', on: Optional[str] = None, parameters: Optional[Dict] = None, + headers: Optional[Mapping[str, str]] = None, **kwargs, ): req: SingleDocumentRequest = SingleDocumentRequest(inputs.to_protobuf()) diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index 47006dd4be329..301ad7cfdbc1c 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -86,8 +86,10 @@ def add_post_route( app_kwargs['response_class'] = DocArrayResponse + from fastapi import Request + @app.api_route(**app_kwargs) - async def post(body: input_model, response: Response): + async def post(body: input_model, response: Response, request: Request): req = DataRequest() if body.header is not None: @@ -95,6 +97,7 @@ async def post(body: input_model, response: Response): if body.parameters is not None: req.parameters = body.parameters + req.headers = request.headers req.header.exec_endpoint = endpoint_path data = body.data if isinstance(data, list): @@ -149,6 +152,7 @@ async def streaming_get(request: Request = None, body: input_doc_model = None): body = Document.from_pydantic_model(body) req = DataRequest() req.header.exec_endpoint = endpoint_path + req.headers = request.headers if not docarray_v2: req.data.docs = DocumentArray([body]) else: diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 0849aaebb388d..5bad985d876d2 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -51,16 +51,16 @@ class WorkerRequestHandler: _KEY_RESULT = '__results__' def __init__( - self, - args: 'argparse.Namespace', - logger: 'JinaLogger', - metrics_registry: Optional['CollectorRegistry'] = None, - tracer_provider: Optional['trace.TracerProvider'] = None, - meter_provider: Optional['metrics.MeterProvider'] = None, - meter=None, - tracer=None, - deployment_name: str = '', - **kwargs, + self, + args: 'argparse.Namespace', + logger: 'JinaLogger', + metrics_registry: Optional['CollectorRegistry'] = None, + tracer_provider: Optional['trace.TracerProvider'] = None, + meter_provider: Optional['metrics.MeterProvider'] = None, + meter=None, + tracer=None, + deployment_name: str = '', + **kwargs, ): """Initialize private parameters and execute private loading functions. @@ -83,8 +83,8 @@ def __init__( self._is_closed = False if self.metrics_registry: with ImportExtensions( - required=True, - help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', + required=True, + help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', ): from prometheus_client import Counter, Summary @@ -178,6 +178,7 @@ def call_handle(request): ] return self.process_single_data(request, None, is_generator=is_generator) + app = get_fastapi_app( request_models_map=request_models_map, caller=call_handle, **kwargs ) @@ -228,9 +229,9 @@ async def _hot_reload(self): watched_files.add(extra_python_file) with ImportExtensions( - required=True, - logger=self.logger, - help_text='''hot reload requires watchfiles dependency to be installed. You can do `pip install + required=True, + logger=self.logger, + help_text='''hot reload requires watchfiles dependency to be installed. You can do `pip install watchfiles''', ): from watchfiles import awatch @@ -297,14 +298,14 @@ def _init_batchqueue_dict(self): } def _init_monitoring( - self, - metrics_registry: Optional['CollectorRegistry'] = None, - meter: Optional['metrics.Meter'] = None, + self, + metrics_registry: Optional['CollectorRegistry'] = None, + meter: Optional['metrics.Meter'] = None, ): if metrics_registry: with ImportExtensions( - required=True, - help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', + required=True, + help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', ): from prometheus_client import Counter, Summary @@ -360,10 +361,10 @@ def _init_monitoring( self._sent_response_size_histogram = None def _load_executor( - self, - metrics_registry: Optional['CollectorRegistry'] = None, - tracer_provider: Optional['trace.TracerProvider'] = None, - meter_provider: Optional['metrics.MeterProvider'] = None, + self, + metrics_registry: Optional['CollectorRegistry'] = None, + tracer_provider: Optional['trace.TracerProvider'] = None, + meter_provider: Optional['metrics.MeterProvider'] = None, ): """ Load the executor to this runtime, specified by ``uses`` CLI argument. @@ -577,8 +578,8 @@ def _setup_req_doc_array_cls(self, requests, exec_endpoint, is_response=False): req.document_array_cls = DocumentArray else: if ( - not endpoint_info.is_generator - and not endpoint_info.is_singleton_doc + not endpoint_info.is_generator + and not endpoint_info.is_singleton_doc ): req.document_array_cls = ( endpoint_info.request_schema @@ -595,9 +596,9 @@ def _setup_req_doc_array_cls(self, requests, exec_endpoint, is_response=False): pass def _setup_requests( - self, - requests: List['DataRequest'], - exec_endpoint: str, + self, + requests: List['DataRequest'], + exec_endpoint: str, ): """Execute a request using the executor. @@ -613,7 +614,7 @@ def _setup_requests( return requests, params async def handle_generator( - self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None + self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None ) -> Generator: """Prepares and executes a request for generator endpoints. @@ -642,13 +643,14 @@ async def handle_generator( req_endpoint=exec_endpoint, doc=doc, parameters=params, + headers=requests[0].headers, docs_matrix=docs_matrix, docs_map=docs_map, tracing_context=tracing_context, ) async def handle( - self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None + self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None ) -> DataRequest: """Initialize private parameters and execute private loading functions. @@ -679,8 +681,12 @@ async def handle( if param_key not in self._batchqueue_instances[exec_endpoint]: self._batchqueue_instances[exec_endpoint][param_key] = BatchQueue( functools.partial(self._executor.__acall__, exec_endpoint), - request_docarray_cls=self._executor.requests[exec_endpoint].request_schema, - response_docarray_cls=self._executor.requests[exec_endpoint].response_schema, + request_docarray_cls=self._executor.requests[ + exec_endpoint + ].request_schema, + response_docarray_cls=self._executor.requests[ + exec_endpoint + ].response_schema, output_array_type=self.args.output_array_type, params=params, **self._batchqueue_config[exec_endpoint], @@ -702,6 +708,7 @@ async def handle( req_endpoint=exec_endpoint, docs=docs, parameters=params, + headers=requests[0].headers, docs_matrix=docs_matrix, docs_map=docs_map, tracing_context=tracing_context, @@ -722,7 +729,7 @@ async def handle( @staticmethod def replace_docs( - request: List['DataRequest'], docs: 'DocumentArray', ndarray_type: str = None + request: List['DataRequest'], docs: 'DocumentArray', ndarray_type: str = None ) -> None: """Replaces the docs in a message with new Documents. @@ -770,7 +777,7 @@ async def close(self): @staticmethod def _get_docs_matrix_from_request( - requests: List['DataRequest'], + requests: List['DataRequest'], ) -> Tuple[Optional[List['DocumentArray']], Optional[Dict[str, 'DocumentArray']]]: """ Returns a docs matrix from a list of DataRequest objects. @@ -794,7 +801,7 @@ def _get_docs_matrix_from_request( @staticmethod def get_parameters_dict_from_request( - requests: List['DataRequest'], + requests: List['DataRequest'], ) -> 'Dict': """ Returns a parameters dict from a list of DataRequest objects. @@ -814,7 +821,7 @@ def get_parameters_dict_from_request( @staticmethod def get_docs_from_request( - requests: List['DataRequest'], + requests: List['DataRequest'], ) -> 'DocumentArray': """ Gets a field from the message @@ -894,7 +901,7 @@ def reduce_requests(requests: List['DataRequest']) -> 'DataRequest': # serving part async def process_single_data( - self, request: DataRequest, context, is_generator: bool = False + self, request: DataRequest, context, is_generator: bool = False ) -> DataRequest: """ Process the received requests and return the result as a new request @@ -908,7 +915,7 @@ async def process_single_data( return await self.process_data([request], context, is_generator=is_generator) async def stream_doc( - self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext' + self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext' ) -> SingleDocumentRequest: """ Process the received requests and return the result as a new request, used for streaming behavior, one doc IN, several out @@ -1034,7 +1041,7 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: return endpoints_proto def _extract_tracing_context( - self, metadata: 'grpc.aio.Metadata' + self, metadata: 'grpc.aio.Metadata' ) -> Optional['Context']: if self.tracer: from opentelemetry.propagate import extract @@ -1045,7 +1052,7 @@ def _extract_tracing_context( return None async def process_data( - self, requests: List[DataRequest], context, is_generator: bool = False + self, requests: List[DataRequest], context, is_generator: bool = False ) -> DataRequest: """ Process the received requests and return the result as a new request @@ -1057,7 +1064,7 @@ async def process_data( """ self.logger.debug('recv a process_data request') with MetricsTimer( - self._summary, self._receiving_request_seconds, self._metric_attributes + self._summary, self._receiving_request_seconds, self._metric_attributes ): try: if self.logger.debug_enabled: @@ -1113,8 +1120,8 @@ async def process_data( ) if ( - self.args.exit_on_exceptions - and type(ex).__name__ in self.args.exit_on_exceptions + self.args.exit_on_exceptions + and type(ex).__name__ in self.args.exit_on_exceptions ): self.logger.info('Exiting because of "--exit-on-exceptions".') raise RuntimeTerminated @@ -1138,7 +1145,7 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto: return info_proto async def stream( - self, request_iterator, context=None, *args, **kwargs + self, request_iterator, context=None, *args, **kwargs ) -> AsyncIterator['Request']: """ stream requests from client iterator and stream responses back. @@ -1156,8 +1163,8 @@ async def stream( Call = stream def _create_snapshot_status( - self, - snapshot_directory: str, + self, + snapshot_directory: str, ) -> 'jina_pb2.SnapshotStatusProto': _id = str(uuid.uuid4()) self.logger.debug(f'Generated snapshot id: {_id}') @@ -1170,7 +1177,7 @@ def _create_snapshot_status( ) def _create_restore_status( - self, + self, ) -> 'jina_pb2.SnapshotStatusProto': _id = str(uuid.uuid4()) self.logger.debug(f'Generated restore id: {_id}') @@ -1189,9 +1196,9 @@ async def snapshot(self, request, context) -> 'jina_pb2.SnapshotStatusProto': """ self.logger.debug('Calling snapshot') if ( - self._snapshot - and self._snapshot_thread - and self._snapshot_thread.is_alive() + self._snapshot + and self._snapshot_thread + and self._snapshot_thread.is_alive() ): raise RuntimeError( f'A snapshot with id {self._snapshot.id.value} is currently in progress. Cannot start another.' @@ -1209,7 +1216,7 @@ async def snapshot(self, request, context) -> 'jina_pb2.SnapshotStatusProto': return self._snapshot async def snapshot_status( - self, request: 'jina_pb2.SnapshotId', context + self, request: 'jina_pb2.SnapshotId', context ) -> 'jina_pb2.SnapshotStatusProto': """ method to start a snapshot process of the Executor @@ -1271,7 +1278,7 @@ async def restore(self, request: 'jina_pb2.RestoreSnapshotCommand', context): return self._restore async def restore_status( - self, request, context + self, request, context ) -> 'jina_pb2.RestoreSnapshotStatusProto': """ method to start a snapshot process of the Executor diff --git a/tests/integration/docarray_v2/test_request_headers.py b/tests/integration/docarray_v2/test_request_headers.py new file mode 100644 index 0000000000000..942ba6b218732 --- /dev/null +++ b/tests/integration/docarray_v2/test_request_headers.py @@ -0,0 +1,148 @@ +import logging +from typing import Dict, List, Literal, Optional + +import pytest +from docarray import BaseDoc, DocList + +from jina import Client, Deployment, Executor, requests +from jina.helper import random_port + + +class PortGetter: + def __init__(self): + self.ports = { + "http": { + True: random_port(), + False: random_port(), + }, + "grpc": { + True: random_port(), + False: random_port(), + }, + } + + def get_port(self, protocol: Literal["http", "grpc"], include_gateway: bool) -> int: + return self.ports[protocol][include_gateway] + + @property + def gateway_ports(self) -> List[int]: + return [self.ports["http"][True], self.ports["grpc"][True]] + + @property + def no_gateway_ports(self) -> List[int]: + return [self.ports["http"][False], self.ports["grpc"][False]] + + +@pytest.fixture(scope='module') +def port_getter() -> callable: + getter = PortGetter() + return getter + + +class DictDoc(BaseDoc): + data: dict + + +class HeaderExecutor(Executor): + @requests(on="/get-headers") + def post_endpoint( + self, + docs: DocList[DictDoc], + parameters: Optional[Dict] = None, + headers: Optional[Dict] = None, + **kwargs, + ) -> DocList[DictDoc]: + return DocList[DictDoc]([DictDoc(data=headers)]) + + @requests(on='/stream-headers') + async def stream_task( + self, doc: DictDoc, headers: Optional[dict] = None, **kwargs + ) -> DictDoc: + for k, v in sorted((headers or {}).items()): + yield DictDoc(data={k: v}) + + yield DictDoc(data={"DONE": "true"}) + + +@pytest.fixture(scope='module') +def deployment_no_gateway(port_getter: PortGetter) -> Deployment: + + with Deployment( + uses=HeaderExecutor, + protocol=["http", "grpc"], + port=port_getter.no_gateway_ports, + include_gateway=False, + ) as dep: + yield dep + + +@pytest.fixture(scope='module') +def deployment_gateway(port_getter: PortGetter) -> Deployment: + + with Deployment( + uses=HeaderExecutor, + protocol=["http", "grpc"], + port=port_getter.gateway_ports, + include_gateway=False, + ) as dep: + yield dep + + +@pytest.fixture(scope='module') +def deployments(deployment_gateway, deployment_no_gateway) -> Dict[bool, Deployment]: + return { + True: deployment_gateway, + False: deployment_no_gateway, + } + + +@pytest.mark.parametrize('include_gateway', [False, True]) +def test_headers_in_http_headers(include_gateway, port_getter: PortGetter, deployments): + port = port_getter.get_port("http", include_gateway) + data = { + "data": [{"text": "test"}], + "parameters": { + "parameter1": "value1", + }, + } + logging.info(f"Posting to {port}") + client = Client(port=port, protocol="http") + resp = client.post( + on=f'/get-headers', + inputs=DocList([DictDoc(data=data)]), + headers={ + "header1": "value1", + "header2": "value2", + }, + return_type=DocList[DictDoc], + ) + assert resp[0].data['header1'] == 'value1' + + +@pytest.mark.asyncio +@pytest.mark.parametrize('include_gateway', [False, True]) +async def test_headers_in_http_headers_streaming( + include_gateway, port_getter: PortGetter, deployments +): + client = Client( + port=port_getter.get_port("http", include_gateway), + protocol="http", + asyncio=True, + ) + data = {"data": [{"text": "test"}], "parameters": {"parameter1": "value1"}} + chunks = [] + + async for doc in client.stream_doc( + on=f'/stream-headers', + inputs=DictDoc(data=data), + headers={ + "header1": "value1", + "header2": "value2", + }, + return_type=DictDoc, + ): + chunks.append(doc) + assert len(chunks) > 2 + + assert DictDoc(data={'header1': 'value1'}) in chunks + assert DictDoc(data={'header2': 'value2'}) in chunks