From 19d8a42f99ece11d2fcd267b5966677309cb830d Mon Sep 17 00:00:00 2001 From: David Lougheed Date: Thu, 15 Aug 2024 13:10:23 -0400 Subject: [PATCH] feat(drs): optional session_kwargs for aiohttp.ClientSession --- bento_lib/drs/utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/bento_lib/drs/utils.py b/bento_lib/drs/utils.py index 7f8db6e..1b2a93a 100644 --- a/bento_lib/drs/utils.py +++ b/bento_lib/drs/utils.py @@ -47,8 +47,7 @@ def decode_drs_uri(drs_uri: str, internal_drs_base_url: Optional[str] = None) -> parsed_drs_uri = urlparse(drs_uri) if parsed_drs_uri.scheme != "drs": - print(f"[Bento Lib] Invalid scheme: '{parsed_drs_uri.scheme}'", - file=sys.stderr, flush=True) + print(f"[Bento Lib] Invalid scheme: '{parsed_drs_uri.scheme}'", file=sys.stderr, flush=True) raise DrsInvalidScheme(f"Encountered invalid DRS scheme: {parsed_drs_uri.scheme}") drs_base_path = internal_drs_base_url.rstrip("/") if internal_drs_base_url else f"https://{parsed_drs_uri.netloc}" @@ -82,12 +81,17 @@ def fetch_drs_record_by_uri(drs_uri: str, internal_drs_base_url: Optional[str] = return drs_res.json() -async def fetch_drs_record_by_uri_async(drs_uri: str, internal_drs_base_url: Optional[str] = None) -> Optional[dict]: +async def fetch_drs_record_by_uri_async( + drs_uri: str, + internal_drs_base_url: Optional[str] = None, + session_kwargs: dict | None = None, +) -> Optional[dict]: """ Given a URI in the format drs:///, decodes it into an HTTP URL and asynchronously fetches the object metadata. :param drs_uri: The URI of the object to fetch. :param internal_drs_base_url: An optional override hard-coded DRS base URL to use, for container networking etc. + :param session_kwargs: Optional dictionary of parameters to pass to the aiohttp.ClientSession constructor. :return: The fetched DRS object metadata. """ @@ -97,7 +101,7 @@ async def fetch_drs_record_by_uri_async(drs_uri: str, internal_drs_base_url: Opt print(f"[Bento Lib] Attempting to fetch {decoded_object_uri}", flush=True) params = {"internal_path": "true"} if internal_drs_base_url else {} - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(**(session_kwargs or {})) as session: async with session.get(decoded_object_uri, params=params) as drs_res: if drs_res.status != 200: print(f"[Bento Lib] Could not fetch: '{decoded_object_uri}'", file=sys.stderr, flush=True)