diff --git a/annif/cli.py b/annif/cli.py index 29e873e0..0563c8aa 100644 --- a/annif/cli.py +++ b/annif/cli.py @@ -633,7 +633,9 @@ def run_upload( that match the given `project_ids_pattern` to archive files, and uploads the archives along with the project configurations to the specified Hugging Face Hub repository. An authentication token and commit message can be given with - options. + options. If the README.md does not exist in the repository it is + created with default contents and metadata of the uploaded projects, if it exists, + its metadata are updated as necessary. """ from huggingface_hub import HfApi from huggingface_hub.utils import HfHubHTTPError, HFValidationError @@ -690,8 +692,14 @@ def run_upload( is_flag=True, help="Replace an existing project/vocabulary/config with the downloaded one", ) +@click.option( + "--trust-repo", + default=False, + is_flag=True, + help="Allow download from the repository even when it has no entries in the cache", +) @cli_util.common_options -def run_download(project_ids_pattern, repo_id, token, revision, force): +def run_download(project_ids_pattern, repo_id, token, revision, force, trust_repo): """ Download selected projects and their vocabularies from a Hugging Face Hub repository. @@ -700,12 +708,14 @@ def run_download(project_ids_pattern, repo_id, token, revision, force): configuration files of the projects that match the given `project_ids_pattern` from the specified Hugging Face Hub repository and unzips the archives to `data/` directory and places the configuration files - to `projects.d/` directory. An authentication token and revision can - be given with options. If the README.md does not exist in the repository it is - created with default contents and metadata of the uploaded projects, if it exists, - its metadata are updated as necessary. + to `projects.d/` directory. An authentication token and revision can be given with + options. If the repository hasn’t been used for downloads previously + (i.e., it doesn’t appear in the Hugging Face Hub cache on local system), the + `--trust-repo` option needs to be used. """ + hfh_util.check_is_download_allowed(trust_repo, repo_id, token) + project_ids = hfh_util.get_matching_project_ids_from_hf_hub( project_ids_pattern, repo_id, token, revision ) diff --git a/annif/hfh_util.py b/annif/hfh_util.py index a99050be..bc421aa3 100644 --- a/annif/hfh_util.py +++ b/annif/hfh_util.py @@ -24,6 +24,35 @@ logger = annif.logger +def check_is_download_allowed(trust_repo, repo_id, token): + """Check if downloading from the specified repository is allowed based on the trust + option and cache status.""" + if trust_repo: + logger.warning( + f'Download allowed from "{repo_id}" because "--trust-repo" flag is used.' + ) + return + if _is_repo_in_cache(repo_id, token): + logger.debug( + f'Download allowed from "{repo_id}" because repo is already in cache.' + ) + return + raise OperationFailedException( + f'Cannot download projects from untrusted repo "{repo_id}"' + ) + + +def _is_repo_in_cache(repo_id, token): + from huggingface_hub import CacheNotFound, scan_cache_dir + + try: + cache = scan_cache_dir() + except CacheNotFound as err: + logger.debug(str(err) + "\nNo HFH cache found.") + return False + return repo_id in [info.repo_id for info in cache.repos] + + def get_matching_projects(pattern: str) -> list[AnnifProject]: """ Get projects that match the given pattern. diff --git a/tests/test_cli.py b/tests/test_cli.py index 98b6f26a..5690be48 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1149,10 +1149,30 @@ def test_upload_nonexistent_repo(): assert "Repository Not Found for url:" in failed_result.output +@mock.patch("annif.hfh_util._is_repo_in_cache", return_value=False) +def test_download_not_allowed_default(mock_is_repo_in_cache): + # Default of --trust-repo is False + failed_result = runner.invoke( + annif.cli.cli, + [ + "download", + "dummy-fi", + "dummy-repo", + ], + ) + assert failed_result.exception + assert failed_result.exit_code != 0 + assert ( + 'Cannot download projects from untrusted repo "dummy-repo"' + in failed_result.output + ) + + def hf_hub_download_mock_side_effect(filename, repo_id, token, revision): return "tests/huggingface-cache/" + filename # Mocks the downloaded file paths +@mock.patch("annif.hfh_util.check_is_download_allowed", return_value=True) @mock.patch( "huggingface_hub.list_repo_files", return_value=[ # Mocks the filenames in repo @@ -1170,7 +1190,11 @@ def hf_hub_download_mock_side_effect(filename, repo_id, token, revision): ) @mock.patch("annif.hfh_util.copy_project_config") def test_download_dummy_fi( - copy_project_config, hf_hub_download, list_repo_files, testdatadir + copy_project_config, + hf_hub_download, + list_repo_files, + check_is_download_allowed, + testdatadir, ): result = runner.invoke( annif.cli.cli, @@ -1211,6 +1235,7 @@ def test_download_dummy_fi( ] +@mock.patch("annif.hfh_util.check_is_download_allowed", return_value=True) @mock.patch( "huggingface_hub.list_repo_files", return_value=[ # Mock filenames in repo @@ -1228,7 +1253,11 @@ def test_download_dummy_fi( ) @mock.patch("annif.hfh_util.copy_project_config") def test_download_dummy_fi_and_en( - copy_project_config, hf_hub_download, list_repo_files, testdatadir + copy_project_config, + hf_hub_download, + list_repo_files, + check_is_download_allowed, + testdatadir, ): result = runner.invoke( annif.cli.cli, @@ -1285,6 +1314,7 @@ def test_download_dummy_fi_and_en( ] +@mock.patch("annif.hfh_util.check_is_download_allowed", return_value=True) @mock.patch( "huggingface_hub.list_repo_files", side_effect=HFValidationError, @@ -1293,8 +1323,7 @@ def test_download_dummy_fi_and_en( "huggingface_hub.hf_hub_download", ) def test_download_list_repo_files_failed( - hf_hub_download, - list_repo_files, + hf_hub_download, list_repo_files, check_is_download_allowed ): failed_result = runner.invoke( annif.cli.cli, @@ -1311,6 +1340,7 @@ def test_download_list_repo_files_failed( assert not hf_hub_download.called +@mock.patch("annif.hfh_util.check_is_download_allowed", return_value=True) @mock.patch( "huggingface_hub.list_repo_files", return_value=[ # Mock filenames in repo @@ -1326,6 +1356,7 @@ def test_download_list_repo_files_failed( def test_download_hf_hub_download_failed( hf_hub_download, list_repo_files, + check_is_download_allowed, ): failed_result = runner.invoke( annif.cli.cli, diff --git a/tests/test_hfh_util.py b/tests/test_hfh_util.py index 6b5f3774..6a783163 100644 --- a/tests/test_hfh_util.py +++ b/tests/test_hfh_util.py @@ -1,16 +1,61 @@ """Unit test module for Hugging Face Hub utilities.""" import io +import logging import os.path import zipfile from datetime import datetime, timezone from unittest import mock import huggingface_hub +import pytest from huggingface_hub.utils import EntryNotFoundError import annif.hfh_util from annif.config import AnnifConfigCFG +from annif.exception import OperationFailedException + + +@mock.patch("annif.hfh_util._is_repo_in_cache", return_value=False) +def test_download_allowed_trust_repo(mock_is_repo_in_cache, caplog): + trust_repo = True + repo_id = "dummy-repo" + token = "dummy-token" + + with caplog.at_level(logging.WARNING, logger="annif"): + annif.hfh_util.check_is_download_allowed(trust_repo, repo_id, token) + assert ( + 'Download allowed from "dummy-repo" because "--trust-repo" flag is used.' + in caplog.text + ) + + +@mock.patch("annif.hfh_util._is_repo_in_cache", return_value=True) +def test_download_allowed_repo_in_cache(mock_is_repo_in_cache, caplog): + trust_repo = False + repo_id = "dummy-repo" + token = "dummy-token" + + with caplog.at_level(logging.DEBUG, logger="annif"): + annif.hfh_util.check_is_download_allowed(trust_repo, repo_id, token) + assert ( + 'Download allowed from "dummy-repo" because repo is already in cache.' + in caplog.text + ) + + +@mock.patch("annif.hfh_util._is_repo_in_cache", return_value=False) +def test_download_not_allowed(mock_is_repo_in_cache): + trust_repo = False + repo_id = "dummy-repo" + token = "dummy-token" + + with pytest.raises(OperationFailedException) as excinfo: + annif.hfh_util.check_is_download_allowed(trust_repo, repo_id, token) + assert ( + str(excinfo.value) + == 'Cannot download projects from untrusted repo "dummy-repo"' + ) def test_archive_dir(testdatadir):