Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid expensive setting #6181

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions jina/clients/base/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async def _get_results(
async for (
response
) in stream_rpc.stream_rpc_with_retry():
yield response
yield response, None
else:
unary_rpc = UnaryRpc(
channel=channel,
Expand All @@ -169,7 +169,7 @@ async def _get_results(
**kwargs,
)
async for response in unary_rpc.unary_rpc_with_retry():
yield response
yield response, None
except (grpc.aio.AioRpcError, InternalNetworkError) as err:
await self._handle_error_and_metadata(err)
except KeyboardInterrupt:
Expand Down
7 changes: 4 additions & 3 deletions jina/clients/base/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,20 +218,21 @@ def _result_handler(result):
del r_str['data']

resp = DataRequest(r_str)
if da is not None:
resp.data.docs = da
#if da is not None:
# resp.data.docs = da

callback_exec(
response=resp,
logger=self.logger,
docs=da,
on_error=on_error,
on_done=on_done,
on_always=on_always,
continue_on_error=self.continue_on_error,
)
if self.show_progress:
p_bar.update()
yield resp
yield resp, da

async def _get_streaming_results(
self,
Expand Down
1 change: 1 addition & 0 deletions jina/clients/base/stream_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ async def stream_rpc_with_retry(self):
callback_exec(
response=resp,
logger=self.logger,
docs=None,
on_error=self.on_error,
on_done=self.on_done,
on_always=self.on_always,
Expand Down
1 change: 1 addition & 0 deletions jina/clients/base/unary_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def _result_handler(resp):
callback_exec(
response=resp,
logger=self.logger,
docs=None,
on_error=self.on_error,
on_done=self.on_done,
on_always=self.on_always,
Expand Down
3 changes: 2 additions & 1 deletion jina/clients/base/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,15 @@ def _request_handler(
callback_exec(
response=response,
logger=self.logger,
docs=None,
on_error=on_error,
on_done=on_done,
on_always=on_always,
continue_on_error=self.continue_on_error,
)
if self.show_progress:
p_bar.update()
yield response
yield response, None
except Exception as ex:
exception_raised = ex
try:
Expand Down
11 changes: 10 additions & 1 deletion jina/clients/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _arg_wrapper(*args, **kwargs):
def callback_exec(
response,
logger: JinaLogger,
docs: Optional = None,
on_done: Optional[Callable] = None,
on_error: Optional[Callable] = None,
on_always: Optional[Callable] = None,
Expand All @@ -66,20 +67,28 @@ def callback_exec(
"""Execute the callback with the response.

:param response: the response
:param logger: a logger instance
:param docs: the docs to attach lazily to response if needed
:param on_done: the on_done callback
:param on_error: the on_error callback
:param on_always: the on_always callback
:param continue_on_error: whether to continue on error
:param logger: a logger instance
"""
if response.header.status.code >= jina_pb2.StatusProto.ERROR:
if on_error:
if docs is not None:
# response.data.docs is expensive and not always needed.
response.data.docs = docs
_safe_callback(on_error, continue_on_error, logger)(response)
elif continue_on_error:
logger.error(f'Server error: {response.header}')
else:
raise BadServer(response.header)
elif on_done and response.header.status.code == jina_pb2.StatusProto.SUCCESS:
if docs is not None:
response.data.docs = docs
_safe_callback(on_done, continue_on_error, logger)(response)
if on_always:
if docs is not None:
response.data.docs = docs
_safe_callback(on_always, continue_on_error, logger)(response)
13 changes: 9 additions & 4 deletions jina/clients/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,14 +407,17 @@ async def _get_results(*args, **kwargs):
inferred_return_type = DocList[return_type]
result = [] if return_responses else inferred_return_type([])

async for resp in c._get_results(*args, **kwargs):
async for resp, da in c._get_results(*args, **kwargs):

if return_results:
resp.document_array_cls = inferred_return_type
if return_responses:
if da is not None:
resp.data.docs = da
result.append(resp)
else:
result.extend(resp.data.docs)
result.extend(da if da is not None else resp.data.docs)

if return_results:
if not return_responses and is_singleton and len(result) == 1:
return result[0]
Expand Down Expand Up @@ -508,7 +511,7 @@ async def post(

parameters = _include_results_field_in_param(parameters)

async for result in c._get_results(
async for result, da in c._get_results(
on=on,
inputs=inputs,
on_done=on_done,
Expand Down Expand Up @@ -538,12 +541,14 @@ async def post(
is_singleton = True
result.document_array_cls = DocList[return_type]
if not return_responses:
ret_docs = result.data.docs
ret_docs = da if da is not None else result.data.docs
if is_singleton and len(ret_docs) == 1:
yield ret_docs[0]
else:
yield ret_docs
else:
if da is not None:
result.data.docs = da
yield result

async def stream_doc(
Expand Down
Loading