Skip to content

Commit

Permalink
fix: fix clientlets
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Sep 13, 2024
1 parent 361f08c commit ea610bd
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 64 deletions.
1 change: 1 addition & 0 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(

async def close(self):
"""Closes the potential resources of the Client.
:return: Return whatever a close method may return
"""
return self.teardown_instrumentation()

Check warning on line 70 in jina/clients/base/__init__.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/__init__.py#L70

Added line #L70 was not covered by tests

Expand Down
6 changes: 5 additions & 1 deletion jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class HTTPClientlet(AioHttpClientlet):
async def send_message(self, url, request: 'Request'):
"""Sends a POST request to the server
:param url: the URL where to send the message
:param request: request as dict
:return: send post message
"""
Expand Down Expand Up @@ -196,6 +197,7 @@ async def send_message(self, url, request: 'Request'):
async def send_streaming_message(self, url, doc: 'Document', on: str):
"""Sends a GET SSE request to the server
:param url: the URL where to send the message
:param doc: Request Document
:param on: Request endpoint
:yields: responses
Expand All @@ -218,6 +220,7 @@ async def send_streaming_message(self, url, doc: 'Document', on: str):

async def send_dry_run(self, url, **kwargs):
"""Query the dry_run endpoint from Gateway
:param url: the URL where to send the message
:param kwargs: keyword arguments to make sure compatible API with other clients
:return: send get message
"""
Expand Down Expand Up @@ -264,8 +267,9 @@ async def __anext__(self):
class WebsocketClientlet(AioHttpClientlet):
"""Websocket Client to be used with the streamer"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, url, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.url = url
self.websocket = None
self.response_iter = None

Expand Down
135 changes: 72 additions & 63 deletions jina/clients/base/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class HTTPBaseClient(BaseClient):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._endpoints = []
self._lock = asyncio.Lock()
self.iolet = None

async def close(self):
Expand Down Expand Up @@ -76,23 +77,23 @@ async def _is_flow_ready(self, **kwargs) -> bool:
proto = 'https' if self.args.tls else 'http'
url = f'{proto}://{self.args.host}:{self.args.port}/dry_run'

if self.iolet is not None and self.args.reuse_session:
iolet = self.iolet
else:
iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)

if self.args.reuse_session and self.iolet is None:
self.iolet = iolet
await self.iolet.__aenter__()

if not self.args.reuse_session:
iolet = await stack.enter_async_context(

Check warning on line 81 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L80-L81

Added lines #L80 - L81 were not covered by tests
iolet
HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
)
else:
async with self._lock:
if self.iolet is None:
self.iolet = HTTPClientlet(

Check warning on line 91 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L89-L91

Added lines #L89 - L91 were not covered by tests
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
await self.iolet.__aenter__()

Check warning on line 96 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L96

Added line #L96 was not covered by tests

response = await iolet.send_dry_run(**kwargs)
r_status = response.status
Expand All @@ -112,20 +113,20 @@ async def _is_flow_ready(self, **kwargs) -> bool:
return False

async def _get_results(
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
max_attempts: int = 1,
initial_backoff: float = 0.5,
max_backoff: float = 0.1,
backoff_multiplier: float = 1.5,
results_in_order: bool = False,
prefetch: Optional[int] = None,
timeout: Optional[int] = None,
return_type: Type[DocumentArray] = DocumentArray,
**kwargs,
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
max_attempts: int = 1,
initial_backoff: float = 0.5,
max_backoff: float = 0.1,
backoff_multiplier: float = 1.5,
results_in_order: bool = False,
prefetch: Optional[int] = None,
timeout: Optional[int] = None,
return_type: Type[DocumentArray] = DocumentArray,
**kwargs,
):
"""
:param inputs: the callable
Expand Down Expand Up @@ -168,30 +169,26 @@ async def _get_results(
else:
url = f'{proto}://{self.args.host}:{self.args.port}/post'

if self.iolet is not None and self.args.reuse_session:
iolet = self.iolet
else:
iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
max_attempts=max_attempts,
initial_backoff=initial_backoff,
max_backoff=max_backoff,
backoff_multiplier=backoff_multiplier,
timeout=timeout,
**kwargs,
)
if self.args.reuse_session and self.iolet is None:
self.iolet = iolet
await self.iolet.__aenter__()

if not self.args.reuse_session:
iolet = await stack.enter_async_context(
iolet
HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
)
else:
async with self._lock:
if self.iolet is None:
self.iolet = HTTPClientlet(

Check warning on line 183 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L181-L183

Added lines #L181 - L183 were not covered by tests
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
await self.iolet.__aenter__()

Check warning on line 188 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L188

Added line #L188 was not covered by tests

def _request_handler(
request: 'Request', **kwargs
request: 'Request', **kwargs
) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]':
"""
For HTTP Client, for each request in the iterator, we `send_message` using
Expand All @@ -215,7 +212,7 @@ def _result_handler(result):
**streamer_args,
)
async for response in streamer.stream(
request_iterator=request_iterator, results_in_order=results_in_order
request_iterator=request_iterator, results_in_order=results_in_order
):
r_status, r_str = response
handle_response_status(r_status, r_str, url)
Expand Down Expand Up @@ -256,13 +253,13 @@ def _result_handler(result):
yield resp

async def _get_streaming_results(
self,
on: str,
inputs: 'Document',
parameters: Optional[Dict] = None,
return_type: Type[Document] = Document,
timeout: Optional[int] = None,
**kwargs,
self,
on: str,
inputs: 'Document',
parameters: Optional[Dict] = None,
return_type: Type[Document] = Document,
timeout: Optional[int] = None,
**kwargs,
):
proto = 'https' if self.args.tls else 'http'
endpoint = on.strip('/')
Expand All @@ -272,15 +269,27 @@ async def _get_streaming_results(
url = f'{proto}://{self.args.host}:{self.args.port}/{endpoint}'
else:
url = f'{proto}://{self.args.host}:{self.args.port}/default'
async with AsyncExitStack() as stack:
if not self.args.reuse_session:
iolet = await stack.enter_async_context(

Check warning on line 274 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L272-L274

Added lines #L272 - L274 were not covered by tests
HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)
)
else:
async with self._lock:
if self.iolet is None:
self.iolet = HTTPClientlet(

Check warning on line 285 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L283-L285

Added lines #L283 - L285 were not covered by tests
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)
await self.iolet.__aenter__()

Check warning on line 291 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L291

Added line #L291 was not covered by tests

iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)

async with iolet:
async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on):

Check warning on line 293 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L293

Added line #L293 was not covered by tests
if not docarray_v2:
yield Document.from_dict(json.loads(doc))
Expand Down

0 comments on commit ea610bd

Please sign in to comment.