Skip to content

Commit

Permalink
dev(narugo): add pattern for directory downloading
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Jul 28, 2024
1 parent 74c0565 commit 4bf6540
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 45 deletions.
108 changes: 74 additions & 34 deletions hfutils/operate/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import fnmatch
import logging
import os
import re
from functools import lru_cache
from typing import Literal, List, Optional
from typing import Literal, List, Optional, Union, Iterator

from huggingface_hub import HfApi, HfFileSystem
from huggingface_hub.hf_api import RepoFolder, RepoFile
from huggingface_hub.utils import HfHubHTTPError

from hfutils.utils import parse_hf_fs_path
from ..utils import parse_hf_fs_path, hf_fs_path, tqdm, hf_normpath

RepoTypeTyping = Literal['dataset', 'model', 'space']
REPO_TYPES = ['dataset', 'model', 'space']
Expand Down Expand Up @@ -76,10 +80,56 @@ def _is_file_ignored(file_segments: List[str], ignore_patterns: List[str]) -> bo
return False


def list_files_in_repository(repo_id: str, repo_type: RepoTypeTyping = 'dataset',
subdir: str = '', revision: str = 'main',
ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET,
hf_token: Optional[str] = None) -> List[str]:
def list_all_with_pattern(
        repo_id: str, pattern: str = '**/*', repo_type: RepoTypeTyping = 'dataset',
        revision: str = 'main', startup_batch: int = 500, batch_factor: float = 0.8,
        hf_token: Optional[str] = None, silent: bool = False
) -> Iterator[Union[RepoFile, RepoFolder]]:
    """
    Iterate over the entries of a Hugging Face repository whose paths match a wildcard pattern.

    The matching paths are first collected with a glob on the repository file system, then
    their information is fetched in batches via ``get_paths_info``. When the server answers
    with HTTP 413, the batch size is reduced and the same batch is retried.

    :param repo_id: The identifier of the repository.
    :type repo_id: str
    :param pattern: Wildcard pattern of the target paths inside the repository.
    :type pattern: str
    :param repo_type: The type of the repository ('dataset', 'model', 'space').
    :type repo_type: RepoTypeTyping
    :param revision: The revision of the repository (e.g., branch, tag, commit hash).
    :type revision: str
    :param startup_batch: Number of paths requested in the first batch. Values below 1 are
        treated as 1.
    :type startup_batch: int
    :param batch_factor: Factor applied to the batch size after an HTTP 413 response.
    :type batch_factor: float
    :param hf_token: Huggingface token for API client.
    :type hf_token: str, optional
    :param silent: Passed through to the progress bar.
    :type silent: bool

    :return: An iterator over the information objects returned by ``get_paths_info``.
        Nothing is yielded when the glob raises ``FileNotFoundError``.
    :rtype: Iterator[Union[RepoFile, RepoFolder]]
    :raises HfHubHTTPError: On any HTTP error other than a 413 that can still be retried
        with a smaller batch.
    """
    hf_fs = get_hf_fs(hf_token=hf_token)
    hf_client = get_hf_client(hf_token=hf_token)

    try:
        paths = [
            parse_hf_fs_path(path).filename
            for path in hf_fs.glob(hf_fs_path(
                repo_id=repo_id,
                repo_type=repo_type,
                filename=pattern,
                revision=revision,
            ))
        ]
    except FileNotFoundError:
        return

    # A non-positive batch size would request empty batches forever.
    offset, batch_size = 0, max(1, startup_batch)
    progress = tqdm(total=len(paths), desc='Paths Info', silent=silent)
    while offset < len(paths):
        batch_paths = paths[offset:offset + batch_size]
        try:
            all_items = hf_client.get_paths_info(
                repo_id=repo_id,
                repo_type=repo_type,
                paths=batch_paths,
                revision=revision,
            )
        except HfHubHTTPError as err:
            # NOTE(review): the Codecov annotation on this commit reports this retry
            # branch as not covered by tests.
            # Retrying only makes sense while the batch can still get smaller; a 413 on
            # a single path is re-raised instead of being retried endlessly.
            if err.response.status_code == 413 and batch_size > 1:
                # Force a strict decrease: rounding alone maps 2 * 0.8 back to 2.
                new_batch_size = max(1, min(batch_size - 1, int(round(batch_size * batch_factor))))
                logging.info(f'Reducing batch size {batch_size} --> {new_batch_size} ...')
                batch_size = new_batch_size
                continue
            raise
        else:
            # Advance by the number of paths requested, not returned: the response may
            # contain fewer entries than were asked for, which would otherwise re-query
            # (and re-yield) paths or never terminate on an empty response.
            progress.update(len(batch_paths))
            offset += len(batch_paths)
            yield from all_items


