diff --git a/jina/clients/base/grpc.py b/jina/clients/base/grpc.py index 204924a57f74d..6a4c0e9f6ae56 100644 --- a/jina/clients/base/grpc.py +++ b/jina/clients/base/grpc.py @@ -145,7 +145,7 @@ async def _get_results( async for ( response ) in stream_rpc.stream_rpc_with_retry(): - yield response + yield response, None else: unary_rpc = UnaryRpc( channel=channel, @@ -169,7 +169,7 @@ async def _get_results( **kwargs, ) async for response in unary_rpc.unary_rpc_with_retry(): - yield response + yield response, None except (grpc.aio.AioRpcError, InternalNetworkError) as err: await self._handle_error_and_metadata(err) except KeyboardInterrupt: diff --git a/jina/clients/base/http.py b/jina/clients/base/http.py index 653a98f051629..cc7a495cd3fdf 100644 --- a/jina/clients/base/http.py +++ b/jina/clients/base/http.py @@ -218,12 +218,13 @@ def _result_handler(result): del r_str['data'] resp = DataRequest(r_str) - if da is not None: - resp.data.docs = da + #if da is not None: + # resp.data.docs = da callback_exec( response=resp, logger=self.logger, + docs=da, on_error=on_error, on_done=on_done, on_always=on_always, @@ -231,7 +232,7 @@ def _result_handler(result): ) if self.show_progress: p_bar.update() - yield resp + yield resp, da async def _get_streaming_results( self, diff --git a/jina/clients/base/stream_rpc.py b/jina/clients/base/stream_rpc.py index 6ea2b9805e27a..ebfef9fdfd71b 100644 --- a/jina/clients/base/stream_rpc.py +++ b/jina/clients/base/stream_rpc.py @@ -58,6 +58,7 @@ async def stream_rpc_with_retry(self): callback_exec( response=resp, logger=self.logger, + docs=None, on_error=self.on_error, on_done=self.on_done, on_always=self.on_always, diff --git a/jina/clients/base/unary_rpc.py b/jina/clients/base/unary_rpc.py index 6cb219706738c..dc762bcee49ab 100644 --- a/jina/clients/base/unary_rpc.py +++ b/jina/clients/base/unary_rpc.py @@ -101,6 +101,7 @@ def _result_handler(resp): callback_exec( response=resp, logger=self.logger, + docs=None, on_error=self.on_error, on_done=self.on_done, on_always=self.on_always, diff --git a/jina/clients/base/websocket.py b/jina/clients/base/websocket.py index a8b868704bac0..806e517182446 100644 --- a/jina/clients/base/websocket.py +++ b/jina/clients/base/websocket.py @@ -209,6 +209,7 @@ def _request_handler( callback_exec( response=response, logger=self.logger, + docs=None, on_error=on_error, on_done=on_done, on_always=on_always, @@ -216,7 +217,7 @@ def _request_handler( ) if self.show_progress: p_bar.update() - yield response + yield response, None except Exception as ex: exception_raised = ex try: diff --git a/jina/clients/helper.py b/jina/clients/helper.py index 063837abc9240..5bdbacd17600f 100644 --- a/jina/clients/helper.py +++ b/jina/clients/helper.py @@ -58,6 +58,7 @@ def _arg_wrapper(*args, **kwargs): def callback_exec( response, logger: JinaLogger, + docs: Optional = None, on_done: Optional[Callable] = None, on_error: Optional[Callable] = None, on_always: Optional[Callable] = None, @@ -66,20 +67,28 @@ def callback_exec( """Execute the callback with the response. :param response: the response + :param logger: a logger instance + :param docs: the docs to attach lazily to response if needed :param on_done: the on_done callback :param on_error: the on_error callback :param on_always: the on_always callback :param continue_on_error: whether to continue on error - :param logger: a logger instance """ if response.header.status.code >= jina_pb2.StatusProto.ERROR: if on_error: + if docs is not None: + # response.data.docs is expensive and not always needed. + response.data.docs = docs _safe_callback(on_error, continue_on_error, logger)(response) elif continue_on_error: logger.error(f'Server error: {response.header}') else: raise BadServer(response.header) elif on_done and response.header.status.code == jina_pb2.StatusProto.SUCCESS: + if docs is not None: + response.data.docs = docs _safe_callback(on_done, continue_on_error, logger)(response) if on_always: + if docs is not None: + response.data.docs = docs _safe_callback(on_always, continue_on_error, logger)(response) diff --git a/jina/clients/mixin.py b/jina/clients/mixin.py index ec0c52049d200..024316df64e46 100644 --- a/jina/clients/mixin.py +++ b/jina/clients/mixin.py @@ -407,14 +407,17 @@ async def _get_results(*args, **kwargs): inferred_return_type = DocList[return_type] result = [] if return_responses else inferred_return_type([]) - async for resp in c._get_results(*args, **kwargs): + async for resp, da in c._get_results(*args, **kwargs): if return_results: resp.document_array_cls = inferred_return_type if return_responses: + if da is not None: + resp.data.docs = da result.append(resp) else: - result.extend(resp.data.docs) + result.extend(da if da is not None else resp.data.docs) + if return_results: if not return_responses and is_singleton and len(result) == 1: return result[0] @@ -508,7 +511,7 @@ async def post( parameters = _include_results_field_in_param(parameters) - async for result in c._get_results( + async for result, da in c._get_results( on=on, inputs=inputs, on_done=on_done, @@ -538,12 +541,14 @@ async def post( is_singleton = True result.document_array_cls = DocList[return_type] if not return_responses: - ret_docs = result.data.docs + ret_docs = da if da is not None else result.data.docs if is_singleton and len(ret_docs) == 1: yield ret_docs[0] else: yield ret_docs else: + if da is not None: + result.data.docs = da yield result async def stream_doc(