diff --git a/docs/concepts/serving/gateway/customization.md b/docs/concepts/serving/gateway/customization.md index cd0e295b01da4..99629b8bc62fa 100644 --- a/docs/concepts/serving/gateway/customization.md +++ b/docs/concepts/serving/gateway/customization.md @@ -240,6 +240,13 @@ class MyGateway(FastAPIBaseGateway): return app ``` +```{hint} +:class: note +if you omit the `return_type` parameter, the gateway streamer can still fetch the Executor output schemas and dynamically construct a DocArray model for it. +Even though the dynamically created schema is very similar to original schema, some validation checks can still fail (for instance adding to a typed `DocList`). +It is recommended to always pass the `return_type` parameter +``` + ### Recovering Executor errors Exceptions raised by an `Executor` are captured in the server object which can be extracted by using the {meth}`jina.serve.streamer.stream()` method. The `stream` method @@ -266,6 +273,14 @@ async def get(text: str): return {'results': results, 'errors': [error.name for error in errors]} ``` + +```{hint} +:class: note +if you omit the `return_type` parameter, the gateway streamer can still fetch the Executor output schemas and dynamically construct a DocArray model for it. +Even though the dynamically created schema is very similar to original schema, some validation checks can still fail (for instance adding to a typed `DocList`). +It is recommended to always pass the `return_type` parameter +``` + (executor-streamer)= ## Calling an individual Executor diff --git a/jina/clients/base/http.py b/jina/clients/base/http.py index 1885f243acd40..44715ce732d6a 100644 --- a/jina/clients/base/http.py +++ b/jina/clients/base/http.py @@ -5,10 +5,7 @@ from jina._docarray import Document, DocumentArray, docarray_v2 from jina.clients.base import BaseClient -from jina.clients.base.helper import ( - HTTPClientlet, - handle_response_status, -) +from jina.clients.base.helper import HTTPClientlet, handle_response_status from jina.clients.helper import callback_exec from jina.importer import ImportExtensions from jina.logging.profile import ProgressBar @@ -165,12 +162,13 @@ async def _get_results( ) def _request_handler( - request: 'Request', + request: 'Request', **kwargs ) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]': """ For HTTP Client, for each request in the iterator, we `send_message` using http POST request and add it to the list of tasks which is awaited and yielded. :param request: current request in the iterator + :param kwargs: kwargs :return: asyncio Task for sending message """ return asyncio.ensure_future(iolet.send_message(request=request)), None @@ -203,11 +201,15 @@ def _result_handler(result): da = DocumentArray.from_dict(r_str['data']) else: from docarray import DocList + if issubclass(return_type, DocList): da = return_type( - [return_type.doc_type(**v) for v in r_str['data']]) + [return_type.doc_type(**v) for v in r_str['data']] + ) else: - da = DocList[return_type]([return_type(**v) for v in r_str['data']]) + da = DocList[return_type]( + [return_type(**v) for v in r_str['data']] + ) del r_str['data'] resp = DataRequest(r_str) diff --git a/jina/clients/base/unary_rpc.py b/jina/clients/base/unary_rpc.py index de55f53155035..6cb219706738c 100644 --- a/jina/clients/base/unary_rpc.py +++ b/jina/clients/base/unary_rpc.py @@ -67,7 +67,7 @@ async def unary_rpc_with_retry(self): stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(self.channel) def _request_handler( - request: 'Request', + request: 'Request', **kwargs ) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]': async def _with_retry(req: 'Request'): for attempt in range(1, self.max_attempts + 1): diff --git a/jina/clients/base/websocket.py b/jina/clients/base/websocket.py index ac6158fff609f..399a0b33a38ac 100644 --- a/jina/clients/base/websocket.py +++ b/jina/clients/base/websocket.py @@ -166,7 +166,7 @@ def _handle_end_of_iter(): asyncio.create_task(iolet.send_eoi()) def _request_handler( - request: 'Request', + request: 'Request', **kwargs ) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]': """ For each request in the iterator, we send the `Message` using `iolet.send_message()`. @@ -175,6 +175,7 @@ def _request_handler( Then add {: } to the request buffer. This empty future is used to track the `result` of this request during `receive`. :param request: current request in the iterator + :param kwargs: kwargs :return: asyncio Future for sending message """ future = get_or_reuse_loop().create_future() diff --git a/jina/serve/runtimes/gateway/async_request_response_handling.py b/jina/serve/runtimes/gateway/async_request_response_handling.py index e662caf3a007e..aa76202d45a0f 100644 --- a/jina/serve/runtimes/gateway/async_request_response_handling.py +++ b/jina/serve/runtimes/gateway/async_request_response_handling.py @@ -1,10 +1,10 @@ import asyncio import copy -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, AsyncGenerator +from typing import TYPE_CHECKING, AsyncGenerator, Callable, List, Optional, Tuple, Type import grpc.aio -from jina._docarray import DocumentArray +from jina._docarray import DocumentArray, docarray_v2 from jina.excepts import InternalNetworkError from jina.helper import GATEWAY_NAME from jina.logging.logger import JinaLogger @@ -13,7 +13,6 @@ from jina.serve.runtimes.helper import _is_param_for_specific_executor from jina.serve.runtimes.monitoring import MonitoringRequestMixin from jina.serve.runtimes.worker.request_handling import WorkerRequestHandler -from jina._docarray import docarray_v2 if TYPE_CHECKING: # pragma: no cover from asyncio import Future @@ -33,11 +32,11 @@ class AsyncRequestResponseHandler(MonitoringRequestMixin): """ def __init__( - self, - metrics_registry: Optional['CollectorRegistry'] = None, - meter: Optional['Meter'] = None, - runtime_name: Optional[str] = None, - logger: Optional[JinaLogger] = None, + self, + metrics_registry: Optional['CollectorRegistry'] = None, + meter: Optional['Meter'] = None, + runtime_name: Optional[str] = None, + logger: Optional[JinaLogger] = None, ): super().__init__(metrics_registry, meter, runtime_name) self._endpoint_discovery_finished = False @@ -45,7 +44,7 @@ def __init__( self.logger = logger or JinaLogger(self.__class__.__name__) def handle_request( - self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool' + self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool' ) -> Callable[['Request'], 'Tuple[Future, Optional[Future]]']: """ Function that handles the requests arriving to the gateway. This will be passed to the streamer. @@ -64,8 +63,8 @@ async def gather_endpoints(request_graph): err_code = err.code() if err_code == grpc.StatusCode.UNAVAILABLE: err._details = ( - err.details() - + f' |Gateway: Communication error while gathering endpoints with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.' + err.details() + + f' |Gateway: Communication error while gathering endpoints with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.' ) raise err else: @@ -75,7 +74,9 @@ async def gather_endpoints(request_graph): raise exc self._endpoint_discovery_finished = True - def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]': + def _handle_request( + request: 'Request', return_type: Type[DocumentArray] + ) -> 'Tuple[Future, Optional[Future]]': self._update_start_request_metrics(request) # important that the gateway needs to have an instance of the graph per request request_graph = copy.deepcopy(graph) @@ -100,11 +101,10 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]': request.header.target_executor = '' exec_endpoint = request.header.exec_endpoint gather_endpoints_task = None - if ( - not self._endpoint_discovery_finished - and not self._gathering_endpoints - ): - gather_endpoints_task = asyncio.create_task(gather_endpoints(request_graph)) + if not self._endpoint_discovery_finished and not self._gathering_endpoints: + gather_endpoints_task = asyncio.create_task( + gather_endpoints(request_graph) + ) init_task = None request_doc_ids = [] @@ -112,8 +112,8 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]': if graph.has_filter_conditions: if not docarray_v2: request_doc_ids = request.data.docs[ - :, 'id' - ] # used to maintain order of docs that are filtered by executors + :, 'id' + ] # used to maintain order of docs that are filtered by executors else: init_task = gather_endpoints_task from docarray import DocList @@ -137,7 +137,8 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]': request_input_parameters=request_input_parameters, request_input_has_specific_params=has_specific_params, copy_request_at_send=num_outgoing_nodes > 1 and has_specific_params, - init_task=init_task + init_task=init_task, + return_type=return_type, ) # Every origin node returns a set of tasks that are the ones corresponding to the leafs of each of their # subtrees that unwrap all the previous tasks. It starts like a chain of waiting for tasks from previous @@ -157,7 +158,7 @@ def sort_by_request_order(doc): response.data.docs = DocumentArray(sorted_docs) async def _process_results_at_end_gateway( - tasks: List[asyncio.Task], request_graph: TopologyGraph + tasks: List[asyncio.Task], request_graph: TopologyGraph ) -> asyncio.Future: try: partial_responses = await asyncio.gather(*tasks) @@ -209,7 +210,7 @@ async def _process_results_at_end_gateway( def handle_single_document_request( self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool' - ) -> Callable[['Request'], 'AsyncGenerator']: + ) -> Callable[['Request', Type[DocumentArray]], 'AsyncGenerator']: """ Function that handles the requests arriving to the gateway. This will be passed to the streamer. @@ -217,7 +218,10 @@ def handle_single_document_request( :param connection_pool: The connection pool to be used to send messages to specific nodes of the graph :return: Return a Function that given a Request will return a Future from where to extract the response """ - async def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]': + + async def _handle_request( + request: 'Request', return_type: Type[DocumentArray] = DocumentArray + ) -> 'Tuple[Future, Optional[Future]]': self._update_start_request_metrics(request) # important that the gateway needs to have an instance of the graph per request request_graph = copy.deepcopy(graph) @@ -229,10 +233,15 @@ async def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future] # reset it in case we send to an external gateway exec_endpoint = request.header.exec_endpoint - node = request_graph.all_nodes[0] # this assumes there is only one Executor behind this Gateway - async for resp in node.stream_single_doc(request=request, - connection_pool=connection_pool, - endpoint=exec_endpoint): + node = request_graph.all_nodes[ + 0 + ] # this assumes there is only one Executor behind this Gateway + async for resp in node.stream_single_doc( + request=request, + connection_pool=connection_pool, + endpoint=exec_endpoint, + return_type=return_type, + ): yield resp return _handle_request diff --git a/jina/serve/runtimes/gateway/graph/topology_graph.py b/jina/serve/runtimes/gateway/graph/topology_graph.py index b389e8e0f0774..942737149c9b7 100644 --- a/jina/serve/runtimes/gateway/graph/topology_graph.py +++ b/jina/serve/runtimes/gateway/graph/topology_graph.py @@ -3,11 +3,11 @@ import re from collections import defaultdict from datetime import datetime -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Type import grpc.aio -from jina._docarray import docarray_v2 +from jina._docarray import DocumentArray, docarray_v2 from jina.constants import __default_endpoint__ from jina.excepts import InternalNetworkError from jina.logging.logger import JinaLogger @@ -315,6 +315,7 @@ async def stream_single_doc( request: SingleDocumentRequest, connection_pool: GrpcConnectionPool, endpoint: Optional[str], + return_type: Type[DocumentArray] = DocumentArray, ): if docarray_v2: if self.endpoints and endpoint in self.endpoints: @@ -336,9 +337,20 @@ async def stream_single_doc( else: if docarray_v2: if self.endpoints and endpoint in self.endpoints: - resp.document_cls = self._pydantic_models_by_endpoint[ - endpoint - ]['output'] + from docarray.base_doc import AnyDoc + + # if return_type is not specified or if it is a default type, cast using retrieved + # schemas + if ( + not return_type + or not return_type.doc_type + or return_type.doc_type is AnyDoc + ): + resp.document_cls = self._pydantic_models_by_endpoint[ + endpoint + ]['output'] + else: + resp.document_array_cls = return_type yield resp async def _wait_previous_and_send( @@ -351,6 +363,7 @@ async def _wait_previous_and_send( request_input_parameters: Dict = {}, copy_request_at_send: bool = False, init_task: Optional[asyncio.Task] = None, + return_type: Type[DocumentArray] = None, ): # Check my condition and send request with the condition metadata = {} @@ -419,12 +432,30 @@ async def _wait_previous_and_send( resp, metadata = result if docarray_v2: - if self.endpoints and endpoint in self.endpoints: - resp.document_array_cls = DocList[ - self._pydantic_models_by_endpoint[endpoint][ - 'output' + if self.endpoints and ( + endpoint in self.endpoints + or __default_endpoint__ in self.endpoints + ): + from docarray.base_doc import AnyDoc + + # if return_type is not specified or if it is a default type, cast using retrieved + # schemas + if ( + not return_type + or not return_type.doc_type + or return_type.doc_type is AnyDoc + ): + pydantic_models = ( + self._pydantic_models_by_endpoint.get(endpoint) + or self._pydantic_models_by_endpoint.get( + __default_endpoint__ + ) + ) + resp.document_array_cls = DocList[ + pydantic_models['output'] ] - ] + else: + resp.document_array_cls = return_type if WorkerRequestHandler._KEY_RESULT in resp.parameters: # Accumulate results from each Node and then add them to the original @@ -570,6 +601,7 @@ def get_leaf_req_response_tasks( request_input_has_specific_params: bool = False, copy_request_at_send: bool = False, init_task: Optional[asyncio.Task] = None, + return_type: Type[DocumentArray] = DocumentArray, ) -> List[Tuple[bool, asyncio.Task]]: """ Gets all the tasks corresponding from all the subgraphs born from this node @@ -602,6 +634,7 @@ def get_leaf_req_response_tasks( When the caller of these methods await them, they will fire the logic of sending requests and responses from and to every deployment + :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`. :return: Return a list of tuples, where tasks corresponding to the leafs of all the subgraphs born from this node are in each tuple. These tasks will be based on awaiting for the task from previous_node and sending a request to the corresponding node. The other member of the pair is a flag indicating if the task is to be awaited by the gateway or not. @@ -616,6 +649,7 @@ def get_leaf_req_response_tasks( request_input_parameters=request_input_parameters, copy_request_at_send=copy_request_at_send, init_task=init_task, + return_type=return_type, ) ) if self.leaf: # I am like a leaf @@ -635,6 +669,7 @@ def get_leaf_req_response_tasks( request_input_has_specific_params=request_input_has_specific_params, copy_request_at_send=num_outgoing_nodes > 1 and request_input_has_specific_params, + return_type=return_type, ) # We are interested in the last one, that will be the task that awaits all the previous hanging_tasks_tuples.extend(t) diff --git a/jina/serve/runtimes/gateway/streamer.py b/jina/serve/runtimes/gateway/streamer.py index 57acefddcdff3..d024f5a07bc06 100644 --- a/jina/serve/runtimes/gateway/streamer.py +++ b/jina/serve/runtimes/gateway/streamer.py @@ -10,11 +10,11 @@ Optional, Sequence, Tuple, + Type, Union, - Type ) -from jina._docarray import DocumentArray, Document +from jina._docarray import Document, DocumentArray, docarray_v2 from jina.excepts import ExecutorError from jina.logging.logger import JinaLogger from jina.proto import jina_pb2 @@ -26,7 +26,6 @@ from jina.serve.stream import RequestStreamer from jina.types.request import Request from jina.types.request.data import DataRequest, SingleDocumentRequest -from jina._docarray import docarray_v2 if docarray_v2: from docarray import DocList @@ -48,23 +47,23 @@ class GatewayStreamer: """ def __init__( - self, - graph_representation: Dict, - executor_addresses: Dict[str, Union[str, List[str]]], - graph_conditions: Dict = {}, - deployments_metadata: Dict[str, Dict[str, str]] = {}, - deployments_no_reduce: List[str] = [], - timeout_send: Optional[float] = None, - retries: int = 0, - compression: Optional[str] = None, - runtime_name: str = 'custom gateway', - prefetch: int = 0, - logger: Optional['JinaLogger'] = None, - metrics_registry: Optional['CollectorRegistry'] = None, - meter: Optional['Meter'] = None, - aio_tracing_client_interceptors: Optional[Sequence['ClientInterceptor']] = None, - tracing_client_interceptor: Optional['OpenTelemetryClientInterceptor'] = None, - grpc_channel_options: Optional[list] = None, + self, + graph_representation: Dict, + executor_addresses: Dict[str, Union[str, List[str]]], + graph_conditions: Dict = {}, + deployments_metadata: Dict[str, Dict[str, str]] = {}, + deployments_no_reduce: List[str] = [], + timeout_send: Optional[float] = None, + retries: int = 0, + compression: Optional[str] = None, + runtime_name: str = 'custom gateway', + prefetch: int = 0, + logger: Optional['JinaLogger'] = None, + metrics_registry: Optional['CollectorRegistry'] = None, + meter: Optional['Meter'] = None, + aio_tracing_client_interceptors: Optional[Sequence['ClientInterceptor']] = None, + tracing_client_interceptor: Optional['OpenTelemetryClientInterceptor'] = None, + grpc_channel_options: Optional[list] = None, ): """ :param graph_representation: A dictionary describing the topology of the Deployments. 2 special nodes are expected, the name `start-gateway` and `end-gateway` to @@ -116,8 +115,11 @@ def __init__( request_handler = AsyncRequestResponseHandler( metrics_registry, meter, runtime_name, logger ) - self._single_doc_request_handler = request_handler.handle_single_document_request(graph=self.topology_graph, - connection_pool=self._connection_pool) + self._single_doc_request_handler = ( + request_handler.handle_single_document_request( + graph=self.topology_graph, connection_pool=self._connection_pool + ) + ) self._streamer = RequestStreamer( request_handler=request_handler.handle_request( graph=self.topology_graph, connection_pool=self._connection_pool @@ -130,15 +132,15 @@ def __init__( self._streamer.Call = self._streamer.stream def _create_connection_pool( - self, - deployments_addresses, - compression, - metrics_registry, - meter, - logger, - aio_tracing_client_interceptors, - tracing_client_interceptor, - grpc_channel_options=None, + self, + deployments_addresses, + compression, + metrics_registry, + meter, + logger, + aio_tracing_client_interceptors, + tracing_client_interceptor, + grpc_channel_options=None, ): # add the connections needed connection_pool = GrpcConnectionPool( @@ -188,9 +190,11 @@ async def _get_endpoints_input_output_models(self, is_cancel): # The logic should be to get the response of all the endpoints protos schemas from all the nodes. Then do a # logic that for every endpoint fom every Executor computes what is the input and output schema seen by the # Flow. - self._endpoints_models_map = await self._streamer._get_endpoints_input_output_models(self.topology_graph, - self._connection_pool, - is_cancel) + self._endpoints_models_map = ( + await self._streamer._get_endpoints_input_output_models( + self.topology_graph, self._connection_pool, is_cancel + ) + ) def _validate_flow_docarray_compatibility(self): """ @@ -199,14 +203,15 @@ def _validate_flow_docarray_compatibility(self): self.topology_graph._validate_flow_docarray_compatibility() async def stream( - self, - docs: DocumentArray, - request_size: int = 100, - return_results: bool = False, - exec_endpoint: Optional[str] = None, - target_executor: Optional[str] = None, - parameters: Optional[Dict] = None, - results_in_order: bool = False, + self, + docs: DocumentArray, + request_size: int = 100, + return_results: bool = False, + exec_endpoint: Optional[str] = None, + target_executor: Optional[str] = None, + parameters: Optional[Dict] = None, + results_in_order: bool = False, + return_type: Type[DocumentArray] = DocumentArray, ) -> AsyncIterator[Tuple[Union[DocumentArray, 'Request'], 'ExecutorError']]: """ stream Documents and yield Documents or Responses and unpacked Executor error if any. @@ -218,16 +223,18 @@ async def stream( :param target_executor: A regex expression indicating the Executors that should receive the Request :param parameters: Parameters to be attached to the Requests :param results_in_order: return the results in the same order as the request_iterator + :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`. :yield: tuple of Documents or Responses and unpacked error from Executors if any """ async for result in self.stream_docs( - docs=docs, - request_size=request_size, - return_results=True, # force return Responses - exec_endpoint=exec_endpoint, - target_executor=target_executor, - parameters=parameters, - results_in_order=results_in_order, + docs=docs, + request_size=request_size, + return_results=True, # force return Responses + exec_endpoint=exec_endpoint, + target_executor=target_executor, + parameters=parameters, + results_in_order=results_in_order, + return_type=return_type, ): error = None if jina_pb2.StatusProto.ERROR == result.status.code: @@ -244,13 +251,14 @@ async def stream( yield result.data.docs, error async def stream_doc( - self, - doc: 'Document', - return_results: bool = False, - exec_endpoint: Optional[str] = None, - target_executor: Optional[str] = None, - parameters: Optional[Dict] = None, - request_id: Optional[str] = None + self, + doc: 'Document', + return_results: bool = False, + exec_endpoint: Optional[str] = None, + target_executor: Optional[str] = None, + parameters: Optional[Dict] = None, + request_id: Optional[str] = None, + return_type: Type[DocumentArray] = DocumentArray, ) -> AsyncIterator[Tuple[Union[DocumentArray, 'Request'], 'ExecutorError']]: """ stream Documents and yield Documents or Responses and unpacked Executor error if any. @@ -261,6 +269,7 @@ async def stream_doc( :param target_executor: A regex expression indicating the Executors that should receive the Request :param parameters: Parameters to be attached to the Requests :param request_id: Request ID to add to the request streamed to Executor. Only applicable if request_size is equal or less to the length of the docs + :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`. :yield: tuple of Documents or Responses and unpacked error from Executors if any """ req = SingleDocumentRequest() @@ -275,9 +284,7 @@ async def stream_doc( if parameters: req.parameters = parameters - async for result in self.rpc_stream_doc( - request=req, - ): + async for result in self.rpc_stream_doc(request=req, return_type=return_type): error = None if jina_pb2.StatusProto.ERROR == result.status.code: exception = result.status.exception @@ -293,15 +300,16 @@ async def stream_doc( yield result.data.doc, error async def stream_docs( - self, - docs: DocumentArray, - request_size: int = 100, - return_results: bool = False, - exec_endpoint: Optional[str] = None, - target_executor: Optional[str] = None, - parameters: Optional[Dict] = None, - results_in_order: bool = False, - request_id: Optional[str] = None, + self, + docs: DocumentArray, + request_size: int = 100, + return_results: bool = False, + exec_endpoint: Optional[str] = None, + target_executor: Optional[str] = None, + parameters: Optional[Dict] = None, + results_in_order: bool = False, + request_id: Optional[str] = None, + return_type: Type[DocumentArray] = DocumentArray, ): """ stream documents and stream responses back. @@ -314,6 +322,7 @@ async def stream_docs( :param parameters: Parameters to be attached to the Requests :param results_in_order: return the results in the same order as the request_iterator :param request_id: Request ID to add to the request streamed to Executor. Only applicable if request_size is equal or less to the length of the docs + :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`. :yield: Yields DocumentArrays or Responses from the Executors """ request_id = request_id if len(docs) <= request_size else None @@ -334,10 +343,11 @@ def _req_generator(): yield req else: from docarray import BaseDoc + def batch(iterable, n=1): l = len(iterable) for ndx in range(0, l, n): - yield iterable[ndx:min(ndx + n, l)] + yield iterable[ndx : min(ndx + n, l)] if len(docs) > 0: for docs_batch in batch(docs, n=request_size): @@ -368,7 +378,9 @@ def batch(iterable, n=1): yield req async for resp in self.rpc_stream( - request_iterator=_req_generator(), results_in_order=results_in_order + request_iterator=_req_generator(), + results_in_order=results_in_order, + return_type=return_type, ): if return_results: yield resp @@ -385,7 +397,7 @@ async def close(self): Call = rpc_stream async def process_single_data( - self, request: DataRequest, context=None + self, request: DataRequest, context=None ) -> DataRequest: """Implements request and response handling of a single DataRequest :param request: DataRequest from Client @@ -454,14 +466,16 @@ def __init__(self, connection_pool: GrpcConnectionPool, executor_name: str) -> N self.executor_name = executor_name async def post( - self, - inputs: DocumentArray, - request_size: int = 100, - on: Optional[str] = None, - parameters: Optional[Dict] = None, - return_type: Type[DocumentArray] = DocumentArray, - **kwargs, + self, + inputs: DocumentArray, + request_size: int = 100, + on: Optional[str] = None, + parameters: Optional[Dict] = None, + return_type: Type[DocumentArray] = DocumentArray, + **kwargs, ): + if not parameters: + parameters = {} if not docarray_v2: reqs = [] for docs_batch in inputs.batch(batch_size=request_size, shuffle=False): @@ -473,10 +487,11 @@ async def post( reqs.append(req) else: from docarray import BaseDoc + def batch(iterable, n=1): l = len(iterable) for ndx in range(0, l, n): - yield iterable[ndx:min(ndx + n, l)] + yield iterable[ndx : min(ndx + n, l)] reqs = [] @@ -519,11 +534,11 @@ def batch(iterable, n=1): return docs async def stream_doc( - self, - inputs: 'Document', - on: Optional[str] = None, - parameters: Optional[Dict] = None, - **kwargs, + self, + inputs: 'Document', + on: Optional[str] = None, + parameters: Optional[Dict] = None, + **kwargs, ): req: SingleDocumentRequest = SingleDocumentRequest(inputs.to_protobuf()) req.header.exec_endpoint = on diff --git a/jina/serve/stream/__init__.py b/jina/serve/stream/__init__.py index fe88f625ddb35..21685388bd4aa 100644 --- a/jina/serve/stream/__init__.py +++ b/jina/serve/stream/__init__.py @@ -7,6 +7,7 @@ Iterator, Optional, Tuple, + Type, Union, ) @@ -17,6 +18,7 @@ __all__ = ['RequestStreamer'] +from jina._docarray import DocumentArray from jina.types.request.data import Response if TYPE_CHECKING: # pragma: no cover @@ -32,16 +34,16 @@ class _EndOfStreaming: pass def __init__( - self, - request_handler: Callable[ - ['Request'], Tuple[Awaitable['Request'], Optional[Awaitable['Request']]] - ], - result_handler: Callable[['Request'], Optional['Request']], - prefetch: int = 0, - iterate_sync_in_thread: bool = True, - end_of_iter_handler: Optional[Callable[[], None]] = None, - logger: Optional['JinaLogger'] = None, - **logger_kwargs, + self, + request_handler: Callable[ + ['Request'], Tuple[Awaitable['Request'], Optional[Awaitable['Request']]] + ], + result_handler: Callable[['Request'], Optional['Request']], + prefetch: int = 0, + iterate_sync_in_thread: bool = True, + end_of_iter_handler: Optional[Callable[[], None]] = None, + logger: Optional['JinaLogger'] = None, + **logger_kwargs, ): """ :param request_handler: The callable responsible for handling the request. It should handle a request as input and return a Future to be awaited @@ -61,7 +63,9 @@ def __init__( self._iterate_sync_in_thread = iterate_sync_in_thread self.total_num_floating_tasks_alive = 0 - async def _get_endpoints_input_output_models(self, topology_graph, connection_pool, is_cancel): + async def _get_endpoints_input_output_models( + self, topology_graph, connection_pool, is_cancel + ): """ Return a Dictionary with endpoints as keys and values as a dictionary of input and output schemas and names taken from the endpoints proto endpoint of Executors @@ -77,27 +81,34 @@ async def _get_endpoints_input_output_models(self, topology_graph, connection_po # create loop and get from topology_graph _endpoints_models_map = {} self.logger.debug(f'Get all endpoints from TopologyGraph') - endpoints = await topology_graph._get_all_endpoints(connection_pool, retry_forever=True, is_cancel=is_cancel) + endpoints = await topology_graph._get_all_endpoints( + connection_pool, retry_forever=True, is_cancel=is_cancel + ) self.logger.debug(f'Got all endpoints from TopologyGraph {endpoints}') if endpoints is not None: for endp in endpoints: for origin_node in topology_graph.origin_nodes: - leaf_input_output_model = origin_node._get_leaf_input_output_model(previous_input=None, - previous_output=None, - previous_is_generator=None, - previous_is_singleton_doc=None, - previous_parameters=None, - endpoint=endp) - if leaf_input_output_model is not None and len(leaf_input_output_model) > 0: + leaf_input_output_model = origin_node._get_leaf_input_output_model( + previous_input=None, + previous_output=None, + previous_is_generator=None, + previous_is_singleton_doc=None, + previous_parameters=None, + endpoint=endp, + ) + if ( + leaf_input_output_model is not None + and len(leaf_input_output_model) > 0 + ): _endpoints_models_map[endp] = leaf_input_output_model[0] return _endpoints_models_map async def stream_doc( - self, - request, - context=None, - *args, + self, + request, + context=None, + *args, ) -> AsyncIterator['Request']: """ stream requests from client iterator and stream responses back. @@ -115,7 +126,7 @@ async def stream_doc( yield response except InternalNetworkError as err: if ( - context is not None + context is not None ): # inside GrpcGateway we can handle the error directly here through the grpc context context.set_details(err.details()) context.set_code(err.code()) @@ -137,12 +148,13 @@ async def stream_doc( raise err async def stream( - self, - request_iterator, - context=None, - results_in_order: bool = False, - prefetch: Optional[int] = None, - *args, + self, + request_iterator, + context=None, + results_in_order: bool = False, + prefetch: Optional[int] = None, + return_type: Type[DocumentArray] = DocumentArray, + *args, ) -> AsyncIterator['Request']: """ stream requests from client iterator and stream responses back. @@ -151,6 +163,7 @@ async def stream( :param context: context of the grpc call :param results_in_order: return the results in the same order as the request_iterator :param prefetch: How many Requests are processed from the Client at the same time. If not provided then the prefetch value from the metadata will be utilized. + :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`. :param args: positional arguments :yield: responses from Executors """ @@ -170,12 +183,13 @@ async def stream( request_iterator=request_iterator, results_in_order=results_in_order, prefetch=prefetch, + return_type=return_type, ) async for response in async_iter: yield response except InternalNetworkError as err: if ( - context is not None + context is not None ): # inside GrpcGateway we can handle the error directly here through the grpc context context.set_details(err.details()) context.set_code(err.code()) @@ -197,15 +211,17 @@ async def stream( raise err async def _stream_requests( - self, - request_iterator: Union[Iterator, AsyncIterator], - results_in_order: bool = False, - prefetch: Optional[int] = None, + self, + request_iterator: Union[Iterator, AsyncIterator], + results_in_order: bool = False, + prefetch: Optional[int] = None, + return_type: Type[DocumentArray] = DocumentArray, ) -> AsyncIterator: """Implements request and response handling without prefetching :param request_iterator: requests iterator from Client :param results_in_order: return the results in the same order as the request_iterator :param prefetch: How many Requests are processed from the Client at the same time. If not provided then the prefetch value from the class will be utilized. + :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`. :yield: responses """ result_queue = asyncio.Queue() @@ -254,15 +270,15 @@ async def iterate_requests() -> None: """ num_reqs = 0 async for request in AsyncRequestsIterator( - iterator=request_iterator, - request_counter=requests_to_handle, - prefetch=prefetch or self._prefetch, - iterate_sync_in_thread=self._iterate_sync_in_thread, + iterator=request_iterator, + request_counter=requests_to_handle, + prefetch=prefetch or self._prefetch, + iterate_sync_in_thread=self._iterate_sync_in_thread, ): num_reqs += 1 requests_to_handle.count += 1 future_responses, future_hanging = self._request_handler( - request=request + request=request, return_type=return_type ) future_queue.put_nowait(future_responses) future_responses.add_done_callback(callback) @@ -284,8 +300,8 @@ async def iterate_requests() -> None: future_cancel = asyncio.ensure_future(end_future()) result_queue.put_nowait(future_cancel) if ( - all_floating_requests_awaited.is_set() - or empty_requests_iterator.is_set() + all_floating_requests_awaited.is_set() + or empty_requests_iterator.is_set() ): # It will be waiting for something that will never appear future_cancel = asyncio.ensure_future(end_future()) @@ -293,8 +309,8 @@ async def iterate_requests() -> None: async def handle_floating_responses(): while ( - not all_floating_requests_awaited.is_set() - and not empty_requests_iterator.is_set() + not all_floating_requests_awaited.is_set() + and not empty_requests_iterator.is_set() ): hanging_response = await floating_results_queue.get() res = hanging_response.result() @@ -347,7 +363,7 @@ async def wait_floating_requests_end(self): await asyncio.sleep(0) async def process_single_data( - self, request: DataRequest, context=None + self, request: DataRequest, context=None ) -> DataRequest: """Implements request and response handling of a single DataRequest :param request: DataRequest from Client diff --git a/tests/integration/docarray_v2/test_v2.py b/tests/integration/docarray_v2/test_v2.py index 34ac28524b706..e029b414785a5 100644 --- a/tests/integration/docarray_v2/test_v2.py +++ b/tests/integration/docarray_v2/test_v2.py @@ -838,8 +838,26 @@ async def get_executor(text: str): parameters=PARAMETERS, return_type=DocList[TextDoc], ) + assert resp.doc_type is TextDoc return {'result': [doc.text for doc in resp]} + @app.get('/endpoint_stream_docs') + async def get_endpoint_stream_docs(text: str): + docs = DocList[TextDoc]( + [ + TextDoc(text=f'stream {text}'), + TextDoc(text=f'stream {text}'.upper()), + ] + ) + async for resp in self.streamer.stream_docs( + docs, + parameters=PARAMETERS, + target_executor='executor1', + return_type=DocList[TextDoc], + ): + assert resp.doc_type is TextDoc + return {'result': [doc.text for doc in resp]} + @app.get('/endpoint_stream') async def get_endpoint_stream(text: str): docs = DocList[TextDoc]( @@ -848,9 +866,13 @@ async def get_endpoint_stream(text: str): TextDoc(text=f'stream {text}'.upper()), ] ) - async for resp in self.streamer.stream_docs( - docs, parameters=PARAMETERS, target_executor='executor1' + async for resp, _ in self.streamer.stream( + docs, + parameters=PARAMETERS, + target_executor='executor1', + return_type=DocList[TextDoc], ): + assert resp.doc_type is TextDoc return {'result': [doc.text for doc in resp]} return app @@ -880,6 +902,12 @@ def func( f'EXECUTOR MEOW Second(parameters={str(PARAMETERS)})', ] + r = requests.get(f'http://localhost:{flow.port}/endpoint_stream_docs?text=meow') + assert r.json()['result'] == [ + f'stream meow Second(parameters={str(PARAMETERS)})', + f'STREAM MEOW Second(parameters={str(PARAMETERS)})', + ] + r = requests.get(f'http://localhost:{flow.port}/endpoint_stream?text=meow') assert r.json()['result'] == [ f'stream meow Second(parameters={str(PARAMETERS)})', @@ -1136,7 +1164,9 @@ class OutputComplexDoc(BaseDoc): class MyComplexServeExec(Executor): @requests(on='/bar') - def bar(self, docs: DocList[InputComplexDoc], **kwargs) -> DocList[OutputComplexDoc]: + def bar( + self, docs: DocList[InputComplexDoc], **kwargs + ) -> DocList[OutputComplexDoc]: docs_return = DocList[OutputComplexDoc]( [ OutputComplexDoc( @@ -1160,9 +1190,13 @@ def bar(self, docs: DocList[InputComplexDoc], **kwargs) -> DocList[OutputComplex ports = [random_port() for _ in protocols] if ctxt_manager == 'flow': - ctxt = Flow(port=ports, protocol=protocols).add(replicas=replicas, uses=MyComplexServeExec) + ctxt = Flow(port=ports, protocol=protocols).add( + replicas=replicas, uses=MyComplexServeExec + ) else: - ctxt = Deployment(port=ports, protocol=protocols, replicas=replicas, uses=MyComplexServeExec) + ctxt = Deployment( + port=ports, protocol=protocols, replicas=replicas, uses=MyComplexServeExec + ) with ctxt: for port, protocol in zip(ports, protocols): c = Client(port=port, protocol=protocol) @@ -1484,22 +1518,27 @@ def foo( def test_doc_with_examples(ctxt_manager, include_gateway): if ctxt_manager == 'flow' and include_gateway: return - import string import random + import string random_example = ''.join(random.choices(string.ascii_letters, k=10)) random_description = ''.join(random.choices(string.ascii_letters, k=10)) from pydantic.fields import Field + class MyDocWithExample(BaseDoc): """This test should be in description""" + t: str = Field(examples=[random_example], description=random_description) + class Config: title: str = 'MyDocWithExampleTitle' schema_extra: Dict = {'extra_key': 'extra_value'} class MyExecDocWithExample(Executor): @requests - def foo(self, docs: DocList[MyDocWithExample], **kwargs) -> DocList[MyDocWithExample]: + def foo( + self, docs: DocList[MyDocWithExample], **kwargs + ) -> DocList[MyDocWithExample]: pass port = random_port() @@ -1507,10 +1546,16 @@ def foo(self, docs: DocList[MyDocWithExample], **kwargs) -> DocList[MyDocWithExa if ctxt_manager == 'flow': ctxt = Flow(protocol='http', port=port).add(uses=MyExecDocWithExample) else: - ctxt = Deployment(uses=MyExecDocWithExample, protocol='http', port=port, include_gateway=include_gateway) + ctxt = Deployment( + uses=MyExecDocWithExample, + protocol='http', + port=port, + include_gateway=include_gateway, + ) with ctxt: import requests as general_requests + resp = general_requests.get(f'http://localhost:{port}/openapi.json') resp_str = str(resp.json()) assert random_example in resp_str @@ -1527,13 +1572,18 @@ class MyRandomModel(BaseDoc): class MyInputModel(BaseDoc): b: Optional[MyRandomModel] = None - class MyFailingExecutor(Executor): @requests(on='/generate') - def generate(self, docs: DocList[MyInputModel], **kwargs) -> DocList[MyRandomModel]: + def generate( + self, docs: DocList[MyInputModel], **kwargs + ) -> DocList[MyRandomModel]: return DocList[MyRandomModel]([doc.b for doc in docs]) with Flow(protocol='http').add(uses=MyFailingExecutor) as f: input_doc = MyRandomModel(a='hello world') - res = f.post(on='/generate', inputs=[MyInputModel(b=MyRandomModel(a='hey'))], return_type=DocList[MyRandomModel]) + res = f.post( + on='/generate', + inputs=[MyInputModel(b=MyRandomModel(a='hey'))], + return_type=DocList[MyRandomModel], + ) assert res[0].a == 'hey' diff --git a/tests/unit/serve/stream/test_stream.py b/tests/unit/serve/stream/test_stream.py index 0d32b1f2e502e..6ef4b3f4f5d88 100644 --- a/tests/unit/serve/stream/test_stream.py +++ b/tests/unit/serve/stream/test_stream.py @@ -2,8 +2,8 @@ import random import pytest - from docarray import Document, DocumentArray + from jina.helper import Namespace, random_identity from jina.serve.stream import RequestStreamer from jina.types.request.data import DataRequest @@ -24,10 +24,10 @@ def __init__(self, num_requests, prefetch, iterate_sync_in_thread): result_handler=self.result_handle_fn, end_of_iter_handler=self.end_of_iter_fn, prefetch=getattr(args, 'prefetch', 0), - iterate_sync_in_thread=iterate_sync_in_thread + iterate_sync_in_thread=iterate_sync_in_thread, ) - def request_handler_fn(self, request): + def request_handler_fn(self, request, **kwargs): self.requests_handled.append(request) async def task(): @@ -82,7 +82,9 @@ async def test_request_streamer( prefetch, num_requests, async_iterator, results_in_order, iterate_sync_in_thread ): - test_streamer = RequestStreamerWrapper(num_requests, prefetch, iterate_sync_in_thread) + test_streamer = RequestStreamerWrapper( + num_requests, prefetch, iterate_sync_in_thread + ) streamer = test_streamer.streamer it = ( @@ -113,7 +115,9 @@ async def test_request_streamer( @pytest.mark.asyncio @pytest.mark.parametrize('num_requests', [1, 5, 13]) @pytest.mark.parametrize('iterate_sync_in_thread', [False, True]) -async def test_request_streamer_process_single_data(monkeypatch, num_requests, iterate_sync_in_thread): +async def test_request_streamer_process_single_data( + monkeypatch, num_requests, iterate_sync_in_thread +): test_streamer = RequestStreamerWrapper(num_requests, 0, iterate_sync_in_thread) streamer = test_streamer.streamer