Skip to content

Commit

Permalink
feat: add return_type parameter to gateway streamer methods to allow … (
Browse files Browse the repository at this point in the history
  • Loading branch information
alaeddine-13 authored Sep 14, 2023
1 parent 6d62bce commit b256f9b
Show file tree
Hide file tree
Showing 10 changed files with 338 additions and 191 deletions.
15 changes: 15 additions & 0 deletions docs/concepts/serving/gateway/customization.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,13 @@ class MyGateway(FastAPIBaseGateway):
return app
```

```{hint}
:class: note
If you omit the `return_type` parameter, the gateway streamer can still fetch the Executor output schemas and dynamically construct a DocArray model from them.
Even though the dynamically created schema is very similar to the original schema, some validation checks can still fail (for instance, when adding to a typed `DocList`).
It is recommended to always pass the `return_type` parameter.
```

### Recovering Executor errors

Exceptions raised by an `Executor` are captured in the server object which can be extracted by using the {meth}`jina.serve.streamer.stream()` method. The `stream` method
Expand All @@ -266,6 +273,14 @@ async def get(text: str):
return {'results': results, 'errors': [error.name for error in errors]}
```


```{hint}
:class: note
If you omit the `return_type` parameter, the gateway streamer can still fetch the Executor output schemas and dynamically construct a DocArray model from them.
Even though the dynamically created schema is very similar to the original schema, some validation checks can still fail (for instance, when adding to a typed `DocList`).
It is recommended to always pass the `return_type` parameter.
```

(executor-streamer)=
## Calling an individual Executor

Expand Down
16 changes: 9 additions & 7 deletions jina/clients/base/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@

from jina._docarray import Document, DocumentArray, docarray_v2
from jina.clients.base import BaseClient
from jina.clients.base.helper import (
HTTPClientlet,
handle_response_status,
)
from jina.clients.base.helper import HTTPClientlet, handle_response_status
from jina.clients.helper import callback_exec
from jina.importer import ImportExtensions
from jina.logging.profile import ProgressBar
Expand Down Expand Up @@ -165,12 +162,13 @@ async def _get_results(
)

def _request_handler(
request: 'Request',
request: 'Request', **kwargs
) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]':
"""
For HTTP Client, for each request in the iterator, we `send_message` using
http POST request and add it to the list of tasks which is awaited and yielded.
:param request: current request in the iterator
:param kwargs: kwargs
:return: asyncio Task for sending message
"""
return asyncio.ensure_future(iolet.send_message(request=request)), None
Expand Down Expand Up @@ -203,11 +201,15 @@ def _result_handler(result):
da = DocumentArray.from_dict(r_str['data'])
else:
from docarray import DocList

if issubclass(return_type, DocList):
da = return_type(
[return_type.doc_type(**v) for v in r_str['data']])
[return_type.doc_type(**v) for v in r_str['data']]
)
else:
da = DocList[return_type]([return_type(**v) for v in r_str['data']])
da = DocList[return_type](
[return_type(**v) for v in r_str['data']]
)
del r_str['data']

resp = DataRequest(r_str)
Expand Down
2 changes: 1 addition & 1 deletion jina/clients/base/unary_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def unary_rpc_with_retry(self):
stub = jina_pb2_grpc.JinaSingleDataRequestRPCStub(self.channel)

def _request_handler(
request: 'Request',
request: 'Request', **kwargs
) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]':
async def _with_retry(req: 'Request'):
for attempt in range(1, self.max_attempts + 1):
Expand Down
3 changes: 2 additions & 1 deletion jina/clients/base/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _handle_end_of_iter():
asyncio.create_task(iolet.send_eoi())

def _request_handler(
request: 'Request',
request: 'Request', **kwargs
) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]':
"""
For each request in the iterator, we send the `Message` using `iolet.send_message()`.
Expand All @@ -175,6 +175,7 @@ def _request_handler(
Then add {<request-id>: <an-empty-future>} to the request buffer.
This empty future is used to track the `result` of this request during `receive`.
:param request: current request in the iterator
:param kwargs: kwargs
:return: asyncio Future for sending message
"""
future = get_or_reuse_loop().create_future()
Expand Down
63 changes: 36 additions & 27 deletions jina/serve/runtimes/gateway/async_request_response_handling.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import copy
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, AsyncGenerator
from typing import TYPE_CHECKING, AsyncGenerator, Callable, List, Optional, Tuple, Type

import grpc.aio

from jina._docarray import DocumentArray
from jina._docarray import DocumentArray, docarray_v2
from jina.excepts import InternalNetworkError
from jina.helper import GATEWAY_NAME
from jina.logging.logger import JinaLogger
Expand All @@ -13,7 +13,6 @@
from jina.serve.runtimes.helper import _is_param_for_specific_executor
from jina.serve.runtimes.monitoring import MonitoringRequestMixin
from jina.serve.runtimes.worker.request_handling import WorkerRequestHandler
from jina._docarray import docarray_v2

