diff --git a/github_app_geo_project/module/pull_request/checks.py b/github_app_geo_project/module/pull_request/checks.py index 50bb92c043e..224fc832aee 100644 --- a/github_app_geo_project/module/pull_request/checks.py +++ b/github_app_geo_project/module/pull_request/checks.py @@ -5,41 +5,60 @@ import os import re import subprocess # nosec -from tempfile import NamedTemporaryFile +import tempfile +import typing from typing import Any, cast import github import github.Commit import github.PullRequest -from github_app_geo_project import configuration, module +from github_app_geo_project import module from github_app_geo_project.module import utils as module_utils from github_app_geo_project.module.pull_request import checks_configuration _LOGGER = logging.getLogger(__name__) +if typing.TYPE_CHECKING: + NamedTemporaryFileStr = tempfile._TemporaryFileWrapper[str] # pylint: disable=protected-access +else: + NamedTemporaryFileStr = tempfile._TemporaryFileWrapper # pylint: disable=protected-access -def _get_codespell_command(config: checks_configuration.PullRequestChecksConfiguration) -> list[str]: + +def _get_code_spell_command( + context: module.ProcessContext[ + checks_configuration.PullRequestChecksConfiguration, dict[str, Any], dict[str, Any] + ], + ignore_file: NamedTemporaryFileStr, +) -> list[str]: """ Get the codespell command. """ - codespell_config = config.get("codespell", {}) - codespell_config = codespell_config if isinstance(codespell_config, dict) else {} + config = context.module_config + code_spell_config = config.get("codespell", {}) + code_spell_config = code_spell_config if isinstance(code_spell_config, dict) else {} command = ["codespell"] for spell_ignore_file in ( ".github/spell-ignore-words.txt", "spell-ignore-words.txt", ".spell-ignore-words.txt", ): - if os.path.exists(spell_ignore_file): - command.append(f"--ignore-words={spell_ignore_file}") - break - dictionaries = codespell_config.get( + try: + content = context.github_project.repo.get_contents(spell_ignore_file) + if isinstance(content, github.ContentFile.ContentFile): + ignore_file.write(content.decoded_content.decode("utf-8")) + ignore_file.cloase() + command.append(f"--ignore-words={ignore_file.name}") + break + except github.GithubException as exc: + if exc.status != 404: + raise + dictionaries = code_spell_config.get( "internal-dictionaries", checks_configuration.CODESPELL_DICTIONARIES_DEFAULT ) if dictionaries: command.append("--builtin=" + ",".join(dictionaries)) - command += codespell_config.get("arguments", checks_configuration.CODESPELL_ARGUMENTS_DEFAULT) + command += code_spell_config.get("arguments", checks_configuration.CODESPELL_ARGUMENTS_DEFAULT) return command @@ -151,14 +170,13 @@ def _commits_messages( def _commits_spell( config: checks_configuration.PullRequestChecksConfiguration, commits: list[github.Commit.Commit], + spellcheck_cmd: list[str], ) -> tuple[bool, list[str]]: """Check the spelling of the commits body.""" - spellcheck_cmd = _get_codespell_command(config) - messages = [] success = True for commit in commits: - with NamedTemporaryFile("w+t", encoding="utf-8", suffix=".yaml") as temp_file: + with tempfile.NamedTemporaryFile("w+t", encoding="utf-8", suffix=".yaml") as temp_file: if config.get( "only-head", checks_configuration.PULL_REQUEST_CHECKS_COMMITS_MESSAGES_ONLY_HEAD_DEFAULT ): @@ -189,13 +207,13 @@ def _commits_spell( def _pull_request_spell( - config: checks_configuration.PullRequestChecksConfiguration, pull_request: github.PullRequest.PullRequest + config: checks_configuration.PullRequestChecksConfiguration, + pull_request: github.PullRequest.PullRequest, + spellcheck_cmd: list[str], ) -> tuple[bool, list[str]]: """Check the spelling of the pull request title and message.""" - spellcheck_cmd = _get_codespell_command(config) - messages = [] - with NamedTemporaryFile("w+t") as temp_file: + with tempfile.NamedTemporaryFile("w+t") as temp_file: temp_file.write(pull_request.title) temp_file.write("\n") if ( @@ -293,18 +311,18 @@ async def process( ], ) -> module.ProcessOutput[dict[str, Any], dict[str, Any]]: """Process the module.""" - repo = context.github_project.github.get_repo( - context.github_project.owner + "/" + context.github_project.repository - ) + repo = context.github_project.repo pull_request = repo.get_pull(number=context.module_event_data["pull-request-number"]) commits = [ # pylint: disable=unnecessary-comprehension commit for commit in pull_request.get_commits() ] - success_1, messages_1 = _commits_messages(context.module_config, commits) - success_2, messages_2 = _commits_spell(context.module_config, commits) - success_3, messages_3 = _pull_request_spell(context.module_config, pull_request) + with tempfile.NamedTemporaryFile("w+t", encoding="utf-8") as ignore_file: + spellcheck_cmd = _get_code_spell_command(context, ignore_file) + success_1, messages_1 = _commits_messages(context.module_config, commits) + success_2, messages_2 = _commits_spell(context.module_config, commits, spellcheck_cmd) + success_3, messages_3 = _pull_request_spell(context.module_config, pull_request, spellcheck_cmd) success = success_1 and success_2 and success_3 message = "\n".join([*messages_1, *messages_2, *messages_3])