# Recovered from a git patch whose line breaks had been collapsed.
#   Commit : 00f2e6c67fd2fe4c4accd65af98c2c99c024a656
#   Author : narugo1992
#   Date   : Wed, 11 Dec 2024 14:33:36 +0800
#   Subject: dev(narugo): save this half completed code
#   Target : hfutils/repository/size.py (new file)
"""
Helpers for summarising the files (and their sizes) of a Hugging Face Hub repository.

The entry point is :func:`hf_hub_repo_analysis`, which returns a :class:`RepoFileList`
of :class:`RepoFileItem` objects.
"""
from collections.abc import Sequence
from typing import Optional, List

from huggingface_hub.hf_api import RepoFile

from ..operate import list_all_with_pattern
from ..operate.base import RepoTypeTyping
from ..utils import hf_normpath


class RepoFileItem:
    """
    Read-only view over a single ``RepoFile`` entry.

    Two items compare equal when their size, LFS flag, LFS sha256, blob id and
    normalised path are all equal; the hash is derived from the same values.

    :param repo_file: The ``RepoFile`` object to wrap. It is stored as-is in ``file``.
    """

    def __init__(self, repo_file: RepoFile):
        self.file: RepoFile = repo_file

    @property
    def size(self) -> int:
        """Size of the file: the LFS size when LFS metadata is present, else the plain size."""
        return self.file.lfs.size if self.file.lfs else self.file.size

    @property
    def is_lfs(self) -> bool:
        """Whether the wrapped file carries LFS metadata."""
        return bool(self.file.lfs)

    @property
    def lfs_sha256(self) -> Optional[str]:
        """The LFS sha256 of the file, or ``None`` when there is no LFS metadata."""
        return self.file.lfs.sha256 if self.file.lfs else None

    @property
    def blob_id(self) -> str:
        """The ``blob_id`` of the wrapped file."""
        return self.file.blob_id

    @property
    def path(self) -> str:
        """The file path, passed through ``hf_normpath``."""
        return hf_normpath(self.file.path)

    def _value(self):
        # Single source of truth for equality, hashing and repr.
        return self.size, self.is_lfs, self.lfs_sha256, self.blob_id, self.path

    def __eq__(self, other):
        return isinstance(other, RepoFileItem) and self._value() == other._value()

    def __hash__(self) -> int:
        # Defining __eq__ alone makes a class unhashable; keep the hash in line with __eq__
        # so that items can be used in sets and as dict keys.
        return hash(self._value())

    def __repr__(self) -> str:
        return (
            f'{type(self).__name__}(path={self.path!r}, size={self.size!r}, '
            f'is_lfs={self.is_lfs!r}, lfs_sha256={self.lfs_sha256!r}, blob_id={self.blob_id!r})'
        )


class RepoFileList(Sequence[RepoFileItem]):
    """
    Immutable sequence of :class:`RepoFileItem` objects belonging to one repository revision.

    :param repo_id: Identifier of the repository the items come from.
    :param items: The file items. They are copied into an internal list.
    :param repo_type: Type of the repository. Default is ``'dataset'``.
    :param revision: Revision the items were listed from. Default is ``'main'``.
    """

    def __init__(self, repo_id: str, items: List[RepoFileItem],
                 repo_type: RepoTypeTyping = 'dataset', revision: str = 'main'):
        self.repo_id = repo_id
        self.repo_type = repo_type
        self.revision = revision
        # Copy so that later changes to the caller's list do not affect this object.
        self._file_items = list(items)
        self._total_size = sum(item.size for item in self._file_items)

    @property
    def total_size(self) -> int:
        """Sum of the ``size`` of every item, computed once at construction time."""
        return self._total_size

    def __getitem__(self, index):
        # Delegates to the underlying list, so a slice yields a plain ``list`` of items.
        return self._file_items[index]

    def __len__(self) -> int:
        return len(self._file_items)

    def __repr__(self) -> str:
        return (
            f'{type(self).__name__}(repo_id={self.repo_id!r}, repo_type={self.repo_type!r}, '
            f'revision={self.revision!r}, files={len(self)!r}, total_size={self.total_size!r})'
        )


def hf_hub_repo_analysis(
        repo_id: str, pattern: str = '**/*', repo_type: RepoTypeTyping = 'dataset',
        revision: str = 'main', hf_token: Optional[str] = None, silent: bool = False,
        subdir: str = '', **kwargs,
) -> RepoFileList:
    """
    Collect the files of a repository that match ``pattern`` into a :class:`RepoFileList`.

    :param repo_id: Identifier of the repository.
    :param pattern: Pattern handed to ``list_all_with_pattern``. Default is ``'**/*'``.
    :param repo_type: Type of the repository. Default is ``'dataset'``.
    :param revision: Revision to inspect. Default is ``'main'``.
    :param hf_token: Optional access token, forwarded to ``list_all_with_pattern``.
    :param silent: Forwarded to ``list_all_with_pattern``.
    :param subdir: Optional sub directory. When non-empty and not ``'.'``, it is prepended
        to ``pattern`` as ``'<subdir>/<pattern>'``. Trailing slashes are ignored.
    :param kwargs: Extra keyword arguments forwarded to ``list_all_with_pattern``.
    :return: The matched files together with the repository coordinates.
    """
    # A trailing slash would otherwise leave a '//' in the joined pattern.
    subdir = (subdir or '').rstrip('/')
    if subdir and subdir != '.':
        pattern = f'{subdir}/{pattern}'

    listing = list_all_with_pattern(
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        pattern=pattern,
        hf_token=hf_token,
        silent=silent,
        **kwargs
    )
    # Only RepoFile entries are kept; anything else yielded by the listing is skipped.
    file_items = [RepoFileItem(entry) for entry in listing if isinstance(entry, RepoFile)]

    return RepoFileList(
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        items=file_items,
    )