if TYPE_CHECKING: # pragma: no cover
from asyncio import Future
Expand All @@ -33,19 +32,19 @@ class AsyncRequestResponseHandler(MonitoringRequestMixin):
"""

def __init__(
self,
metrics_registry: Optional['CollectorRegistry'] = None,
meter: Optional['Meter'] = None,
runtime_name: Optional[str] = None,
logger: Optional[JinaLogger] = None,
self,
metrics_registry: Optional['CollectorRegistry'] = None,
meter: Optional['Meter'] = None,
runtime_name: Optional[str] = None,
logger: Optional[JinaLogger] = None,
):
super().__init__(metrics_registry, meter, runtime_name)
self._endpoint_discovery_finished = False
self._gathering_endpoints = False
self.logger = logger or JinaLogger(self.__class__.__name__)

def handle_request(
self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
) -> Callable[['Request'], 'Tuple[Future, Optional[Future]]']:
"""
Function that handles the requests arriving to the gateway. This will be passed to the streamer.
Expand All @@ -64,8 +63,8 @@ async def gather_endpoints(request_graph):
err_code = err.code()
if err_code == grpc.StatusCode.UNAVAILABLE:
err._details = (
err.details()
+ f' |Gateway: Communication error while gathering endpoints with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.'
err.details()
+ f' |Gateway: Communication error while gathering endpoints with deployment at address(es) {err.dest_addr}. Head or worker(s) may be down.'
)
raise err
else:
Expand All @@ -75,7 +74,9 @@ async def gather_endpoints(request_graph):
raise exc
self._endpoint_discovery_finished = True

def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
def _handle_request(
request: 'Request', return_type: Type[DocumentArray]
) -> 'Tuple[Future, Optional[Future]]':
self._update_start_request_metrics(request)
# important that the gateway needs to have an instance of the graph per request
request_graph = copy.deepcopy(graph)
Expand All @@ -100,20 +101,19 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
request.header.target_executor = ''
exec_endpoint = request.header.exec_endpoint
gather_endpoints_task = None
if (
not self._endpoint_discovery_finished
and not self._gathering_endpoints
):
gather_endpoints_task = asyncio.create_task(gather_endpoints(request_graph))
if not self._endpoint_discovery_finished and not self._gathering_endpoints:
gather_endpoints_task = asyncio.create_task(
gather_endpoints(request_graph)
)

init_task = None
request_doc_ids = []

if graph.has_filter_conditions:
if not docarray_v2:
request_doc_ids = request.data.docs[
:, 'id'
] # used to maintain order of docs that are filtered by executors
:, 'id'
] # used to maintain order of docs that are filtered by executors
else:
init_task = gather_endpoints_task
from docarray import DocList
Expand All @@ -137,7 +137,8 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
request_input_parameters=request_input_parameters,
request_input_has_specific_params=has_specific_params,
copy_request_at_send=num_outgoing_nodes > 1 and has_specific_params,
init_task=init_task
init_task=init_task,
return_type=return_type,
)
# Every origin node returns a set of tasks that are the ones corresponding to the leafs of each of their
# subtrees that unwrap all the previous tasks. It starts like a chain of waiting for tasks from previous
Expand All @@ -157,7 +158,7 @@ def sort_by_request_order(doc):
response.data.docs = DocumentArray(sorted_docs)

async def _process_results_at_end_gateway(
tasks: List[asyncio.Task], request_graph: TopologyGraph
tasks: List[asyncio.Task], request_graph: TopologyGraph
) -> asyncio.Future:
try:
partial_responses = await asyncio.gather(*tasks)
Expand Down Expand Up @@ -209,15 +210,18 @@ async def _process_results_at_end_gateway(

def handle_single_document_request(
self, graph: 'TopologyGraph', connection_pool: 'GrpcConnectionPool'
) -> Callable[['Request'], 'AsyncGenerator']:
) -> Callable[['Request', Type[DocumentArray]], 'AsyncGenerator']:
"""
Function that handles the requests arriving to the gateway. This will be passed to the streamer.
:param graph: The TopologyGraph of the Flow.
:param connection_pool: The connection pool to be used to send messages to specific nodes of the graph
:return: Return a Function that given a Request will return a Future from where to extract the response
"""
async def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':

async def _handle_request(
request: 'Request', return_type: Type[DocumentArray] = DocumentArray
) -> 'Tuple[Future, Optional[Future]]':
self._update_start_request_metrics(request)
# important that the gateway needs to have an instance of the graph per request
request_graph = copy.deepcopy(graph)
Expand All @@ -229,10 +233,15 @@ async def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]
# reset it in case we send to an external gateway
exec_endpoint = request.header.exec_endpoint

