Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev(narugo): add retry session in entries #37

Merged
merged 4 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api_doc/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ hfutils.utils
download
number
path
session
tqdm_
walk

31 changes: 31 additions & 0 deletions docs/source/api_doc/utils/session.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
hfutils.utils.session
=================================

.. currentmodule:: hfutils.utils.session

.. automodule:: hfutils.utils.session



TimeoutHTTPAdapter
-----------------------------------------------------

.. autoclass:: TimeoutHTTPAdapter
:members: __init__, send



get_requests_session
-----------------------------------------------------

.. autofunction:: get_requests_session



get_random_ua
-----------------------------------------------------

.. autofunction:: get_random_ua



4 changes: 4 additions & 0 deletions hfutils/entry/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import Optional

import click
from huggingface_hub import configure_http_backend

from .base import CONTEXT_SETTINGS, command_wrap, ClickErrorException
from ..operate import download_file_to_file, download_archive_as_directory, download_directory_as_directory
from ..operate.base import REPO_TYPES, RepoTypeTyping
from ..utils import get_requests_session


class NoRemotePathAssignedWithDownload(ClickErrorException):
Expand Down Expand Up @@ -84,6 +86,8 @@ def download(
:param tmpdir: Use custom temporary Directory.
:type tmpdir: str, optional
"""
configure_http_backend(get_requests_session)

if tmpdir:
os.environ['TMPDIR'] = tmpdir

Expand Down
6 changes: 5 additions & 1 deletion hfutils/entry/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

import click
from hbutils.string import plural_word
from huggingface_hub import configure_http_backend

from .base import CONTEXT_SETTINGS
from ..cache import delete_detached_cache
from ..index import hf_tar_validate, tar_create_index
from ..operate import get_hf_fs, download_file_to_file, upload_directory_as_directory
from ..operate.base import REPO_TYPES, RepoTypeTyping, get_hf_client
from ..utils import tqdm, hf_fs_path, parse_hf_fs_path, TemporaryDirectory, hf_normpath, ColoredFormatter
from ..utils import tqdm, hf_fs_path, parse_hf_fs_path, TemporaryDirectory, hf_normpath, ColoredFormatter, \
get_requests_session


def _add_index_subcommand(cli: click.Group) -> click.Group:
Expand Down Expand Up @@ -86,6 +88,8 @@
This function is typically invoked through the CLI interface, like:
$ python script.py index -r my_repo -x my_index_repo -t dataset -R main --min_upload_interval 120
"""
configure_http_backend(get_requests_session)

Check warning on line 91 in hfutils/entry/index.py

View check run for this annotation

Codecov / codecov/patch

hfutils/entry/index.py#L91

Added line #L91 was not covered by tests

logger = logging.getLogger()
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
Expand Down
4 changes: 4 additions & 0 deletions hfutils/entry/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

import click
import tzlocal
from huggingface_hub import configure_http_backend
from huggingface_hub.hf_api import RepoFolder, RepoFile

from .base import CONTEXT_SETTINGS
from ..operate.base import REPO_TYPES, get_hf_client
from ..utils import get_requests_session

mimetypes.add_type('image/webp', '.webp')

Expand Down Expand Up @@ -124,6 +126,8 @@ def ls(repo_id: str, repo_type: str, dir_in_repo, revision: str, show_all: bool,
:param show_detailed: Flag to indicate whether to show detailed file information.
:type show_detailed: bool
"""
configure_http_backend(get_requests_session)

hf_client = get_hf_client()
items: List[ListItem] = []
for item in hf_client.list_repo_tree(
Expand Down
4 changes: 4 additions & 0 deletions hfutils/entry/ls_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Optional

import click
from huggingface_hub import configure_http_backend
from huggingface_hub.utils import LocalTokenNotFoundError

from .base import CONTEXT_SETTINGS, ClickErrorException
from ..operate.base import REPO_TYPES, get_hf_client
from ..utils import get_requests_session


class NoLocalAuthentication(ClickErrorException):
Expand Down Expand Up @@ -46,6 +48,8 @@ def ls(author: Optional[str], repo_type: str, pattern: str):
:param pattern: Pattern of the repository names.
:type pattern: str
"""
configure_http_backend(get_requests_session)

hf_client = get_hf_client()
if not author:
try:
Expand Down
4 changes: 4 additions & 0 deletions hfutils/entry/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from typing import Optional

import click
from huggingface_hub import configure_http_backend

from .base import CONTEXT_SETTINGS, command_wrap, ClickErrorException
from ..operate import upload_file_to_file, upload_directory_as_archive, upload_directory_as_directory
from ..operate.base import REPO_TYPES, RepoTypeTyping, get_hf_client
from ..utils import get_requests_session


class NoRemotePathAssignedWithUpload(ClickErrorException):
Expand Down Expand Up @@ -78,6 +80,8 @@
:param public: Set public repository when created.
:type public: bool
"""
configure_http_backend(get_requests_session)

Check warning on line 83 in hfutils/entry/upload.py

View check run for this annotation

Codecov / codecov/patch

hfutils/entry/upload.py#L83

Added line #L83 was not covered by tests

if not file_in_repo and not archive_in_repo and not dir_in_repo:
raise NoRemotePathAssignedWithUpload('No remote path in repository assigned.\n'
'One of the -f, -a, or -d option is required.')
Expand Down
4 changes: 4 additions & 0 deletions hfutils/entry/whoami.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import click
from hbutils.string import plural_word
from huggingface_hub import configure_http_backend
from huggingface_hub.utils import LocalTokenNotFoundError

from .base import CONTEXT_SETTINGS
from ..operate.base import get_hf_client
from ..utils import get_requests_session


def _add_whoami_subcommand(cli: click.Group) -> click.Group:
Expand All @@ -28,6 +30,8 @@ def whoami():
This function retrieves the current user's identification from the Hugging Face Hub API and displays it.

"""
configure_http_backend(get_requests_session)

hf_client = get_hf_client()
try:
info = hf_client.whoami()
Expand Down
1 change: 1 addition & 0 deletions hfutils/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .logging import ColoredFormatter
from .number import number_to_tag
from .path import hf_normpath, hf_fs_path, parse_hf_fs_path, HfFileSystemPath
from .session import TimeoutHTTPAdapter, get_requests_session, get_random_ua
from .temp import TemporaryDirectory
from .tqdm_ import tqdm
from .walk import walk_files
115 changes: 115 additions & 0 deletions hfutils/utils/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
This module provides functionality for creating and managing HTTP sessions with customizable retry logic,
timeout settings, and user-agent rotation using random user-agent generation. It is designed to help with
robust web scraping and API consumption by handling common HTTP errors and timeouts gracefully.

Main Features:

- Automatic retries on specified HTTP response status codes.
- Configurable request timeout.
- Rotating user-agent for each session to mimic different browsers and operating systems.
- Optional SSL verification.
"""

from functools import lru_cache
from typing import Optional, Dict

import requests
from random_user_agent.params import SoftwareName, OperatingSystem
from random_user_agent.user_agent import UserAgent
from requests.adapters import HTTPAdapter, Retry

DEFAULT_TIMEOUT = 15 # seconds


class TimeoutHTTPAdapter(HTTPAdapter):
"""
A custom HTTPAdapter that enforces a default timeout on all requests.

:param args: Variable length argument list for HTTPAdapter.
:param kwargs: Arbitrary keyword arguments. 'timeout' can be specified to set a custom timeout.
"""

def __init__(self, *args, **kwargs):
self.timeout = DEFAULT_TIMEOUT
if "timeout" in kwargs:
self.timeout = kwargs["timeout"]
del kwargs["timeout"]
super().__init__(*args, **kwargs)

def send(self, request, **kwargs):
"""
Sends the Request object, applying the timeout setting.

:param request: The Request object to send.
:type request: requests.PreparedRequest
:param kwargs: Keyword arguments that may contain 'timeout'.
:return: The response to the request.
"""
timeout = kwargs.get("timeout")
if timeout is None:
kwargs["timeout"] = self.timeout
return super().send(request, **kwargs)


def get_requests_session(max_retries: int = 5, timeout: int = DEFAULT_TIMEOUT, verify: bool = True,
headers: Optional[Dict[str, str]] = None, session: Optional[requests.Session] = None) \
-> requests.Session:
"""
Creates a requests session with retry logic, timeout settings, and random user-agent headers.

:param max_retries: Maximum number of retries on failed requests.
:type max_retries: int
:param timeout: Request timeout in seconds.
:type timeout: int
:param verify: Whether to verify SSL certificates.
:type verify: bool
:param headers: Additional headers to include in the requests.
:type headers: Optional[Dict[str, str]]
:param session: An existing requests.Session instance to use.
:type session: Optional[requests.Session]
:return: A configured requests.Session object.
:rtype: requests.Session
"""
session = session or requests.session()
retries = Retry(
total=max_retries, backoff_factor=1,
status_forcelist=[408, 429, 500, 501, 502, 503, 504, 505, 506, 507, 509, 510, 511],
allowed_methods=["HEAD", "GET", "POST", "PUT", "DELETE", "OPTIONS", "TRACE"],
)
adapter = TimeoutHTTPAdapter(max_retries=retries, timeout=timeout, pool_connections=32, pool_maxsize=32)
session.mount('http://', adapter)
session.mount('https://', adapter)
session.headers.update({
"User-Agent": get_random_ua(),
**dict(headers or {}),
})
if not verify:
session.verify = False

return session


@lru_cache()
def _ua_pool():
"""
Creates and caches a UserAgent rotator instance with a specified number of user agents.

:return: A UserAgent rotator instance.
:rtype: UserAgent
"""
software_names = [SoftwareName.CHROME.value, SoftwareName.FIREFOX.value, SoftwareName.EDGE.value]
operating_systems = [OperatingSystem.WINDOWS.value, OperatingSystem.MACOS.value]

user_agent_rotator = UserAgent(software_names=software_names, operating_systems=operating_systems, limit=1000)
return user_agent_rotator


def get_random_ua():
"""
Retrieves a random user agent string from the cached UserAgent rotator.

:return: A random user agent string.
:rtype: str
"""
return _ua_pool().get_random_user_agent()
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ click>=7
tzlocal
natsort
urlobject
fsspec>=2024
fsspec>=2024
random_user_agent
82 changes: 82 additions & 0 deletions test/utils/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from unittest.mock import patch, Mock

import pytest
import requests
from huggingface_hub import hf_hub_url
from requests.adapters import HTTPAdapter

from hfutils.utils.session import TimeoutHTTPAdapter, get_requests_session, get_random_ua


@pytest.fixture
def mock_requests_session():
with patch('requests.session') as mock_session:
yield mock_session.return_value


@pytest.fixture
def mock_ua_pool():
with patch('hfutils.utils.session._ua_pool') as mock_pool:
mock_pool.return_value.get_random_user_agent.return_value = 'MockUserAgent'
yield mock_pool


@pytest.fixture()
def example_url():
return hf_hub_url(
repo_id='deepghs/danbooru_newest',
repo_type='dataset',
filename='README.md'
)


@pytest.mark.unittest
class TestUtilsSession:
def test_timeout_http_adapter_init(self, ):
adapter = TimeoutHTTPAdapter()
assert adapter.timeout == 15

adapter = TimeoutHTTPAdapter(timeout=30)
assert adapter.timeout == 30

def test_timeout_http_adapter_send(self, ):
adapter = TimeoutHTTPAdapter(timeout=10)
mock_request = Mock()
mock_kwargs = {}

with patch.object(HTTPAdapter, 'send') as mock_send:
adapter.send(mock_request, **mock_kwargs)
mock_send.assert_called_once_with(mock_request, timeout=10)

mock_kwargs = {'timeout': 20}
with patch.object(HTTPAdapter, 'send') as mock_send:
adapter.send(mock_request, **mock_kwargs)
mock_send.assert_called_once_with(mock_request, timeout=20)

def test_get_requests_session(self, mock_ua_pool):
session = get_requests_session()
assert isinstance(session, requests.Session)
assert 'User-Agent' in session.headers
assert session.headers['User-Agent'] == 'MockUserAgent'

custom_headers = {'Custom-Header': 'Value'}
session = get_requests_session(headers=custom_headers)
assert 'Custom-Header' in session.headers
assert session.headers['Custom-Header'] == 'Value'

session = get_requests_session(verify=False)
assert session.verify is False

existing_session = requests.Session()
session = get_requests_session(session=existing_session)
assert session is existing_session

def test_get_requests_session_with_custom_params(self):
session = get_requests_session(max_retries=3, timeout=30)
assert isinstance(session, requests.Session)
# You might want to add more assertions here to check if the custom parameters are applied correctly

def test_get_random_ua(self, mock_ua_pool):
ua = get_random_ua()
assert ua == 'MockUserAgent'
mock_ua_pool.return_value.get_random_user_agent.assert_called_once()
Loading