Skip to content

Commit

Permalink
fix: fix clientlets
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Sep 15, 2024
1 parent 361f08c commit 8136939
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 89 deletions.
1 change: 1 addition & 0 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(

async def close(self):
"""Closes the potential resources of the Client.
:return: Return whatever a close method may return
"""
return self.teardown_instrumentation()

Check warning on line 70 in jina/clients/base/__init__.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/__init__.py#L70

Added line #L70 was not covered by tests

Expand Down
12 changes: 8 additions & 4 deletions jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ async def start(self):
"""
with ImportExtensions(required=True):
import aiohttp

self.session = aiohttp.ClientSession(
**self._session_kwargs, trace_configs=self._trace_config
)
Expand All @@ -154,6 +153,7 @@ class HTTPClientlet(AioHttpClientlet):
async def send_message(self, url, request: 'Request'):

Check warning on line 153 in jina/clients/base/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/helper.py#L153

Added line #L153 was not covered by tests
"""Sends a POST request to the server
:param url: the URL where to send the message
:param request: request as dict
:return: send post message
"""
Expand All @@ -170,14 +170,15 @@ async def send_message(self, url, request: 'Request'):
from docarray.base_doc.io.json import orjson_dumps

request_kwargs['data'] = JinaJsonPayload(value=req_dict)

async with self.session.post(**request_kwargs) as response:
try:
r_str = await response.json()
except aiohttp.ContentTypeError:
r_str = await response.text()
r_status = response.status
handle_response_status(response.status, r_str, url)
return r_status, r_str
handle_response_status(r_status, r_str, url)
return r_status, r_str

Check warning on line 181 in jina/clients/base/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/helper.py#L180-L181

Added lines #L180 - L181 were not covered by tests
except (ValueError, ConnectionError, BadClient, aiohttp.ClientError, aiohttp.ClientConnectionError) as err:
self.logger.debug(f'Got an error: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}')

Check warning on line 183 in jina/clients/base/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/helper.py#L183

Added line #L183 was not covered by tests
await retry.wait_or_raise_err(
Expand All @@ -196,6 +197,7 @@ async def send_message(self, url, request: 'Request'):
async def send_streaming_message(self, url, doc: 'Document', on: str):

Check warning on line 197 in jina/clients/base/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/helper.py#L197

Added line #L197 was not covered by tests
"""Sends a GET SSE request to the server
:param url: the URL where to send the message
:param doc: Request Document
:param on: Request endpoint
:yields: responses
Expand All @@ -218,6 +220,7 @@ async def send_streaming_message(self, url, doc: 'Document', on: str):

async def send_dry_run(self, url, **kwargs):

Check warning on line 221 in jina/clients/base/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/helper.py#L221

Added line #L221 was not covered by tests
"""Query the dry_run endpoint from Gateway
:param url: the URL where to send the message
:param kwargs: keyword arguments to make sure compatible API with other clients
:return: send get message
"""
Expand Down Expand Up @@ -264,8 +267,9 @@ async def __anext__(self):
class WebsocketClientlet(AioHttpClientlet):
"""Websocket Client to be used with the streamer"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, url, *args, **kwargs) -> None:

Check warning on line 270 in jina/clients/base/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/helper.py#L270

Added line #L270 was not covered by tests
super().__init__(*args, **kwargs)
self.url = url

Check warning on line 272 in jina/clients/base/helper.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/helper.py#L272

Added line #L272 was not covered by tests
self.websocket = None
self.response_iter = None

Expand Down
152 changes: 84 additions & 68 deletions jina/clients/base/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ class HTTPBaseClient(BaseClient):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._endpoints = []
self.reuse_session = False
self._lock = AsyncExitStack()
self.iolet = None

Check warning on line 28 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L26-L28

Added lines #L26 - L28 were not covered by tests

async def close(self):

Check warning on line 30 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L30

Added line #L30 was not covered by tests
await super().close()
"""Closes the potential resources of the Client.
:return: Return whatever a close method may return
"""
ret = super().close()
if self.iolet is not None:
await self.iolet.__aexit__()
return ret

Check warning on line 37 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L34-L37

Added lines #L34 - L37 were not covered by tests

async def _get_endpoints_from_openapi(self, **kwargs):
def extract_paths_by_method(spec):
Expand Down Expand Up @@ -76,25 +82,26 @@ async def _is_flow_ready(self, **kwargs) -> bool:
proto = 'https' if self.args.tls else 'http'
url = f'{proto}://{self.args.host}:{self.args.port}/dry_run'

if self.iolet is not None and self.args.reuse_session:
iolet = self.iolet
else:
iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)

if self.args.reuse_session and self.iolet is None:
self.iolet = iolet
await self.iolet.__aenter__()

if not self.args.reuse_session:
if not self.reuse_session:
iolet = await stack.enter_async_context(

Check warning on line 86 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L85-L86

Added lines #L85 - L86 were not covered by tests
iolet
HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
)
else:
async with self._lock:
if self.iolet is None:
self.iolet = HTTPClientlet(

Check warning on line 96 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L94-L96

Added lines #L94 - L96 were not covered by tests
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
await self.iolet.__aenter__()
iolet = self.iolet

Check warning on line 102 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L101-L102

Added lines #L101 - L102 were not covered by tests

response = await iolet.send_dry_run(**kwargs)
response = await iolet.send_dry_run(url=url, **kwargs)

Check warning on line 104 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L104

Added line #L104 was not covered by tests
r_status = response.status

r_str = await response.json()
Expand All @@ -112,20 +119,20 @@ async def _is_flow_ready(self, **kwargs) -> bool:
return False

async def _get_results(
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
max_attempts: int = 1,
initial_backoff: float = 0.5,
max_backoff: float = 0.1,
backoff_multiplier: float = 1.5,
results_in_order: bool = False,
prefetch: Optional[int] = None,
timeout: Optional[int] = None,
return_type: Type[DocumentArray] = DocumentArray,
**kwargs,
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
max_attempts: int = 1,
initial_backoff: float = 0.5,
max_backoff: float = 0.1,
backoff_multiplier: float = 1.5,
results_in_order: bool = False,
prefetch: Optional[int] = None,
timeout: Optional[int] = None,
return_type: Type[DocumentArray] = DocumentArray,
**kwargs,
):
"""
:param inputs: the callable
Expand Down Expand Up @@ -168,30 +175,27 @@ async def _get_results(
else:
url = f'{proto}://{self.args.host}:{self.args.port}/post'

if self.iolet is not None and self.args.reuse_session:
iolet = self.iolet
else:
iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
max_attempts=max_attempts,
initial_backoff=initial_backoff,
max_backoff=max_backoff,
backoff_multiplier=backoff_multiplier,
timeout=timeout,
**kwargs,
)
if self.args.reuse_session and self.iolet is None:
self.iolet = iolet
await self.iolet.__aenter__()

if not self.args.reuse_session:
if not self.reuse_session:
iolet = await stack.enter_async_context(

Check warning on line 179 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L178-L179

Added lines #L178 - L179 were not covered by tests
iolet
HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
)
else:
async with self._lock:
if self.iolet is None:
self.iolet = HTTPClientlet(

Check warning on line 189 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L187-L189

Added lines #L187 - L189 were not covered by tests
logger=self.logger,
tracer_provider=self.tracer_provider,
**kwargs,
)
self.iolet = await self.iolet.__aenter__()
iolet = self.iolet

Check warning on line 195 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L194-L195

Added lines #L194 - L195 were not covered by tests

def _request_handler(
request: 'Request', **kwargs
request: 'Request', **kwargs
) -> 'Tuple[asyncio.Future, Optional[asyncio.Future]]':
"""
For HTTP Client, for each request in the iterator, we `send_message` using
Expand All @@ -215,7 +219,7 @@ def _result_handler(result):
**streamer_args,
)
async for response in streamer.stream(
request_iterator=request_iterator, results_in_order=results_in_order
request_iterator=request_iterator, results_in_order=results_in_order
):
r_status, r_str = response
handle_response_status(r_status, r_str, url)
Expand Down Expand Up @@ -256,13 +260,13 @@ def _result_handler(result):
yield resp

async def _get_streaming_results(
self,
on: str,
inputs: 'Document',
parameters: Optional[Dict] = None,
return_type: Type[Document] = Document,
timeout: Optional[int] = None,
**kwargs,
self,
on: str,
inputs: 'Document',
parameters: Optional[Dict] = None,
return_type: Type[Document] = Document,
timeout: Optional[int] = None,
**kwargs,
):
proto = 'https' if self.args.tls else 'http'
endpoint = on.strip('/')
Expand All @@ -272,15 +276,27 @@ async def _get_streaming_results(
url = f'{proto}://{self.args.host}:{self.args.port}/{endpoint}'
else:
url = f'{proto}://{self.args.host}:{self.args.port}/default'

iolet = HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)

async with iolet:
async with AsyncExitStack() as stack:
if not self.reuse_session:
iolet = await stack.enter_async_context(

Check warning on line 281 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L279-L281

Added lines #L279 - L281 were not covered by tests
HTTPClientlet(
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)
)
else:
async with self._lock:
if self.iolet is None:
self.iolet = HTTPClientlet(

Check warning on line 292 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L290-L292

Added lines #L290 - L292 were not covered by tests
logger=self.logger,
tracer_provider=self.tracer_provider,
timeout=timeout,
**kwargs,
)
await self.iolet.__aenter__()
iolet = self.iolet
async for doc in iolet.send_streaming_message(url=url, doc=inputs, on=on):

Check warning on line 300 in jina/clients/base/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/base/http.py#L298-L300

Added lines #L298 - L300 were not covered by tests
if not docarray_v2:
yield Document.from_dict(json.loads(doc))
Expand Down
6 changes: 6 additions & 0 deletions jina/clients/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
PostMixin,
ProfileMixin,
)
import asyncio

Check warning on line 12 in jina/clients/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/http.py#L12

Added line #L12 was not covered by tests


class HTTPClient(
Expand Down Expand Up @@ -80,3 +81,8 @@ async def async_inputs():
print(resp)
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._lock = asyncio.Lock()
self.reuse_session = self.args.reuse_session

Check warning on line 88 in jina/clients/http.py

View check run for this annotation

Codecov / codecov/patch

jina/clients/http.py#L85-L88

Added lines #L85 - L88 were not covered by tests
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,23 @@ def ping(self, **kwargs):
@pytest.mark.parametrize('prefetch', [1, 10])
@pytest.mark.parametrize('concurrent', [15])
@pytest.mark.parametrize('use_stream', [False, True])
@pytest.mark.parametrize('reuse_session', [True, False])
def test_concurrent_clients(
concurrent, protocol, shards, polling, prefetch, reraise, use_stream
concurrent, protocol, shards, polling, prefetch, reraise, use_stream, reuse_session
):

if not use_stream and protocol != 'grpc':
return

if reuse_session and protocol != 'http':
return

def pong(peer_hash, queue, resp: Response):
for d in resp.docs:
queue.put((peer_hash, d.text))

def peer_client(port, protocol, peer_hash, queue):
c = Client(protocol=protocol, port=port)
c = Client(protocol=protocol, port=port, reuse_session=reuse_session)
for _ in range(NUM_REQUESTS):
c.post(
'/ping',
Expand Down
14 changes: 10 additions & 4 deletions tests/integration/docarray_v2/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
)
@pytest.mark.parametrize('return_type', ['batch', 'singleton'])
@pytest.mark.parametrize('include_gateway', [True, False])
def test_singleton_return(ctxt_manager, protocols, return_type, include_gateway):
@pytest.mark.parametrize('reuse_session', [True, False])
def test_singleton_return(ctxt_manager, protocols, return_type, include_gateway, reuse_session):
if reuse_session and 'http' not in protocols:
return
if 'websocket' in protocols and ctxt_manager != 'flow':
return
if not include_gateway and ctxt_manager == 'flow':
Expand Down Expand Up @@ -63,7 +66,7 @@ def foo_single(

with ctxt:
for port, protocol in zip(ports, protocols):
c = Client(port=port, protocol=protocol)
c = Client(port=port, protocol=protocol, reuse_session=reuse_session)
docs = c.post(
on='/foo',
inputs=MySingletonReturnInputDoc(text='hello', price=2),
Expand Down Expand Up @@ -102,7 +105,10 @@ def foo_single(
'protocols', [['grpc'], ['http'], ['websocket'], ['grpc', 'http']]
)
@pytest.mark.parametrize('return_type', ['batch', 'singleton'])
def test_singleton_return_async(ctxt_manager, protocols, return_type):
@pytest.mark.parametrize('reuse_session', [True, False])
def test_singleton_return_async(ctxt_manager, protocols, return_type, reuse_session):
if reuse_session and 'http' not in protocols:
return
if 'websocket' in protocols and ctxt_manager != 'flow':
return

Expand Down Expand Up @@ -149,7 +155,7 @@ async def foo_single(

with ctxt:
for port, protocol in zip(ports, protocols):
c = Client(port=port, protocol=protocol)
c = Client(port=port, protocol=protocol, reuse_session=reuse_session)
docs = c.post(
on='/foo',
inputs=MySingletonReturnInputDoc(text='hello', price=2),
Expand Down
Loading

0 comments on commit 8136939

Please sign in to comment.