def list_files_in_repository(
repo_id: str, repo_type: RepoTypeTyping = 'dataset',
subdir: str = '', pattern: str = '**/*', revision: str = 'main',
ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET,
hf_token: Optional[str] = None, silent: bool = False) -> List[str]:
"""
List files in a Hugging Face repository based on the given parameters.
Expand All @@ -89,6 +139,8 @@ def list_files_in_repository(repo_id: str, repo_type: RepoTypeTyping = 'dataset'
:type repo_type: RepoTypeTyping
:param subdir: The subdirectory to list files from.
:type subdir: str
:param pattern: Wildcard pattern of the target files.
:type pattern: str
:param revision: The revision of the repository (e.g., branch, tag, commit hash).
:type revision: str
:param ignore_patterns: List of file patterns to ignore.
Expand All @@ -101,35 +153,23 @@ def list_files_in_repository(repo_id: str, repo_type: RepoTypeTyping = 'dataset'
"""
if ignore_patterns is _IGNORE_PATTERN_UNSET:
ignore_patterns = _DEFAULT_IGNORE_PATTERNS
hf_fs = get_hf_fs(hf_token)
if repo_type == 'model':
repo_root_path = repo_id
elif repo_type == 'dataset':
repo_root_path = f'datasets/{repo_id}'
elif repo_type == 'space':
repo_root_path = f'spaces/{repo_id}'
else:
raise ValueError(f'Invalid repo_type - {repo_type!r}.')
if subdir and subdir != '.':
repo_root_path = f'{repo_root_path}/{subdir}'

try:
_exist_files = [
parse_hf_fs_path(file).filename
for file in hf_fs.glob(f'{repo_root_path}/**', revision=revision)
]
except FileNotFoundError:
return []
if subdir and subdir != '.':
_exist_files = [os.path.relpath(file, subdir) for file in _exist_files]

_exist_ps = sorted([(file, file.split('/')) for file in _exist_files], key=lambda x: x[1])
retval = []
for i, (file, segments) in enumerate(_exist_ps):
if i < len(_exist_ps) - 1 and segments == _exist_ps[i + 1][1][:len(segments)]:
continue
if file != '.':
pattern = f'{subdir}/{pattern}'

result = []
for item in list_all_with_pattern(
repo_id=repo_id,
repo_type=repo_type,
revision=revision,
pattern=pattern,
hf_token=hf_token,
silent=silent,
):
if isinstance(item, RepoFile):
path = hf_normpath(os.path.relpath(item.path, start=subdir or '.'))
segments = list(filter(bool, re.split(r'[\\/]+', path)))
if not _is_file_ignored(segments, ignore_patterns):
retval.append('/'.join(segments))
result.append(path)

return retval
return result
23 changes: 17 additions & 6 deletions hfutils/operate/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@ def download_archive_as_directory(local_directory: str, repo_id: str, file_in_re
archive_unpack(archive_file, local_directory, password=password)


def download_directory_as_directory(local_directory: str, repo_id: str, dir_in_repo: str = '.',
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
silent: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET,
resume_download: bool = True, max_workers: int = 8, max_retries: int = 5,
hf_token: Optional[str] = None):
def download_directory_as_directory(
local_directory: str, repo_id: str, dir_in_repo: str = '.', pattern: str = '**/*',
repo_type: RepoTypeTyping = 'dataset', revision: str = 'main',
silent: bool = False, ignore_patterns: List[str] = _IGNORE_PATTERN_UNSET,
resume_download: bool = True, max_workers: int = 8, max_retries: int = 5,
hf_token: Optional[str] = None
):
"""
Download all files in a directory from a Hugging Face repository to a local directory.
Expand All @@ -94,6 +96,8 @@ def download_directory_as_directory(local_directory: str, repo_id: str, dir_in_r
:type repo_id: str
:param dir_in_repo: The directory path within the repository.
:type dir_in_repo: str
:param pattern: Patterns for filtering.
:type pattern: str
:param repo_type: The type of the repository ('dataset', 'model', 'space').
:type repo_type: RepoTypeTyping
:param revision: The revision of the repository (e.g., branch, tag, commit hash).
Expand All @@ -111,7 +115,14 @@ def download_directory_as_directory(local_directory: str, repo_id: str, dir_in_r
:param hf_token: Huggingface token for API client, use ``HF_TOKEN`` variable if not assigned.
:type hf_token: str, optional
"""
files = list_files_in_repository(repo_id, repo_type, dir_in_repo, revision, ignore_patterns, hf_token=hf_token)
files = list_files_in_repository(
repo_id=repo_id,
repo_type=repo_type,
subdir=dir_in_repo,
revision=revision,
ignore_patterns=ignore_patterns,
hf_token=hf_token,
)
progress = tqdm(files, silent=silent, desc=f'Downloading {dir_in_repo!r} ...')

def _download_one_file(rel_file):
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ requests
click>=7
tzlocal
natsort
urlobject
urlobject
fsspec>=2024
4 changes: 0 additions & 4 deletions test/operate/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,5 @@ def test_list_files_in_repository_space(self):
assert (set(should_exists) & set(files)) == set(should_exists)
assert not (set(should_not_exists) & set(files))

def test_list_files_in_repository_failed(self):
with pytest.raises(ValueError):
list_files_in_repository('deepghs/highres_datasets', repo_type='fff')

def test_list_files_in_repository_repo_not_exist(self):
assert list_files_in_repository('deepghs/highres_datasets', repo_type='model') == []

0 comments on commit 4bf6540

Please sign in to comment.