node = request_graph.all_nodes[0] # this assumes there is only one Executor behind this Gateway
async for resp in node.stream_single_doc(request=request,
connection_pool=connection_pool,
endpoint=exec_endpoint):
node = request_graph.all_nodes[
0
] # this assumes there is only one Executor behind this Gateway
async for resp in node.stream_single_doc(
request=request,
connection_pool=connection_pool,
endpoint=exec_endpoint,
return_type=return_type,
):
yield resp

return _handle_request
Expand Down
55 changes: 45 additions & 10 deletions jina/serve/runtimes/gateway/graph/topology_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import re
from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Type

import grpc.aio

from jina._docarray import docarray_v2
from jina._docarray import DocumentArray, docarray_v2
from jina.constants import __default_endpoint__
from jina.excepts import InternalNetworkError
from jina.logging.logger import JinaLogger
Expand Down Expand Up @@ -315,6 +315,7 @@ async def stream_single_doc(
request: SingleDocumentRequest,
connection_pool: GrpcConnectionPool,
endpoint: Optional[str],
return_type: Type[DocumentArray] = DocumentArray,
):
if docarray_v2:
if self.endpoints and endpoint in self.endpoints:
Expand All @@ -336,9 +337,20 @@ async def stream_single_doc(
else:
if docarray_v2:
if self.endpoints and endpoint in self.endpoints:
resp.document_cls = self._pydantic_models_by_endpoint[
endpoint
]['output']
from docarray.base_doc import AnyDoc

# if return_type is not specified or if it is a default type, cast using retrieved
# schemas
if (
not return_type
or not return_type.doc_type
or return_type.doc_type is AnyDoc
):
resp.document_cls = self._pydantic_models_by_endpoint[
endpoint
]['output']
else:
resp.document_array_cls = return_type
yield resp

async def _wait_previous_and_send(
Expand All @@ -351,6 +363,7 @@ async def _wait_previous_and_send(
request_input_parameters: Dict = {},
copy_request_at_send: bool = False,
init_task: Optional[asyncio.Task] = None,
return_type: Type[DocumentArray] = None,
):
# Check my condition and send request with the condition
metadata = {}
Expand Down Expand Up @@ -419,12 +432,30 @@ async def _wait_previous_and_send(
resp, metadata = result

if docarray_v2:
if self.endpoints and endpoint in self.endpoints:
resp.document_array_cls = DocList[
self._pydantic_models_by_endpoint[endpoint][
'output'
if self.endpoints and (
endpoint in self.endpoints
or __default_endpoint__ in self.endpoints
):
from docarray.base_doc import AnyDoc

# if return_type is not specified or if it is a default type, cast using retrieved
# schemas
if (
not return_type
or not return_type.doc_type
or return_type.doc_type is AnyDoc
):
pydantic_models = (
self._pydantic_models_by_endpoint.get(endpoint)
or self._pydantic_models_by_endpoint.get(
__default_endpoint__
)
)
resp.document_array_cls = DocList[
pydantic_models['output']
]
]
else:
resp.document_array_cls = return_type

if WorkerRequestHandler._KEY_RESULT in resp.parameters:
# Accumulate results from each Node and then add them to the original
Expand Down Expand Up @@ -570,6 +601,7 @@ def get_leaf_req_response_tasks(
request_input_has_specific_params: bool = False,
copy_request_at_send: bool = False,
init_task: Optional[asyncio.Task] = None,
return_type: Type[DocumentArray] = DocumentArray,
) -> List[Tuple[bool, asyncio.Task]]:
"""
Gets all the tasks corresponding from all the subgraphs born from this node
Expand Down Expand Up @@ -602,6 +634,7 @@ def get_leaf_req_response_tasks(
When the caller of these methods await them, they will fire the logic of sending requests and responses from and to every deployment
:param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`.
:return: Return a list of tuples, where tasks corresponding to the leafs of all the subgraphs born from this node are in each tuple.
These tasks will be based on awaiting for the task from previous_node and sending a request to the corresponding node. The other member of the pair
is a flag indicating if the task is to be awaited by the gateway or not.
Expand All @@ -616,6 +649,7 @@ def get_leaf_req_response_tasks(
request_input_parameters=request_input_parameters,
copy_request_at_send=copy_request_at_send,
init_task=init_task,
return_type=return_type,
)
)
if self.leaf: # I am like a leaf
Expand All @@ -635,6 +669,7 @@ def get_leaf_req_response_tasks(
request_input_has_specific_params=request_input_has_specific_params,
copy_request_at_send=num_outgoing_nodes > 1
and request_input_has_specific_params,
return_type=return_type,
)
# We are interested in the last one, that will be the task that awaits all the previous
hanging_tasks_tuples.extend(t)
Expand Down
Loading

0 comments on commit b256f9b

Please sign in to comment.