diff --git a/docs/source/api_doc/operate/base.rst b/docs/source/api_doc/operate/base.rst index e1a69a3463..3d3664d0a8 100644 --- a/docs/source/api_doc/operate/base.rst +++ b/docs/source/api_doc/operate/base.rst @@ -20,6 +20,13 @@ get_hf_fs +list_all_with_pattern +---------------------------------------- + +.. autofunction:: list_all_with_pattern + + + list_files_in_repository ---------------------------------------- diff --git a/hfutils/entry/download.py b/hfutils/entry/download.py index d174596490..550dbedf28 100644 --- a/hfutils/entry/download.py +++ b/hfutils/entry/download.py @@ -47,12 +47,17 @@ def _add_download_subcommand(cli: click.Group) -> click.Group: help='Max threads to download.', show_default=True) @click.option('-p', '--password', 'password', type=str, default=None, help='Password for the archive file. Only applied when -a is used.', show_default=True) + @click.option('-w', '--wildcard', 'wildcard', type=str, default=None, + help='Wildcard for files to download. Only applied when -d is used.', show_default=True) @click.option('--tmpdir', 'tmpdir', type=str, default=None, help='Use custom temporary Directory.', show_default=True) @command_wrap() - def download(repo_id: str, repo_type: RepoTypeTyping, - file_in_repo: Optional[str], archive_in_repo: Optional[str], dir_in_repo: Optional[str], - output_path: str, revision: str, max_workers: int, password: Optional[str], tmpdir: Optional[str]): + def download( + repo_id: str, repo_type: RepoTypeTyping, + file_in_repo: Optional[str], archive_in_repo: Optional[str], dir_in_repo: Optional[str], + output_path: str, revision: str, max_workers: int, + password: Optional[str], wildcard: Optional[str], tmpdir: Optional[str] + ): """ Download data from HuggingFace repositories. @@ -74,6 +79,8 @@ def download(repo_id: str, repo_type: RepoTypeTyping, :type max_workers: int :param password: Password for the archive file. Only applied when -a is used. :type password: str, optional + :param wildcard: Wildcard for files to download. Only applied when -d is used. + :type password: str, optional :param tmpdir: Use custom temporary Directory. :type tmpdir: str, optional """ @@ -114,6 +121,7 @@ def download(repo_id: str, repo_type: RepoTypeTyping, local_directory=output_path, repo_id=repo_id, dir_in_repo=dir_in_repo, + pattern=wildcard or '**/*', repo_type=repo_type, revision=revision, silent=False, diff --git a/hfutils/operate/__init__.py b/hfutils/operate/__init__.py index 207e512c84..d2142812ab 100644 --- a/hfutils/operate/__init__.py +++ b/hfutils/operate/__init__.py @@ -1,4 +1,4 @@ -from .base import get_hf_client, get_hf_fs, list_files_in_repository +from .base import get_hf_client, get_hf_fs, list_all_with_pattern, list_files_in_repository from .download import download_file_to_file, download_archive_as_directory, download_directory_as_directory from .upload import upload_file_to_file, upload_directory_as_archive, upload_directory_as_directory from .validate import is_local_file_ready diff --git a/hfutils/operate/base.py b/hfutils/operate/base.py index 85a9ae49c5..6744808e39 100644 --- a/hfutils/operate/base.py +++ b/hfutils/operate/base.py @@ -1,11 +1,31 @@ +""" +This module provides utilities for interacting with the Hugging Face Hub API and filesystem. +It includes functions for retrieving API clients, listing files in repositories, and handling +file patterns and ignore rules. + +The module offers the following main functionalities: + +1. Retrieving Hugging Face API tokens and clients +2. Accessing the Hugging Face filesystem +3. Listing files in Hugging Face repositories with pattern matching and ignore rules +4. Parsing and normalizing Hugging Face filesystem paths + +These utilities are designed to simplify working with Hugging Face repositories, especially +when dealing with datasets, models, and spaces. +""" + import fnmatch +import logging import os +import re from functools import lru_cache -from typing import Literal, List, Optional +from typing import Literal, List, Optional, Union, Iterator from huggingface_hub import HfApi, HfFileSystem +from huggingface_hub.hf_api import RepoFolder, RepoFile +from huggingface_hub.utils import HfHubHTTPError -from hfutils.utils import parse_hf_fs_path +from ..utils import parse_hf_fs_path, hf_fs_path, tqdm, hf_normpath RepoTypeTyping = Literal['dataset', 'model', 'space'] REPO_TYPES = ['dataset', 'model', 'space'] @@ -16,7 +36,10 @@ def _get_hf_token() -> Optional[str]: """ Retrieve the Hugging Face token from the environment variable. - :return: The Hugging Face token. + This function checks for the 'HF_TOKEN' environment variable and returns its value. + It is cached to avoid repeated environment variable lookups. + + :return: The Hugging Face token if set, otherwise None. :rtype: Optional[str] """ return os.environ.get('HF_TOKEN') @@ -27,11 +50,20 @@ def get_hf_client(hf_token: Optional[str] = None) -> HfApi: """ Get the Hugging Face API client. - :param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned. - :type hf_token: str, optional + This function returns an instance of the Hugging Face API client. If a token is not provided, + it attempts to use the token from the environment variable. + + :param hf_token: Hugging Face token for API client. If not provided, uses the 'HF_TOKEN' environment variable. + :type hf_token: Optional[str] - :return: The Hugging Face API client. + :return: An instance of the Hugging Face API client. :rtype: HfApi + + :example: + + >>> client = get_hf_client() + >>> # Use client to interact with Hugging Face API + >>> client.list_repos(organization="huggingface") """ return HfApi(token=hf_token or _get_hf_token()) @@ -41,11 +73,21 @@ def get_hf_fs(hf_token: Optional[str] = None) -> HfFileSystem: """ Get the Hugging Face file system. - :param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned. - :type hf_token: str, optional + This function returns an instance of the Hugging Face file system. If a token is not provided, + it attempts to use the token from the environment variable. The file system is configured + not to use listings cache to ensure fresh results. + + :param hf_token: Hugging Face token for API client. If not provided, uses the 'HF_TOKEN' environment variable. + :type hf_token: Optional[str] - :return: The Hugging Face file system. + :return: An instance of the Hugging Face file system. :rtype: HfFileSystem + + :example: + + >>> fs = get_hf_fs() + >>> # Use fs to interact with Hugging Face file system + >>> fs.ls("dataset/example") """ # use_listings_cache=False is necessary # or the result of glob and ls will be cached, the unittest will down @@ -60,6 +102,9 @@ def _is_file_ignored(file_segments: List[str], ignore_patterns: List[str]) -> bo """ Check if a file should be ignored based on the given ignore patterns. + This function checks each segment of the file path against the provided ignore patterns. + If any segment matches any of the patterns, the file is considered ignored. + :param file_segments: The segments of the file path. :type file_segments: List[str] :param ignore_patterns: List of file patterns to ignore. @@ -67,6 +112,13 @@ def _is_file_ignored(file_segments: List[str], ignore_patterns: List[str]) -> bo :return: True if the file should be ignored, False otherwise. :rtype: bool + + :example: + + >>> _is_file_ignored(['folder', 'file.txt'], ['.git*', '*.log']) + False + >>> _is_file_ignored(['folder', '.gitignore'], ['.git*', '*.log']) + True """ for segment in file_segments: for pattern in ignore_patterns: @@ -76,60 +128,140 @@ def _is_file_ignored(file_segments: List[str], ignore_patterns: List[str]) -> bo return False -def list_files_in_repository(repo_id: str, repo_type: RepoTypeTyping = 'dataset', - subdir: str = '', revision: str = 'main', - ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET, - hf_token: Optional[str] = None) -> List[str]: +def list_all_with_pattern( + repo_id: str, pattern: str = '**/*', repo_type: RepoTypeTyping = 'dataset', + revision: str = 'main', startup_batch: int = 500, batch_factor: float = 0.8, + hf_token: Optional[str] = None, silent: bool = False +) -> Iterator[Union[RepoFile, RepoFolder]]: + """ + List all files and folders in a Hugging Face repository matching a given pattern. + + This function retrieves information about files and folders in a repository that match + the specified pattern. It uses batching to handle large repositories efficiently. + + :param repo_id: The identifier of the repository. + :type repo_id: str + :param pattern: Wildcard pattern to match files and folders. Default is `**/*` (all files and folders). + :type pattern: str + :param repo_type: The type of the repository ('dataset', 'model', 'space'). Default is 'dataset'. + :type repo_type: RepoTypeTyping + :param revision: The revision of the repository (e.g., branch, tag, commit hash). Default is 'main'. + :type revision: str + :param startup_batch: Initial batch size for retrieving path information. Default is 500. + :type startup_batch: int + :param batch_factor: Factor to reduce batch size if a request fails. Default is 0.8. + :type batch_factor: float + :param hf_token: Hugging Face token for API client. If not provided, uses the 'HF_TOKEN' environment variable. + :type hf_token: Optional[str] + :param silent: If True, suppresses progress bar. Default is False. + :type silent: bool + + :return: An iterator of RepoFile and RepoFolder objects matching the pattern. + :rtype: Iterator[Union[RepoFile, RepoFolder]] + + :raises HfHubHTTPError: If there's an error in the API request that's not related to batch size. + + :example: + + >>> for item in list_all_with_pattern("username/repo", pattern="*.txt"): + ... print(item.path) + """ + hf_fs = get_hf_fs(hf_token=hf_token) + hf_client = get_hf_client(hf_token=hf_token) + + try: + paths = [ + parse_hf_fs_path(path).filename + for path in hf_fs.glob(hf_fs_path( + repo_id=repo_id, + repo_type=repo_type, + filename=pattern, + revision=revision, + )) + ] + except FileNotFoundError: + return + + offset, batch_size = 0, startup_batch + progress = tqdm(total=len(paths), desc='Paths Info', silent=silent) + while offset < len(paths): + batch_paths = paths[offset:offset + batch_size] + try: + all_items = hf_client.get_paths_info( + repo_id=repo_id, + repo_type=repo_type, + paths=batch_paths, + revision=revision, + ) + except HfHubHTTPError as err: + if err.response.status_code == 413: + new_batch_size = max(1, int(round(batch_size * batch_factor))) + logging.warning(f'Reducing batch size {batch_size} --> {new_batch_size} ...') + batch_size = new_batch_size + continue + raise + else: + progress.update(len(all_items)) + offset += len(all_items) + yield from all_items + + +def list_files_in_repository( + repo_id: str, repo_type: RepoTypeTyping = 'dataset', + subdir: str = '', pattern: str = '**/*', revision: str = 'main', + ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET, + hf_token: Optional[str] = None, silent: bool = False) -> List[str]: """ List files in a Hugging Face repository based on the given parameters. + This function retrieves a list of file paths in a specified repository that match + the given pattern and are not ignored by the ignore patterns. + :param repo_id: The identifier of the repository. :type repo_id: str - :param repo_type: The type of the repository ('dataset', 'model', 'space'). + :param repo_type: The type of the repository ('dataset', 'model', 'space'). Default is 'dataset'. :type repo_type: RepoTypeTyping - :param subdir: The subdirectory to list files from. + :param subdir: The subdirectory to list files from. Default is an empty string (root directory). :type subdir: str - :param revision: The revision of the repository (e.g., branch, tag, commit hash). + :param pattern: Wildcard pattern of the target files. Default is `**/*` (all files). + :type pattern: str + :param revision: The revision of the repository (e.g., branch, tag, commit hash). Default is 'main'. :type revision: str - :param ignore_patterns: List of file patterns to ignore. + :param ignore_patterns: List of file patterns to ignore. If not set, uses default ignore patterns. :type ignore_patterns: List[str] - :param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned. - :type hf_token: str, optional + :param hf_token: Hugging Face token for API client. If not provided, uses the 'HF_TOKEN' environment variable. + :type hf_token: Optional[str] + :param silent: If True, suppresses progress bar. Default is False. + :type silent: bool - :return: A list of file paths. + :return: A list of file paths that match the criteria. :rtype: List[str] + + :example: + + >>> files = list_files_in_repository("username/repo", pattern="*.txt", ignore_patterns=[".git*", "*.log"]) + >>> print(files) + ['file1.txt', 'folder/file2.txt'] """ if ignore_patterns is _IGNORE_PATTERN_UNSET: ignore_patterns = _DEFAULT_IGNORE_PATTERNS - hf_fs = get_hf_fs(hf_token) - if repo_type == 'model': - repo_root_path = repo_id - elif repo_type == 'dataset': - repo_root_path = f'datasets/{repo_id}' - elif repo_type == 'space': - repo_root_path = f'spaces/{repo_id}' - else: - raise ValueError(f'Invalid repo_type - {repo_type!r}.') - if subdir and subdir != '.': - repo_root_path = f'{repo_root_path}/{subdir}' - try: - _exist_files = [ - parse_hf_fs_path(file).filename - for file in hf_fs.glob(f'{repo_root_path}/**', revision=revision) - ] - except FileNotFoundError: - return [] if subdir and subdir != '.': - _exist_files = [os.path.relpath(file, subdir) for file in _exist_files] - - _exist_ps = sorted([(file, file.split('/')) for file in _exist_files], key=lambda x: x[1]) - retval = [] - for i, (file, segments) in enumerate(_exist_ps): - if i < len(_exist_ps) - 1 and segments == _exist_ps[i + 1][1][:len(segments)]: - continue - if file != '.': + pattern = f'{subdir}/{pattern}' + + result = [] + for item in list_all_with_pattern( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + pattern=pattern, + hf_token=hf_token, + silent=silent, + ): + if isinstance(item, RepoFile): + path = hf_normpath(os.path.relpath(item.path, start=subdir or '.')) + segments = list(filter(bool, re.split(r'[\\/]+', path))) if not _is_file_ignored(segments, ignore_patterns): - retval.append('/'.join(segments)) + result.append(path) - return retval + return result diff --git a/hfutils/operate/download.py b/hfutils/operate/download.py index 46cdfc9657..d24ba94207 100644 --- a/hfutils/operate/download.py +++ b/hfutils/operate/download.py @@ -80,11 +80,13 @@ def download_archive_as_directory(local_directory: str, repo_id: str, file_in_re archive_unpack(archive_file, local_directory, password=password) -def download_directory_as_directory(local_directory: str, repo_id: str, dir_in_repo: str = '.', - repo_type: RepoTypeTyping = 'dataset', revision: str = 'main', - silent: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET, - resume_download: bool = True, max_workers: int = 8, max_retries: int = 5, - hf_token: Optional[str] = None): +def download_directory_as_directory( + local_directory: str, repo_id: str, dir_in_repo: str = '.', pattern: str = '**/*', + repo_type: RepoTypeTyping = 'dataset', revision: str = 'main', + silent: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET, + resume_download: bool = True, max_workers: int = 8, max_retries: int = 5, + hf_token: Optional[str] = None +): """ Download all files in a directory from a Hugging Face repository to a local directory. @@ -94,6 +96,8 @@ def download_directory_as_directory(local_directory: str, repo_id: str, dir_in_r :type repo_id: str :param dir_in_repo: The directory path within the repository. :type dir_in_repo: str + :param pattern: Patterns for filtering. + :type pattern: str :param repo_type: The type of the repository ('dataset', 'model', 'space'). :type repo_type: RepoTypeTyping :param revision: The revision of the repository (e.g., branch, tag, commit hash). @@ -111,7 +115,15 @@ def download_directory_as_directory(local_directory: str, repo_id: str, dir_in_r :param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned. :type hf_token: str, optional """ - files = list_files_in_repository(repo_id, repo_type, dir_in_repo, revision, ignore_patterns, hf_token=hf_token) + files = list_files_in_repository( + repo_id=repo_id, + repo_type=repo_type, + subdir=dir_in_repo, + pattern=pattern, + revision=revision, + ignore_patterns=ignore_patterns, + hf_token=hf_token, + ) progress = tqdm(files, silent=silent, desc=f'Downloading {dir_in_repo!r} ...') def _download_one_file(rel_file): diff --git a/requirements.txt b/requirements.txt index f72b7ece0f..42c0a868a8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ requests click>=7 tzlocal natsort -urlobject \ No newline at end of file +urlobject +fsspec>=2024 \ No newline at end of file diff --git a/test/operate/test_base.py b/test/operate/test_base.py index 2b4388df70..1597947d09 100644 --- a/test/operate/test_base.py +++ b/test/operate/test_base.py @@ -1,6 +1,7 @@ import pytest +from natsort import natsorted -from hfutils.operate import list_files_in_repository +from hfutils.operate import list_files_in_repository, list_all_with_pattern should_not_exists = [ '.gitignore', @@ -109,9 +110,42 @@ def test_list_files_in_repository_space(self): assert (set(should_exists) & set(files)) == set(should_exists) assert not (set(should_not_exists) & set(files)) - def test_list_files_in_repository_failed(self): - with pytest.raises(ValueError): - list_files_in_repository('deepghs/highres_datasets', repo_type='fff') + def test_list_files_in_repository_large(self): + files = list_files_in_repository('deepghs/danbooru_newest', repo_type='dataset', pattern='**/*.tar') + files = natsorted(files) + assert files == [ + f'images/0{i:03d}.tar' + for i in range(1000) + ] def test_list_files_in_repository_repo_not_exist(self): assert list_files_in_repository('deepghs/highres_datasets', repo_type='model') == [] + + def test_list_all_with_pattern(self): + vs = natsorted([ + item.path for item in + list_all_with_pattern( + 'deepghs/danbooru_newest', + repo_type='dataset', + pattern='images/*', + ) + ]) + assert vs == natsorted([ + *[f'images/0{i:03d}.tar' for i in range(1000)], + *[f'images/0{i:03d}.json' for i in range(1000)], + ]) + + def test_list_all_with_pattern_with_large_startup(self): + vs = natsorted([ + item.path for item in + list_all_with_pattern( + 'deepghs/danbooru_newest', + repo_type='dataset', + pattern='images/*', + startup_batch=1500, + ) + ]) + assert vs == natsorted([ + *[f'images/0{i:03d}.tar' for i in range(1000)], + *[f'images/0{i:03d}.json' for i in range(1000)], + ]) diff --git a/test/operate/test_download.py b/test/operate/test_download.py index 1ceed608de..7e4c4131dc 100644 --- a/test/operate/test_download.py +++ b/test/operate/test_download.py @@ -71,3 +71,25 @@ def _my_download(*args, **kwargs): dir_compare(target_dir, 'download_dir') assert call_times == 6 + + def test_download_directory_as_directory_with_pattern(self): + target_dir = get_testfile('skin_mashu_pattern') + + call_times = 0 + + def _my_download(*args, **kwargs): + nonlocal call_times + call_times += 1 + return download_file_to_file(*args, **kwargs) + + with patch('hfutils.operate.download.download_file_to_file', _my_download), \ + isolated_directory(): + download_directory_as_directory( + 'download_dir', + repo_id='deepghs/game_character_skins', + dir_in_repo='fgo/1', + pattern='第*.png', + ) + dir_compare(target_dir, 'download_dir') + + assert call_times == 4 diff --git "a/test/testfile/skin_mashu_pattern/\347\254\2541\351\230\266\346\256\265.png" "b/test/testfile/skin_mashu_pattern/\347\254\2541\351\230\266\346\256\265.png" new file mode 100644 index 0000000000..468941a4dc Binary files /dev/null and "b/test/testfile/skin_mashu_pattern/\347\254\2541\351\230\266\346\256\265.png" differ diff --git "a/test/testfile/skin_mashu_pattern/\347\254\2542\351\230\266\346\256\265.png" "b/test/testfile/skin_mashu_pattern/\347\254\2542\351\230\266\346\256\265.png" new file mode 100644 index 0000000000..2f0397b7f9 Binary files /dev/null and "b/test/testfile/skin_mashu_pattern/\347\254\2542\351\230\266\346\256\265.png" differ diff --git "a/test/testfile/skin_mashu_pattern/\347\254\2543\351\230\266\346\256\265.png" "b/test/testfile/skin_mashu_pattern/\347\254\2543\351\230\266\346\256\265.png" new file mode 100644 index 0000000000..875d3e2f59 Binary files /dev/null and "b/test/testfile/skin_mashu_pattern/\347\254\2543\351\230\266\346\256\265.png" differ diff --git "a/test/testfile/skin_mashu_pattern/\347\254\2544\351\230\266\346\256\265.png" "b/test/testfile/skin_mashu_pattern/\347\254\2544\351\230\266\346\256\265.png" new file mode 100644 index 0000000000..79bed1704a Binary files /dev/null and "b/test/testfile/skin_mashu_pattern/\347\254\2544\351\230\266\346\256\265.png" differ