Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: reuse session #6196

Merged
merged 4 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions jina/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def Client(
prefetch: Optional[int] = 1000,
protocol: Optional[Union[str, List[str]]] = 'GRPC',
proxy: Optional[bool] = False,
reuse_session: Optional[bool] = False,
suppress_root_logging: Optional[bool] = False,
tls: Optional[bool] = False,
traces_exporter_host: Optional[str] = None,
Expand Down Expand Up @@ -59,6 +60,7 @@ def Client(
Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
:param protocol: Communication protocol between server and client.
:param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
:param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it
:param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
:param tls: If set, connect to gateway using tls encryption
:param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
Expand Down Expand Up @@ -113,6 +115,7 @@ def Client(args: Optional['argparse.Namespace'] = None, **kwargs) -> Union[
Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
:param protocol: Communication protocol between server and client.
:param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
:param reuse_session: True if HTTPClient should reuse ClientSession. If true, user will be responsible to close it
:param suppress_root_logging: If set, then no root handlers will be suppressed from logging.
:param tls: If set, connect to gateway using tls encryption
:param traces_exporter_host: If tracing is enabled, this hostname will be used to configure the trace exporter agent.
Expand Down
29 changes: 18 additions & 11 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class BaseClient(InstrumentationMixin, ABC):
"""

def __init__(
self,
args: Optional['argparse.Namespace'] = None,
**kwargs,
self,
args: Optional['argparse.Namespace'] = None,
**kwargs,
):
if args and isinstance(args, argparse.Namespace):
self.args = args
Expand Down Expand Up @@ -63,6 +63,12 @@ def __init__(
)
send_telemetry_event(event='start', obj_cls_name=self.__class__.__name__)

async def close(self):
"""Closes the potential resources of the Client.
:return: Return whatever a close method may return
"""
return self.teardown_instrumentation()

def teardown_instrumentation(self):
"""Shut down the OpenTelemetry tracer and meter if available. This ensures that the daemon threads for
exporting metrics data is properly cleaned up.
Expand Down Expand Up @@ -118,7 +124,7 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None:
raise BadClientInput from ex

def _get_requests(
self, **kwargs
self, **kwargs
) -> Union[Iterator['Request'], AsyncIterator['Request']]:
"""
Get request in generator.
Expand Down Expand Up @@ -177,13 +183,14 @@ def inputs(self, bytes_gen: 'InputType') -> None:

@abc.abstractmethod
async def _get_results(
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
**kwargs,
): ...
self,
inputs: 'InputType',
on_done: 'CallbackFnType',
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
**kwargs,
):
...

@abc.abstractmethod
def _is_flow_ready(self, **kwargs) -> bool:
Expand Down
31 changes: 16 additions & 15 deletions jina/clients/base/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class AioHttpClientlet(ABC):

def __init__(
self,
url: str,
logger: 'JinaLogger',
max_attempts: int = 1,
initial_backoff: float = 0.5,
Expand All @@ -59,7 +58,6 @@ def __init__(
) -> None:
"""HTTP Client to be used with the streamer

:param url: url to send http/websocket request to
:param logger: jina logger
:param max_attempts: Number of sending attempts, including the original request.
:param initial_backoff: The first retry will happen with a delay of random(0, initial_backoff)
Expand All @@ -68,7 +66,6 @@ def __init__(
:param tracer_provider: Optional tracer_provider that will be used to configure aiohttp tracing.
:param kwargs: kwargs which will be forwarded to the `aiohttp.Session` instance. Used to pass headers to requests
"""
self.url = url
self.logger = logger
self.msg_recv = 0
self.msg_sent = 0
Expand Down Expand Up @@ -131,7 +128,6 @@ async def start(self):
"""
with ImportExtensions(required=True):
import aiohttp

self.session = aiohttp.ClientSession(
**self._session_kwargs, trace_configs=self._trace_config
)
Expand All @@ -154,9 +150,10 @@ class HTTPClientlet(AioHttpClientlet):

UPDATE_EVENT_PREFIX = 14 # the update event has the following format: "event: update: {document_json}"

async def send_message(self, request: 'Request'):
async def send_message(self, url, request: 'Request'):
"""Sends a POST request to the server

:param url: the URL where to send the message
:param request: request as dict
:return: send post message
"""
Expand All @@ -166,23 +163,24 @@ async def send_message(self, request: 'Request'):
req_dict['target_executor'] = req_dict['header']['target_executor']
for attempt in range(1, self.max_attempts + 1):
try:
request_kwargs = {'url': self.url}
request_kwargs = {'url': url}
if not docarray_v2:
request_kwargs['json'] = req_dict
else:
from docarray.base_doc.io.json import orjson_dumps

request_kwargs['data'] = JinaJsonPayload(value=req_dict)

async with self.session.post(**request_kwargs) as response:
try:
r_str = await response.json()
except aiohttp.ContentTypeError:
r_str = await response.text()
r_status = response.status
handle_response_status(response.status, r_str, self.url)
return r_status, r_str
handle_response_status(r_status, r_str, url)
return r_status, r_str
except (ValueError, ConnectionError, BadClient, aiohttp.ClientError, aiohttp.ClientConnectionError) as err:
self.logger.debug(f'Got an error: {err} sending POST to {self.url} in attempt {attempt}/{self.max_attempts}')
self.logger.debug(f'Got an error: {err} sending POST to {url} in attempt {attempt}/{self.max_attempts}')
await retry.wait_or_raise_err(
attempt=attempt,
err=err,
Expand All @@ -193,19 +191,20 @@ async def send_message(self, request: 'Request'):
)
except Exception as exc:
self.logger.debug(
f'Got a non-retried error: {exc} sending POST to {self.url}')
f'Got a non-retried error: {exc} sending POST to {url}')
raise exc

async def send_streaming_message(self, doc: 'Document', on: str):
async def send_streaming_message(self, url, doc: 'Document', on: str):
"""Sends a GET SSE request to the server

:param url: the URL where to send the message
:param doc: Request Document
:param on: Request endpoint
:yields: responses
"""
req_dict = doc.to_dict() if hasattr(doc, "to_dict") else doc.dict()
request_kwargs = {
'url': self.url,
'url': url,
'headers': {'Accept': 'text/event-stream'},
'json': req_dict,
}
Expand All @@ -219,13 +218,14 @@ async def send_streaming_message(self, doc: 'Document', on: str):
elif event.startswith(b'end'):
pass

async def send_dry_run(self, **kwargs):
async def send_dry_run(self, url, **kwargs):
"""Query the dry_run endpoint from Gateway
:param url: the URL where to send the message
:param kwargs: keyword arguments to make sure compatible API with other clients
:return: send get message
"""
return await self.session.get(
url=self.url, timeout=kwargs.get('timeout', None)
url=url, timeout=kwargs.get('timeout', None)
).__aenter__()

async def recv_message(self):
Expand Down Expand Up @@ -267,8 +267,9 @@ async def __anext__(self):
class WebsocketClientlet(AioHttpClientlet):
"""Websocket Client to be used with the streamer"""

def __init__(self, *args, **kwargs) -> None:
def __init__(self, url, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.url = url
self.websocket = None
self.response_iter = None

Expand Down
Loading
Loading