diff --git a/.github/workflows/ir_repos.yml b/.github/workflows/ir_repos.yml index c3438df2bc..da21d740fe 100644 --- a/.github/workflows/ir_repos.yml +++ b/.github/workflows/ir_repos.yml @@ -4,7 +4,7 @@ name: Irregular Repos on: workflow_dispatch: schedule: - - cron: '0 * * * 0' + - cron: '0 12 * * 0' jobs: check_irregular_repo: diff --git a/docs/source/api_doc/entry/index.rst b/docs/source/api_doc/entry/index.rst index fa19907297..3e230abab8 100644 --- a/docs/source/api_doc/entry/index.rst +++ b/docs/source/api_doc/entry/index.rst @@ -14,6 +14,7 @@ hfutils.entry dispatch download ls + ls_repo upload whoami diff --git a/docs/source/api_doc/entry/ls_repo.rst b/docs/source/api_doc/entry/ls_repo.rst new file mode 100644 index 0000000000..d8cf5e702d --- /dev/null +++ b/docs/source/api_doc/entry/ls_repo.rst @@ -0,0 +1,14 @@ +hfutils.entry.ls_repo +================================ + +.. currentmodule:: hfutils.entry.ls_repo + +.. automodule:: hfutils.entry.ls_repo + + +NoLocalAuthentication +---------------------------------- + +.. autoclass:: NoLocalAuthentication + + diff --git a/hfutils/__main__.py b/hfutils/__main__.py new file mode 100644 index 0000000000..7062ebf8f6 --- /dev/null +++ b/hfutils/__main__.py @@ -0,0 +1,4 @@ +from .entry import hfutilscli + +if __name__ == '__main__': + hfutilscli() diff --git a/hfutils/entry/cli.py b/hfutils/entry/cli.py index 00f4d929d9..98030edc19 100644 --- a/hfutils/entry/cli.py +++ b/hfutils/entry/cli.py @@ -1,6 +1,7 @@ from .dispatch import hfutilcli from .download import _add_download_subcommand from .ls import _add_ls_subcommand +from .ls_repo import _add_ls_repo_subcommand from .upload import _add_upload_subcommand from .whoami import _add_whoami_subcommand @@ -9,6 +10,7 @@ _add_upload_subcommand, _add_ls_subcommand, _add_whoami_subcommand, + _add_ls_repo_subcommand, ] cli = hfutilcli diff --git a/hfutils/entry/ls_repo.py b/hfutils/entry/ls_repo.py new file mode 100644 index 0000000000..fc108ce00b --- /dev/null +++ b/hfutils/entry/ls_repo.py @@ -0,0 +1,74 @@ +import fnmatch +from typing import Optional + +import click +from huggingface_hub.utils import LocalTokenNotFoundError + +from .base import CONTEXT_SETTINGS, ClickErrorException +from ..operate.base import REPO_TYPES, get_hf_client + + +class NoLocalAuthentication(ClickErrorException): + """ + Exception raised when there is no local authentication token. + """ + exit_code = 0x31 + + +def _add_ls_repo_subcommand(cli: click.Group) -> click.Group: + """ + Add the ls_repo subcommand to the CLI. + + :param cli: The click Group object. + :type cli: click.Group + + :return: The updated click Group object. + :rtype: click.Group + """ + + @cli.command('ls_repo', help='List repositories from HuggingFace.\n\n' + 'Set environment $HF_TOKEN to use your own access token.', + context_settings=CONTEXT_SETTINGS) + @click.option('-a', '--author', 'author', type=str, default=None, + help='Author of the repositories. Search my repositories when not given.') + @click.option('-t', '--type', 'repo_type', type=click.Choice(REPO_TYPES), default='dataset', + help='Type of the HuggingFace repository.', show_default=True) + @click.option('-p', '--pattern', 'pattern', type=str, default='*', + help='Pattern of the repository names.', show_default=True) + def ls(author: Optional[str], repo_type: str, pattern: str): + """ + List repositories from HuggingFace. + + :param author: Author of the repositories. + :type author: Optional[str] + :param repo_type: Type of the HuggingFace repository. + :type repo_type: str + :param pattern: Pattern of the repository names. + :type pattern: str + """ + hf_client = get_hf_client() + if not author: + try: + info = hf_client.whoami() + author = author or info['name'] + except LocalTokenNotFoundError: + raise NoLocalAuthentication( + 'Authentication failed.\n' + 'Make sure you have set the correct Huggingface token.\n' + 'Or if need to use this with guest mode, please explicitly set the `-a` option.' + ) + + if repo_type == 'model': + r = hf_client.list_models(author=author) + elif repo_type == 'dataset': + r = hf_client.list_datasets(author=author) + elif repo_type == 'space': + r = hf_client.list_spaces(author=author) + else: + raise ValueError(f'Unknown repository type - {repo_type!r}.') # pragma: no cover + + for repo_item in r: + if fnmatch.fnmatch(repo_item.id, pattern): + print(repo_item.id) + + return cli diff --git a/test/entry/test_ls_repo.py b/test/entry/test_ls_repo.py new file mode 100644 index 0000000000..9b27a7282b --- /dev/null +++ b/test/entry/test_ls_repo.py @@ -0,0 +1,60 @@ +import os +from unittest.mock import patch + +import click +import pytest +from hbutils.testing import simulate_entry +from huggingface_hub import HfApi + +from hfutils.entry import hfutilscli + + +@pytest.fixture() +def no_hf_token(): + def _get_hf_client(): + return HfApi(token='') + + with patch('hfutils.entry.ls_repo.get_hf_client', _get_hf_client), \ + patch.dict(os.environ, {'HF_TOKEN': ''}): + yield + + +@pytest.mark.unittest +class TestEntryLsRepo: + def test_ls_repo(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'ls_repo', + ]) + assert result.exitcode == 0 + repos = click.unstyle(result.stdout).splitlines(keepends=False) + assert 'narugo/manual_packs' in repos + assert 'narugo/csip_v1_info' in repos + + def test_ls_repo_space(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'ls_repo', '-t', 'space', + ]) + assert result.exitcode == 0 + repos = click.unstyle(result.stdout).splitlines(keepends=False) + assert 'narugo/jupyterlab' in repos + assert 'narugo/CDC_anime_demo' in repos + + def test_ls_repo_model(self): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'ls_repo', '-t', 'model', + ]) + assert result.exitcode == 0 + repos = click.unstyle(result.stdout).splitlines(keepends=False) + assert 'narugo/gchar_models' in repos + assert 'narugo/test_v1.5_kristen' in repos + assert 'narugo/test_v1.5_nian' in repos + + def test_ls_repo_anonymous(self, no_hf_token): + result = simulate_entry(hfutilscli, [ + 'hfutils', 'ls_repo', + ]) + assert result.exitcode == 0x31 + stdout_lines = click.unstyle(result.stdout).splitlines(keepends=False) + assert len(stdout_lines) == 0 + stderr_lines = click.unstyle(result.stderr).splitlines(keepends=False) + assert 'Authentication failed.' in stderr_lines