diff --git a/docs/source/api_doc/utils/index.rst b/docs/source/api_doc/utils/index.rst index bfa8088324..a4eb28e71e 100644 --- a/docs/source/api_doc/utils/index.rst +++ b/docs/source/api_doc/utils/index.rst @@ -13,6 +13,7 @@ hfutils.utils download number path + session tqdm_ walk diff --git a/docs/source/api_doc/utils/session.rst b/docs/source/api_doc/utils/session.rst new file mode 100644 index 0000000000..cd30ba5b5a --- /dev/null +++ b/docs/source/api_doc/utils/session.rst @@ -0,0 +1,31 @@ +hfutils.utils.session +================================= + +.. currentmodule:: hfutils.utils.session + +.. automodule:: hfutils.utils.session + + + +TimeoutHTTPAdapter +----------------------------------------------------- + +.. autoclass:: TimeoutHTTPAdapter + :members: __init__, send + + + +get_requests_session +----------------------------------------------------- + +.. autofunction:: get_requests_session + + + +get_random_ua +----------------------------------------------------- + +.. autofunction:: get_random_ua + + + diff --git a/hfutils/entry/download.py b/hfutils/entry/download.py index 550dbedf28..b657310ddb 100644 --- a/hfutils/entry/download.py +++ b/hfutils/entry/download.py @@ -3,10 +3,12 @@ from typing import Optional import click +from huggingface_hub import configure_http_backend from .base import CONTEXT_SETTINGS, command_wrap, ClickErrorException from ..operate import download_file_to_file, download_archive_as_directory, download_directory_as_directory from ..operate.base import REPO_TYPES, RepoTypeTyping +from ..utils import get_requests_session class NoRemotePathAssignedWithDownload(ClickErrorException): @@ -84,6 +86,8 @@ def download( :param tmpdir: Use custom temporary Directory. :type tmpdir: str, optional """ + configure_http_backend(get_requests_session) + if tmpdir: os.environ['TMPDIR'] = tmpdir diff --git a/hfutils/entry/index.py b/hfutils/entry/index.py index 8b032009f1..87653e4570 100644 --- a/hfutils/entry/index.py +++ b/hfutils/entry/index.py @@ -18,13 +18,15 @@ import click from hbutils.string import plural_word +from huggingface_hub import configure_http_backend from .base import CONTEXT_SETTINGS from ..cache import delete_detached_cache from ..index import hf_tar_validate, tar_create_index from ..operate import get_hf_fs, download_file_to_file, upload_directory_as_directory from ..operate.base import REPO_TYPES, RepoTypeTyping, get_hf_client -from ..utils import tqdm, hf_fs_path, parse_hf_fs_path, TemporaryDirectory, hf_normpath, ColoredFormatter +from ..utils import tqdm, hf_fs_path, parse_hf_fs_path, TemporaryDirectory, hf_normpath, ColoredFormatter, \ + get_requests_session def _add_index_subcommand(cli: click.Group) -> click.Group: @@ -86,6 +88,8 @@ def index(repo_id: str, idx_repo_id: Optional[str], repo_type: RepoTypeTyping, r This function is typically invoked through the CLI interface, like: $ python script.py index -r my_repo -x my_index_repo -t dataset -R main --min_upload_interval 120 """ + configure_http_backend(get_requests_session) + logger = logging.getLogger() logger.setLevel(logging.INFO) console_handler = logging.StreamHandler() diff --git a/hfutils/entry/ls.py b/hfutils/entry/ls.py index 6c057bc30e..349d38f76f 100644 --- a/hfutils/entry/ls.py +++ b/hfutils/entry/ls.py @@ -5,10 +5,12 @@ import click import tzlocal +from huggingface_hub import configure_http_backend from huggingface_hub.hf_api import RepoFolder, RepoFile from .base import CONTEXT_SETTINGS from ..operate.base import REPO_TYPES, get_hf_client +from ..utils import get_requests_session mimetypes.add_type('image/webp', '.webp') @@ -124,6 +126,8 @@ def ls(repo_id: str, repo_type: str, dir_in_repo, revision: str, show_all: bool, :param show_detailed: Flag to indicate whether to show detailed file information. :type show_detailed: bool """ + configure_http_backend(get_requests_session) + hf_client = get_hf_client() items: List[ListItem] = [] for item in hf_client.list_repo_tree( diff --git a/hfutils/entry/ls_repo.py b/hfutils/entry/ls_repo.py index fc108ce00b..856859b614 100644 --- a/hfutils/entry/ls_repo.py +++ b/hfutils/entry/ls_repo.py @@ -2,10 +2,12 @@ from typing import Optional import click +from huggingface_hub import configure_http_backend from huggingface_hub.utils import LocalTokenNotFoundError from .base import CONTEXT_SETTINGS, ClickErrorException from ..operate.base import REPO_TYPES, get_hf_client +from ..utils import get_requests_session class NoLocalAuthentication(ClickErrorException): @@ -46,6 +48,8 @@ def ls(author: Optional[str], repo_type: str, pattern: str): :param pattern: Pattern of the repository names. :type pattern: str """ + configure_http_backend(get_requests_session) + hf_client = get_hf_client() if not author: try: diff --git a/hfutils/entry/upload.py b/hfutils/entry/upload.py index 8b299703cf..cd9582d491 100644 --- a/hfutils/entry/upload.py +++ b/hfutils/entry/upload.py @@ -2,10 +2,12 @@ from typing import Optional import click +from huggingface_hub import configure_http_backend from .base import CONTEXT_SETTINGS, command_wrap, ClickErrorException from ..operate import upload_file_to_file, upload_directory_as_archive, upload_directory_as_directory from ..operate.base import REPO_TYPES, RepoTypeTyping, get_hf_client +from ..utils import get_requests_session class NoRemotePathAssignedWithUpload(ClickErrorException): @@ -78,6 +80,8 @@ def upload(repo_id: str, repo_type: RepoTypeTyping, :param public: Set public repository when created. :type public: bool """ + configure_http_backend(get_requests_session) + if not file_in_repo and not archive_in_repo and not dir_in_repo: raise NoRemotePathAssignedWithUpload('No remote path in repository assigned.\n' 'One of the -f, -a, or -d option is required.') diff --git a/hfutils/entry/whoami.py b/hfutils/entry/whoami.py index c366ce08d4..478fc8f4cc 100644 --- a/hfutils/entry/whoami.py +++ b/hfutils/entry/whoami.py @@ -1,9 +1,11 @@ import click from hbutils.string import plural_word +from huggingface_hub import configure_http_backend from huggingface_hub.utils import LocalTokenNotFoundError from .base import CONTEXT_SETTINGS from ..operate.base import get_hf_client +from ..utils import get_requests_session def _add_whoami_subcommand(cli: click.Group) -> click.Group: @@ -28,6 +30,8 @@ def whoami(): This function retrieves the current user's identification from the Hugging Face Hub API and displays it. """ + configure_http_backend(get_requests_session) + hf_client = get_hf_client() try: info = hf_client.whoami() diff --git a/hfutils/utils/__init__.py b/hfutils/utils/__init__.py index 3e2899347f..e76a3d4bd8 100644 --- a/hfutils/utils/__init__.py +++ b/hfutils/utils/__init__.py @@ -3,6 +3,7 @@ from .logging import ColoredFormatter from .number import number_to_tag from .path import hf_normpath, hf_fs_path, parse_hf_fs_path, HfFileSystemPath +from .session import TimeoutHTTPAdapter, get_requests_session, get_random_ua from .temp import TemporaryDirectory from .tqdm_ import tqdm from .walk import walk_files diff --git a/hfutils/utils/session.py b/hfutils/utils/session.py new file mode 100644 index 0000000000..34eca6d1d1 --- /dev/null +++ b/hfutils/utils/session.py @@ -0,0 +1,115 @@ +""" +This module provides functionality for creating and managing HTTP sessions with customizable retry logic, +timeout settings, and user-agent rotation using random user-agent generation. It is designed to help with +robust web scraping and API consumption by handling common HTTP errors and timeouts gracefully. + +Main Features: + +- Automatic retries on specified HTTP response status codes. +- Configurable request timeout. +- Rotating user-agent for each session to mimic different browsers and operating systems. +- Optional SSL verification. +""" + +from functools import lru_cache +from typing import Optional, Dict + +import requests +from random_user_agent.params import SoftwareName, OperatingSystem +from random_user_agent.user_agent import UserAgent +from requests.adapters import HTTPAdapter, Retry + +DEFAULT_TIMEOUT = 15 # seconds + + +class TimeoutHTTPAdapter(HTTPAdapter): + """ + A custom HTTPAdapter that enforces a default timeout on all requests. + + :param args: Variable length argument list for HTTPAdapter. + :param kwargs: Arbitrary keyword arguments. 'timeout' can be specified to set a custom timeout. + """ + + def __init__(self, *args, **kwargs): + self.timeout = DEFAULT_TIMEOUT + if "timeout" in kwargs: + self.timeout = kwargs["timeout"] + del kwargs["timeout"] + super().__init__(*args, **kwargs) + + def send(self, request, **kwargs): + """ + Sends the Request object, applying the timeout setting. + + :param request: The Request object to send. + :type request: requests.PreparedRequest + :param kwargs: Keyword arguments that may contain 'timeout'. + :return: The response to the request. + """ + timeout = kwargs.get("timeout") + if timeout is None: + kwargs["timeout"] = self.timeout + return super().send(request, **kwargs) + + +def get_requests_session(max_retries: int = 5, timeout: int = DEFAULT_TIMEOUT, verify: bool = True, + headers: Optional[Dict[str, str]] = None, session: Optional[requests.Session] = None) \ + -> requests.Session: + """ + Creates a requests session with retry logic, timeout settings, and random user-agent headers. + + :param max_retries: Maximum number of retries on failed requests. + :type max_retries: int + :param timeout: Request timeout in seconds. + :type timeout: int + :param verify: Whether to verify SSL certificates. + :type verify: bool + :param headers: Additional headers to include in the requests. + :type headers: Optional[Dict[str, str]] + :param session: An existing requests.Session instance to use. + :type session: Optional[requests.Session] + :return: A configured requests.Session object. + :rtype: requests.Session + """ + session = session or requests.session() + retries = Retry( + total=max_retries, backoff_factor=1, + status_forcelist=[408, 429, 500, 501, 502, 503, 504, 505, 506, 507, 509, 510, 511], + allowed_methods=["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"], + ) + adapter = TimeoutHTTPAdapter(max_retries=retries, timeout=timeout, pool_connections=32, pool_maxsize=32) + session.mount('http://', adapter) + session.mount('https://', adapter) + session.headers.update({ + "User-Agent": get_random_ua(), + **dict(headers or {}), + }) + if not verify: + session.verify = False + + return session + + +@lru_cache() +def _ua_pool(): + """ + Creates and caches a UserAgent rotator instance with a specified number of user agents. + + :return: A UserAgent rotator instance. + :rtype: UserAgent + """ + software_names = [SoftwareName.CHROME.value, SoftwareName.FIREFOX.value, SoftwareName.EDGE.value] + operating_systems = [OperatingSystem.WINDOWS.value, OperatingSystem.MACOS.value] + + user_agent_rotator = UserAgent(software_names=software_names, operating_systems=operating_systems, limit=1000) + return user_agent_rotator + + +def get_random_ua(): + """ + Retrieves a random user agent string from the cached UserAgent rotator. + + :return: A random user agent string. + :rtype: str + """ + return _ua_pool().get_random_user_agent() diff --git a/requirements.txt b/requirements.txt index 42c0a868a8..0d1fd6079a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ click>=7 tzlocal natsort urlobject -fsspec>=2024 \ No newline at end of file +fsspec>=2024 +random_user_agent \ No newline at end of file diff --git a/test/utils/test_session.py b/test/utils/test_session.py new file mode 100644 index 0000000000..79efd8cb56 --- /dev/null +++ b/test/utils/test_session.py @@ -0,0 +1,82 @@ +from unittest.mock import patch, Mock + +import pytest +import requests +from huggingface_hub import hf_hub_url +from requests.adapters import HTTPAdapter + +from hfutils.utils.session import TimeoutHTTPAdapter, get_requests_session, get_random_ua + + +@pytest.fixture +def mock_requests_session(): + with patch('requests.session') as mock_session: + yield mock_session.return_value + + +@pytest.fixture +def mock_ua_pool(): + with patch('hfutils.utils.session._ua_pool') as mock_pool: + mock_pool.return_value.get_random_user_agent.return_value = 'MockUserAgent' + yield mock_pool + + +@pytest.fixture() +def example_url(): + return hf_hub_url( + repo_id='deepghs/danbooru_newest', + repo_type='dataset', + filename='README.md' + ) + + +@pytest.mark.unittest +class TestUtilsSession: + def test_timeout_http_adapter_init(self, ): + adapter = TimeoutHTTPAdapter() + assert adapter.timeout == 15 + + adapter = TimeoutHTTPAdapter(timeout=30) + assert adapter.timeout == 30 + + def test_timeout_http_adapter_send(self, ): + adapter = TimeoutHTTPAdapter(timeout=10) + mock_request = Mock() + mock_kwargs = {} + + with patch.object(HTTPAdapter, 'send') as mock_send: + adapter.send(mock_request, **mock_kwargs) + mock_send.assert_called_once_with(mock_request, timeout=10) + + mock_kwargs = {'timeout': 20} + with patch.object(HTTPAdapter, 'send') as mock_send: + adapter.send(mock_request, **mock_kwargs) + mock_send.assert_called_once_with(mock_request, timeout=20) + + def test_get_requests_session(self, mock_ua_pool): + session = get_requests_session() + assert isinstance(session, requests.Session) + assert 'User-Agent' in session.headers + assert session.headers['User-Agent'] == 'MockUserAgent' + + custom_headers = {'Custom-Header': 'Value'} + session = get_requests_session(headers=custom_headers) + assert 'Custom-Header' in session.headers + assert session.headers['Custom-Header'] == 'Value' + + session = get_requests_session(verify=False) + assert session.verify is False + + existing_session = requests.Session() + session = get_requests_session(session=existing_session) + assert session is existing_session + + def test_get_requests_session_with_custom_params(self): + session = get_requests_session(max_retries=3, timeout=30) + assert isinstance(session, requests.Session) + # You might want to add more assertions here to check if the custom parameters are applied correctly + + def test_get_random_ua(self, mock_ua_pool): + ua = get_random_ua() + assert ua == 'MockUserAgent' + mock_ua_pool.return_value.get_random_user_agent.assert_called_once()