From a83e7ecb1578c40a86c9f08965c7ffdc556b107b Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 16:11:23 +0800 Subject: [PATCH 01/17] dev(narugo): add archive file check --- hfutils/entry/cli.py | 2 + hfutils/entry/tree.py | 39 ++++ hfutils/utils/__init__.py | 1 + hfutils/utils/archive.py | 121 ++++++++++++ test/utils/test_archive.py | 379 +++++++++++++++++++++++++++++++++++++ 5 files changed, 542 insertions(+) create mode 100644 hfutils/entry/tree.py create mode 100644 hfutils/utils/archive.py create mode 100644 test/utils/test_archive.py diff --git a/hfutils/entry/cli.py b/hfutils/entry/cli.py index 8194f4c06b..b7d31bef53 100644 --- a/hfutils/entry/cli.py +++ b/hfutils/entry/cli.py @@ -5,6 +5,7 @@ from .ls import _add_ls_subcommand from .ls_repo import _add_ls_repo_subcommand from .rollback import _add_rollback_subcommand +from .tree import _add_tree_subcommand from .upload import _add_upload_subcommand from .whoami import _add_whoami_subcommand @@ -17,6 +18,7 @@ _add_index_subcommand, _add_rollback_subcommand, _add_clone_subcommand, + _add_tree_subcommand, ] cli = hfutilcli diff --git a/hfutils/entry/tree.py b/hfutils/entry/tree.py new file mode 100644 index 0000000000..9e2eba5cfd --- /dev/null +++ b/hfutils/entry/tree.py @@ -0,0 +1,39 @@ +import mimetypes + +import click +from huggingface_hub import configure_http_backend +from hbutils.string import format_tree +from .base import CONTEXT_SETTINGS +from ..operate.base import REPO_TYPES, get_hf_client, list_files_in_repository, RepoTypeTyping +from ..utils import get_requests_session + + +def _add_tree_subcommand(cli: click.Group) -> click.Group: + @cli.command('tree', help='List files from HuggingFace repository.\n\n' + 'Set environment $HF_TOKEN to use your own access token.', + context_settings=CONTEXT_SETTINGS) + @click.option('-r', '--repository', 'repo_id', type=str, required=True, + help='Repository to download from.') + @click.option('-t', '--type', 'repo_type', 
type=click.Choice(REPO_TYPES), default='dataset', + help='Type of the HuggingFace repository.', show_default=True) + @click.option('-d', '--directory', 'dir_in_repo', type=str, default=None, + help='Directory in repository to download the full directory tree.') + @click.option('-R', '--revision', 'revision', type=str, default='main', + help='Revision of repository.', show_default=True) + @click.option('-a', '--all', 'show_all', is_flag=True, type=bool, default=False, + help='Show all files, including hidden files.', show_default=True) + def tree(repo_id: str, repo_type: RepoTypeTyping, dir_in_repo, revision: str, show_all: bool): + configure_http_backend(get_requests_session) + + hf_client = get_hf_client() + + list_files_in_repository( + repo_id=repo_id, + repo_type=repo_type, + subdir='.', + revision=revision, + ignore_patterns=[], + + ) + + return cli diff --git a/hfutils/utils/__init__.py b/hfutils/utils/__init__.py index e76a3d4bd8..9ae89b5fed 100644 --- a/hfutils/utils/__init__.py +++ b/hfutils/utils/__init__.py @@ -1,3 +1,4 @@ +from .archive import is_archive_or_compressed from .binary import is_binary_file from .download import download_file from .logging import ColoredFormatter diff --git a/hfutils/utils/archive.py b/hfutils/utils/archive.py new file mode 100644 index 0000000000..a95d1f685e --- /dev/null +++ b/hfutils/utils/archive.py @@ -0,0 +1,121 @@ +import os.path +import re +from typing import Union + +_ARCHIVE_KNOWN_EXTS = { + '.tar', # Tape Archive + '.rar', # Roshal Archive + '.rar5', # RAR version 5 + '.zip', # ZIP archive + '.7z', # 7-Zip archive + '.7zip', # Alternative extension for 7-Zip + '.gz', # Gzip compressed file + '.bz2', # Bzip2 compressed file + '.xz', # XZ compressed file + '.ace', # ACE archive + '.lz', # Lzip compressed file + '.lzma', # LZMA compressed file + '.z', # Compress (Unix) file + '.cab', # Microsoft Cabinet file + '.arj', # ARJ archive + '.iso', # ISO disk image + '.lzh', # LZH archive + '.sit', # StuffIt archive + 
'.sitx', # StuffIt X archive + '.sea', # Self-Extracting Archive + '.alz', # ALZip archive + '.egg', # Python Egg + '.whl', # Python Wheel + '.deb', # Debian package + '.rpm', # Red Hat Package Manager + '.pkg', # macOS Installer Package + '.dmg', # Apple Disk Image + '.msi', # Microsoft Installer + '.tgz', # Gzipped tar archive + '.tbz2', # Bzip2 compressed tar archive + '.lzw', # LZW compressed file + '.rz', # RZip compressed file + '.lzo', # Lempel-Ziv-Oberhumer compressed file + '.zst', # Zstandard compressed file + '.tar.gz', # Gzipped tar archive + '.tar.bz2', # Bzip2 compressed tar archive + '.tar.xz', # XZ compressed tar archive + '.tar.lz', # Lzip compressed tar archive + '.tar.lzma', # LZMA compressed tar archive + '.zipx', # Extended ZIP archive + '.arc', # ARC archive + '.ark', # ARC archive (alternative extension) + '.lha', # LHA archive + '.zoo', # ZOO archive + '.gca', # GCA archive + '.uc2', # UC2 archive + '.uha', # UHarc archive + '.war', # Web Application Archive + '.ear', # Enterprise Application aRchive + '.sar', # SAR archive + '.jar', # Java Archive + '.apk', # Android Package Kit + '.xpi', # XPInstall (Mozilla browser extension) + '.snap', # Snap package (Ubuntu) + '.appimage', # AppImage package + '.squashfs', # Squashfs filesystem + '.cpio', # CPIO archive + '.shar', # Shell archive + '.lbr', # LBR archive + '.mar', # Mozilla Archive + '.sbx', # Sandbox file + '.qcow2', # QEMU Copy On Write 2 + '.vdi', # VirtualBox Disk Image + '.vhd', # Virtual Hard Disk + '.vmdk', # Virtual Machine Disk + '.ova', # Open Virtual Appliance + '.xar', # eXtensible ARchive + '.mpq', # MoPaQ archive (Blizzard games) +} + +# Additional generic patterns +_EXTERNAL_PATTERNS = [ + r'\.tar(?:\.(?:gz|bz2|xz|lz|lzma|Z))?$', # Tar and compressed tar archives + r'\.t\.(?:gz|bz2|xz|lz|lzma|Z)$', # Compressed files with .t.XX extension + r'\.(?:zip|z)$', # ZIP archives and Z compressed files + r'\.rar$', # RAR archives + r'\.7z$', # 7-Zip archives + r'\.(?:ar|a)$', # AR 
archives + r'\.(?:lz|lzma?)$', # LZ and LZMA compressed files + r'\.gz(?:ip)?$', # Gzip compressed files + r'\.bz(?:ip)?2?$', # Bzip and Bzip2 compressed files + r'\.(?:xz|lzh|lha)$', # XZ, LZH, and LHA archives + r'\.(?:iso|img|dmg|pkg|msi)$', # Disk images and installers + r'\.(?:deb|rpm|apk|ipa)$', # Package formats +] + +_ARCHIVE_SPLIT_PATTERNS = [ + r'^(.+\.(?:zip|rar|7z|tar|tar\.gz|tar\.bz2|tar\.xz))\.\d+$', # Split archives (e.g., .zip.001) + r'^(.+\.t)\.[a-z0-9]{2,4}$', # Split archives with .t.XX extension + r'^(.+\.part\d+)\.rar$', # Split RAR archives + r'^(.+)\.r\d{2}$', # Old-style split RAR archives +] + + +def is_archive_or_compressed(filename: Union[str, os.PathLike]) -> bool: + if not isinstance(filename, (str, os.PathLike)): + raise TypeError(f'Unknown file name type - {filename!r}') + filename = os.path.basename(os.path.normcase(str(filename))) + + # Check for known extensions + for ext in _ARCHIVE_KNOWN_EXTS: + if filename.lower().endswith(ext): + return True + + # Check for split archives + for pattern in _ARCHIVE_SPLIT_PATTERNS: + match = re.match(pattern, filename.lower()) + if match: + return True + + # Check for external patterns + for pattern in _EXTERNAL_PATTERNS: + if re.search(pattern, filename.lower()): + return True + + return False diff --git a/test/utils/test_archive.py b/test/utils/test_archive.py new file mode 100644 index 0000000000..a570281a1f --- /dev/null +++ b/test/utils/test_archive.py @@ -0,0 +1,379 @@ +import os +from pathlib import Path + +import pytest + +from hfutils.utils.archive import is_archive_or_compressed + + +@pytest.mark.unittest +class TestUtilsArchive: + @pytest.mark.parametrize("filename, expected", [ + ("file.tar", True), + ("file.rar", True), + ("file.zip", True), + ("file.7z", True), + ("file.gz", True), + ("file.bz2", True), + ("file.xz", True), + ("file.ace", True), + ("file.lz", True), + ("file.lzma", True), + ("file.z", True), + ("file.cab", True), + ("file.arj", True), + ("file.iso", True), + 
("file.lzh", True), + ("file.sit", True), + ("file.sitx", True), + ("file.sea", True), + ("file.alz", True), + ("file.egg", True), + ("file.whl", True), + ("file.deb", True), + ("file.rpm", True), + ("file.pkg", True), + ("file.dmg", True), + ("file.msi", True), + ("file.tgz", True), + ("file.tbz2", True), + ("file.lzw", True), + ("file.rz", True), + ("file.lzo", True), + ("file.zst", True), + + ("file.tar.gz", True), + ("file.tar.bz2", True), + ("file.tar.xz", True), + ("file.tar.lz", True), + ("file.tar.lzma", True), + + ("file.zipx", True), + ("file.arc", True), + ("file.ark", True), + ("file.lha", True), + ("file.zoo", True), + ("file.gca", True), + ("file.uc2", True), + ("file.uha", True), + ("file.war", True), + ("file.ear", True), + ("file.sar", True), + ("file.jar", True), + ("file.apk", True), + ("file.xpi", True), + ("file.snap", True), + ("file.appimage", True), + ("file.squashfs", True), + ("file.cpio", True), + ("file.shar", True), + ("file.lbr", True), + ("file.mar", True), + ("file.sbx", True), + ("file.qcow2", True), + ("file.vdi", True), + ("file.vhd", True), + ("file.vmdk", True), + ("file.ova", True), + ("file.xar", True), + ("file.mpq", True), + + ("file.tar.gz", True), + ("file.t.lz", True), + ("file.zipx", True), + ("file.rar5", True), + ("file.7zip", True), + ("file.arj", True), + ("file.lzma", True), + ("file.gzip", True), + ("file.bzip2", True), + ("file.xz", True), + ("file.lzh", True), + ("file.iso", True), + ("file.img", True), + ("file.dmg", True), + ("file.pkg", True), + ("file.msi", True), + ("file.deb", True), + ("file.rpm", True), + ("file.apk", True), + ("file.ipa", True), + + ("文件.zip", True), + ("ファイル.tar.gz", True), + ("파일.rar", True), + ("файл.7z", True), + ("αρχείο.iso", True), + ("फ़ाइल.deb", True), + ("ملف.rpm", True), + ("文件.exe", False), + ("ファイル.txt", False), + ("파일.doc", False), + + ("/home/user/file.zip", True), + ("C:\\Users\\User\\file.rar", True), + ("../relative/path/file.tar.gz", True), + + ("file.txt", False), + 
("file.exe", False), + ("file.doc", False), + ("file", False), + ("", False), + ]) + def test_is_archive_or_compressed(self, filename, expected): + assert is_archive_or_compressed(filename) == expected + + @pytest.mark.parametrize("filename", [ + "FILE.ZIP", "File.Tar.Gz", "ARCHIVE.RAR", "package.DEB" + ]) + def test_case_insensitive(self, filename): + assert is_archive_or_compressed(filename) + + def test_no_extension(self, ): + assert not is_archive_or_compressed("file_without_extension") + + def test_hidden_file(self, ): + assert is_archive_or_compressed(".hidden_archive.zip") + assert not is_archive_or_compressed(".hidden_file") + + def test_empty_filename(self, ): + assert not is_archive_or_compressed("") + + def test_only_extension(self, ): + assert is_archive_or_compressed(".zip") + assert not is_archive_or_compressed(".txt") + + @pytest.mark.parametrize("filename, expected", [ + ("file.tar.gz", True), + ("archive.tar.bz2", True), + ("data.tar.xz", True), + ("backup.tar.lz", True), + ("compressed.tar.lzma", True), + ("old_archive.tar.Z", True), + ("legacy_file.tar.lzo", True), + ("file.backup.tar.gz", True), + ("archive.old.tar.bz2", True), + ("文档.tar.gz", True), + ("アーカイブ.tar.bz2", True), + ("압축파일.tar.xz", True), + ("архив.tar.lzma", True), + ("αρχείο.tar.lz", True), + ("File.TAR.GZ", True), + ("ARCHIVE.Tar.Bz2", True), + ("file.doc.pdf", False), + ("archive.zip.txt", False), + ("file.targz", False), + ("archive.tarbz2", False), + ("/home/user.name/file.tar.gz", True), + ("C:\\Users\\user.name\\archive.tar.bz2", True), + (".hidden_archive.tar.xz", True), + (".tar.gz", True), + (".tar.bz2", True), + ]) + def test_compound_extensions(self, filename, expected): + assert is_archive_or_compressed(filename) == expected + + def test_compound_extension_edge_cases(self): + assert not is_archive_or_compressed("file.tar.") + assert is_archive_or_compressed("file..tar.gz") + assert is_archive_or_compressed(".tar.gz") + assert is_archive_or_compressed("tar.gz") + + 
@pytest.mark.parametrize("filename", [ + "archive.tar", "data.tar.gz", "file.tar.bz2", "backup.tar.xz" + ]) + def test_tar_pattern(self, filename): + assert is_archive_or_compressed(filename) + + @pytest.mark.parametrize("filename", [ + "archive.t.gz", "data.t.bz", "file.t.xz", "backup.t.lz" + ]) + def test_t_pattern(self, filename): + assert is_archive_or_compressed(filename) + + @pytest.mark.parametrize("filename", [ + "archive.zip", "data.z" + ]) + def test_zip_pattern(self, filename): + assert is_archive_or_compressed(filename) + + def test_rar_pattern(self, ): + assert is_archive_or_compressed("archive.rar") + + @pytest.mark.parametrize("filename", [ + "archive.r00", "data.r01", "file.r99" + ]) + def test_r_pattern(self, filename): + assert is_archive_or_compressed(filename) + + def test_7z_pattern(self, ): + assert is_archive_or_compressed("archive.7z") + + @pytest.mark.parametrize("filename", [ + "archive.ar", "data.a" + ]) + def test_ar_pattern(self, filename): + assert is_archive_or_compressed(filename) + + @pytest.mark.parametrize("filename", [ + "archive.lz", "data.lzm", "file.lzma" + ]) + def test_lz_pattern(self, filename): + assert is_archive_or_compressed(filename) + + @pytest.mark.parametrize("filename", [ + "archive.gz", "data.gzip" + ]) + def test_gz_pattern(self, filename): + assert is_archive_or_compressed(filename) + + @pytest.mark.parametrize("filename", [ + "archive.bz", "data.bz2", "file.bzip", "backup.bzip2" + ]) + def test_bz_pattern(self, filename): + assert is_archive_or_compressed(filename) + + @pytest.mark.parametrize("filename", [ + "archive.xz", "data.lzh", "file.lha" + ]) + def test_xz_lzh_pattern(self, filename): + assert is_archive_or_compressed(filename) + + @pytest.mark.parametrize("filename", [ + "image.iso", "disk.img", "installer.dmg", "package.pkg", "setup.msi" + ]) + def test_disk_image_pattern(self, filename): + assert is_archive_or_compressed(filename) + + @pytest.mark.parametrize("filename", [ + "package.deb", 
"software.rpm", "app.apk", "application.ipa" + ]) + def test_package_pattern(self, filename): + assert is_archive_or_compressed(filename) + + # Test cases for non-matching filenames + @pytest.mark.parametrize("filename", [ + "document.txt", "image.png", "script.py", "webpage.html", + "archive.tar.txt", "file.zipper", "data.rarr", "backup.7zz", + "program.exe", "library.dll", "sheet.xlsx", "presentation.pptx" + ]) + def test_non_matching_patterns(self, filename): + assert not is_archive_or_compressed(filename) + + # Test case sensitivity + @pytest.mark.parametrize("filename", [ + "ARCHIVE.TAR", "DATA.ZIP", "FILE.RAR", "BACKUP.7Z", + "Image.ISO", "Package.DEB", "Software.RPM", "App.APK" + ]) + def test_case_insensitivity(self, filename): + assert is_archive_or_compressed(filename) + + @pytest.mark.parametrize("filename, expected", [ + # Common archive formats + ("file.tar", True), + ("archive.zip", True), + ("data.rar", True), + ("backup.7z", True), + + # Compressed tar formats + ("file.tar.gz", True), + ("archive.tar.bz2", True), + ("data.tar.xz", True), + ("backup.tgz", True), + ("file.tbz2", True), + + # Other compression formats + ("document.gz", True), + ("file.bz2", True), + ("data.xz", True), + ("archive.lz", True), + ("file.lzma", True), + ("data.Z", True), + + # Less common but valid formats + ("file.cab", True), + ("archive.iso", True), + ("data.dmg", True), + ("backup.vhd", True), + + # Split archives + ("large_file.zip.001", True), + ("big_archive.7z.001", True), + ("huge_data.tar.gz.001", True), + + # Non-archive/compressed files + ("document.txt", False), + ("image.png", False), + ("script.py", False), + ("data.csv", False), + + # Edge cases + ("archive.tar.gz.txt", False), # Archive extension but .txt at the end + (".htaccess", False), # Hidden file, not an archive + ("file_without_extension", False), + ("archive.tar.123", True), # Assuming this is considered valid + ("data.001", False), # Just a number extension, not necessarily an archive + ]) + def 
test_is_archive_or_compressed(self, filename, expected): + assert is_archive_or_compressed(filename) == expected + + # Test case insensitivity + @pytest.mark.parametrize("filename", [ + "ARCHIVE.ZIP", + "File.TaR.Gz", + "DATA.RAR", + ]) + def test_is_archive_or_compressed_case_insensitive(self, filename): + assert is_archive_or_compressed(filename) + + # Test with path-like filenames + @pytest.mark.parametrize("filename", [ + "/home/user/documents/archive.zip", + "C:\\Users\\Documents\\backup.tar.gz", + "../data/file.rar", + ]) + def test_is_archive_or_compressed_with_paths(self, filename): + assert is_archive_or_compressed(filename) + + # Test invalid inputs + @pytest.mark.parametrize("invalid_input", [ + None, + 123, + [], + {}, + ]) + def test_is_archive_or_compressed_invalid_input(self, invalid_input): + with pytest.raises(TypeError): # Replace with specific exception if known + is_archive_or_compressed(invalid_input) + + @pytest.mark.parametrize("filename, expected", [ + (Path("/home/user/file.zip"), True), + (Path("/home/user/file.tar.gz"), True), + (Path("/home/user/file.txt"), False), + (Path("/home/user/file.7z"), True), + (Path("/home/user/file.rar"), True), + (Path("/home/user/file.iso"), True), + (Path("/home/user/file.doc"), False), + (Path("/home/user/file.zip.001"), True), + (Path("/home/user/file.part1.rar"), True), + (Path("/home/user/file.r00"), True), + (Path("/home/user/file.001"), False), + ]) + def test_is_archive_or_compressed_with_pathlib(self, filename, expected): + assert is_archive_or_compressed(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + (os.path.join("home", "user", "file.zip"), True), + (os.path.join("home", "user", "file.tar.gz"), True), + (os.path.join("home", "user", "file.txt"), False), + (os.path.join("home", "user", "file.7z"), True), + (os.path.join("home", "user", "file.rar"), True), + (os.path.join("home", "user", "file.iso"), True), + (os.path.join("home", "user", "file.doc"), False), + 
(os.path.join("home", "user", "file.zip.001"), True), + (os.path.join("home", "user", "file.part1.rar"), True), + (os.path.join("home", "user", "file.r00"), True), + (os.path.join("home", "user", "file.001"), False), + ]) + def test_is_archive_or_compressed_with_os_path(self, filename, expected): + assert is_archive_or_compressed(filename) == expected From b0d01a1a3b53dcf3421488bc55070a1baf342a2b Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 16:22:05 +0800 Subject: [PATCH 02/17] dev(narugo): add docs for function is_archive_or_compressed --- docs/source/api_doc/utils/archive.rst | 15 +++++++++++ docs/source/api_doc/utils/index.rst | 1 + hfutils/utils/archive.py | 36 +++++++++++++++++++++++++++ 3 files changed, 52 insertions(+) create mode 100644 docs/source/api_doc/utils/archive.rst diff --git a/docs/source/api_doc/utils/archive.rst b/docs/source/api_doc/utils/archive.rst new file mode 100644 index 0000000000..ba13fe37e5 --- /dev/null +++ b/docs/source/api_doc/utils/archive.rst @@ -0,0 +1,15 @@ +hfutils.utils.archive +================================= + +.. currentmodule:: hfutils.utils.archive + +.. automodule:: hfutils.utils.archive + + +is_archive_or_compressed +--------------------------- + +.. autofunction:: is_archive_or_compressed + + + diff --git a/docs/source/api_doc/utils/index.rst b/docs/source/api_doc/utils/index.rst index a4eb28e71e..cb0aef8cc5 100644 --- a/docs/source/api_doc/utils/index.rst +++ b/docs/source/api_doc/utils/index.rst @@ -9,6 +9,7 @@ hfutils.utils .. toctree:: :maxdepth: 3 + archive binary download number diff --git a/hfutils/utils/archive.py b/hfutils/utils/archive.py index a95d1f685e..2ae0d63104 100644 --- a/hfutils/utils/archive.py +++ b/hfutils/utils/archive.py @@ -1,3 +1,15 @@ +""" +This module provides functionality for identifying archive and compressed files based on their filenames. 
+ +It includes a comprehensive list of known archive and compressed file extensions, as well as patterns for +identifying split archives and other generic compressed file formats. The main +function :func:`is_archive_or_compressed` can be used to determine if a given filename +represents an archive or compressed file. + +The module is useful for file handling operations where it's necessary to distinguish between regular files +and archives or compressed files. +""" + import os.path import re from typing import Union @@ -98,6 +110,30 @@ def is_archive_or_compressed(filename: Union[str, os.PathLike]) -> bool: + """ + Determine if the given filename represents an archive or compressed file. + + This function checks the filename against a list of known archive and compressed file extensions, + as well as patterns for split archives and other generic compressed file formats. + + :param filename: The name of the file to check. Can be a string or a path-like object. + :type filename: Union[str, os.PathLike] + + :return: True if the filename represents an archive or compressed file, False otherwise. + :rtype: bool + + :raises TypeError: If the filename is not a string or path-like object. 
+ + Usage: + >>> is_archive_or_compressed('example.zip') + True + >>> is_archive_or_compressed('document.txt') + False + >>> is_archive_or_compressed('archive.tar.gz') + True + >>> is_archive_or_compressed('split_archive.zip.001') + True + """ if not isinstance(filename, (str, os.PathLike)): raise TypeError(f'Unknown file name type - {filename!r}') filename = os.path.basename(os.path.normcase(str(filename))) From d9f360055814fdba712e8030b07a7c45a4105747 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 16:31:18 +0800 Subject: [PATCH 03/17] dev(narugo): update ls code --- hfutils/entry/ls.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hfutils/entry/ls.py b/hfutils/entry/ls.py index 349d38f76f..d4f9ea82f2 100644 --- a/hfutils/entry/ls.py +++ b/hfutils/entry/ls.py @@ -10,7 +10,7 @@ from .base import CONTEXT_SETTINGS from ..operate.base import REPO_TYPES, get_hf_client -from ..utils import get_requests_session +from ..utils import get_requests_session, is_archive_or_compressed mimetypes.add_type('image/webp', '.webp') @@ -71,8 +71,10 @@ def __init__(self, item: Union[RepoFolder, RepoFile], base_dir: str): mimetype, _ = mimetypes.guess_type(item.path) _, ext = os.path.splitext(item.path) self.type = ListItemType.FILE - if ext in {'.ckpt', '.pt', '.safetensors', '.onnx', '.model', '.h5', '.mlmodel', - '.ftz', '.pb', '.pth', '.tflite'}: + if is_archive_or_compressed(item.path): + self.type = ListItemType.ARCHIVE + elif ext in {'.ckpt', '.pt', '.safetensors', '.onnx', '.model', '.h5', '.mlmodel', + '.ftz', '.pb', '.pth', '.tflite'}: self.type = ListItemType.MODEL elif ext in {'.json', '.csv', '.tsv', '.arrow', '.bin', '.msgpack', '.npy', '.npz', '.parquet', '.pickle', '.pkl', '.wasm'}: @@ -80,8 +82,6 @@ def __init__(self, item: Union[RepoFolder, RepoFile], base_dir: str): elif mimetype: if 'image' in mimetype: self.type = ListItemType.IMAGE - elif mimetype.split('/', maxsplit=1)[-1] in {'zip', 'rar', 'x-tar', 
'x-7z-compressed'}: - self.type = ListItemType.ARCHIVE def _add_ls_subcommand(cli: click.Group) -> click.Group: From 035f320496e638e2cb171059609e1a85c178d5b0 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 17:09:37 +0800 Subject: [PATCH 04/17] dev(narugo): add is_model_file --- hfutils/entry/ls.py | 59 +--------- hfutils/utils/__init__.py | 1 + hfutils/utils/model.py | 88 ++++++++++++++ hfutils/utils/type_.py | 68 +++++++++++ test/utils/test_model.py | 234 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 394 insertions(+), 56 deletions(-) create mode 100644 hfutils/utils/model.py create mode 100644 hfutils/utils/type_.py create mode 100644 test/utils/test_model.py diff --git a/hfutils/entry/ls.py b/hfutils/entry/ls.py index d4f9ea82f2..eae9bb95df 100644 --- a/hfutils/entry/ls.py +++ b/hfutils/entry/ls.py @@ -1,6 +1,5 @@ -import mimetypes import os.path -from enum import Enum, unique +import os.path from typing import Union, List import click @@ -10,46 +9,7 @@ from .base import CONTEXT_SETTINGS from ..operate.base import REPO_TYPES, get_hf_client -from ..utils import get_requests_session, is_archive_or_compressed - -mimetypes.add_type('image/webp', '.webp') - - -@unique -class ListItemType(Enum): - """ - Enum class representing different types of list items. - """ - - FILE = 0x1 - FOLDER = 0x2 - IMAGE = 0x3 - ARCHIVE = 0x4 - MODEL = 0x5 - DATA = 0x6 - - @property - def render_color(self): - """ - Get the render color based on the item type. - - :return: The render color for the item type. 
- :rtype: str - """ - if self == self.FILE: - return None - elif self == self.FOLDER: - return 'blue' - elif self == self.IMAGE: - return 'magenta' - elif self == self.ARCHIVE: - return 'red' - elif self == self.MODEL: - return 'green' - elif self == self.DATA: - return 'yellow' - else: - raise ValueError(f'Unknown type - {self!r}') # pragma: no cover +from ..utils import get_requests_session, ListItemType, get_file_type class ListItem: @@ -68,20 +28,7 @@ def __init__(self, item: Union[RepoFolder, RepoFile], base_dir: str): if isinstance(item, RepoFolder): self.type = ListItemType.FOLDER else: - mimetype, _ = mimetypes.guess_type(item.path) - _, ext = os.path.splitext(item.path) - self.type = ListItemType.FILE - if is_archive_or_compressed(item.path): - self.type = ListItemType.ARCHIVE - elif ext in {'.ckpt', '.pt', '.safetensors', '.onnx', '.model', '.h5', '.mlmodel', - '.ftz', '.pb', '.pth', '.tflite'}: - self.type = ListItemType.MODEL - elif ext in {'.json', '.csv', '.tsv', '.arrow', '.bin', '.msgpack', '.npy', '.npz', - '.parquet', '.pickle', '.pkl', '.wasm'}: - self.type = ListItemType.DATA - elif mimetype: - if 'image' in mimetype: - self.type = ListItemType.IMAGE + self.type = get_file_type(item.path) def _add_ls_subcommand(cli: click.Group) -> click.Group: diff --git a/hfutils/utils/__init__.py b/hfutils/utils/__init__.py index 9ae89b5fed..d544ee8fbe 100644 --- a/hfutils/utils/__init__.py +++ b/hfutils/utils/__init__.py @@ -7,4 +7,5 @@ from .session import TimeoutHTTPAdapter, get_requests_session, get_random_ua from .temp import TemporaryDirectory from .tqdm_ import tqdm +from .type_ import ListItemType, get_file_type from .walk import walk_files diff --git a/hfutils/utils/model.py b/hfutils/utils/model.py new file mode 100644 index 0000000000..ca28bf0cd2 --- /dev/null +++ b/hfutils/utils/model.py @@ -0,0 +1,88 @@ +import os +import re +from typing import Union + +_MODEL_EXTS = { + '.ckpt', # Checkpoint file + '.pt', # PyTorch model file + '.pth', # PyTorch 
model file (alternative extension) + '.safetensors', # SafeTensors model file + '.onnx', # Open Neural Network Exchange model file + '.model', # Generic model file + '.h5', # Hierarchical Data Format version 5 + '.hdf5', # Hierarchical Data Format version 5 (alternative extension) + '.mlmodel', # Core ML model file + '.ftz', # FastText model file + '.pb', # Protocol Buffer file (often used for TensorFlow models) + '.tflite', # TensorFlow Lite model file + '.pkl', # Pickle file (often used for scikit-learn models) + '.joblib', # Joblib file (often used for scikit-learn models) + '.bin', # Binary file (generic) + '.meta', # Meta file (often associated with TensorFlow checkpoints) + '.params', # Parameters file (often used in MXNet) + '.pdparams', # PaddlePaddle parameters file + '.pdmodel', # PaddlePaddle model file + '.ot', # OpenVINO model file + '.nnet', # Neural network file + '.dnn', # Deep neural network file + '.mar', # MXNet Archive + '.tf', # TensorFlow SavedModel file + '.keras', # Keras model file + '.weights', # Weights file (generic) + '.pmml', # Predictive Model Markup Language file + '.gguf', # GGUF (GPT-Generated Unified Format) file + '.ggml', # GGML (GPT-Generated Model Language) file + '.q4_0', # 4-bit quantized model (type 0) + '.q4_1', # 4-bit quantized model (type 1) + '.q5_0', # 5-bit quantized model (type 0) + '.q5_1', # 5-bit quantized model (type 1) + '.q8_0', # 8-bit quantized model + '.qnt', # Quantized model (generic) + '.int8', # 8-bit integer quantized model + '.fp16', # 16-bit floating point model + '.bk', # Backup file (often used for model checkpoints) + '.engine', # TensorRT engine file + '.plan', # TensorRT plan file + '.trt', # TensorRT model file + '.torchscript', # TorchScript model file + '.pdiparams', # PaddlePaddle inference parameters file + '.pdopt', # PaddlePaddle optimizer file + '.nb', # Neural network binary file + '.mnn', # MNN (Mobile Neural Network) model file + '.ncnn', # NCNN model file + '.om', # CANN (Compute 
Architecture for Neural Networks) offline model + '.rknn', # Rockchip Neural Network model file + '.xmodel', # Vitis AI model file + '.kmodel', # Kendryte model file +} + +_MODEL_SHARD_PATTERNS = [ + r'.*-\d{5}-of-\d{5}', # Pattern for sharded files like "model-00001-of-00005" + r'.*\.bin\.\d+', # Pattern for binary shards like "model.bin.1" + r'.*\.part\.\d+', # Pattern for part files like "model.part.0" + r'.*_part_\d+', # Alternative pattern for part files like "model_part_0" + r'.*-shard\d+', # Pattern for shard files like "model-shard1" +] + +_HF_MODEL_PATTERNS = [ + r'pytorch_model.*\.bin', # Hugging Face PyTorch model file + r'tf_model.*\.h5', # Hugging Face TensorFlow model file + r'model.*\.ckpt', # Hugging Face checkpoint file + r'flax_model.*\.msgpack', # Hugging Face Flax model file + r'.*\.safetensors', # SafeTensors file (often used in Hugging Face models) +] + + +def is_model_file(filename: Union[str, os.PathLike]) -> bool: + if not isinstance(filename, (str, os.PathLike)): + raise TypeError(f'Unknown file name type - {filename!r}') + filename = os.path.basename(os.path.normcase(str(filename))) + + if any(filename.lower().endswith(ext) for ext in _MODEL_EXTS): + return True + if any(re.match(pattern, filename.lower()) for pattern in _MODEL_SHARD_PATTERNS): + return True + if any(re.match(pattern, filename.lower()) for pattern in _HF_MODEL_PATTERNS): + return True + + return False diff --git a/hfutils/utils/type_.py b/hfutils/utils/type_.py new file mode 100644 index 0000000000..62e04e3328 --- /dev/null +++ b/hfutils/utils/type_.py @@ -0,0 +1,68 @@ +import mimetypes +import os +from enum import Enum, unique +from typing import Union + +from .archive import is_archive_or_compressed + +mimetypes.add_type('image/webp', '.webp') + + +@unique +class ListItemType(Enum): + """ + Enum class representing different types of list items. 
+ """ + + FILE = 0x1 + FOLDER = 0x2 + IMAGE = 0x3 + ARCHIVE = 0x4 + MODEL = 0x5 + DATA = 0x6 + + @property + def render_color(self): + """ + Get the render color based on the item type. + + :return: The render color for the item type. + :rtype: str + """ + if self == self.FILE: + return None + elif self == self.FOLDER: + return 'blue' + elif self == self.IMAGE: + return 'magenta' + elif self == self.ARCHIVE: + return 'red' + elif self == self.MODEL: + return 'green' + elif self == self.DATA: + return 'yellow' + else: + raise ValueError(f'Unknown type - {self!r}') # pragma: no cover + + +def get_file_type(filename: Union[str, os.PathLike]) -> ListItemType: + if not isinstance(filename, (str, os.PathLike)): + raise TypeError(f'Unknown file name type - {filename!r}') + filename = os.path.basename(os.path.normcase(str(filename))) + + mimetype, _ = mimetypes.guess_type(filename) + _, ext = os.path.splitext(filename) + type_ = ListItemType.FILE + if is_archive_or_compressed(filename): + type_ = ListItemType.ARCHIVE + elif ext in {'.ckpt', '.pt', '.safetensors', '.onnx', '.model', '.h5', '.mlmodel', + '.ftz', '.pb', '.pth', '.tflite'}: + type_ = ListItemType.MODEL + elif ext in {'.json', '.csv', '.tsv', '.arrow', '.bin', '.msgpack', '.npy', '.npz', + '.parquet', '.pickle', '.pkl', '.wasm'}: + type_ = ListItemType.DATA + elif mimetype: + if 'image' in mimetype: + type_ = ListItemType.IMAGE + + return type_ diff --git a/test/utils/test_model.py b/test/utils/test_model.py new file mode 100644 index 0000000000..820b0c8eb6 --- /dev/null +++ b/test/utils/test_model.py @@ -0,0 +1,234 @@ +from pathlib import Path + +import pytest + +from hfutils.utils.model import is_model_file + + +@pytest.mark.unittest +class TestUtilsModels: + @pytest.mark.parametrize("filename, expected", [ + ("model.ckpt", True), + ("model.pt", True), + ("model.pth", True), + ("model.safetensors", True), + ("model.onnx", True), + ("model.h5", True), + ("model.tflite", True), + ("model.pkl", True), + 
("model.bin", True), + ("model.params", True), + ("model.pdparams", True), + ("model.keras", True), + ("model.weights", True), + ("model.pmml", True), + ("model.gguf", True), + ("model.ggml", True), + ("model.q4_0", True), + ("model.q4_1", True), + ("model.q5_0", True), + ("model.q5_1", True), + ("model.q8_0", True), + ("model.qnt", True), + ("model.int8", True), + ("model.fp16", True), + ("model.bk", True), + ("model.engine", True), + ("model.plan", True), + ("model.trt", True), + ("model.torchscript", True), + ("model.pdiparams", True), + ("model.pdopt", True), + ("model.nb", True), + ("model.mnn", True), + ("model.ncnn", True), + ("model.om", True), + ("model.rknn", True), + ("model.xmodel", True), + ("model.kmodel", True), + ("data.txt", False), + ("image.jpg", False), + ("archive.zip", False), + ]) + def test_common_model_extensions(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("model-00001-of-00002", True), + ("model.bin.1", True), + ("model.part.0", True), + ("model_part_0", True), + ("model-shard1", True), + ("model_01", False), + ("model-part", False), + ]) + def test_shard_patterns(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("pytorch_model.bin", True), + ("tf_model.h5", True), + ("model.ckpt", True), + ("flax_model.msgpack", True), + ("model.safetensors", True), + ("tokenizer.json", False), + ("config.json", False), + ]) + def test_huggingface_patterns(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("model_quant.gguf", True), + ("model_quantized.q4_0", True), + ("model_int8.q5_1", True), + ("model_fp16.q8_0", True), + ("model_quantized.qnt", True), + ("model_int8.int8", True), + ("model_fp16.fp16", True), + ("model_float32.pth", True), + ]) + def test_quantized_patterns(self, filename, expected): + assert 
is_model_file(filename) == expected + + def test_pathlib_input(self): + assert is_model_file(Path("/home/user/model.pt")) + assert not is_model_file(Path("/home/user/data.txt")) + + def test_invalid_input(self): + with pytest.raises(TypeError): + is_model_file(123) + + def test_empty_string(self): + assert not is_model_file("") + + def test_none_input(self): + with pytest.raises(TypeError): + is_model_file(None) + + @pytest.mark.parametrize("filename, expected", [ + ("/home/user/MODEL.PT", True), + ("/home/user/Model.Safetensors", True), + ("/home/user/PYTORCH_MODEL.BIN", True), + ("/home/user/model_QUANT.bin", True), + ]) + def test_case_insensitivity(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("model.ckpt", True), + ("model.pt", True), + ("model.pth", True), + ("model.safetensors", True), + ("model.onnx", True), + ("model.h5", True), + ("model.tflite", True), + ("model.pkl", True), + ("model.bin", True), + ("model.params", True), + ("model.pdparams", True), + ("model.keras", True), + ("model.weights", True), + ("model.pmml", True), + ("data.txt", False), + ("image.jpg", False), + ("archive.zip", False), + ]) + def test_common_extensions(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("model-00001-of-00002", True), + ("model.bin.1", True), + ("model.part.0", True), + ("model_part_0", True), + ("model-shard1", True), + ("model_01", False), + ("model-part", False), + ]) + def test_shard_patterns(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("pytorch_model.bin", True), + ("tf_model.h5", True), + ("model.ckpt", True), + ("flax_model.msgpack", True), + ("model.safetensors", True), + ]) + def test_huggingface_patterns(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, 
expected", [ + ("/home/user/model.pt", True), + ("C:\\Users\\User\\model.ckpt", True), + ("/Users/user/Documents/model.safetensors", True), + ("\\\\server\\share\\model.onnx", True), + ]) + def test_different_os_paths(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("模型.pt", True), + ("モデル.ckpt", True), + ("модель.safetensors", True), + ("mødel.onnx", True), + ]) + def test_non_ascii_filenames(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("MODEL.PT", True), + ("Model.Safetensors", True), + ("PYTORCH_MODEL.BIN", True), + ("model_QUANT.bin", True), + ]) + def test_case_insensitivity(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("model.tar.gz", False), + ("model.zip", False), + ("model.npy", False), + ("model.npz", False), + ("model.json", False), + ("model.yaml", False), + ("model.xml", False), + ]) + def test_non_model_files(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("model.ckpt.1", True), + ("model-00001.safetensors", True), + ("pytorch_model-00001-of-00002.bin", True), + ("model.ckpt.data-00000-of-00001", True), + ]) + def test_complex_model_filenames(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, expected", [ + ("model.engine", True), + ("model.plan", True), + ("model.trt", True), + ("model.torchscript", True), + ("model.pdiparams", True), + ("model.pdopt", True), + ("model.nb", True), + ("model.mnn", True), + ("model.ncnn", True), + ("model.om", True), + ("model.rknn", True), + ("model.xmodel", True), + ("model.kmodel", True), + ]) + def test_additional_model_formats(self, filename, expected): + assert is_model_file(filename) == expected + + @pytest.mark.parametrize("filename, 
expected", [ + ("model-q4_0.gguf", True), + ("model.q5_1-00001-of-00002", True), + ("pytorch_model-q8_0.bin", True), + ("model.ckpt.int8.data-00000-of-00001", True), + ]) + def test_complex_quantized_filenames(self, filename, expected): + assert is_model_file(filename) == expected From 51d30af02360a2124a6d55aae0d34315eccf1a84 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 17:49:24 +0800 Subject: [PATCH 05/17] dev(narugo): add model files --- hfutils/utils/type_.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hfutils/utils/type_.py b/hfutils/utils/type_.py index 62e04e3328..6d53bd0e0f 100644 --- a/hfutils/utils/type_.py +++ b/hfutils/utils/type_.py @@ -4,6 +4,7 @@ from typing import Union from .archive import is_archive_or_compressed +from .model import is_model_file mimetypes.add_type('image/webp', '.webp') @@ -55,8 +56,7 @@ def get_file_type(filename: Union[str, os.PathLike]) -> ListItemType: type_ = ListItemType.FILE if is_archive_or_compressed(filename): type_ = ListItemType.ARCHIVE - elif ext in {'.ckpt', '.pt', '.safetensors', '.onnx', '.model', '.h5', '.mlmodel', - '.ftz', '.pb', '.pth', '.tflite'}: + elif is_model_file(filename): type_ = ListItemType.MODEL elif ext in {'.json', '.csv', '.tsv', '.arrow', '.bin', '.msgpack', '.npy', '.npz', '.parquet', '.pickle', '.pkl', '.wasm'}: From 4089a1ba75e2f5edc897333efb243f28c05e9ea3 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 19:43:05 +0800 Subject: [PATCH 06/17] dev(narugo): add data file check --- hfutils/utils/__init__.py | 1 + hfutils/utils/data.py | 53 ++++++++++++++ test/utils/test_data.py | 144 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 198 insertions(+) create mode 100644 hfutils/utils/data.py create mode 100644 test/utils/test_data.py diff --git a/hfutils/utils/__init__.py b/hfutils/utils/__init__.py index d544ee8fbe..7979f30e4d 100644 --- a/hfutils/utils/__init__.py +++ b/hfutils/utils/__init__.py @@ -1,5 +1,6 @@ from .archive 
import is_archive_or_compressed from .binary import is_binary_file +from .data import is_data_file from .download import download_file from .logging import ColoredFormatter from .number import number_to_tag diff --git a/hfutils/utils/data.py b/hfutils/utils/data.py new file mode 100644 index 0000000000..cf36bd9e79 --- /dev/null +++ b/hfutils/utils/data.py @@ -0,0 +1,53 @@ +import os +from typing import Union + +_DATA_EXTS = { + '.json', # JavaScript Object Notation + '.csv', # Comma-Separated Values + '.tsv', # Tab-Separated Values + '.arrow', # Apache Arrow file format + '.feather', # Feather file format (fast, language-agnostic columnar format) + '.parquet', # Apache Parquet file format + '.avro', # Apache Avro file format + '.orc', # Optimized Row Columnar file format + '.npy', # NumPy array file + '.npz', # NumPy compressed archive file + '.hdf5', # Hierarchical Data Format version 5 + '.h5', # Alternative extension for HDF5 + '.mat', # MATLAB file format + '.sav', # SPSS data file + '.dta', # Stata data file + '.sas7bdat', # SAS data file + '.xpt', # SAS transport file + '.xlsx', # Microsoft Excel Open XML Spreadsheet + '.xls', # Microsoft Excel Binary File Format + '.ods', # OpenDocument Spreadsheet + '.db', # Generic database file + '.sqlite', # SQLite database file + '.mdb', # Microsoft Access database file + '.accdb', # Microsoft Access database file (newer version) + '.dbf', # dBase database file + '.ftr', # Feather file format (alternative extension) + '.geojson', # GeoJSON file (for geographical data) + '.shp', # Shapefile (for geographical data) + '.kml', # Keyhole Markup Language (for geographical data) + '.gpx', # GPS Exchange Format + '.nc', # NetCDF (Network Common Data Form) file + '.grib', # GRIdded Binary or General Regularly-distributed Information in Binary form + '.hdf', # Hierarchical Data Format (older version) + '.zarr', # Zarr array storage format + '.bin', # Generic Binary File + '.pickle', # Pickle dumped file + '.pkl', # Shortcut of 
.pickle + '.wasm', # WASM +} + + +def is_data_file(filename: Union[str, os.PathLike]) -> bool: + if not isinstance(filename, (str, os.PathLike)): + raise TypeError(f'Unknown file name type - {filename!r}') + + # Normalize the filename and get the extension + filename = os.path.basename(os.path.normcase(str(filename))) + _, ext = os.path.splitext(filename.lower()) + return ext in _DATA_EXTS diff --git a/test/utils/test_data.py b/test/utils/test_data.py new file mode 100644 index 0000000000..a05c101a87 --- /dev/null +++ b/test/utils/test_data.py @@ -0,0 +1,144 @@ +import os + +import pytest + +from hfutils.utils import is_data_file + + +@pytest.mark.unittest +class TestUtilsData: + @pytest.mark.parametrize("filename, expected", [ + ("data.json", True), + ("file.csv", True), + ("document.tsv", True), + ("data.arrow", True), + ("file.feather", True), + ("data.parquet", True), + ("file.avro", True), + ("data.orc", True), + ("array.npy", True), + ("compressed.npz", True), + ("data.hdf5", True), + ("file.h5", True), + ("matlab_data.mat", True), + ("spss_file.sav", True), + ("stata_data.dta", True), + ("sas_data.sas7bdat", True), + ("sas_transport.xpt", True), + ("excel_file.xlsx", True), + ("old_excel.xls", True), + ("open_document.ods", True), + ("database.db", True), + ("sqlite_db.sqlite", True), + ("access_db.mdb", True), + ("new_access.accdb", True), + ("dbase_file.dbf", True), + ("feather_data.ftr", True), + ("geo_data.geojson", True), + ("shape_file.shp", True), + ("keyhole_markup.kml", True), + ("gps_data.gpx", True), + ("netcdf_file.nc", True), + ("gridded_data.grib", True), + ("hierarchical_data.hdf", True), + ("zarr_data.zarr", True), + ("binary_data.bin", True), + ("pickled_data.pickle", True), + ("short_pickle.pkl", True), + ("webassembly.wasm", True), + ("text_file.txt", False), + ("image.png", False), + ("script.py", False), + ("DATA.JSON", True), # Test case insensitivity + ("/path/to/data.csv", True), # Test with path + ("file_without_extension", False), + 
]) + def test_is_data_file(self, filename, expected): + assert is_data_file(filename) == expected + + def test_is_data_file_with_pathlike(self): + path = os.path.join("some", "path", "data.csv") + assert is_data_file(os.fspath(path)) + + def test_is_data_file_with_invalid_type(self): + with pytest.raises(TypeError): + is_data_file(123) + + def test_is_data_file_with_empty_string(self): + assert not is_data_file("") + + @pytest.mark.parametrize("filename", [ + "file.json", "file.CSV", "FILE.JSON", "DATA.CSV", + "/absolute/path/to/data.json", + "relative/path/to/data.csv", + r"C:\Windows\Path\To\data.tsv", + ]) + def test_is_data_file_case_and_path_variations(self, filename): + assert is_data_file(filename) + + @pytest.mark.parametrize("path", [ + "data/file.csv", + "data\\file.csv", + "/tmp/data.json", + "C:\\Users\\User\\data.json", + "~/documents/data.parquet", + "..\\..\\data.arrow", + "./data/file.feather", + ]) + def test_is_data_file_different_path_styles(self, path): + assert is_data_file(path) + + @pytest.mark.parametrize("filename", [ + "数据.csv", + "données.json", + "データ.parquet", + "данные.arrow", + "αρχείο.feather", + "파일.npy", + "ファイル.npz", + "ملف.hdf5", + ]) + def test_is_data_file_non_ascii_filenames(self, filename): + assert is_data_file(filename) + + @pytest.mark.parametrize("path", [ + "/用户/数据/file.csv", + "/utilisateur/données/file.json", + "/ユーザー/データ/file.parquet", + "/пользователь/данные/file.arrow", + "/χρήστης/αρχείο/file.feather", + "/사용자/파일/file.npy", + "/ユーザー/ファイル/file.npz", + "/المستخدم/ملف/file.hdf5", + ]) + def test_is_data_file_non_ascii_paths(self, path): + assert is_data_file(path) + + def test_is_data_file_windows_paths(self): + assert is_data_file(r"C:\Users\用户\Documents\data.csv") + assert is_data_file(r"\\server\share\データ.json") + + def test_is_data_file_macos_paths(self): + assert is_data_file("/Users/ユーザー/Documents/data.parquet") + assert is_data_file("/Volumes/External/données.arrow") + + def test_is_data_file_linux_paths(self): 
+ assert is_data_file("/home/пользователь/documents/data.feather") + assert is_data_file("/mnt/external/αρχείο.npy") + + def test_is_data_file_with_os_path_objects(self): + paths = [ + os.path.join("data", "file.csv"), + os.path.join("用户", "数据", "file.json"), + os.path.join("ユーザー", "データ", "file.parquet"), + ] + for path in paths: + assert is_data_file(os.fspath(path)) + + @pytest.mark.parametrize("path", [ + "file:///C:/Users/User/data.csv", + "https://example.com/data.json", + "ftp://ftp.example.com/data.parquet", + ]) + def test_is_data_file_with_urls(self, path): + assert is_data_file(path) From 0d161a959c6211a3aaa658957a9db5a43dcb9b43 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 21:07:04 +0800 Subject: [PATCH 07/17] dev(narugo): add docs for is_data_file and is_model_file --- docs/source/api_doc/utils/data.rst | 15 +++++++++ docs/source/api_doc/utils/index.rst | 2 ++ docs/source/api_doc/utils/model.rst | 15 +++++++++ hfutils/utils/data.py | 34 +++++++++++++++++++ hfutils/utils/model.py | 52 +++++++++++++++++++++++++++++ 5 files changed, 118 insertions(+) create mode 100644 docs/source/api_doc/utils/data.rst create mode 100644 docs/source/api_doc/utils/model.rst diff --git a/docs/source/api_doc/utils/data.rst b/docs/source/api_doc/utils/data.rst new file mode 100644 index 0000000000..e1a70e98cc --- /dev/null +++ b/docs/source/api_doc/utils/data.rst @@ -0,0 +1,15 @@ +hfutils.utils.data +================================= + +.. currentmodule:: hfutils.utils.data + +.. automodule:: hfutils.utils.data + + +is_data_file +--------------------------- + +.. 
autofunction:: is_data_file + + + diff --git a/docs/source/api_doc/utils/index.rst b/docs/source/api_doc/utils/index.rst index cb0aef8cc5..088c55876a 100644 --- a/docs/source/api_doc/utils/index.rst +++ b/docs/source/api_doc/utils/index.rst @@ -11,7 +11,9 @@ hfutils.utils archive binary + data download + model number path session diff --git a/docs/source/api_doc/utils/model.rst b/docs/source/api_doc/utils/model.rst new file mode 100644 index 0000000000..df369a25a1 --- /dev/null +++ b/docs/source/api_doc/utils/model.rst @@ -0,0 +1,15 @@ +hfutils.utils.model +================================= + +.. currentmodule:: hfutils.utils.model + +.. automodule:: hfutils.utils.model + + +is_model_file +--------------------------- + +.. autofunction:: is_model_file + + + diff --git a/hfutils/utils/data.py b/hfutils/utils/data.py index cf36bd9e79..ec7e57339b 100644 --- a/hfutils/utils/data.py +++ b/hfutils/utils/data.py @@ -1,3 +1,11 @@ +""" +This module provides functionality for identifying data files based on their file extensions. + +It includes a comprehensive set of data file extensions and a function to check if a given +filename corresponds to a known data file format. This can be useful in various data processing +and file handling scenarios where it's necessary to distinguish data files from other types of files. +""" + import os from typing import Union @@ -44,6 +52,32 @@ def is_data_file(filename: Union[str, os.PathLike]) -> bool: + """ + Determine if a given filename corresponds to a known data file format. + + This function checks if the file extension of the provided filename matches + any of the known data file extensions defined in the `_DATA_EXTS` set. + + :param filename: The name of the file to check. Can be a string or a path-like object. + :type filename: Union[str, os.PathLike] + + :return: True if the file extension matches a known data file format, False otherwise. 
+ :rtype: bool + + :raises TypeError: If the provided filename is not a string or path-like object. + + Usage: + >>> is_data_file('data.csv') + True + >>> is_data_file('script.py') + False + >>> is_data_file(Path('/path/to/data.json')) + True + + .. note:: + The function is case-insensitive and works with both file names and full paths. + It normalizes the filename and extracts only the extension for comparison. + """ if not isinstance(filename, (str, os.PathLike)): raise TypeError(f'Unknown file name type - {filename!r}') diff --git a/hfutils/utils/model.py b/hfutils/utils/model.py index ca28bf0cd2..5d60312b20 100644 --- a/hfutils/utils/model.py +++ b/hfutils/utils/model.py @@ -1,3 +1,28 @@ +""" +This module provides functionality for identifying model files based on their extensions and naming patterns. + +It includes a comprehensive list of model file extensions, patterns for sharded model files, and specific patterns +for Hugging Face model files. The main function, :func:`is_model_file`, determines whether a given filename corresponds +to a model file based on these predefined patterns and extensions. + +This module can be useful in various scenarios, such as: + +- Automated model file detection in directories +- Validation of uploaded files in machine learning platforms +- Preprocessing steps in model loading pipelines + +Usage: + .. code:: python + + from hfutils.utils.model import is_model_file + + filename = "model.pt" + if is_model_file(filename): + print(f"{filename} is a model file") + else: + print(f"{filename} is not a model file") +""" + import os import re from typing import Union @@ -74,6 +99,33 @@ def is_model_file(filename: Union[str, os.PathLike]) -> bool: + """ + Determine if a given filename corresponds to a model file. + + This function checks if the provided filename matches any of the known model file extensions + or patterns, including sharded model files and Hugging Face specific patterns. 
+ + :param filename: The name of the file to check. Can be a full path or just the filename. + :type filename: Union[str, os.PathLike] + + :return: True if the filename corresponds to a model file, False otherwise. + :rtype: bool + + :raises TypeError: If the filename is not a string or os.PathLike object. + + Usage: + >>> is_model_file("model.pt") + True + >>> is_model_file("data.csv") + False + >>> is_model_file("model-00001-of-00005") + True + >>> is_model_file("pytorch_model.bin") + True + + .. note:: + This function is case-insensitive and works with both file names and full paths. + """ if not isinstance(filename, (str, os.PathLike)): raise TypeError(f'Unknown file name type - {filename!r}') filename = os.path.basename(os.path.normcase(str(filename))) From fa15af679c069fbc1c6d175896b08afbfe45cf1b Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 21:26:05 +0800 Subject: [PATCH 08/17] dev(narugo): add unittest for get_file_type --- hfutils/utils/type_.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/hfutils/utils/type_.py b/hfutils/utils/type_.py index 6d53bd0e0f..4adef6ae45 100644 --- a/hfutils/utils/type_.py +++ b/hfutils/utils/type_.py @@ -4,6 +4,7 @@ from typing import Union from .archive import is_archive_or_compressed +from .data import is_data_file from .model import is_model_file mimetypes.add_type('image/webp', '.webp') @@ -30,17 +31,17 @@ def render_color(self): :return: The render color for the item type. 
:rtype: str """ - if self == self.FILE: + if self == ListItemType.FILE: return None - elif self == self.FOLDER: + elif self == ListItemType.FOLDER: return 'blue' - elif self == self.IMAGE: + elif self == ListItemType.IMAGE: return 'magenta' - elif self == self.ARCHIVE: + elif self == ListItemType.ARCHIVE: return 'red' - elif self == self.MODEL: + elif self == ListItemType.MODEL: return 'green' - elif self == self.DATA: + elif self == ListItemType.DATA: return 'yellow' else: raise ValueError(f'Unknown type - {self!r}') # pragma: no cover @@ -52,14 +53,12 @@ def get_file_type(filename: Union[str, os.PathLike]) -> ListItemType: filename = os.path.basename(os.path.normcase(str(filename))) mimetype, _ = mimetypes.guess_type(filename) - _, ext = os.path.splitext(filename) type_ = ListItemType.FILE if is_archive_or_compressed(filename): type_ = ListItemType.ARCHIVE elif is_model_file(filename): type_ = ListItemType.MODEL - elif ext in {'.json', '.csv', '.tsv', '.arrow', '.bin', '.msgpack', '.npy', '.npz', - '.parquet', '.pickle', '.pkl', '.wasm'}: + elif is_data_file(filename): type_ = ListItemType.DATA elif mimetype: if 'image' in mimetype: From 1357652df768ca1ae2d78e4e11fdf3468da9e91d Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 21:26:10 +0800 Subject: [PATCH 09/17] dev(narugo): add unittest for get_file_type --- test/utils/test_type.py | 90 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 test/utils/test_type.py diff --git a/test/utils/test_type.py b/test/utils/test_type.py new file mode 100644 index 0000000000..80b9059c20 --- /dev/null +++ b/test/utils/test_type.py @@ -0,0 +1,90 @@ +import os +from enum import Enum +from unittest.mock import patch + +import pytest + +from hfutils.utils import ListItemType, get_file_type + + +@pytest.mark.unittest +class TestUtilsType: + def test_list_item_type_enum(self): + assert isinstance(ListItemType.FILE, Enum) + assert ListItemType.FILE.value == 0x1 + assert 
ListItemType.FOLDER.value == 0x2 + assert ListItemType.IMAGE.value == 0x3 + assert ListItemType.ARCHIVE.value == 0x4 + assert ListItemType.MODEL.value == 0x5 + assert ListItemType.DATA.value == 0x6 + + @pytest.mark.parametrize("item_type, expected_color", [ + (ListItemType.FILE, None), + (ListItemType.FOLDER, 'blue'), + (ListItemType.IMAGE, 'magenta'), + (ListItemType.ARCHIVE, 'red'), + (ListItemType.MODEL, 'green'), + (ListItemType.DATA, 'yellow'), + ]) + def test_render_color(self, item_type, expected_color): + assert item_type.render_color == expected_color + + def test_render_color_unknown_type(self): + class UnknownType(Enum): + UNKNOWN = 0x7 + + with pytest.raises(ValueError, match='Unknown type'): + ListItemType.render_color.fget(UnknownType.UNKNOWN) + + @pytest.mark.parametrize("filename, expected_type", [ + ('file.txt', ListItemType.FILE), + ('image.jpg', ListItemType.IMAGE), + ('archive.zip', ListItemType.ARCHIVE), + ('model.pkl', ListItemType.MODEL), + ('data.csv', ListItemType.DATA), + ('folder', ListItemType.FILE), # Assuming folders are not detected by filename + ]) + def test_get_file_type(self, filename, expected_type): + with patch('mimetypes.guess_type') as mock_guess_type, \ + patch('hfutils.utils.type_.is_archive_or_compressed') as mock_is_archive, \ + patch('hfutils.utils.type_.is_model_file') as mock_is_model, \ + patch('hfutils.utils.type_.is_data_file') as mock_is_data: + mock_guess_type.return_value = (None, None) + mock_is_archive.return_value = expected_type == ListItemType.ARCHIVE + mock_is_model.return_value = expected_type == ListItemType.MODEL + mock_is_data.return_value = expected_type == ListItemType.DATA + + if expected_type == ListItemType.IMAGE: + mock_guess_type.return_value = ('image/jpeg', None) + + assert get_file_type(filename) == expected_type + + def test_get_file_type_with_path(self): + with patch('mimetypes.guess_type') as mock_guess_type: + mock_guess_type.return_value = (None, None) + assert 
get_file_type('/path/to/file.txt') == ListItemType.FILE + + def test_get_file_type_invalid_input(self): + with pytest.raises(TypeError, match='Unknown file name type'): + get_file_type(123) + + def test_get_file_type_empty_string(self): + with patch('mimetypes.guess_type') as mock_guess_type: + mock_guess_type.return_value = (None, None) + assert get_file_type('') == ListItemType.FILE + + @pytest.mark.parametrize("filename", [ + 'file.txt', 'FILE.TXT', 'FiLe.TxT', + '/path/to/file.txt', + r'C:\path\to\file.txt', + ]) + def test_get_file_type_case_insensitive(self, filename): + with patch('mimetypes.guess_type') as mock_guess_type: + mock_guess_type.return_value = (None, None) + assert get_file_type(filename) == ListItemType.FILE + + def test_get_file_type_with_pathlike_object(self): + with patch('mimetypes.guess_type') as mock_guess_type: + mock_guess_type.return_value = (None, None) + path = os.path.join('path', 'to', 'file.txt') + assert get_file_type(os.fspath(path)) == ListItemType.FILE From 24d729a6bccd420bb060116ae9e2cea20ad726a7 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Mon, 19 Aug 2024 21:30:51 +0800 Subject: [PATCH 10/17] dev(narugo): add pydocs for get_file_type --- docs/source/api_doc/utils/index.rst | 1 + docs/source/api_doc/utils/tqdm_.rst | 2 +- docs/source/api_doc/utils/type_.rst | 23 ++++++++++++ hfutils/utils/type_.py | 56 ++++++++++++++++++++++++++++- 4 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 docs/source/api_doc/utils/type_.rst diff --git a/docs/source/api_doc/utils/index.rst b/docs/source/api_doc/utils/index.rst index 088c55876a..0d2259b338 100644 --- a/docs/source/api_doc/utils/index.rst +++ b/docs/source/api_doc/utils/index.rst @@ -18,5 +18,6 @@ hfutils.utils path session tqdm_ + type_ walk diff --git a/docs/source/api_doc/utils/tqdm_.rst b/docs/source/api_doc/utils/tqdm_.rst index 633a1e6dd9..e790361ca3 100644 --- a/docs/source/api_doc/utils/tqdm_.rst +++ b/docs/source/api_doc/utils/tqdm_.rst @@ -1,4 +1,4 @@ 
-hfutils.utils.tqdm_ +hfutils.utils.tqdm\_ ================================= .. currentmodule:: hfutils.utils.tqdm_ diff --git a/docs/source/api_doc/utils/type_.rst b/docs/source/api_doc/utils/type_.rst new file mode 100644 index 0000000000..1f7b6d2c70 --- /dev/null +++ b/docs/source/api_doc/utils/type_.rst @@ -0,0 +1,23 @@ +hfutils.utils.type\_ +================================= + +.. currentmodule:: hfutils.utils.type_ + +.. automodule:: hfutils.utils.type_ + + +ListItemType +-------------------------- + +.. autoenum:: ListItemType + :members: render_color + + + +get_file_type +--------------------------- + +.. autofunction:: get_file_type + + + diff --git a/hfutils/utils/type_.py b/hfutils/utils/type_.py index 4adef6ae45..a56a83e9e4 100644 --- a/hfutils/utils/type_.py +++ b/hfutils/utils/type_.py @@ -1,3 +1,13 @@ +""" +This module provides functionality for determining file types and managing list item types. + +It includes an enumeration class for different types of list items, and a function to determine +the type of a given file. The module also adds support for the WebP image format. + +The module uses the ``mimetypes`` library for MIME type detection and imports custom functions +for identifying archive, model, and data files. +""" + import mimetypes import os from enum import Enum, unique @@ -14,6 +24,16 @@ class ListItemType(Enum): """ Enum class representing different types of list items. + + This enumeration defines various file and folder types that can be encountered + in a file system or list of items. Each type is associated with a unique integer value. + + Usage: + >>> item_type = ListItemType.FILE + >>> print(item_type) + ListItemType.FILE + >>> print(item_type.value) + 1 """ FILE = 0x1 @@ -28,8 +48,18 @@ def render_color(self): """ Get the render color based on the item type. + This property returns a color string associated with each item type, + which can be used for rendering or display purposes. + :return: The render color for the item type. 
- :rtype: str + :rtype: str or None + + :raises ValueError: If an unknown item type is encountered. + + Usage: + >>> item_type = ListItemType.FOLDER + >>> print(item_type.render_color) + blue """ if self == ListItemType.FILE: return None @@ -48,6 +78,30 @@ def render_color(self): def get_file_type(filename: Union[str, os.PathLike]) -> ListItemType: + """ + Determine the type of a given file. + + This function analyzes the provided filename and returns the corresponding ListItemType. + It uses various methods to determine the file type, including checking for archives, + model files, data files, and image files based on MIME types. + + :param filename: The name or path of the file to analyze. + :type filename: Union[str, os.PathLike] + + :return: The determined ListItemType for the given file. + :rtype: ListItemType + + :raises TypeError: If the provided filename is not a string or PathLike object. + + Usage: + >>> file_type = get_file_type('image.jpg') + >>> print(file_type) + ListItemType.IMAGE + + >>> file_type = get_file_type('data.csv') + >>> print(file_type) + ListItemType.DATA + """ if not isinstance(filename, (str, os.PathLike)): raise TypeError(f'Unknown file name type - {filename!r}') filename = os.path.basename(os.path.normcase(str(filename))) From 301196266acc7a30e1ed9057bbc19098d49f4fb0 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 20 Aug 2024 00:35:33 +0800 Subject: [PATCH 11/17] dev(narugo): add tree cli --- docs/source/api_doc/entry/ls.rst | 4 +- docs/source/api_doc/utils/type_.rst | 4 +- hfutils/entry/ls.py | 12 ++-- hfutils/entry/tree.py | 94 ++++++++++++++++++++++++++--- hfutils/utils/__init__.py | 2 +- hfutils/utils/type_.py | 42 ++++++------- test/utils/test_type.py | 58 +++++++++--------- 7 files changed, 145 insertions(+), 71 deletions(-) diff --git a/docs/source/api_doc/entry/ls.rst b/docs/source/api_doc/entry/ls.rst index c57f0e0c33..91ebeacb53 100644 --- a/docs/source/api_doc/entry/ls.rst +++ b/docs/source/api_doc/entry/ls.rst @@ 
-6,10 +6,10 @@ hfutils.entry.ls .. automodule:: hfutils.entry.ls -ListItemType +FileItemType ------------------------- -.. autoenum:: ListItemType +.. autoenum:: FileItemType diff --git a/docs/source/api_doc/utils/type_.rst b/docs/source/api_doc/utils/type_.rst index 1f7b6d2c70..f7678013a1 100644 --- a/docs/source/api_doc/utils/type_.rst +++ b/docs/source/api_doc/utils/type_.rst @@ -6,10 +6,10 @@ hfutils.utils.type\_ .. automodule:: hfutils.utils.type_ -ListItemType +FileItemType -------------------------- -.. autoenum:: ListItemType +.. autoenum:: FileItemType :members: render_color diff --git a/hfutils/entry/ls.py b/hfutils/entry/ls.py index eae9bb95df..13c9c92f7f 100644 --- a/hfutils/entry/ls.py +++ b/hfutils/entry/ls.py @@ -9,7 +9,7 @@ from .base import CONTEXT_SETTINGS from ..operate.base import REPO_TYPES, get_hf_client -from ..utils import get_requests_session, ListItemType, get_file_type +from ..utils import get_requests_session, FileItemType, get_file_type class ListItem: @@ -26,7 +26,7 @@ def __init__(self, item: Union[RepoFolder, RepoFile], base_dir: str): self.item = item self.base_dir = base_dir if isinstance(item, RepoFolder): - self.type = ListItemType.FOLDER + self.type = FileItemType.FOLDER else: self.type = get_file_type(item.path) @@ -101,7 +101,7 @@ def ls(repo_id: str, repo_type: str, dir_in_repo, revision: str, show_all: bool, max_size_length = 0 max_commit_info_length = 0 for item in items: - if item.type == ListItemType.FOLDER: + if item.type == FileItemType.FOLDER: size_text = '-' else: size_text = str(item.item.size) @@ -111,8 +111,8 @@ def ls(repo_id: str, repo_type: str, dir_in_repo, revision: str, show_all: bool, max_commit_info_length = max(max_commit_info_length, len(commit_text)) for item in items: - print('d' if item.type == ListItemType.FOLDER else '-', end='') - print('L' if item.type != ListItemType.FOLDER and item.item.lfs else '-', end='') + print('d' if item.type == FileItemType.FOLDER else '-', end='') + print('L' if 
item.type != FileItemType.FOLDER and item.item.lfs else '-', end='') commit_text = (item.item.last_commit.title or '').splitlines(keepends=False)[0] print( @@ -121,7 +121,7 @@ def ls(repo_id: str, repo_type: str, dir_in_repo, revision: str, show_all: bool, end='' ) - if item.type == ListItemType.FOLDER: + if item.type == FileItemType.FOLDER: size_text = '-' else: size_text = str(item.item.size) diff --git a/hfutils/entry/tree.py b/hfutils/entry/tree.py index 9e2eba5cfd..2c3ae9a99a 100644 --- a/hfutils/entry/tree.py +++ b/hfutils/entry/tree.py @@ -1,11 +1,83 @@ -import mimetypes +import dataclasses +import os +import re +from typing import Optional, List, Union import click -from huggingface_hub import configure_http_backend from hbutils.string import format_tree +from huggingface_hub import configure_http_backend +from natsort import natsorted + from .base import CONTEXT_SETTINGS -from ..operate.base import REPO_TYPES, get_hf_client, list_files_in_repository, RepoTypeTyping -from ..utils import get_requests_session +from ..operate.base import REPO_TYPES, list_files_in_repository, RepoTypeTyping +from ..utils import get_requests_session, hf_normpath, get_file_type, hf_fs_path, FileItemType + + +@dataclasses.dataclass +class _TreeItem: + name: str + type_: FileItemType + children: Optional[List['_TreeItem']] + + def get_name(self): + return click.style(self.name, fg=self.type_.render_color) + + def get_children(self): + return self.children if self.type_ == FileItemType.FOLDER else [] + + +def _get_tree(repo_id: str, repo_type: RepoTypeTyping, dir_in_repo: str, + revision: Optional[str] = None, show_all: bool = False) -> _TreeItem: + root = {} + for filepath in list_files_in_repository( + repo_id=repo_id, + repo_type=repo_type, + subdir=dir_in_repo, + revision=revision, + ignore_patterns=[], + ): + filename = hf_normpath(os.path.relpath(filepath, dir_in_repo)) + segments = re.split(r'[\\/]+', filename) + if any(segment.startswith('.') for segment in segments) and not 
show_all: + continue + + current_node = root + for i, segment in enumerate(segments): + if segment not in current_node: + if i == (len(segments) - 1): + current_node[segment] = get_file_type(segment) + else: + current_node[segment] = {} + current_node = current_node[segment] + + root_name = hf_fs_path( + repo_id=repo_id, + repo_type=repo_type, + filename=dir_in_repo, + revision=revision, + ) + + def _recursion(cur_node: Union[dict, FileItemType], parent_name: str): + if isinstance(cur_node, dict): + return _TreeItem( + name=parent_name, + type_=FileItemType.FOLDER, + children=[ + _recursion(cur_node=value, parent_name=name) + for name, value in natsorted(cur_node.items()) + ] + ) + else: + return _TreeItem( + name=parent_name, + type_=cur_node, + children=[], + ) + + return _recursion( + cur_node=root, + parent_name=root_name, + ) def _add_tree_subcommand(cli: click.Group) -> click.Group: @@ -25,15 +97,17 @@ def _add_tree_subcommand(cli: click.Group) -> click.Group: def tree(repo_id: str, repo_type: RepoTypeTyping, dir_in_repo, revision: str, show_all: bool): configure_http_backend(get_requests_session) - hf_client = get_hf_client() - - list_files_in_repository( + _tree = _get_tree( repo_id=repo_id, repo_type=repo_type, - subdir='.', + dir_in_repo=dir_in_repo or '.', revision=revision, - ignore_patterns=[], - + show_all=show_all, ) + print(format_tree( + _tree, + format_node=_TreeItem.get_name, + get_children=_TreeItem.get_children, + )) return cli diff --git a/hfutils/utils/__init__.py b/hfutils/utils/__init__.py index 7979f30e4d..375f00eb07 100644 --- a/hfutils/utils/__init__.py +++ b/hfutils/utils/__init__.py @@ -8,5 +8,5 @@ from .session import TimeoutHTTPAdapter, get_requests_session, get_random_ua from .temp import TemporaryDirectory from .tqdm_ import tqdm -from .type_ import ListItemType, get_file_type +from .type_ import FileItemType, get_file_type from .walk import walk_files diff --git a/hfutils/utils/type_.py b/hfutils/utils/type_.py index 
a56a83e9e4..16931acc1e 100644 --- a/hfutils/utils/type_.py +++ b/hfutils/utils/type_.py @@ -21,7 +21,7 @@ @unique -class ListItemType(Enum): +class FileItemType(Enum): """ Enum class representing different types of list items. @@ -29,9 +29,9 @@ class ListItemType(Enum): in a file system or list of items. Each type is associated with a unique integer value. Usage: - >>> item_type = ListItemType.FILE + >>> item_type = FileItemType.FILE >>> print(item_type) - ListItemType.FILE + FileItemType.FILE >>> print(item_type.value) 1 """ @@ -57,65 +57,65 @@ def render_color(self): :raises ValueError: If an unknown item type is encountered. Usage: - >>> item_type = ListItemType.FOLDER + >>> item_type = FileItemType.FOLDER >>> print(item_type.render_color) blue """ - if self == ListItemType.FILE: + if self == FileItemType.FILE: return None - elif self == ListItemType.FOLDER: + elif self == FileItemType.FOLDER: return 'blue' - elif self == ListItemType.IMAGE: + elif self == FileItemType.IMAGE: return 'magenta' - elif self == ListItemType.ARCHIVE: + elif self == FileItemType.ARCHIVE: return 'red' - elif self == ListItemType.MODEL: + elif self == FileItemType.MODEL: return 'green' - elif self == ListItemType.DATA: + elif self == FileItemType.DATA: return 'yellow' else: raise ValueError(f'Unknown type - {self!r}') # pragma: no cover -def get_file_type(filename: Union[str, os.PathLike]) -> ListItemType: +def get_file_type(filename: Union[str, os.PathLike]) -> FileItemType: """ Determine the type of a given file. - This function analyzes the provided filename and returns the corresponding ListItemType. + This function analyzes the provided filename and returns the corresponding FileItemType. It uses various methods to determine the file type, including checking for archives, model files, data files, and image files based on MIME types. :param filename: The name or path of the file to analyze. 
:type filename: Union[str, os.PathLike] - :return: The determined ListItemType for the given file. - :rtype: ListItemType + :return: The determined FileItemType for the given file. + :rtype: FileItemType :raises TypeError: If the provided filename is not a string or PathLike object. Usage: >>> file_type = get_file_type('image.jpg') >>> print(file_type) - ListItemType.IMAGE + FileItemType.IMAGE >>> file_type = get_file_type('data.csv') >>> print(file_type) - ListItemType.DATA + FileItemType.DATA """ if not isinstance(filename, (str, os.PathLike)): raise TypeError(f'Unknown file name type - {filename!r}') filename = os.path.basename(os.path.normcase(str(filename))) mimetype, _ = mimetypes.guess_type(filename) - type_ = ListItemType.FILE + type_ = FileItemType.FILE if is_archive_or_compressed(filename): - type_ = ListItemType.ARCHIVE + type_ = FileItemType.ARCHIVE elif is_model_file(filename): - type_ = ListItemType.MODEL + type_ = FileItemType.MODEL elif is_data_file(filename): - type_ = ListItemType.DATA + type_ = FileItemType.DATA elif mimetype: if 'image' in mimetype: - type_ = ListItemType.IMAGE + type_ = FileItemType.IMAGE return type_ diff --git a/test/utils/test_type.py b/test/utils/test_type.py index 80b9059c20..db236d6925 100644 --- a/test/utils/test_type.py +++ b/test/utils/test_type.py @@ -4,27 +4,27 @@ import pytest -from hfutils.utils import ListItemType, get_file_type +from hfutils.utils import FileItemType, get_file_type @pytest.mark.unittest class TestUtilsType: def test_list_item_type_enum(self): - assert isinstance(ListItemType.FILE, Enum) - assert ListItemType.FILE.value == 0x1 - assert ListItemType.FOLDER.value == 0x2 - assert ListItemType.IMAGE.value == 0x3 - assert ListItemType.ARCHIVE.value == 0x4 - assert ListItemType.MODEL.value == 0x5 - assert ListItemType.DATA.value == 0x6 + assert isinstance(FileItemType.FILE, Enum) + assert FileItemType.FILE.value == 0x1 + assert FileItemType.FOLDER.value == 0x2 + assert FileItemType.IMAGE.value == 0x3 + 
assert FileItemType.ARCHIVE.value == 0x4 + assert FileItemType.MODEL.value == 0x5 + assert FileItemType.DATA.value == 0x6 @pytest.mark.parametrize("item_type, expected_color", [ - (ListItemType.FILE, None), - (ListItemType.FOLDER, 'blue'), - (ListItemType.IMAGE, 'magenta'), - (ListItemType.ARCHIVE, 'red'), - (ListItemType.MODEL, 'green'), - (ListItemType.DATA, 'yellow'), + (FileItemType.FILE, None), + (FileItemType.FOLDER, 'blue'), + (FileItemType.IMAGE, 'magenta'), + (FileItemType.ARCHIVE, 'red'), + (FileItemType.MODEL, 'green'), + (FileItemType.DATA, 'yellow'), ]) def test_render_color(self, item_type, expected_color): assert item_type.render_color == expected_color @@ -34,15 +34,15 @@ class UnknownType(Enum): UNKNOWN = 0x7 with pytest.raises(ValueError, match='Unknown type'): - ListItemType.render_color.fget(UnknownType.UNKNOWN) + FileItemType.render_color.fget(UnknownType.UNKNOWN) @pytest.mark.parametrize("filename, expected_type", [ - ('file.txt', ListItemType.FILE), - ('image.jpg', ListItemType.IMAGE), - ('archive.zip', ListItemType.ARCHIVE), - ('model.pkl', ListItemType.MODEL), - ('data.csv', ListItemType.DATA), - ('folder', ListItemType.FILE), # Assuming folders are not detected by filename + ('file.txt', FileItemType.FILE), + ('image.jpg', FileItemType.IMAGE), + ('archive.zip', FileItemType.ARCHIVE), + ('model.pkl', FileItemType.MODEL), + ('data.csv', FileItemType.DATA), + ('folder', FileItemType.FILE), # Assuming folders are not detected by filename ]) def test_get_file_type(self, filename, expected_type): with patch('mimetypes.guess_type') as mock_guess_type, \ @@ -50,11 +50,11 @@ def test_get_file_type(self, filename, expected_type): patch('hfutils.utils.type_.is_model_file') as mock_is_model, \ patch('hfutils.utils.type_.is_data_file') as mock_is_data: mock_guess_type.return_value = (None, None) - mock_is_archive.return_value = expected_type == ListItemType.ARCHIVE - mock_is_model.return_value = expected_type == ListItemType.MODEL - 
mock_is_data.return_value = expected_type == ListItemType.DATA + mock_is_archive.return_value = expected_type == FileItemType.ARCHIVE + mock_is_model.return_value = expected_type == FileItemType.MODEL + mock_is_data.return_value = expected_type == FileItemType.DATA - if expected_type == ListItemType.IMAGE: + if expected_type == FileItemType.IMAGE: mock_guess_type.return_value = ('image/jpeg', None) assert get_file_type(filename) == expected_type @@ -62,7 +62,7 @@ def test_get_file_type(self, filename, expected_type): def test_get_file_type_with_path(self): with patch('mimetypes.guess_type') as mock_guess_type: mock_guess_type.return_value = (None, None) - assert get_file_type('/path/to/file.txt') == ListItemType.FILE + assert get_file_type('/path/to/file.txt') == FileItemType.FILE def test_get_file_type_invalid_input(self): with pytest.raises(TypeError, match='Unknown file name type'): @@ -71,7 +71,7 @@ def test_get_file_type_invalid_input(self): def test_get_file_type_empty_string(self): with patch('mimetypes.guess_type') as mock_guess_type: mock_guess_type.return_value = (None, None) - assert get_file_type('') == ListItemType.FILE + assert get_file_type('') == FileItemType.FILE @pytest.mark.parametrize("filename", [ 'file.txt', 'FILE.TXT', 'FiLe.TxT', @@ -81,10 +81,10 @@ def test_get_file_type_empty_string(self): def test_get_file_type_case_insensitive(self, filename): with patch('mimetypes.guess_type') as mock_guess_type: mock_guess_type.return_value = (None, None) - assert get_file_type(filename) == ListItemType.FILE + assert get_file_type(filename) == FileItemType.FILE def test_get_file_type_with_pathlike_object(self): with patch('mimetypes.guess_type') as mock_guess_type: mock_guess_type.return_value = (None, None) path = os.path.join('path', 'to', 'file.txt') - assert get_file_type(os.fspath(path)) == ListItemType.FILE + assert get_file_type(os.fspath(path)) == FileItemType.FILE From 1e03e04cdb36686e2d8ed65fb2f747ae53bb31bd Mon Sep 17 00:00:00 2001 From: 
narugo1992 Date: Tue, 20 Aug 2024 01:07:41 +0800 Subject: [PATCH 12/17] dev(narugo): add unittest for tree cli --- hfutils/entry/download.py | 8 +- test/entry/test_tree.py | 183 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 189 insertions(+), 2 deletions(-) create mode 100644 test/entry/test_tree.py diff --git a/hfutils/entry/download.py b/hfutils/entry/download.py index 0ffabed4f0..7069af0aec 100644 --- a/hfutils/entry/download.py +++ b/hfutils/entry/download.py @@ -7,7 +7,7 @@ from .base import CONTEXT_SETTINGS, command_wrap, ClickErrorException from ..operate import download_file_to_file, download_archive_as_directory, download_directory_as_directory -from ..operate.base import REPO_TYPES, RepoTypeTyping +from ..operate.base import REPO_TYPES, RepoTypeTyping, _IGNORE_PATTERN_UNSET from ..utils import get_requests_session @@ -55,12 +55,15 @@ def _add_download_subcommand(cli: click.Group) -> click.Group: help='Just check the file size when validating the downloaded files.', show_default=True) @click.option('--tmpdir', 'tmpdir', type=str, default=None, help='Use custom temporary Directory.', show_default=True) + @click.option('--all', 'show_all', is_flag=True, type=bool, default=False, + help='Show all files, including hidden files.', show_default=True) @command_wrap() def download( repo_id: str, repo_type: RepoTypeTyping, file_in_repo: Optional[str], archive_in_repo: Optional[str], dir_in_repo: Optional[str], output_path: str, revision: str, max_workers: int, - password: Optional[str], wildcard: Optional[str], soft_mode_when_check: bool, tmpdir: Optional[str] + password: Optional[str], wildcard: Optional[str], soft_mode_when_check: bool, tmpdir: Optional[str], + show_all: bool = False, ): """ Download data from HuggingFace repositories. 
@@ -135,6 +138,7 @@ def download( silent=False, max_workers=max_workers, soft_mode_when_check=soft_mode_when_check, + ignore_patterns=_IGNORE_PATTERN_UNSET if not show_all else [], ) else: diff --git a/test/entry/test_tree.py b/test/entry/test_tree.py new file mode 100644 index 0000000000..8d8b9f961c --- /dev/null +++ b/test/entry/test_tree.py @@ -0,0 +1,183 @@ +import click +import pytest +from hbutils.testing import simulate_entry + +from hfutils.entry import hfutilscli + + +@pytest.mark.unittest +class TestEntryTree: + def test_simple_tree_1(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'tree', + '-r', 'deepghs/test_nested_dataset', + ]) + assert result.exitcode == 0 + lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + assert lines == [ + 'datasets/deepghs/test_nested_dataset@main/.', + '├── README.md', + '├── data.parquet', + '├── images', + '│ ├── 20240808', + '│ │ ├── 20240808015751528545_642b8ce09a5b1543e88cf95e359d39218d6b3ac5__narugo.json', + '│ │ ├── 20240808015751528545_642b8ce09a5b1543e88cf95e359d39218d6b3ac5__narugo.tar', + '│ │ ├── 20240808091226009067_be359abd170ee3e1a37d3bda7cdf9ff2490f5380__narugo.json', + '│ │ └── 20240808091226009067_be359abd170ee3e1a37d3bda7cdf9ff2490f5380__narugo.tar', + '│ ├── 20240810', + '│ │ ├── 20240810025407329132_7fbe690d6dca73e971036fbb884eba67d11c68d7__narugo.json', + '│ │ ├── 20240810025407329132_7fbe690d6dca73e971036fbb884eba67d11c68d7__narugo.tar', + '│ │ ├── 20240810025642281532_4c13dc63689d93e25a5de44bc9add04ea7d56162__narugo.json', + '│ │ ├── 20240810025642281532_4c13dc63689d93e25a5de44bc9add04ea7d56162__narugo.tar', + '│ │ ├── 20240810220450715507_f95017bb0ff97ee35cd878ba11e6c3d5b4eb6e1f__narugo.json', + '│ │ ├── 20240810220450715507_f95017bb0ff97ee35cd878ba11e6c3d5b4eb6e1f__narugo.tar', + '│ │ ├── 20240810222438167877_c60911f8922933991a20d190175c9eada582af7b__narugo.json', + '│ │ └── 20240810222438167877_c60911f8922933991a20d190175c9eada582af7b__narugo.tar', + '│ └── 
20240811', + '│ ├── 20240811011334412620_ce548cb70673e563ad46a37a75b6c1f933b17292__narugo.json', + '│ └── 20240811011334412620_ce548cb70673e563ad46a37a75b6c1f933b17292__narugo.tar', + '├── meta.json', + '├── samples', + '│ ├── colored', + '│ │ ├── 0.webp', + '│ │ ├── 1.webp', + '│ │ ├── 2.webp', + '│ │ ├── 3.webp', + '│ │ ├── 4.webp', + '│ │ ├── 5.webp', + '│ │ ├── 6.webp', + '│ │ └── 7.webp', + '│ └── monochrome', + '│ ├── 0.webp', + '│ ├── 1.webp', + '│ ├── 2.webp', + '│ └── 3.webp', + '└── unarchived', + ' └── 20240810222438167877_c60911f8922933991a20d190175c9eada582af7b__narugo.parquet', + ] + + def test_simple_tree_all(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'tree', + '-r', 'deepghs/test_nested_dataset', + '--all' + ]) + assert result.exitcode == 0 + lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + assert lines == [ + 'datasets/deepghs/test_nested_dataset@main/.', + '├── .gitattributes', + '├── README.md', + '├── data.parquet', + '├── images', + '│ ├── 20240808', + '│ │ ├── 20240808015751528545_642b8ce09a5b1543e88cf95e359d39218d6b3ac5__narugo.json', + '│ │ ├── 20240808015751528545_642b8ce09a5b1543e88cf95e359d39218d6b3ac5__narugo.tar', + '│ │ ├── 20240808091226009067_be359abd170ee3e1a37d3bda7cdf9ff2490f5380__narugo.json', + '│ │ └── 20240808091226009067_be359abd170ee3e1a37d3bda7cdf9ff2490f5380__narugo.tar', + '│ ├── 20240810', + '│ │ ├── 20240810025407329132_7fbe690d6dca73e971036fbb884eba67d11c68d7__narugo.json', + '│ │ ├── 20240810025407329132_7fbe690d6dca73e971036fbb884eba67d11c68d7__narugo.tar', + '│ │ ├── 20240810025642281532_4c13dc63689d93e25a5de44bc9add04ea7d56162__narugo.json', + '│ │ ├── 20240810025642281532_4c13dc63689d93e25a5de44bc9add04ea7d56162__narugo.tar', + '│ │ ├── 20240810220450715507_f95017bb0ff97ee35cd878ba11e6c3d5b4eb6e1f__narugo.json', + '│ │ ├── 20240810220450715507_f95017bb0ff97ee35cd878ba11e6c3d5b4eb6e1f__narugo.tar', + '│ │ ├── 
20240810222438167877_c60911f8922933991a20d190175c9eada582af7b__narugo.json', + '│ │ └── 20240810222438167877_c60911f8922933991a20d190175c9eada582af7b__narugo.tar', + '│ └── 20240811', + '│ ├── 20240811011334412620_ce548cb70673e563ad46a37a75b6c1f933b17292__narugo.json', + '│ └── 20240811011334412620_ce548cb70673e563ad46a37a75b6c1f933b17292__narugo.tar', + '├── meta.json', + '├── samples', + '│ ├── colored', + '│ │ ├── 0.webp', + '│ │ ├── 1.webp', + '│ │ ├── 2.webp', + '│ │ ├── 3.webp', + '│ │ ├── 4.webp', + '│ │ ├── 5.webp', + '│ │ ├── 6.webp', + '│ │ └── 7.webp', + '│ └── monochrome', + '│ ├── 0.webp', + '│ ├── 1.webp', + '│ ├── 2.webp', + '│ └── 3.webp', + '└── unarchived', + ' └── 20240810222438167877_c60911f8922933991a20d190175c9eada582af7b__narugo.parquet' + ] + + def test_tree_subdir_1(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'tree', + '-r', 'deepghs/test_nested_dataset', + '-d', 'images' + ]) + assert result.exitcode == 0 + lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + assert lines == [ + "datasets/deepghs/test_nested_dataset@main/images", + "├── 20240808", + "│ ├── 20240808015751528545_642b8ce09a5b1543e88cf95e359d39218d6b3ac5__narugo.json", + "│ ├── 20240808015751528545_642b8ce09a5b1543e88cf95e359d39218d6b3ac5__narugo.tar", + "│ ├── 20240808091226009067_be359abd170ee3e1a37d3bda7cdf9ff2490f5380__narugo.json", + "│ └── 20240808091226009067_be359abd170ee3e1a37d3bda7cdf9ff2490f5380__narugo.tar", + "├── 20240810", + "│ ├── 20240810025407329132_7fbe690d6dca73e971036fbb884eba67d11c68d7__narugo.json", + "│ ├── 20240810025407329132_7fbe690d6dca73e971036fbb884eba67d11c68d7__narugo.tar", + "│ ├── 20240810025642281532_4c13dc63689d93e25a5de44bc9add04ea7d56162__narugo.json", + "│ ├── 20240810025642281532_4c13dc63689d93e25a5de44bc9add04ea7d56162__narugo.tar", + "│ ├── 20240810220450715507_f95017bb0ff97ee35cd878ba11e6c3d5b4eb6e1f__narugo.json", + "│ ├── 20240810220450715507_f95017bb0ff97ee35cd878ba11e6c3d5b4eb6e1f__narugo.tar", 
+ "│ ├── 20240810222438167877_c60911f8922933991a20d190175c9eada582af7b__narugo.json", + "│ └── 20240810222438167877_c60911f8922933991a20d190175c9eada582af7b__narugo.tar", + "└── 20240811", + " ├── 20240811011334412620_ce548cb70673e563ad46a37a75b6c1f933b17292__narugo.json", + " └── 20240811011334412620_ce548cb70673e563ad46a37a75b6c1f933b17292__narugo.tar" + ] + + def test_tree_subdir_2(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'tree', + '-r', 'deepghs/test_nested_dataset', + '-d', 'samples' + ]) + assert result.exitcode == 0 + lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + assert lines == [ + "datasets/deepghs/test_nested_dataset@main/samples", + "├── colored", + "│ ├── 0.webp", + "│ ├── 1.webp", + "│ ├── 2.webp", + "│ ├── 3.webp", + "│ ├── 4.webp", + "│ ├── 5.webp", + "│ ├── 6.webp", + "│ └── 7.webp", + "└── monochrome", + " ├── 0.webp", + " ├── 1.webp", + " ├── 2.webp", + " └── 3.webp" + ] + + def test_tree_subdir_3(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'tree', + '-r', 'deepghs/test_nested_dataset', + '-d', 'samples/colored' + ]) + assert result.exitcode == 0 + lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + assert lines == [ + "datasets/deepghs/test_nested_dataset@main/samples/colored", + "├── 0.webp", + "├── 1.webp", + "├── 2.webp", + "├── 3.webp", + "├── 4.webp", + "├── 5.webp", + "├── 6.webp", + "└── 7.webp" + ] From 28f1d5c7a50c267d6fc974a89ac725f8a81e4d89 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 20 Aug 2024 01:17:32 +0800 Subject: [PATCH 13/17] dev(narugo): add unittest for tree cli --- test/entry/test_tree.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/entry/test_tree.py b/test/entry/test_tree.py index 8d8b9f961c..82d576754f 100644 --- a/test/entry/test_tree.py +++ b/test/entry/test_tree.py @@ -1,3 +1,5 @@ +from pprint import pprint + import click import pytest from hbutils.testing import simulate_entry @@ -169,6 +171,8 @@ def 
test_tree_subdir_3(self): '-d', 'samples/colored' ]) assert result.exitcode == 0 + text = click.unstyle(result.stdout).strip() + pprint(text) lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) assert lines == [ "datasets/deepghs/test_nested_dataset@main/samples/colored", From daf43a7959b4ef9d314efca50944e02d131fcdac Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 20 Aug 2024 01:27:35 +0800 Subject: [PATCH 14/17] dev(narugo): add unittest for tree cli --- test/entry/test_tree.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/entry/test_tree.py b/test/entry/test_tree.py index 82d576754f..6d8e6b4d7e 100644 --- a/test/entry/test_tree.py +++ b/test/entry/test_tree.py @@ -1,5 +1,3 @@ -from pprint import pprint - import click import pytest from hbutils.testing import simulate_entry @@ -16,6 +14,7 @@ def test_simple_tree_1(self): ]) assert result.exitcode == 0 lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + lines = list(filter(bool, lines)) assert lines == [ 'datasets/deepghs/test_nested_dataset@main/.', '├── README.md', @@ -66,6 +65,7 @@ def test_simple_tree_all(self): ]) assert result.exitcode == 0 lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + lines = list(filter(bool, lines)) assert lines == [ 'datasets/deepghs/test_nested_dataset@main/.', '├── .gitattributes', @@ -117,6 +117,7 @@ def test_tree_subdir_1(self): ]) assert result.exitcode == 0 lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + lines = list(filter(bool, lines)) assert lines == [ "datasets/deepghs/test_nested_dataset@main/images", "├── 20240808", @@ -146,6 +147,7 @@ def test_tree_subdir_2(self): ]) assert result.exitcode == 0 lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + lines = list(filter(bool, lines)) assert lines == [ "datasets/deepghs/test_nested_dataset@main/samples", "├── colored", @@ -171,9 +173,8 @@ def test_tree_subdir_3(self): '-d', 
'samples/colored' ]) assert result.exitcode == 0 - text = click.unstyle(result.stdout).strip() - pprint(text) lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + lines = list(filter(bool, lines)) assert lines == [ "datasets/deepghs/test_nested_dataset@main/samples/colored", "├── 0.webp", From 0f94b68f068ab8b40fbf30efed8693bf80c5dd54 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 20 Aug 2024 01:49:26 +0800 Subject: [PATCH 15/17] dev(narugo): add better layout --- hfutils/entry/tree.py | 41 +++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/hfutils/entry/tree.py b/hfutils/entry/tree.py index 2c3ae9a99a..ce4d3fbd4f 100644 --- a/hfutils/entry/tree.py +++ b/hfutils/entry/tree.py @@ -6,10 +6,11 @@ import click from hbutils.string import format_tree from huggingface_hub import configure_http_backend +from huggingface_hub.hf_api import RepoFile from natsort import natsorted from .base import CONTEXT_SETTINGS -from ..operate.base import REPO_TYPES, list_files_in_repository, RepoTypeTyping +from ..operate.base import REPO_TYPES, list_files_in_repository, RepoTypeTyping, get_hf_client from ..utils import get_requests_session, hf_normpath, get_file_type, hf_fs_path, FileItemType @@ -18,9 +19,14 @@ class _TreeItem: name: str type_: FileItemType children: Optional[List['_TreeItem']] + exist: bool = True def get_name(self): - return click.style(self.name, fg=self.type_.render_color) + return click.style( + self.name, + fg=self.type_.render_color if self.exist else None, + strikethrough=not self.exist, + ) def get_children(self): return self.children if self.type_ == FileItemType.FOLDER else [] @@ -38,7 +44,7 @@ def _get_tree(repo_id: str, repo_type: RepoTypeTyping, dir_in_repo: str, ): filename = hf_normpath(os.path.relpath(filepath, dir_in_repo)) segments = re.split(r'[\\/]+', filename) - if any(segment.startswith('.') for segment in segments) and not show_all: + if any(segment.startswith('.') and 
segment != '.' for segment in segments) and not show_all: continue current_node = root @@ -57,26 +63,49 @@ def _get_tree(repo_id: str, repo_type: RepoTypeTyping, dir_in_repo: str, revision=revision, ) - def _recursion(cur_node: Union[dict, FileItemType], parent_name: str): + def _recursion(cur_node: Union[dict, FileItemType], parent_name: str, is_exist: bool = False): if isinstance(cur_node, dict): return _TreeItem( name=parent_name, type_=FileItemType.FOLDER, children=[ - _recursion(cur_node=value, parent_name=name) + _recursion(cur_node=value, parent_name=name, is_exist=is_exist) for name, value in natsorted(cur_node.items()) - ] + ], + exist=is_exist, ) else: return _TreeItem( name=parent_name, type_=cur_node, children=[], + exist=is_exist, ) + exist = True + if not root: + hf_client = get_hf_client() + paths = hf_client.get_paths_info( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + paths=[dir_in_repo], + ) + if len(paths) == 0: + exist = False + elif len(paths) == 1: + pathobj = paths[0] + if isinstance(pathobj, RepoFile): # the subdir is a file + root = get_file_type(dir_in_repo) + else: + assert len(paths) == 1, \ + f'Multiple path {dir_in_repo!r} found in repo {root_name!r}, ' \ + f'this must be caused by HuggingFace API.' 
# pragma: no cover + return _recursion( cur_node=root, parent_name=root_name, + is_exist=exist, ) From 48ca86ce253633ed30e743c0305b048e541142d6 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 20 Aug 2024 02:02:46 +0800 Subject: [PATCH 16/17] dev(narugo): better layout test --- hfutils/entry/tree.py | 2 +- test/entry/test_tree.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/hfutils/entry/tree.py b/hfutils/entry/tree.py index ce4d3fbd4f..d5519642d7 100644 --- a/hfutils/entry/tree.py +++ b/hfutils/entry/tree.py @@ -26,7 +26,7 @@ def get_name(self): self.name, fg=self.type_.render_color if self.exist else None, strikethrough=not self.exist, - ) + ) + ('' if self.exist else ' ') def get_children(self): return self.children if self.type_ == FileItemType.FOLDER else [] diff --git a/test/entry/test_tree.py b/test/entry/test_tree.py index 6d8e6b4d7e..301cc75795 100644 --- a/test/entry/test_tree.py +++ b/test/entry/test_tree.py @@ -186,3 +186,29 @@ def test_tree_subdir_3(self): "├── 6.webp", "└── 7.webp" ] + + def test_tree_file(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'tree', + '-r', 'deepghs/test_nested_dataset', + '-d', 'meta.json' + ]) + assert result.exitcode == 0 + lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + lines = list(filter(bool, lines)) + assert lines == [ + "datasets/deepghs/test_nested_dataset@main/meta.json" + ] + + def test_tree_not_exist(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'tree', + '-r', 'deepghs/test_nested_dataset', + '-d', 'not_exist' + ]) + assert result.exitcode == 0 + lines = click.unstyle(result.stdout).strip().splitlines(keepends=False) + lines = list(filter(bool, lines)) + assert lines == [ + "datasets/deepghs/test_nested_dataset@main/not_exist " + ] From 1389ce920ba81df266dfe1f735dc2511676e6521 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 20 Aug 2024 02:07:11 +0800 Subject: [PATCH 17/17] dev(narugo): add docs for 
item --- docs/source/api_doc/entry/index.rst | 3 +- docs/source/api_doc/entry/tree.rst | 15 ++++ hfutils/entry/tree.py | 113 ++++++++++++++++++++++++++-- 3 files changed, 122 insertions(+), 9 deletions(-) create mode 100644 docs/source/api_doc/entry/tree.rst diff --git a/docs/source/api_doc/entry/index.rst b/docs/source/api_doc/entry/index.rst index b1bef7074f..00d33070f7 100644 --- a/docs/source/api_doc/entry/index.rst +++ b/docs/source/api_doc/entry/index.rst @@ -17,6 +17,7 @@ hfutils.entry index_ ls ls_repo + rollback + tree upload whoami - rollback diff --git a/docs/source/api_doc/entry/tree.rst b/docs/source/api_doc/entry/tree.rst new file mode 100644 index 0000000000..7e7dc776b0 --- /dev/null +++ b/docs/source/api_doc/entry/tree.rst @@ -0,0 +1,15 @@ +hfutils.entry.tree +================================ + +.. currentmodule:: hfutils.entry.tree + +.. automodule:: hfutils.entry.tree + + +TreeItem +---------------------------------- + +.. autoclass:: TreeItem + + + diff --git a/hfutils/entry/tree.py b/hfutils/entry/tree.py index d5519642d7..3ad7778461 100644 --- a/hfutils/entry/tree.py +++ b/hfutils/entry/tree.py @@ -1,3 +1,20 @@ +""" +This module provides functionality for listing and displaying files from a HuggingFace repository in a tree-like structure. + +It includes functions for parsing repository paths, retrieving file information, and formatting the output as a tree. +The module also defines a CLI command for easy interaction with the tree functionality. + +Key components: + +- TreeItem: A dataclass representing an item (file or folder) in the tree structure. +- _get_tree: Function to retrieve the tree structure of files in a HuggingFace repository. +- _add_tree_subcommand: Function to add the 'tree' subcommand to a Click CLI group. + +Usage: +This module is typically used as part of a larger CLI application for interacting with HuggingFace repositories. +The 'tree' command can be used to visualize the structure of files in a repository. 
+""" + import dataclasses import os import re @@ -15,13 +32,37 @@ @dataclasses.dataclass -class _TreeItem: +class TreeItem: + """ + Represents an item (file or folder) in the tree structure. + + :param name: The name of the item. + :type name: str + :param type_: The type of the item (file or folder). + :type type_: FileItemType + :param children: List of child items if this is a folder. + :type children: Optional[List[TreeItem]] + :param exist: Whether the item exists in the repository. + :type exist: bool + + :ivar name: The name of the item. + :ivar type_: The type of the item. + :ivar children: List of child items. + :ivar exist: Existence status of the item. + """ + name: str type_: FileItemType - children: Optional[List['_TreeItem']] + children: Optional[List['TreeItem']] exist: bool = True def get_name(self): + """ + Get the formatted name of the item for display. + + :return: Formatted name string with color and strike-through if applicable. + :rtype: str + """ return click.style( self.name, fg=self.type_.render_color if self.exist else None, @@ -29,11 +70,34 @@ def get_name(self): ) + ('' if self.exist else ' ') def get_children(self): + """ + Get the children of this item if it's a folder. + + :return: List of child items if folder, empty list otherwise. + :rtype: List[TreeItem] + """ return self.children if self.type_ == FileItemType.FOLDER else [] def _get_tree(repo_id: str, repo_type: RepoTypeTyping, dir_in_repo: str, - revision: Optional[str] = None, show_all: bool = False) -> _TreeItem: + revision: Optional[str] = None, show_all: bool = False) -> TreeItem: + """ + Retrieve the tree structure of files in a HuggingFace repository. + + :param repo_id: The ID of the repository. + :type repo_id: str + :param repo_type: The type of the repository. + :type repo_type: RepoTypeTyping + :param dir_in_repo: The directory in the repository to start from. + :type dir_in_repo: str + :param revision: The revision of the repository to use. 
+ :type revision: Optional[str] + :param show_all: Whether to show hidden files. + :type show_all: bool + + :return: The root TreeItem representing the directory structure. + :rtype: TreeItem + """ root = {} for filepath in list_files_in_repository( repo_id=repo_id, @@ -65,7 +129,7 @@ def _get_tree(repo_id: str, repo_type: RepoTypeTyping, dir_in_repo: str, def _recursion(cur_node: Union[dict, FileItemType], parent_name: str, is_exist: bool = False): if isinstance(cur_node, dict): - return _TreeItem( + return TreeItem( name=parent_name, type_=FileItemType.FOLDER, children=[ @@ -75,7 +139,7 @@ def _recursion(cur_node: Union[dict, FileItemType], parent_name: str, is_exist: exist=is_exist, ) else: - return _TreeItem( + return TreeItem( name=parent_name, type_=cur_node, children=[], @@ -110,7 +174,26 @@ def _recursion(cur_node: Union[dict, FileItemType], parent_name: str, is_exist: def _add_tree_subcommand(cli: click.Group) -> click.Group: - @cli.command('tree', help='List files from HuggingFace repository.\n\n' + """ + Add the 'tree' subcommand to a Click CLI group. + + This function defines a new 'tree' command that lists files from a HuggingFace repository + in a tree-like structure. + + :param cli: The Click CLI group to add the command to. + :type cli: click.Group + + :return: The modified CLI group with the 'tree' command added. 
+ :rtype: click.Group + + Usage: + This function is typically called when setting up a CLI application: + + cli = click.Group() + cli = _add_tree_subcommand(cli) + """ + + @cli.command('tree', help='List files as a tree from HuggingFace repository.\n\n' 'Set environment $HF_TOKEN to use your own access token.', context_settings=CONTEXT_SETTINGS) @click.option('-r', '--repository', 'repo_id', type=str, required=True, @@ -124,6 +207,20 @@ def _add_tree_subcommand(cli: click.Group) -> click.Group: @click.option('-a', '--all', 'show_all', is_flag=True, type=bool, default=False, help='Show all files, including hidden files.', show_default=True) def tree(repo_id: str, repo_type: RepoTypeTyping, dir_in_repo, revision: str, show_all: bool): + """ + List files from a HuggingFace repository in a tree-like structure. + + :param repo_id: The ID of the repository. + :type repo_id: str + :param repo_type: The type of the repository. + :type repo_type: RepoTypeTyping + :param dir_in_repo: The directory in the repository to start from. + :type dir_in_repo: str + :param revision: The revision of the repository to use. + :type revision: str + :param show_all: Whether to show hidden files. + :type show_all: bool + """ configure_http_backend(get_requests_session) _tree = _get_tree( @@ -135,8 +232,8 @@ def tree(repo_id: str, repo_type: RepoTypeTyping, dir_in_repo, revision: str, sh ) print(format_tree( _tree, - format_node=_TreeItem.get_name, - get_children=_TreeItem.get_children, + format_node=TreeItem.get_name, + get_children=TreeItem.get_children, )) return cli