diff --git a/gitalong/batch.py b/gitalong/batch.py
index f7ff073..81e5cda 100644
--- a/gitalong/batch.py
+++ b/gitalong/batch.py
@@ -1,35 +1,14 @@
 import os
 import asyncio
 
-from typing import Coroutine, Optional, List
+from typing import List, Coroutine
 
 import git
 import git.exc
 
-from git import Repo
-
 from .enums import CommitSpread
 from .repository import Repository
-from .exceptions import RepositoryNotSetup
-from .functions import pulled_within, set_read_only
-
-
-def get_repository_safe(
-    filename: str,
-) -> Optional[Repository]:
-    """
-    Args:
-        filename (str): A path that belong to the repository including itself.
-
-    Returns:
-        Optional[Repository]: The repository or None.
-    """
-    try:
-        return Repository(repository=filename, use_cached_instances=True)
-    except git.exc.InvalidGitRepositoryError:
-        return None
-    except RepositoryNotSetup:
-        return None
+from .functions import set_read_only
 
 
 async def get_files_last_commits(
@@ -47,7 +26,7 @@ async def get_files_last_commits(
 
     for filename in filenames:
         last_commit = {}
-        repository = get_repository_safe(filename)
+        repository = Repository.from_filename(filename)
         if not repository:
             last_commits.append(last_commit)
             continue
@@ -92,19 +71,17 @@ async def get_files_last_commits(
 
         if not last_commit:
             pull_threshold = repository.config.get("pull_threshold", 60)
-            repo = Repo(repository.working_dir)
-            if not pulled_within(repo, pull_threshold):
+            if not repository.pulled_within(pull_threshold):
                 try:
-                    git.Repo(repository.working_dir).remotes.origin.fetch(prune=prune)
                     repository.remote.fetch(prune=prune)
                 except git.exc.GitCommandError:
                     pass
 
             args = ["--all", "--remotes", '--pretty=format:"%H"', "--", filename]
-            output = repo.git.log(*args)
+            output = repository.git.log(*args)
             file_commits = output.replace('"', "").split("\n") if output else []
             sha = file_commits[0] if file_commits else ""
-            last_commit = repo.commit(sha) or {}
+            last_commit = repository.get_commit(sha) or {}
 
         last_commits.append(last_commit)
 
@@ -257,7 +234,7 @@ async def claim_files(
     for filename in filenames:
         last_commits = await get_files_last_commits([filename], prune=prune)
         last_commit = last_commits[0]
-        repository = get_repository_safe(filename)
+        repository = Repository.from_filename(filename)
         config = repository.config if repository else {}
         modify_permissions = config.get("modify_permissions")
         spread = repository.get_commit_spread(last_commit) if repository else 0
diff --git a/gitalong/cli.py b/gitalong/cli.py
index 2048f56..05c02b1 100644
--- a/gitalong/cli.py
+++ b/gitalong/cli.py
@@ -12,29 +12,8 @@
 
 from .__info__ import __version__
 from .enums import CommitSpread
-from .exceptions import RepositoryNotSetup
 from .repository import Repository
-from .batch import get_files_last_commits, claim_files, get_repository_safe
-
-
-def get_repository(filename: str) -> Repository:
-    """
-    Args:
-        filename (str): A path that belong to the repository including itself.
-
-    Returns:
-        Repository: The repository.
-    """
-    try:
-        # Initializing Gitalong for each file allows to handle files from multiple
-        # repository. This is especially import to support submodules.
-        return Repository(repository=filename, use_cached_instances=True)
-    except git.exc.InvalidGitRepositoryError:
-        click.echo("fatal: not a git repository")
-        raise click.Abort()  # pylint: disable=raise-missing-from
-    except RepositoryNotSetup:
-        click.echo("fatal: not a gitalong repository")
-        raise click.Abort()  # pylint: disable=raise-missing-from
+from .batch import get_files_last_commits, claim_files
 
 
 def get_status_string(filename: str, commit: dict, spread: int) -> str:
@@ -80,8 +59,8 @@ def version():  # pylint: disable=missing-function-docstring
 )
 @click.pass_context
 def config(ctx, prop):  # pylint: disable=missing-function-docstring
-    repository = get_repository(ctx.obj.get("REPOSITORY", ""))
-    repository_config = repository.config
+    repository = Repository.from_filename(ctx.obj.get("REPOSITORY", ""))
+    repository_config = repository.config if repository else {}
     prop = prop.replace("-", "_")
     if prop in repository_config:
         value = repository_config[prop]
@@ -103,19 +82,22 @@ def config(ctx, prop):  # pylint: disable=missing-function-docstring
 @click.pass_context
 def update(ctx, repository):
     """TODO: Improve error handling."""
-    repository = get_repository(ctx.obj.get("REPOSITORY", ""))
-    root = repository.working_dir if repository else ""
+    repository = Repository.from_filename(ctx.obj.get("REPOSITORY", ""))
+    if not repository:
+        return
+    working_dir = repository.working_dir
     repository.update_tracked_commits()
     locally_changed = {}
     permission_changes = []
     if repository.config.get("modify_permissions"):
-        # TODO: This is a very expensive operation and needs to be optimized.
+        # TODO: This is an expensive operation and needs to be optimized.
+        # Also probably should not be done here at the CLI level.
         for filename in repository.files:
             if os.path.isfile(repository.get_absolute_path(filename)):
-                if root not in locally_changed:
-                    locally_changed[root] = repository.locally_changed_files
+                if working_dir not in locally_changed:
+                    locally_changed[working_dir] = repository.locally_changed_files
                 perm_change = repository.update_file_permissions(
-                    filename, locally_changed[root]
+                    filename, locally_changed[working_dir]
                 )
                 if perm_change:
                     permission_changes.append(f"{' '.join(perm_change)}")
@@ -157,7 +139,7 @@ def run_status(ctx, filename):  # pylint: disable=missing-function-docstring
     file_status = []
    commits = asyncio.run(get_files_last_commits(filename))
     for _filename, commit in zip(filename, commits):
-        repository = get_repository_safe(ctx.obj.get("REPOSITORY", _filename))
+        repository = Repository.from_filename(ctx.obj.get("REPOSITORY", _filename))
         absolute_filename = (
             repository.get_absolute_path(_filename) if repository else _filename
         )
@@ -182,7 +164,7 @@ def claim(ctx, filename):  # pylint: disable=missing-function-docstring
     statuses = []
     blocking_commits = asyncio.run(claim_files(filename))
     for _filename, commit in zip(filename, blocking_commits):
-        repository = get_repository_safe(ctx.obj.get("REPOSITORY", _filename))
+        repository = Repository.from_filename(ctx.obj.get("REPOSITORY", _filename))
         absolute_filename = (
             repository.get_absolute_path(_filename) if repository else _filename
         )
diff --git a/gitalong/functions.py b/gitalong/functions.py
index bab4f01..f43c0c3 100644
--- a/gitalong/functions.py
+++ b/gitalong/functions.py
@@ -6,6 +6,7 @@
 
 from git.repo import Repo
 
+
 MOVE_STRING_REGEX = re.compile("{(.*)}")
diff --git a/gitalong/repository.py b/gitalong/repository.py
index c453056..8e66984 100644
--- a/gitalong/repository.py
+++ b/gitalong/repository.py
@@ -6,6 +6,7 @@
 import os
 import shutil
 import socket
+import asyncio
 
 from typing import Optional, List
 
@@ -15,6 +16,9 @@
 
 from git.repo import Repo
 
+# Deliberately import the module to avoid circular imports.
+from . import batch
+
 from .store import Store
 from .enums import CommitSpread
 from .stores.git_store import GitStore
@@ -24,7 +28,7 @@
     get_real_path,
     is_binary_file,
     set_read_only,
-    get_filenames_from_move_string,
+    pulled_within,
 )
@@ -157,6 +161,22 @@ def setup(
         gitalong.install_hooks()
         return gitalong
 
+    @classmethod
+    def from_filename(cls, filename: str) -> Optional["Repository"]:
+        """
+        Args:
+            filename (str): A path that belongs to the repository, including itself.
+
+        Returns:
+            Optional[Repository]: The repository or None.
+        """
+        try:
+            return cls(repository=filename, use_cached_instances=True)
+        except git.exc.InvalidGitRepositoryError:
+            return None
+        except RepositoryNotSetup:
+            return None
+
     @staticmethod
     def _write_config_file(config: dict, path: str):
         with open(path, "w", encoding="utf8") as config_file:
@@ -410,17 +430,18 @@ def _accumulate_local_only_commits(self, start: git.Commit, local_commits: list)
             start (git.objects.Commit): The commit that we start peeling from last commit.
         """
-        # TODO: Maybe there is a way to get this information using pure Python.
         if self._managed_repository.git.branch("--remotes", "--contains", start.hexsha):
             return
-        # TODO: This is an expensive thing to call inividualy.
-        commit_dict = self.get_commit_dict(start)
-        commit_dict.update(self.context_dict)
-        # TODO: This is an expensive thing to call inividualy.
-        commit_dict["branches"] = {"local": self.get_commit_branches(start.hexsha)}
-        # TODO: Maybe we should compare the SHA here.
-        if commit_dict not in local_commits:
-            local_commits.append(commit_dict)
+        # TODO: These calls to batch functions are expensive for a single file.
+        commits = asyncio.run(batch.get_commits_dicts([start]))
+        commit = commits[0] if commits else {}
+        commit.update(self.context_dict)
+        branches_list = asyncio.run(batch.get_commits_branches([commit]))
+        branches = branches_list[0] if branches_list else []
+        commit["branches"] = {"local": branches}
+        # Maybe we should compare the SHA here.
+        if commit not in local_commits:
+            local_commits.append(commit)
         for parent in start.parents:
             self._accumulate_local_only_commits(parent, local_commits)
@@ -469,7 +490,6 @@ def _uncommitted_changes(self) -> list:
         Returns:
             list: A list of unique relative filenames that feature uncommitted changes.
         """
-        # TODO: Maybe there is a way to get this information using pure Python.
         git_cmd = self._managed_repository.git
         output = git_cmd.ls_files("--exclude-standard", "--others")
         untracked_changes = output.split("\n") if output else []
@@ -505,8 +525,7 @@ def files(self) -> list:
         """
         git_cmd = self._managed_repository.git
         try:
-            # TODO: HEAD might not be safe here since user could checkout an earlier
-            # commit.
+            # TODO: HEAD might not be safe. The user could check out an earlier commit.
             filenames = git_cmd.ls_tree(full_tree=True, name_only=True, r="HEAD")
             return filenames.split("\n")
         except git.exc.GitCommandError:
@@ -604,45 +623,38 @@ def _get_updated_tracked_commits(self, claims: Optional[List[str]] = None) -> li
             tracked_commits.append(commit)
         return tracked_commits
 
-    def get_commit_dict(self, commit: git.Commit) -> dict:
+    def pulled_within(self, seconds: float) -> bool:
         """
         Args:
-            commit (git.objects.Commit): The commit to get as a dict.
+            seconds (float): Maximum time in seconds since the last pull.
 
         Returns:
-            dict: A simplified JSON serializable dict that represents the commit.
+            bool: Whether the repository pulled within the provided time.
         """
-        changes = []
-        for change in list(commit.stats.files.keys()):
-            changes += get_filenames_from_move_string(str(change))
-        return {
-            "sha": commit.hexsha,
-            "remote": self._remote.url,
-            "changes": changes,
-            "date": str(commit.committed_datetime),
-            "author": commit.author.name,
-        }
+        return pulled_within(self._managed_repository, seconds)
 
-    def get_commit_branches(self, sha: str, remote: bool = False) -> list:
+    def log(self, message: str):
+        """Logs a message to the managed repository.
+
+        Args:
+            message (str): The message to log.
+        """
+        self._managed_repository.git.log(message)
+
+    def get_commit(self, sha: str) -> git.Commit:
         """
         Args:
-            sha (str): The sha of the commit to check for.
-            remote (bool, optional): Whether we should return local or remote branches.
+            sha (str): The SHA of the commit to get.
 
         Returns:
-            list: A list of branch names that this commit is living on.
+            git.Commit: The commit object for the provided SHA.
         """
-        args = ["--remote" if remote else []]
-        args += ["--contains", sha]
-        try:
-            branches = self._managed_repository.git.branch(*args)
-        # If the commit is not on any branch we get a git.exc.GitCommandError.
-        except git.exc.GitCommandError:
-            return []
-        branches = branches.replace("*", "")
-        branches = branches.replace(" ", "")
-        branches = branches.split("\n") if branches else []
-        branch_names = set()
-        for branch in branches:
-            branch_names.add(branch.split("/")[-1])
-        return list(branch_names)
+        return self._managed_repository.commit(sha)
+
+    @property
+    def git(self) -> git.Git:
+        """
+        Returns:
+            git.cmd.Git: The Git command line interface for the managed repository.
+        """
+        return self._managed_repository.git
diff --git a/gitalong/stores/git_store.py b/gitalong/stores/git_store.py
index 740b258..467abc5 100644
--- a/gitalong/stores/git_store.py
+++ b/gitalong/stores/git_store.py
@@ -53,8 +53,8 @@ def commits(self) -> typing.List[dict]:
         remote = store_repository.remote()
         pull_threshold = self._managed_repository.config.get("pull_threshold", 60)
         if not pulled_within(store_repository, pull_threshold) and remote.refs:
-            # TODO: If we could check that a pull is already happening then we could
-            # avoid this try except and save time.
+            # TODO: We could check whether a pull is already happening, thus avoiding
+            # this try/except and saving time.
             try:
                 remote.pull(
                     ff=True,
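For reference, a minimal caller-side sketch (not part of the diff) of how the new `Repository.from_filename` classmethod replaces the removed `get_repository_safe` helper, combined with the new `pulled_within` and `get_commit` methods exercised in `batch.py` above. The file path is hypothetical; everything else uses only names introduced or touched by this change.

```python
# Illustrative only: mirrors the pattern used in batch.get_files_last_commits.
from gitalong.repository import Repository

# Hypothetical path; any path inside a managed repository would do.
repository = Repository.from_filename("/projects/game/assets/hero.ma")
if repository is None:
    # Not a Git repository, or Gitalong was never set up for it.
    print("fatal: not a gitalong repository")
else:
    # Fetch only if the last pull is older than the configured threshold.
    if not repository.pulled_within(repository.config.get("pull_threshold", 60)):
        repository.remote.fetch(prune=True)
    print(repository.get_commit("HEAD").hexsha)
```

Returning `None` instead of raising keeps the batch and CLI loops simple: callers can skip files that live outside a configured repository without wrapping every call in try/except, which is exactly how the `if not repository` / `if repository else` checks above use it.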