From bbb1678e48a240d87dff78283272288a62d72443 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 14 Aug 2024 22:47:05 +0800 Subject: [PATCH 1/2] dev(narugo): optimize batch download --- hfutils/operate/download.py | 131 ++++++++++++++++++---------------- test/operate/test_download.py | 13 ++-- 2 files changed, 77 insertions(+), 67 deletions(-) diff --git a/hfutils/operate/download.py b/hfutils/operate/download.py index 1103ad7d2a..e369cb580c 100644 --- a/hfutils/operate/download.py +++ b/hfutils/operate/download.py @@ -12,9 +12,30 @@ from ..utils import tqdm, TemporaryDirectory, hf_normpath +def _raw_download_file(td: str, local_file: str, repo_id: str, file_in_repo: str, + repo_type: RepoTypeTyping = 'dataset', revision: str = 'main', + hf_token: Optional[str] = None): + hf_client = get_hf_client(hf_token=hf_token) + relative_filename = os.path.join(*file_in_repo.split("/")) + temp_path = os.path.join(td, relative_filename) + try: + hf_client.hf_hub_download( + repo_id=repo_id, + repo_type=repo_type, + filename=hf_normpath(file_in_repo), + revision=revision, + local_dir=td, + ) + finally: + if os.path.exists(temp_path): + if os.path.dirname(local_file): + os.makedirs(os.path.dirname(local_file), exist_ok=True) + shutil.move(temp_path, local_file) + + def download_file_to_file(local_file: str, repo_id: str, file_in_repo: str, repo_type: RepoTypeTyping = 'dataset', revision: str = 'main', - resume_download: bool = True, hf_token: Optional[str] = None): + hf_token: Optional[str] = None): """ Download a file from a Hugging Face repository and save it to a local file. @@ -28,29 +49,19 @@ def download_file_to_file(local_file: str, repo_id: str, file_in_repo: str, :type repo_type: RepoTypeTyping :param revision: The revision of the repository (e.g., branch, tag, commit hash). :type revision: str - :param resume_download: Resume the existing download. - :type resume_download: bool :param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned. :type hf_token: str, optional """ - hf_client = get_hf_client(hf_token) - relative_filename = os.path.join(*file_in_repo.split("/")) with TemporaryDirectory() as td: - temp_path = os.path.join(td, relative_filename) - try: - hf_client.hf_hub_download( - repo_id=repo_id, - repo_type=repo_type, - filename=hf_normpath(file_in_repo), - revision=revision, - local_dir=td, - resume_download=resume_download, - ) - finally: - if os.path.exists(temp_path): - if os.path.dirname(local_file): - os.makedirs(os.path.dirname(local_file), exist_ok=True) - shutil.move(temp_path, local_file) + _raw_download_file( + td=td, + local_file=local_file, + repo_id=repo_id, + file_in_repo=file_in_repo, + repo_type=repo_type, + revision=revision, + hf_token=hf_token, + ) def download_archive_as_directory(local_directory: str, repo_id: str, file_in_repo: str, @@ -84,7 +95,7 @@ def download_directory_as_directory( local_directory: str, repo_id: str, dir_in_repo: str = '.', pattern: str = '**/*', repo_type: RepoTypeTyping = 'dataset', revision: str = 'main', silent: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET, - resume_download: bool = True, max_workers: int = 8, max_retries: int = 5, + max_workers: int = 8, max_retries: int = 5, soft_mode_when_check: bool = False, hf_token: Optional[str] = None ): """ @@ -110,8 +121,6 @@ def download_directory_as_directory( :type max_workers: int :param max_retries: Max retry times when downloading. Default is ``5``. :type max_retries: int - :param resume_download: Resume the existing download. - :type resume_download: bool :param soft_mode_when_check: Just check the size of the expected file when enabled. Default is False. :type soft_mode_when_check: bool :param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned. @@ -129,46 +138,46 @@ def download_directory_as_directory( progress = tqdm(files, silent=silent, desc=f'Downloading {dir_in_repo!r} ...') def _download_one_file(rel_file): - current_resume_download = resume_download - try: - dst_file = os.path.join(local_directory, rel_file) - file_in_repo = hf_normpath(f'{dir_in_repo}/{rel_file}') - if os.path.exists(dst_file) and is_local_file_ready( - repo_id=repo_id, - repo_type=repo_type, - local_file=dst_file, - file_in_repo=file_in_repo, - revision=revision, - hf_token=hf_token, - soft_mode=soft_mode_when_check, - ): - logging.info(f'Local file {rel_file} is ready, download skipped.') - else: - tries = 0 - while True: - try: - download_file_to_file( - local_file=dst_file, - repo_id=repo_id, - file_in_repo=file_in_repo, - repo_type=repo_type, - revision=revision, - resume_download=current_resume_download, - hf_token=hf_token, - ) - except requests.exceptions.RequestException as err: - if tries < max_retries: - tries += 1 - logging.warning(f'Download {rel_file!r} failed, retry ({tries}/{max_retries}) - {err!r}.') - current_resume_download = True + with TemporaryDirectory() as td: + try: + dst_file = os.path.join(local_directory, rel_file) + file_in_repo = hf_normpath(f'{dir_in_repo}/{rel_file}') + if os.path.exists(dst_file) and is_local_file_ready( + repo_id=repo_id, + repo_type=repo_type, + local_file=dst_file, + file_in_repo=file_in_repo, + revision=revision, + hf_token=hf_token, + soft_mode=soft_mode_when_check, + ): + logging.info(f'Local file {rel_file} is ready, download skipped.') + else: + tries = 0 + while True: + try: + _raw_download_file( + td=td, + local_file=dst_file, + repo_id=repo_id, + file_in_repo=file_in_repo, + repo_type=repo_type, + revision=revision, + hf_token=hf_token, + ) + except requests.exceptions.RequestException as err: + if tries < max_retries: + tries += 1 + logging.warning( + f'Download {rel_file!r} failed, retry ({tries}/{max_retries}) - {err!r}.') + else: + raise else: - raise - else: - break + break - progress.update() - except Exception as err: - logging.error(f'Unexpected error when downloading {rel_file!r} - {err!r}') + progress.update() + except Exception as err: + logging.exception(f'Unexpected error when downloading {rel_file!r} - {err!r}') tp = ThreadPoolExecutor(max_workers=max_workers) for file in files: diff --git a/test/operate/test_download.py b/test/operate/test_download.py index 7e4c4131dc..c7f7d97f46 100644 --- a/test/operate/test_download.py +++ b/test/operate/test_download.py @@ -4,6 +4,7 @@ from hbutils.testing import isolated_directory from hfutils.operate import download_file_to_file, download_archive_as_directory, download_directory_as_directory +from hfutils.operate.download import _raw_download_file from test.testings import get_testfile, file_compare, dir_compare @@ -37,9 +38,9 @@ def test_download_directory_as_directory(self): def _my_download(*args, **kwargs): nonlocal call_times call_times += 1 - return download_file_to_file(*args, **kwargs) + return _raw_download_file(*args, **kwargs) - with patch('hfutils.operate.download.download_file_to_file', _my_download), \ + with patch('hfutils.operate.download._raw_download_file', _my_download), \ isolated_directory(): download_directory_as_directory( 'download_dir', @@ -59,9 +60,9 @@ def test_download_directory_as_directory_partial(self): def _my_download(*args, **kwargs): nonlocal call_times call_times += 1 - return download_file_to_file(*args, **kwargs) + return _raw_download_file(*args, **kwargs) - with patch('hfutils.operate.download.download_file_to_file', _my_download), \ + with patch('hfutils.operate.download._raw_download_file', _my_download), \ isolated_directory({'download_dir': src_dir}): download_directory_as_directory( 'download_dir', @@ -80,9 +81,9 @@ def test_download_directory_as_directory_with_pattern(self): def _my_download(*args, **kwargs): nonlocal call_times call_times += 1 - return download_file_to_file(*args, **kwargs) + return _raw_download_file(*args, **kwargs) - with patch('hfutils.operate.download.download_file_to_file', _my_download), \ + with patch('hfutils.operate.download._raw_download_file', _my_download), \ isolated_directory(): download_directory_as_directory( 'download_dir', From ddc0e9d97a743f60b7f5ab9595bfb2fc031eca3b Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 14 Aug 2024 22:52:08 +0800 Subject: [PATCH 2/2] dev(narugo): add pydocs --- hfutils/operate/download.py | 44 +++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/hfutils/operate/download.py b/hfutils/operate/download.py index e369cb580c..b94f825ee5 100644 --- a/hfutils/operate/download.py +++ b/hfutils/operate/download.py @@ -1,3 +1,27 @@ +""" +This module provides functions for downloading files and directories from Hugging Face repositories. + +It includes utilities for downloading individual files, archives, and entire directories, +with support for concurrent downloads, retries, and progress tracking. + +The module interacts with the Hugging Face Hub API to fetch repository contents and +download files, handling various repository types and revisions. + +Key features: + +- Download individual files from Hugging Face repositories +- Download and extract archive files +- Download entire directories with pattern matching and ignore rules +- Concurrent downloads with configurable worker count +- Retry mechanism for failed downloads +- Progress tracking with tqdm +- Support for different repository types (dataset, model, space) +- Token-based authentication for accessing private repositories + +This module is particularly useful for managing and synchronizing local copies of +Hugging Face repository contents, especially when dealing with large datasets or models. +""" + import logging import os.path import shutil @@ -15,6 +39,26 @@ def _raw_download_file(td: str, local_file: str, repo_id: str, file_in_repo: str, repo_type: RepoTypeTyping = 'dataset', revision: str = 'main', hf_token: Optional[str] = None): + """ + Download a file from a Hugging Face repository to a temporary directory and then move it to the final location. + + This internal function handles the actual download process using the Hugging Face Hub client. + + :param td: Temporary directory path. + :type td: str + :param local_file: The final local file path where the downloaded file will be moved. + :type local_file: str + :param repo_id: The identifier of the repository. + :type repo_id: str + :param file_in_repo: The file path within the repository. + :type file_in_repo: str + :param repo_type: The type of the repository ('dataset', 'model', 'space'). + :type repo_type: RepoTypeTyping + :param revision: The revision of the repository (e.g., branch, tag, commit hash). + :type revision: str + :param hf_token: Hugging Face token for API client. + :type hf_token: str, optional + """ hf_client = get_hf_client(hf_token=hf_token) relative_filename = os.path.join(*file_in_repo.split("/")) temp_path = os.path.join(td, relative_filename)