From 9d07268fc49255f14cfab725e36e383ae4e37df8 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Sat, 10 Aug 2024 00:20:36 +0800 Subject: [PATCH] dev(narugo): add retry session in entries --- hfutils/entry/download.py | 4 ++ hfutils/entry/index.py | 6 +- hfutils/entry/ls.py | 4 ++ hfutils/entry/ls_repo.py | 4 ++ hfutils/entry/upload.py | 4 ++ hfutils/entry/whoami.py | 4 ++ hfutils/utils/__init__.py | 1 + hfutils/utils/session.py | 115 ++++++++++++++++++++++++++++++++++++++ requirements.txt | 3 +- 9 files changed, 143 insertions(+), 2 deletions(-) create mode 100644 hfutils/utils/session.py diff --git a/hfutils/entry/download.py b/hfutils/entry/download.py index 550dbedf28..b657310ddb 100644 --- a/hfutils/entry/download.py +++ b/hfutils/entry/download.py @@ -3,10 +3,12 @@ from typing import Optional import click +from huggingface_hub import configure_http_backend from .base import CONTEXT_SETTINGS, command_wrap, ClickErrorException from ..operate import download_file_to_file, download_archive_as_directory, download_directory_as_directory from ..operate.base import REPO_TYPES, RepoTypeTyping +from ..utils import get_requests_session class NoRemotePathAssignedWithDownload(ClickErrorException): @@ -84,6 +86,8 @@ def download( :param tmpdir: Use custom temporary Directory. :type tmpdir: str, optional """ + configure_http_backend(get_requests_session) + if tmpdir: os.environ['TMPDIR'] = tmpdir diff --git a/hfutils/entry/index.py b/hfutils/entry/index.py index 8b032009f1..87653e4570 100644 --- a/hfutils/entry/index.py +++ b/hfutils/entry/index.py @@ -18,13 +18,15 @@ import click from hbutils.string import plural_word +from huggingface_hub import configure_http_backend from .base import CONTEXT_SETTINGS from ..cache import delete_detached_cache from ..index import hf_tar_validate, tar_create_index from ..operate import get_hf_fs, download_file_to_file, upload_directory_as_directory from ..operate.base import REPO_TYPES, RepoTypeTyping, get_hf_client -from ..utils import tqdm, hf_fs_path, parse_hf_fs_path, TemporaryDirectory, hf_normpath, ColoredFormatter +from ..utils import tqdm, hf_fs_path, parse_hf_fs_path, TemporaryDirectory, hf_normpath, ColoredFormatter, \ + get_requests_session def _add_index_subcommand(cli: click.Group) -> click.Group: @@ -86,6 +88,8 @@ def index(repo_id: str, idx_repo_id: Optional[str], repo_type: RepoTypeTyping, r This function is typically invoked through the CLI interface, like: $ python script.py index -r my_repo -x my_index_repo -t dataset -R main --min_upload_interval 120 """ + configure_http_backend(get_requests_session) + logger = logging.getLogger() logger.setLevel(logging.INFO) console_handler = logging.StreamHandler() diff --git a/hfutils/entry/ls.py b/hfutils/entry/ls.py index 6c057bc30e..349d38f76f 100644 --- a/hfutils/entry/ls.py +++ b/hfutils/entry/ls.py @@ -5,10 +5,12 @@ import click import tzlocal +from huggingface_hub import configure_http_backend from huggingface_hub.hf_api import RepoFolder, RepoFile from .base import CONTEXT_SETTINGS from ..operate.base import REPO_TYPES, get_hf_client +from ..utils import get_requests_session mimetypes.add_type('image/webp', '.webp') @@ -124,6 +126,8 @@ def ls(repo_id: str, repo_type: str, dir_in_repo, revision: str, show_all: bool, :param show_detailed: Flag to indicate whether to show detailed file information. :type show_detailed: bool """ + configure_http_backend(get_requests_session) + hf_client = get_hf_client() items: List[ListItem] = [] for item in hf_client.list_repo_tree( diff --git a/hfutils/entry/ls_repo.py b/hfutils/entry/ls_repo.py index fc108ce00b..856859b614 100644 --- a/hfutils/entry/ls_repo.py +++ b/hfutils/entry/ls_repo.py @@ -2,10 +2,12 @@ from typing import Optional import click +from huggingface_hub import configure_http_backend from huggingface_hub.utils import LocalTokenNotFoundError from .base import CONTEXT_SETTINGS, ClickErrorException from ..operate.base import REPO_TYPES, get_hf_client +from ..utils import get_requests_session class NoLocalAuthentication(ClickErrorException): @@ -46,6 +48,8 @@ def ls(author: Optional[str], repo_type: str, pattern: str): :param pattern: Pattern of the repository names. :type pattern: str """ + configure_http_backend(get_requests_session) + hf_client = get_hf_client() if not author: try: diff --git a/hfutils/entry/upload.py b/hfutils/entry/upload.py index 8b299703cf..cd9582d491 100644 --- a/hfutils/entry/upload.py +++ b/hfutils/entry/upload.py @@ -2,10 +2,12 @@ from typing import Optional import click +from huggingface_hub import configure_http_backend from .base import CONTEXT_SETTINGS, command_wrap, ClickErrorException from ..operate import upload_file_to_file, upload_directory_as_archive, upload_directory_as_directory from ..operate.base import REPO_TYPES, RepoTypeTyping, get_hf_client +from ..utils import get_requests_session class NoRemotePathAssignedWithUpload(ClickErrorException): @@ -78,6 +80,8 @@ def upload(repo_id: str, repo_type: RepoTypeTyping, :param public: Set public repository when created. :type public: bool """ + configure_http_backend(get_requests_session) + if not file_in_repo and not archive_in_repo and not dir_in_repo: raise NoRemotePathAssignedWithUpload('No remote path in repository assigned.\n' 'One of the -f, -a, or -d option is required.') diff --git a/hfutils/entry/whoami.py b/hfutils/entry/whoami.py index c366ce08d4..478fc8f4cc 100644 --- a/hfutils/entry/whoami.py +++ b/hfutils/entry/whoami.py @@ -1,9 +1,11 @@ import click from hbutils.string import plural_word +from huggingface_hub import configure_http_backend from huggingface_hub.utils import LocalTokenNotFoundError from .base import CONTEXT_SETTINGS from ..operate.base import get_hf_client +from ..utils import get_requests_session def _add_whoami_subcommand(cli: click.Group) -> click.Group: @@ -28,6 +30,8 @@ def whoami(): This function retrieves the current user's identification from the Hugging Face Hub API and displays it. """ + configure_http_backend(get_requests_session) + hf_client = get_hf_client() try: info = hf_client.whoami() diff --git a/hfutils/utils/__init__.py b/hfutils/utils/__init__.py index 3e2899347f..e76a3d4bd8 100644 --- a/hfutils/utils/__init__.py +++ b/hfutils/utils/__init__.py @@ -3,6 +3,7 @@ from .logging import ColoredFormatter from .number import number_to_tag from .path import hf_normpath, hf_fs_path, parse_hf_fs_path, HfFileSystemPath +from .session import TimeoutHTTPAdapter, get_requests_session, get_random_ua from .temp import TemporaryDirectory from .tqdm_ import tqdm from .walk import walk_files diff --git a/hfutils/utils/session.py b/hfutils/utils/session.py new file mode 100644 index 0000000000..f6675eec66 --- /dev/null +++ b/hfutils/utils/session.py @@ -0,0 +1,115 @@ +""" +This module provides functionality for creating and managing HTTP sessions with customizable retry logic, +timeout settings, and user-agent rotation using random user-agent generation. It is designed to help with +robust web scraping and API consumption by handling common HTTP errors and timeouts gracefully. + +Main Features: + +- Automatic retries on specified HTTP response status codes. +- Configurable request timeout. +- Rotating user-agent for each session to mimic different browsers and operating systems. +- Optional SSL verification. +""" + +from functools import lru_cache +from typing import Optional, Dict + +import requests +from random_user_agent.params import SoftwareName, OperatingSystem +from random_user_agent.user_agent import UserAgent +from requests.adapters import HTTPAdapter, Retry + +DEFAULT_TIMEOUT = 15 # seconds + + +class TimeoutHTTPAdapter(HTTPAdapter): + """ + A custom HTTPAdapter that enforces a default timeout on all requests. + + :param args: Variable length argument list for HTTPAdapter. + :param kwargs: Arbitrary keyword arguments. 'timeout' can be specified to set a custom timeout. + """ + + def __init__(self, *args, **kwargs): + self.timeout = DEFAULT_TIMEOUT + if "timeout" in kwargs: + self.timeout = kwargs["timeout"] + del kwargs["timeout"] + super().__init__(*args, **kwargs) + + def send(self, request, **kwargs): + """ + Sends the Request object, applying the timeout setting. + + :param request: The Request object to send. + :type request: requests.PreparedRequest + :param kwargs: Keyword arguments that may contain 'timeout'. + :return: The response to the request. + """ + timeout = kwargs.get("timeout") + if timeout is None: + kwargs["timeout"] = self.timeout + return super().send(request, **kwargs) + + +def get_requests_session(max_retries: int = 5, timeout: int = DEFAULT_TIMEOUT, verify: bool = True, + headers: Optional[Dict[str, str]] = None, session: Optional[requests.Session] = None) \ + -> requests.Session: + """ + Creates a requests session with retry logic, timeout settings, and random user-agent headers. + + :param max_retries: Maximum number of retries on failed requests. + :type max_retries: int + :param timeout: Request timeout in seconds. + :type timeout: int + :param verify: Whether to verify SSL certificates. + :type verify: bool + :param headers: Additional headers to include in the requests. + :type headers: Optional[Dict[str, str]] + :param session: An existing requests.Session instance to use. + :type session: Optional[requests.Session] + :return: A configured requests.Session object. + :rtype: requests.Session + """ + session = session or requests.session() + retries = Retry( + total=max_retries, backoff_factor=1, + status_forcelist=[408, 413, 429, 500, 501, 502, 503, 504, 505, 506, 507, 509, 510, 511], + allowed_methods=["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"], + ) + adapter = TimeoutHTTPAdapter(max_retries=retries, timeout=timeout, pool_connections=32, pool_maxsize=32) + session.mount('http://', adapter) + session.mount('https://', adapter) + session.headers.update({ + "User-Agent": get_random_ua(), + **dict(headers or {}), + }) + if not verify: + session.verify = False + + return session + + +@lru_cache() +def _ua_pool(): + """ + Creates and caches a UserAgent rotator instance with a specified number of user agents. + + :return: A UserAgent rotator instance. + :rtype: UserAgent + """ + software_names = [SoftwareName.CHROME.value, SoftwareName.FIREFOX.value, SoftwareName.EDGE.value] + operating_systems = [OperatingSystem.WINDOWS.value, OperatingSystem.MACOS.value] + + user_agent_rotator = UserAgent(software_names=software_names, operating_systems=operating_systems, limit=1000) + return user_agent_rotator + + +def get_random_ua(): + """ + Retrieves a random user agent string from the cached UserAgent rotator. + + :return: A random user agent string. + :rtype: str + """ + return _ua_pool().get_random_user_agent() diff --git a/requirements.txt b/requirements.txt index 42c0a868a8..0d1fd6079a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ click>=7 tzlocal natsort urlobject -fsspec>=2024 \ No newline at end of file +fsspec>=2024 +random_user_agent \ No newline at end of file