From 131abe4a6d676109c524f5af939f17386943a8d5 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Thu, 16 May 2024 19:48:23 +0800 Subject: [PATCH] dev(narugo): add ls_repo into the commands --- .github/workflows/ir_repos.yml | 2 +- hfutils/__main__.py | 4 ++ hfutils/entry/cli.py | 2 + hfutils/entry/ls_repo.py | 74 ++++++++++++++++++++++++++++++++++ test/entry/test_ls_repo.py | 60 +++++++++++++++++++++++++++ 5 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 hfutils/__main__.py create mode 100644 hfutils/entry/ls_repo.py create mode 100644 test/entry/test_ls_repo.py diff --git a/.github/workflows/ir_repos.yml b/.github/workflows/ir_repos.yml index c3438df2bc..da21d740fe 100644 --- a/.github/workflows/ir_repos.yml +++ b/.github/workflows/ir_repos.yml @@ -4,7 +4,7 @@ name: Irregular Repos on: workflow_dispatch: schedule: - - cron: '0 * * * 0' + - cron: '0 12 * * 0' jobs: check_irregular_repo: diff --git a/hfutils/__main__.py b/hfutils/__main__.py new file mode 100644 index 0000000000..7062ebf8f6 --- /dev/null +++ b/hfutils/__main__.py @@ -0,0 +1,4 @@ +from .entry import hfutilscli + +if __name__ == '__main__': + hfutilscli() diff --git a/hfutils/entry/cli.py b/hfutils/entry/cli.py index 00f4d929d9..98030edc19 100644 --- a/hfutils/entry/cli.py +++ b/hfutils/entry/cli.py @@ -1,6 +1,7 @@ from .dispatch import hfutilcli from .download import _add_download_subcommand from .ls import _add_ls_subcommand +from .ls_repo import _add_ls_repo_subcommand from .upload import _add_upload_subcommand from .whoami import _add_whoami_subcommand @@ -9,6 +10,7 @@ _add_upload_subcommand, _add_ls_subcommand, _add_whoami_subcommand, + _add_ls_repo_subcommand, ] cli = hfutilcli diff --git a/hfutils/entry/ls_repo.py b/hfutils/entry/ls_repo.py new file mode 100644 index 0000000000..fc108ce00b --- /dev/null +++ b/hfutils/entry/ls_repo.py @@ -0,0 +1,74 @@ +import fnmatch +from typing import Optional + +import click +from huggingface_hub.utils import LocalTokenNotFoundError + +from .base import CONTEXT_SETTINGS, ClickErrorException +from ..operate.base import REPO_TYPES, get_hf_client + + +class NoLocalAuthentication(ClickErrorException): + """ + Exception raised when there is no local authentication token. + """ + exit_code = 0x31 + + +def _add_ls_repo_subcommand(cli: click.Group) -> click.Group: + """ + Add the ls_repo subcommand to the CLI. + + :param cli: The click Group object. + :type cli: click.Group + + :return: The updated click Group object. + :rtype: click.Group + """ + + @cli.command('ls_repo', help='List repositories from HuggingFace.\n\n' + 'Set environment $HF_TOKEN to use your own access token.', + context_settings=CONTEXT_SETTINGS) + @click.option('-a', '--author', 'author', type=str, default=None, + help='Author of the repositories. Search my repositories when not given.') + @click.option('-t', '--type', 'repo_type', type=click.Choice(REPO_TYPES), default='dataset', + help='Type of the HuggingFace repository.', show_default=True) + @click.option('-p', '--pattern', 'pattern', type=str, default='*', + help='Pattern of the repository names.', show_default=True) + def ls(author: Optional[str], repo_type: str, pattern: str): + """ + List repositories from HuggingFace. + + :param author: Author of the repositories. + :type author: Optional[str] + :param repo_type: Type of the HuggingFace repository. + :type repo_type: str + :param pattern: Pattern of the repository names. + :type pattern: str + """ + hf_client = get_hf_client() + if not author: + try: + info = hf_client.whoami() + author = author or info['name'] + except LocalTokenNotFoundError: + raise NoLocalAuthentication( + 'Authentication failed.\n' + 'Make sure you have set the correct Huggingface token.\n' + 'Or if need to use this with guest mode, please explicitly set the `-a` option.' + ) + + if repo_type == 'model': + r = hf_client.list_models(author=author) + elif repo_type == 'dataset': + r = hf_client.list_datasets(author=author) + elif repo_type == 'space': + r = hf_client.list_spaces(author=author) + else: + raise ValueError(f'Unknown repository type - {repo_type!r}.') # pragma: no cover + + for repo_item in r: + if fnmatch.fnmatch(repo_item.id, pattern): + print(repo_item.id) + + return cli diff --git a/test/entry/test_ls_repo.py b/test/entry/test_ls_repo.py new file mode 100644 index 0000000000..9b27a7282b --- /dev/null +++ b/test/entry/test_ls_repo.py @@ -0,0 +1,60 @@ +import os +from unittest.mock import patch + +import click +import pytest +from hbutils.testing import simulate_entry +from huggingface_hub import HfApi + +from hfutils.entry import hfutilscli + + +@pytest.fixture() +def no_hf_token(): + def _get_hf_client(): + return HfApi(token='') + + with patch('hfutils.entry.ls_repo.get_hf_client', _get_hf_client), \ + patch.dict(os.environ, {'HF_TOKEN': ''}): + yield + + +@pytest.mark.unittest +class TestEntryLsRepo: + def test_ls_repo(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'ls_repo', + ]) + assert result.exitcode == 0 + repos = click.unstyle(result.stdout).splitlines(keepends=False) + assert 'narugo/manual_packs' in repos + assert 'narugo/csip_v1_info' in repos + + def test_ls_repo_space(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'ls_repo', '-t', 'space', + ]) + assert result.exitcode == 0 + repos = click.unstyle(result.stdout).splitlines(keepends=False) + assert 'narugo/jupyterlab' in repos + assert 'narugo/CDC_anime_demo' in repos + + def test_ls_repo_model(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'ls_repo', '-t', 'model', + ]) + assert result.exitcode == 0 + repos = click.unstyle(result.stdout).splitlines(keepends=False) + assert 'narugo/gchar_models' in repos + assert 'narugo/test_v1.5_kristen' in repos + assert 'narugo/test_v1.5_nian' in repos + + def test_ls_repo_anonymous(self, no_hf_token): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'ls_repo', + ]) + assert result.exitcode == 0x31 + stdout_lines = click.unstyle(result.stdout).splitlines(keepends=False) + assert len(stdout_lines) == 0 + stderr_lines = click.unstyle(result.stderr).splitlines(keepends=False) + assert 'Authentication failed.' in stderr_lines